Attention Plugin

class AttentionPlugin : public nvinfer1::IPluginV3, public nvinfer1::IPluginV3OneCore, public nvinfer1::IPluginV3OneBuildV2, public nvinfer1::IPluginV3OneRuntime#

TensorRT plugin for attention operations (V3 — IPluginV3).

This plugin implements efficient attention mechanisms including context attention (prefill) and decode attention with KV cache support.

Public Functions

AttentionPlugin(
std::string const &name,
int32_t numQHeads,
int32_t numKVHeads,
int32_t headSize,
int32_t supportsSpecDecode,
int32_t enableFp8KVCache,
int32_t slidingWindowSize = -1,
std::vector<float> const &qkvScales = {}
)#

Constructor for attention plugin with configuration parameters.

Parameters:
  • name[in] Plugin instance name

  • numQHeads[in] Number of query heads

  • numKVHeads[in] Number of key-value heads

  • headSize[in] Head dimension size

  • supportsSpecDecode[in] Whether to support speculative decoding (tree attention)

  • enableFp8KVCache[in] Whether to enable FP8 KV cache

  • slidingWindowSize[in] Sliding window size (-1 = no sliding window)

  • qkvScales[in] Optional [q, k, v] FP8 dequantization scales (required when enableFp8KVCache is enabled)

AttentionPlugin(
std::string const &name,
nvinfer1::PluginFieldCollection const *fc
)#
AttentionPlugin() = delete#
AttentionPlugin(AttentionPlugin const&) = delete#
~AttentionPlugin() override#
nvinfer1::IPluginCapability *getCapabilityInterface(
nvinfer1::PluginCapabilityType type
) noexcept override#
nvinfer1::IPluginV3 *clone() noexcept override#
char const *getPluginName() const noexcept override#
char const *getPluginVersion() const noexcept override#
char const *getPluginNamespace() const noexcept override#
int32_t getNbOutputs() const noexcept override#
int32_t getOutputDataTypes(
nvinfer1::DataType *outputTypes,
int32_t nbOutputs,
nvinfer1::DataType const *inputTypes,
int32_t nbInputs
) const noexcept override#
int32_t getOutputShapes(
nvinfer1::DimsExprs const *inputs,
int32_t nbInputs,
nvinfer1::DimsExprs const *shapeInputs,
int32_t nbShapeInputs,
nvinfer1::DimsExprs *outputs,
int32_t nbOutputs,
nvinfer1::IExprBuilder &exprBuilder
) noexcept override#
bool supportsFormatCombination(
int32_t pos,
nvinfer1::DynamicPluginTensorDesc const *inOut,
int32_t nbInputs,
int32_t nbOutputs
) noexcept override#
int32_t configurePlugin(
nvinfer1::DynamicPluginTensorDesc const *in,
int32_t nbInputs,
nvinfer1::DynamicPluginTensorDesc const *out,
int32_t nbOutputs
) noexcept override#
size_t getWorkspaceSize(
nvinfer1::DynamicPluginTensorDesc const *inputs,
int32_t nbInputs,
nvinfer1::DynamicPluginTensorDesc const *outputs,
int32_t nbOutputs
) const noexcept override#
int32_t getAliasedInput(int32_t outputIndex) noexcept override#
int32_t enqueue(
nvinfer1::PluginTensorDesc const *inputDesc,
nvinfer1::PluginTensorDesc const *outputDesc,
void const *const *inputs,
void *const *outputs,
void *workspace,
cudaStream_t stream
) noexcept override#
int32_t onShapeChange(
nvinfer1::PluginTensorDesc const *in,
int32_t nbInputs,
nvinfer1::PluginTensorDesc const *out,
int32_t nbOutputs
) noexcept override#
nvinfer1::IPluginV3 *attachToContext(
nvinfer1::IPluginResourceContext *context
) noexcept override#
nvinfer1::PluginFieldCollection const *getFieldsToSerialize(
) noexcept override#
void setPluginNamespace(char const *pluginNamespace) noexcept#
class AttentionPluginCreator : public nvinfer1::IPluginCreatorV3One#

Factory class for creating AttentionPlugin instances.

Public Functions

AttentionPluginCreator()#
~AttentionPluginCreator() override = default#
char const *getPluginName() const noexcept override#
char const *getPluginVersion() const noexcept override#
nvinfer1::PluginFieldCollection const *getFieldNames(
) noexcept override#
char const *getPluginNamespace() const noexcept override#
void setPluginNamespace(char const *pluginNamespace) noexcept#
nvinfer1::IPluginV3 *createPlugin(
char const *name,
nvinfer1::PluginFieldCollection const *fc,
nvinfer1::TensorRTPhase phase
) noexcept override#