Attention Plugin#

class AttentionPlugin : public nvinfer1::IPluginV2DynamicExt#

TensorRT plugin for attention operations, covering both the context (prefill) phase and the decode phase.

This plugin implements efficient attention mechanisms including context attention (prefill) and decode attention with KV cache support.

IPluginV2DynamicExt Methods

nvinfer1::IPluginV2DynamicExt *clone() const noexcept override#

Clone the plugin instance.

Returns:

Pointer to cloned plugin

int32_t getNbOutputs() const noexcept override#

Get number of outputs.

Returns:

Number of output tensors

nvinfer1::DataType getOutputDataType(
int32_t index,
nvinfer1::DataType const *inputTypes,
int32_t nbInputs
) const noexcept override#

Get output data type.

Parameters:
  • index[in] Output index

  • inputTypes[in] Array of input data types

  • nbInputs[in] Number of inputs

Returns:

Output data type

nvinfer1::DimsExprs getOutputDimensions(
int32_t outputIndex,
nvinfer1::DimsExprs const *inputs,
int32_t nbInputs,
nvinfer1::IExprBuilder &exprBuilder
) noexcept override#

Get output dimensions.

Parameters:
  • outputIndex[in] Output tensor index

  • inputs[in] Input tensor dimensions

  • nbInputs[in] Number of inputs

  • exprBuilder[in] Expression builder for dimension calculations

Returns:

Output tensor dimensions

bool supportsFormatCombination(
int32_t pos,
nvinfer1::PluginTensorDesc const *inOut,
int32_t nbInputs,
int32_t nbOutputs
) noexcept override#

Check if format combination is supported.

Parameters:
  • pos[in] Position in the input/output tensor list

  • inOut[in] Array of input and output tensor descriptors

  • nbInputs[in] Number of inputs

  • nbOutputs[in] Number of outputs

Returns:

True if format combination is supported

void configurePlugin(
nvinfer1::DynamicPluginTensorDesc const *in,
int32_t nbInputs,
nvinfer1::DynamicPluginTensorDesc const *out,
int32_t nbOutputs
) noexcept override#

Configure the plugin with the input and output tensor descriptors.

Parameters:
  • in[in] Input tensor descriptors

  • nbInputs[in] Number of inputs

  • out[in] Output tensor descriptors

  • nbOutputs[in] Number of outputs

size_t getWorkspaceSize(
nvinfer1::PluginTensorDesc const *inputs,
int32_t nbInputs,
nvinfer1::PluginTensorDesc const *outputs,
int32_t nbOutputs
) const noexcept override#

Get workspace size required by the plugin.

Parameters:
  • inputs[in] Input tensor descriptors

  • nbInputs[in] Number of inputs

  • outputs[in] Output tensor descriptors

  • nbOutputs[in] Number of outputs

Returns:

Workspace size in bytes

int32_t enqueue(
nvinfer1::PluginTensorDesc const *inputDesc,
nvinfer1::PluginTensorDesc const *outputDesc,
void const *const *inputs,
void *const *outputs,
void *workspace,
cudaStream_t stream
) noexcept override#

Execute the plugin.

Parameters:
  • inputDesc[in] Input tensor descriptors

  • outputDesc[in] Output tensor descriptors

  • inputs[in] Input tensor data pointers

  • outputs[out] Output tensor data pointers

  • workspace[in] Workspace memory pointer

  • stream[in] CUDA stream for execution

Returns:

0 on success, non-zero on failure

size_t getSerializationSize() const noexcept override#

Get serialization size.

Returns:

Size in bytes required for serialization

void serialize(void *buffer) const noexcept override#

Serialize the plugin.

Parameters:

buffer[out] Buffer to write serialized data

char const *getPluginType() const noexcept override#

Get plugin type.

Returns:

Plugin type string

char const *getPluginNamespace() const noexcept override#

Get plugin namespace.

Returns:

Plugin namespace string

void setPluginNamespace(char const *pluginNamespace) noexcept#

Set plugin namespace.

Parameters:

pluginNamespace[in] Namespace to set

char const *getPluginVersion() const noexcept override#

Get plugin version.

Returns:

Plugin version string

int32_t initialize() noexcept override#

Initialize the plugin.

Returns:

0 on success, non-zero on failure

void terminate() noexcept override#

Terminate the plugin and release resources.

void destroy() noexcept override#

Destroy the plugin instance.

Public Functions

AttentionPlugin(
std::string const &name,
int32_t numQHeads,
int32_t numKVHeads,
int32_t headSize,
int32_t supportsSpecDecode
)#

Constructor for attention plugin with configuration parameters.

Parameters:
  • name[in] Plugin instance name

  • numQHeads[in] Number of query heads

  • numKVHeads[in] Number of key-value heads

  • headSize[in] Head dimension size

  • supportsSpecDecode[in] Whether to support speculative decoding (Tree attention)

AttentionPlugin(
std::string const &name,
void const *data,
size_t length
)#

Constructor for deserialization.

Parameters:
  • name[in] Plugin instance name

  • data[in] Serialized plugin data

  • length[in] Length of serialized data

AttentionPlugin() = delete#

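The default constructor is deleted: every instance must be created through one of the constructors above, each of which takes an instance name, so that different instances of the plugin can be distinguished.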

AttentionPlugin(AttentionPlugin const&) = delete#

Copy construction is disabled.

~AttentionPlugin() override#

Destructor.

class AttentionPluginCreator : public nvinfer1::IPluginCreator#

Factory class for creating AttentionPlugin instances.

Public Functions

AttentionPluginCreator()#

Default constructor.

~AttentionPluginCreator() override = default#

Destructor (defaulted).

char const *getPluginName() const noexcept override#

Get plugin name.

Returns:

Plugin name string

nvinfer1::PluginFieldCollection const *getFieldNames() noexcept override#

Get plugin field collection.

Returns:

Pointer to plugin field collection containing all plugin fields

void setPluginNamespace(char const *pluginNamespace) noexcept#

Set plugin namespace.

Parameters:

pluginNamespace[in] Namespace to set

char const *getPluginNamespace() const noexcept override#

Get plugin namespace.

Returns:

Plugin namespace string

char const *getPluginVersion() const noexcept override#

Get plugin version.

Returns:

Plugin version string

nvinfer1::IPluginV2 *createPlugin(
char const *name,
nvinfer1::PluginFieldCollection const *fc
) noexcept override#

Create a new plugin instance.

Parameters:
  • name[in] Plugin instance name

  • fc[in] Plugin field collection containing configuration parameters

Returns:

Pointer to created plugin instance

nvinfer1::IPluginV2 *deserializePlugin(
char const *name,
void const *serialData,
size_t serialLength
) noexcept override#

Deserialize a plugin instance from data.

Parameters:
  • name[in] Plugin instance name

  • serialData[in] Serialized plugin data

  • serialLength[in] Length of serialized data in bytes

Returns:

Pointer to deserialized plugin instance