Attention Plugin
-
class AttentionPlugin : public nvinfer1::IPluginV2DynamicExt
TensorRT plugin for attention operations (context and decode).
This plugin implements efficient attention mechanisms, including context (prefill) attention and decode attention with KV-cache support.
IPluginV2DynamicExt Methods
-
nvinfer1::IPluginV2DynamicExt *clone() const noexcept override#
Clone the plugin instance.
- Returns:
Pointer to cloned plugin
-
int32_t getNbOutputs() const noexcept override#
Get number of outputs.
- Returns:
Number of output tensors
- nvinfer1::DataType getOutputDataType(
      int32_t index,
      nvinfer1::DataType const *inputTypes,
      int32_t nbInputs)
Get output data type.
- Parameters:
index – [in] Output index
inputTypes – [in] Array of input data types
nbInputs – [in] Number of inputs
- Returns:
Output data type
- nvinfer1::DimsExprs getOutputDimensions(
      int32_t outputIndex,
      nvinfer1::DimsExprs const *inputs,
      int32_t nbInputs,
      nvinfer1::IExprBuilder &exprBuilder)
Get output dimensions.
- Parameters:
outputIndex – [in] Output tensor index
inputs – [in] Input tensor dimensions
nbInputs – [in] Number of inputs
exprBuilder – [in] Expression builder for dimension calculations
- Returns:
Output tensor dimensions
- bool supportsFormatCombination(
      int32_t pos,
      nvinfer1::PluginTensorDesc const *inOut,
      int32_t nbInputs,
      int32_t nbOutputs)
Check if format combination is supported.
- Parameters:
pos – [in] Position in the input/output tensor list
inOut – [in] Array of input and output tensor descriptors
nbInputs – [in] Number of inputs
nbOutputs – [in] Number of outputs
- Returns:
True if format combination is supported
- void configurePlugin(
      nvinfer1::DynamicPluginTensorDesc const *in,
      int32_t nbInputs,
      nvinfer1::DynamicPluginTensorDesc const *out,
      int32_t nbOutputs)
Configure the plugin with input and output tensors.
- Parameters:
in – [in] Input tensor descriptors
nbInputs – [in] Number of inputs
out – [in] Output tensor descriptors
nbOutputs – [in] Number of outputs
- size_t getWorkspaceSize(
      nvinfer1::PluginTensorDesc const *inputs,
      int32_t nbInputs,
      nvinfer1::PluginTensorDesc const *outputs,
      int32_t nbOutputs)
Get workspace size required by the plugin.
- Parameters:
inputs – [in] Input tensor descriptors
nbInputs – [in] Number of inputs
outputs – [in] Output tensor descriptors
nbOutputs – [in] Number of outputs
- Returns:
Workspace size in bytes
- int32_t enqueue(
      nvinfer1::PluginTensorDesc const *inputDesc,
      nvinfer1::PluginTensorDesc const *outputDesc,
      void const *const *inputs,
      void *const *outputs,
      void *workspace,
      cudaStream_t stream)
Execute the plugin.
- Parameters:
inputDesc – [in] Input tensor descriptors
outputDesc – [in] Output tensor descriptors
inputs – [in] Input tensor data pointers
outputs – [out] Output tensor data pointers
workspace – [in] Workspace memory pointer
stream – [in] CUDA stream for execution
- Returns:
0 on success, non-zero on failure
-
size_t getSerializationSize() const noexcept override#
Get serialization size.
- Returns:
Size in bytes required for serialization
-
void serialize(void *buffer) const noexcept override#
Serialize the plugin.
- Parameters:
buffer – [out] Buffer to write serialized data
-
char const *getPluginType() const noexcept override#
Get plugin type.
- Returns:
Plugin type string
-
char const *getPluginNamespace() const noexcept override#
Get plugin namespace.
- Returns:
Plugin namespace string
-
void setPluginNamespace(char const *pluginNamespace) noexcept#
Set plugin namespace.
- Parameters:
pluginNamespace – [in] Namespace to set
-
char const *getPluginVersion() const noexcept override#
Get plugin version.
- Returns:
Plugin version string
-
int32_t initialize() noexcept override#
Initialize the plugin.
- Returns:
0 on success, non-zero on failure
-
void terminate() noexcept override#
Terminate the plugin and release resources.
-
void destroy() noexcept override#
Destroy the plugin instance.
Public Functions
- AttentionPlugin(
      std::string const &name,
      int32_t numQHeads,
      int32_t numKVHeads,
      int32_t headSize,
      int32_t supportsSpecDecode)
Construct an attention plugin from its configuration parameters.
- Parameters:
name – [in] Plugin instance name
numQHeads – [in] Number of query heads
numKVHeads – [in] Number of key-value heads
headSize – [in] Head dimension size
supportsSpecDecode – [in] Whether to support speculative decoding (Tree attention)
- AttentionPlugin(
      std::string const &name,
      void const *data,
      size_t length)
Construct an attention plugin from previously serialized data (used during deserialization).
- Parameters:
name – [in] Plugin instance name
data – [in] Serialized plugin data
length – [in] Length of serialized data
-
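AttentionPlugin() = delete
Default construction is disabled. Every instance must be created with a name and either configuration parameters or serialized data, so that different plugin instances can be distinguished.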
-
AttentionPlugin(AttentionPlugin const&) = delete
Copy construction is disabled.
-
~AttentionPlugin() override
Destructor.
-
nvinfer1::IPluginV2DynamicExt *clone() const noexcept override
Clone the plugin instance and return a pointer to the copy.
-
class AttentionPluginCreator : public nvinfer1::IPluginCreator#
Factory class for creating AttentionPlugin instances.
Public Functions
-
AttentionPluginCreator()#
-
~AttentionPluginCreator() override = default#
-
char const *getPluginName() const noexcept override#
Get plugin name.
- Returns:
Plugin name string
- nvinfer1::PluginFieldCollection const *getFieldNames()
Get plugin field collection.
- Returns:
Pointer to plugin field collection containing all plugin fields
-
void setPluginNamespace(char const *pluginNamespace) noexcept#
Set plugin namespace.
- Parameters:
pluginNamespace – [in] Namespace to set
-
char const *getPluginNamespace() const noexcept override#
Get plugin namespace.
- Returns:
Plugin namespace string
-
char const *getPluginVersion() const noexcept override#
Get plugin version.
- Returns:
Plugin version string
- nvinfer1::IPluginV2 *createPlugin(
      char const *name,
      nvinfer1::PluginFieldCollection const *fc)
Create a new plugin instance.
- Parameters:
name – [in] Plugin instance name
fc – [in] Plugin field collection containing configuration parameters
- Returns:
Pointer to created plugin instance
- nvinfer1::IPluginV2 *deserializePlugin(
      char const *name,
      void const *serialData,
      size_t serialLength)
Deserialize a plugin instance from data.
- Parameters:
name – [in] Plugin instance name
serialData – [in] Serialized plugin data
serialLength – [in] Length of serialized data in bytes
- Returns:
Pointer to deserialized plugin instance
-
AttentionPluginCreator()#