Int4 Groupwise GEMM Plugin
-
class Int4GroupwiseGemmPlugin : public nvinfer1::IPluginV2DynamicExt
TensorRT plugin for INT4 group-wise quantized GEMM.
Implements INT4 quantized matrix multiplication with group-wise quantization. Used for quantized-weight matrix multiplications in LLM inference.
Public Functions
- Int4GroupwiseGemmPlugin(std::string const &name, int32_t N, int32_t K, int32_t groupSize)
Construct an INT4 group-wise GEMM plugin.
- Parameters:
name – Layer name
N – Output dimension (columns in weight matrix)
K – Input dimension (rows in weight matrix)
groupSize – Quantization group size
- Int4GroupwiseGemmPlugin(std::string const &name, void const *data, size_t length)
Construct the plugin from serialized data.
- Parameters:
name – Layer name
data – Serialized plugin data
length – Size of serialized data
-
Int4GroupwiseGemmPlugin() = delete
Deleted default constructor.
-
Int4GroupwiseGemmPlugin(Int4GroupwiseGemmPlugin const&) = delete
Deleted copy constructor.
-
~Int4GroupwiseGemmPlugin() override
Destructor.
-
nvinfer1::IPluginV2DynamicExt *clone() const noexcept override
Clone the plugin for use in another network.
- Returns:
Cloned plugin instance
-
int32_t getNbOutputs() const noexcept override
Get number of output tensors.
- Returns:
Number of outputs (1)
- nvinfer1::DataType getOutputDataType(int32_t index, nvinfer1::DataType const *inputTypes, int32_t nbInputs)
Get the output tensor data type.
- Parameters:
index – Output index
inputTypes – Input data types
nbInputs – Number of inputs
- Returns:
Output data type
- nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex, nvinfer1::DimsExprs const *inputs, int32_t nbInputs, nvinfer1::IExprBuilder &exprBuilder)
Get the output tensor dimensions.
- Parameters:
outputIndex – Output index
inputs – Input dimensions
nbInputs – Number of inputs
exprBuilder – Expression builder for dynamic shapes
- Returns:
Output dimensions
- bool supportsFormatCombination(int32_t pos, nvinfer1::PluginTensorDesc const *inOut, int32_t nbInputs, int32_t nbOutputs)
Check whether a format combination is supported.
- Parameters:
pos – Position in input/output array
inOut – Input and output tensor descriptors
nbInputs – Number of inputs
nbOutputs – Number of outputs
- Returns:
True if supported
- void configurePlugin(nvinfer1::DynamicPluginTensorDesc const *in, int32_t nbInputs, nvinfer1::DynamicPluginTensorDesc const *out, int32_t nbOutputs)
Configure the plugin with tensor descriptions.
- Parameters:
in – Input tensor descriptors
nbInputs – Number of inputs
out – Output tensor descriptors
nbOutputs – Number of outputs
- size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const *inputs, int32_t nbInputs, nvinfer1::PluginTensorDesc const *outputs, int32_t nbOutputs)
Get the workspace size required for execution.
- Parameters:
inputs – Input tensor descriptors
nbInputs – Number of inputs
outputs – Output tensor descriptors
nbOutputs – Number of outputs
- Returns:
Workspace size in bytes
- int32_t enqueue(nvinfer1::PluginTensorDesc const *inputDesc, nvinfer1::PluginTensorDesc const *outputDesc, void const *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream)
Execute the plugin.
- Parameters:
inputDesc – Input tensor descriptors
outputDesc – Output tensor descriptors
inputs – Input tensor pointers
outputs – Output tensor pointers
workspace – Workspace pointer
stream – CUDA stream
- Returns:
0 on success, non-zero on error
-
size_t getSerializationSize() const noexcept override
Get serialization size.
- Returns:
Size in bytes
-
void serialize(void *buffer) const noexcept override
Serialize plugin state.
- Parameters:
buffer – Output buffer
-
char const *getPluginType() const noexcept override
Get plugin type name.
- Returns:
Plugin type string
-
char const *getPluginNamespace() const noexcept override
Get plugin namespace.
- Returns:
Namespace string
-
void setPluginNamespace(char const *pluginNamespace) noexcept
Set plugin namespace.
- Parameters:
pluginNamespace – Namespace string
-
char const *getPluginVersion() const noexcept override
Get plugin version.
- Returns:
Version string
-
int32_t initialize() noexcept override
Initialize plugin resources.
- Returns:
0 on success
-
void terminate() noexcept override
Release plugin resources.
-
void destroy() noexcept override
Destroy plugin instance.
-
class Int4GroupwiseGemmPluginCreator : public nvinfer1::IPluginCreator
Factory for creating Int4GroupwiseGemmPlugin instances.
Handles plugin registration and creation in TensorRT.
Public Functions
-
Int4GroupwiseGemmPluginCreator()
Constructor.
-
~Int4GroupwiseGemmPluginCreator() override = default
Destructor.
-
char const *getPluginName() const noexcept override
Get plugin name.
- Returns:
Plugin name string
- nvinfer1::PluginFieldCollection const *getFieldNames()
Get the plugin field names.
- Returns:
Field collection
-
void setPluginNamespace(char const *pluginNamespace) noexcept
Set plugin namespace.
- Parameters:
pluginNamespace – Namespace string
-
char const *getPluginNamespace() const noexcept override
Get plugin namespace.
- Returns:
Namespace string
-
char const *getPluginVersion() const noexcept override
Get plugin version.
- Returns:
Version string
- nvinfer1::IPluginV2 *createPlugin(char const *name, nvinfer1::PluginFieldCollection const *fc)
Create a plugin from a field collection.
- Parameters:
name – Plugin name
fc – Field collection with parameters
- Returns:
Created plugin instance
- nvinfer1::IPluginV2 *deserializePlugin(char const *name, void const *serialData, size_t serialLength)
Deserialize a plugin from serialized data.
- Parameters:
name – Plugin name
serialData – Serialized data
serialLength – Data size
- Returns:
Deserialized plugin instance
-
Int4GroupwiseGemmPluginCreator()#