Int4 Groupwise GEMM Plugin

class Int4GroupwiseGemmPlugin : public nvinfer1::IPluginV2DynamicExt

TensorRT plugin for INT4 group-wise quantized GEMM.

Implements efficient matrix multiplication with INT4 weights quantized group-wise, where each group of `groupSize` weights shares its own quantization parameters. Used for the quantized weight matrix multiplications in LLM inference.

Public Functions

Int4GroupwiseGemmPlugin(
std::string const &name,
int32_t N,
int32_t K,
int32_t groupSize
)

Construct INT4 group-wise GEMM plugin.

Parameters:
  • name – Layer name

  • N – Output dimension (number of columns in the weight matrix)

  • K – Input dimension (number of rows in the weight matrix)

  • groupSize – Quantization group size (number of weights that share one set of quantization parameters)

Int4GroupwiseGemmPlugin(
std::string const &name,
void const *data,
size_t length
)

Construct from serialized data.

Parameters:
  • name – Layer name

  • data – Serialized plugin data

  • length – Size of serialized data

Int4GroupwiseGemmPlugin() = delete

Deleted default constructor.

Int4GroupwiseGemmPlugin(Int4GroupwiseGemmPlugin const&) = delete

Deleted copy constructor.

~Int4GroupwiseGemmPlugin() override

Destructor.

nvinfer1::IPluginV2DynamicExt *clone() const noexcept override

Clone the plugin for use in another network.

Returns:

Cloned plugin instance

int32_t getNbOutputs() const noexcept override

Get number of output tensors.

Returns:

Number of outputs (1)

nvinfer1::DataType getOutputDataType(
int32_t index,
nvinfer1::DataType const *inputTypes,
int32_t nbInputs
) const noexcept override

Get output tensor data type.

Parameters:
  • index – Output index

  • inputTypes – Input data types

  • nbInputs – Number of inputs

Returns:

Output data type

nvinfer1::DimsExprs getOutputDimensions(
int32_t outputIndex,
nvinfer1::DimsExprs const *inputs,
int32_t nbInputs,
nvinfer1::IExprBuilder &exprBuilder
) noexcept override

Get output tensor dimensions.

Parameters:
  • outputIndex – Output index

  • inputs – Input dimensions

  • nbInputs – Number of inputs

  • exprBuilder – Expression builder for dynamic shapes

Returns:

Output dimensions

bool supportsFormatCombination(
int32_t pos,
nvinfer1::PluginTensorDesc const *inOut,
int32_t nbInputs,
int32_t nbOutputs
) noexcept override

Check whether the data type and format combination for the tensor at position `pos` is supported.

Parameters:
  • pos – Position in input/output array

  • inOut – Input and output tensor descriptors

  • nbInputs – Number of inputs

  • nbOutputs – Number of outputs

Returns:

True if supported

void configurePlugin(
nvinfer1::DynamicPluginTensorDesc const *in,
int32_t nbInputs,
nvinfer1::DynamicPluginTensorDesc const *out,
int32_t nbOutputs
) noexcept override

Configure plugin with tensor descriptions.

Parameters:
  • in – Input tensor descriptors

  • nbInputs – Number of inputs

  • out – Output tensor descriptors

  • nbOutputs – Number of outputs

size_t getWorkspaceSize(
nvinfer1::PluginTensorDesc const *inputs,
int32_t nbInputs,
nvinfer1::PluginTensorDesc const *outputs,
int32_t nbOutputs
) const noexcept override

Get workspace size required for execution.

Parameters:
  • inputs – Input tensor descriptors

  • nbInputs – Number of inputs

  • outputs – Output tensor descriptors

  • nbOutputs – Number of outputs

Returns:

Workspace size in bytes

int32_t enqueue(
nvinfer1::PluginTensorDesc const *inputDesc,
nvinfer1::PluginTensorDesc const *outputDesc,
void const *const *inputs,
void *const *outputs,
void *workspace,
cudaStream_t stream
) noexcept override

Execute the plugin.

Parameters:
  • inputDesc – Input tensor descriptors

  • outputDesc – Output tensor descriptors

  • inputs – Input tensor pointers

  • outputs – Output tensor pointers

  • workspace – Workspace pointer

  • stream – CUDA stream

Returns:

0 on success, non-zero on error

size_t getSerializationSize() const noexcept override

Get serialization size.

Returns:

Size in bytes

void serialize(void *buffer) const noexcept override

Serialize plugin state.

Parameters:

buffer – Output buffer

char const *getPluginType() const noexcept override

Get plugin type name.

Returns:

Plugin type string

char const *getPluginNamespace() const noexcept override

Get plugin namespace.

Returns:

Namespace string

void setPluginNamespace(char const *pluginNamespace) noexcept

Set plugin namespace.

Parameters:

pluginNamespace – Namespace string

char const *getPluginVersion() const noexcept override

Get plugin version.

Returns:

Version string

int32_t initialize() noexcept override

Initialize plugin resources.

Returns:

0 on success

void terminate() noexcept override

Release plugin resources.

void destroy() noexcept override

Destroy plugin instance.

class Int4GroupwiseGemmPluginCreator : public nvinfer1::IPluginCreator

Factory for creating Int4GroupwiseGemmPlugin instances.

Handles plugin registration and creation in TensorRT.

Public Functions

Int4GroupwiseGemmPluginCreator()

Constructor.

~Int4GroupwiseGemmPluginCreator() override = default

Destructor.

char const *getPluginName() const noexcept override#

Get plugin name.

Returns:

Plugin name string

nvinfer1::PluginFieldCollection const *getFieldNames(
) noexcept override

Get plugin field names.

Returns:

Field collection

void setPluginNamespace(char const *pluginNamespace) noexcept

Set plugin namespace.

Parameters:

pluginNamespace – Namespace string

char const *getPluginNamespace() const noexcept override

Get plugin namespace.

Returns:

Namespace string

char const *getPluginVersion() const noexcept override

Get plugin version.

Returns:

Version string

nvinfer1::IPluginV2 *createPlugin(
char const *name,
nvinfer1::PluginFieldCollection const *fc
) noexcept override

Create plugin from field collection.

Parameters:
  • name – Plugin name

  • fc – Field collection with parameters

Returns:

Created plugin instance

nvinfer1::IPluginV2 *deserializePlugin(
char const *name,
void const *serialData,
size_t serialLength
) noexcept override

Deserialize plugin from data.

Parameters:
  • name – Plugin name

  • serialData – Serialized data

  • serialLength – Data size

Returns:

Deserialized plugin instance