Int4 Groupwise GEMM Plugin

class Int4GroupwiseGemmPlugin : public nvinfer1::IPluginV3, public nvinfer1::IPluginV3OneCore, public nvinfer1::IPluginV3OneBuild, public nvinfer1::IPluginV3OneRuntime#

TensorRT plugin for INT4 group-wise quantized GEMM.

Implements efficient INT4 quantized matrix multiplication with group-wise quantization. It is used for quantized weight-matrix multiplications in LLM inference.

Public Functions

Int4GroupwiseGemmPlugin(
std::string const &name,
int32_t N,
int32_t K,
int32_t groupSize
)#

Construct an INT4 group-wise GEMM plugin.

Parameters:
  • name – Layer name

  • N – Output dimension (columns in weight matrix)

  • K – Input dimension (rows in weight matrix)

  • groupSize – Quantization group size

Int4GroupwiseGemmPlugin(
std::string const &name,
nvinfer1::PluginFieldCollection const *fc
)#

Construct the plugin from a plugin field collection.

Parameters:
  • name – Layer name

  • fc – Plugin field collection

Int4GroupwiseGemmPlugin() = delete#

Deleted default constructor.

Int4GroupwiseGemmPlugin(Int4GroupwiseGemmPlugin const&) = delete#

Deleted copy constructor.

~Int4GroupwiseGemmPlugin() override#

Destructor.

nvinfer1::IPluginCapability *getCapabilityInterface(
nvinfer1::PluginCapabilityType type
) noexcept override#

Return the plugin capability interface for the given type.

nvinfer1::IPluginV3 *clone() noexcept override#

Clone the plugin for use in another network.

Returns:

Cloned plugin instance

char const *getPluginName() const noexcept override#

Get plugin name.

Returns:

Plugin name string

char const *getPluginVersion() const noexcept override#

Get plugin version.

Returns:

Version string

char const *getPluginNamespace() const noexcept override#

Get plugin namespace.

Returns:

Namespace string

int32_t getNbOutputs() const noexcept override#

Get number of output tensors.

Returns:

Number of outputs (1)

int32_t getOutputDataTypes(
nvinfer1::DataType *outputTypes,
int32_t nbOutputs,
nvinfer1::DataType const *inputTypes,
int32_t nbInputs
) const noexcept override#

Get output tensor data types.

Parameters:
  • outputTypes – Output array for data types

  • nbOutputs – Number of outputs

  • inputTypes – Input data types

  • nbInputs – Number of inputs

Returns:

0 on success, non-zero on error

int32_t getOutputShapes(
nvinfer1::DimsExprs const *inputs,
int32_t nbInputs,
nvinfer1::DimsExprs const *shapeInputs,
int32_t nbShapeInputs,
nvinfer1::DimsExprs *outputs,
int32_t nbOutputs,
nvinfer1::IExprBuilder &exprBuilder
) noexcept override#

Get output tensor shapes.

Parameters:
  • inputs – Input dimensions

  • nbInputs – Number of inputs

  • shapeInputs – Shape tensor inputs

  • nbShapeInputs – Number of shape inputs

  • outputs – Output dimensions

  • nbOutputs – Number of outputs

  • exprBuilder – Expression builder for dynamic shapes

Returns:

0 on success, non-zero on error

bool supportsFormatCombination(
int32_t pos,
nvinfer1::DynamicPluginTensorDesc const *inOut,
int32_t nbInputs,
int32_t nbOutputs
) noexcept override#

Check whether the format combination at the given position is supported.

Parameters:
  • pos – Position in input/output array

  • inOut – Input and output tensor descriptors

  • nbInputs – Number of inputs

  • nbOutputs – Number of outputs

Returns:

True if supported

int32_t configurePlugin(
nvinfer1::DynamicPluginTensorDesc const *in,
int32_t nbInputs,
nvinfer1::DynamicPluginTensorDesc const *out,
int32_t nbOutputs
) noexcept override#

Configure plugin with tensor descriptions.

Parameters:
  • in – Input tensor descriptors

  • nbInputs – Number of inputs

  • out – Output tensor descriptors

  • nbOutputs – Number of outputs

Returns:

0 on success, non-zero on error

size_t getWorkspaceSize(
nvinfer1::DynamicPluginTensorDesc const *inputs,
int32_t nbInputs,
nvinfer1::DynamicPluginTensorDesc const *outputs,
int32_t nbOutputs
) const noexcept override#

Get workspace size required for execution.

Parameters:
  • inputs – Input tensor descriptors

  • nbInputs – Number of inputs

  • outputs – Output tensor descriptors

  • nbOutputs – Number of outputs

Returns:

Workspace size in bytes

int32_t enqueue(
nvinfer1::PluginTensorDesc const *inputDesc,
nvinfer1::PluginTensorDesc const *outputDesc,
void const *const *inputs,
void *const *outputs,
void *workspace,
cudaStream_t stream
) noexcept override#

Execute the plugin.

Parameters:
  • inputDesc – Input tensor descriptors

  • outputDesc – Output tensor descriptors

  • inputs – Input tensor pointers

  • outputs – Output tensor pointers

  • workspace – Workspace pointer

  • stream – CUDA stream

Returns:

0 on success, non-zero on error

int32_t onShapeChange(
nvinfer1::PluginTensorDesc const *in,
int32_t nbInputs,
nvinfer1::PluginTensorDesc const *out,
int32_t nbOutputs
) noexcept override#

Called when input/output shapes change at runtime.

Parameters:
  • in – Input tensor descriptors

  • nbInputs – Number of inputs

  • out – Output tensor descriptors

  • nbOutputs – Number of outputs

Returns:

0 on success, non-zero on error

nvinfer1::IPluginV3 *attachToContext(
nvinfer1::IPluginResourceContext *context
) noexcept override#

Attach plugin to an execution context.

Parameters:
  • context – Plugin resource context

Returns:

Cloned plugin attached to context

nvinfer1::PluginFieldCollection const *getFieldsToSerialize(
) noexcept override#

Get plugin fields for serialization.

Returns:

Field collection for serialization

void setPluginNamespace(char const *pluginNamespace) noexcept#

Set plugin namespace.

Parameters:
  • pluginNamespace – Namespace string

class Int4GroupwiseGemmPluginCreator : public nvinfer1::IPluginCreatorV3One#

Factory for creating Int4GroupwiseGemmPlugin instances.

Handles plugin registration and creation in TensorRT.

Public Functions

Int4GroupwiseGemmPluginCreator()#

Constructor.

~Int4GroupwiseGemmPluginCreator() override = default#

Destructor.

char const *getPluginName() const noexcept override#

Get plugin name.

Returns:

Plugin name string

char const *getPluginVersion() const noexcept override#

Get plugin version.

Returns:

Version string

nvinfer1::PluginFieldCollection const *getFieldNames(
) noexcept override#

Get plugin field names.

Returns:

Field collection

char const *getPluginNamespace() const noexcept override#

Get plugin namespace.

Returns:

Namespace string

void setPluginNamespace(char const *pluginNamespace) noexcept#

Set plugin namespace.

Parameters:
  • pluginNamespace – Namespace string

nvinfer1::IPluginV3 *createPlugin(
char const *name,
nvinfer1::PluginFieldCollection const *fc,
nvinfer1::TensorRTPhase phase
) noexcept override#

Create plugin from field collection.

Parameters:
  • name – Plugin name

  • fc – Field collection with parameters

  • phase – TensorRT phase (build or runtime)

Returns:

Created plugin instance