Attention Plugin
-
class AttentionPlugin : public nvinfer1::IPluginV3, public nvinfer1::IPluginV3OneCore, public nvinfer1::IPluginV3OneBuildV2, public nvinfer1::IPluginV3OneRuntime
TensorRT plugin for attention operations (V3 plugin interface, IPluginV3).
This plugin implements efficient attention mechanisms, including context attention (prefill) and decode attention with KV-cache support.
Public Functions
- AttentionPlugin(
      std::string const &name,
      int32_t numQHeads,
      int32_t numKVHeads,
      int32_t headSize,
      int32_t supportsSpecDecode,
      int32_t enableFp8KVCache,
      int32_t slidingWindowSize = -1,
      std::vector<float> const &qkvScales = {})
Constructs an attention plugin from explicit configuration parameters.
- Parameters:
name – [in] Plugin instance name
numQHeads – [in] Number of query heads
numKVHeads – [in] Number of key/value heads
headSize – [in] Head dimension (size of each attention head)
supportsSpecDecode – [in] Whether to support speculative decoding (tree attention)
enableFp8KVCache – [in] Whether to enable the FP8 KV cache
slidingWindowSize – [in] Sliding-window size (-1 means no sliding window)
qkvScales – [in] Optional [q, k, v] FP8 dequantization scales (required when enableFp8KVCache is set)
- AttentionPlugin(
      std::string const &name,
      nvinfer1::PluginFieldCollection const *fc)
Constructs an attention plugin from the fields in the plugin field collection `fc`.
- AttentionPlugin() = delete
- AttentionPlugin(AttentionPlugin const&) = delete
- ~AttentionPlugin() override
- nvinfer1::IPluginCapability *getCapabilityInterface(
      nvinfer1::PluginCapabilityType type)
- nvinfer1::IPluginV3 *clone() noexcept override
- char const *getPluginName() const noexcept override
- char const *getPluginVersion() const noexcept override
- char const *getPluginNamespace() const noexcept override
- int32_t getNbOutputs() const noexcept override
- int32_t getOutputDataTypes(
      nvinfer1::DataType *outputTypes,
      int32_t nbOutputs,
      nvinfer1::DataType const *inputTypes,
      int32_t nbInputs)
- int32_t getOutputShapes(
      nvinfer1::DimsExprs const *inputs,
      int32_t nbInputs,
      nvinfer1::DimsExprs const *shapeInputs,
      int32_t nbShapeInputs,
      nvinfer1::DimsExprs *outputs,
      int32_t nbOutputs,
      nvinfer1::IExprBuilder &exprBuilder)
- bool supportsFormatCombination(
      int32_t pos,
      nvinfer1::DynamicPluginTensorDesc const *inOut,
      int32_t nbInputs,
      int32_t nbOutputs)
- int32_t configurePlugin(
      nvinfer1::DynamicPluginTensorDesc const *in,
      int32_t nbInputs,
      nvinfer1::DynamicPluginTensorDesc const *out,
      int32_t nbOutputs)
- size_t getWorkspaceSize(
      nvinfer1::DynamicPluginTensorDesc const *inputs,
      int32_t nbInputs,
      nvinfer1::DynamicPluginTensorDesc const *outputs,
      int32_t nbOutputs)
- int32_t getAliasedInput(int32_t outputIndex) noexcept override
- int32_t enqueue(
      nvinfer1::PluginTensorDesc const *inputDesc,
      nvinfer1::PluginTensorDesc const *outputDesc,
      void const *const *inputs,
      void *const *outputs,
      void *workspace,
      cudaStream_t stream)
- int32_t onShapeChange(
      nvinfer1::PluginTensorDesc const *in,
      int32_t nbInputs,
      nvinfer1::PluginTensorDesc const *out,
      int32_t nbOutputs)
- nvinfer1::IPluginV3 *attachToContext(
      nvinfer1::IPluginResourceContext *context)
- nvinfer1::PluginFieldCollection const *getFieldsToSerialize()
- void setPluginNamespace(char const *pluginNamespace) noexcept
-
class AttentionPluginCreator : public nvinfer1::IPluginCreatorV3One
Factory class for creating AttentionPlugin instances.
Public Functions
- AttentionPluginCreator()
- ~AttentionPluginCreator() override = default
- char const *getPluginName() const noexcept override
- char const *getPluginVersion() const noexcept override
- nvinfer1::PluginFieldCollection const *getFieldNames()
- char const *getPluginNamespace() const noexcept override
- void setPluginNamespace(char const *pluginNamespace) noexcept
- nvinfer1::IPluginV3 *createPlugin(
      char const *name,
      nvinfer1::PluginFieldCollection const *fc,
      nvinfer1::TensorRTPhase phase)
-
AttentionPluginCreator()#