Gated Delta Net Plugin#
-
class GatedDeltaNetPlugin : public nvinfer1::IPluginV3, public nvinfer1::IPluginV3OneCore, public nvinfer1::IPluginV3OneBuildV2, public nvinfer1::IPluginV3OneRuntime#
TensorRT plugin for Gated Delta Net (V3 — IPluginV3).
Registered as “gated_delta_net”. Dispatches to decode (seq_len==1), prefill (seq_len>1), or MTP verify (use_mtp=true) CuTe DSL kernels. Requires SM80+ and K=V=128.
- Dimension notation
n = batch size
h = number of Q/K heads
hv = number of V heads
k = head dimension K (must be 128)
v = head dimension V (must be 128)
- Inputs
[0] q [n, seq_len, h, k] FP16 query
[1] k [n, seq_len, h, k] FP16 key
[2] v [n, seq_len, hv, v] FP16 value
[3] a [n, seq_len, hv] FP16 input gate
[4] b [n, seq_len, hv] FP16 output gate
[5] A_log [hv] FP32 log decay
[6] dt_bias [hv] FP16 delta-time bias
[7] h0_source [n, hv, k, v] FP32 recurrent state in (batch-dense)
[8] context_lengths [n] INT32 valid token count per batch row
- Outputs
[0] o [n, seq_len, hv, v] FP16 output
[1] h0_out [n, hv, k, v] FP32 recurrent state out
[2] intermediate_states [n, seq_len, hv, k, v] FP32 (MTP only, optional): per-step recurrent state cache for speculative-decoding rollback. Only populated when use_mtp=true (plugin attribute) and seq_len>1. When use_mtp=false, this output is a 1-element dummy.
Public Functions
- GatedDeltaNetPlugin(
- std::string const &name,
- int32_t kDim = 128,
- int32_t vDim = 128,
- bool useMTP = false
- Parameters:
name – Plugin instance name
kDim – Head dimension K (must be 128 for CuTe DSL kernel)
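vDim – Head dimension V (must be 128 for CuTe DSL kernel)
useMTP – Enables the MTP verify path (default false); presumably corresponds to the use_mtp plugin attribute, which also controls whether the intermediate_states output is populated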
- GatedDeltaNetPlugin(
- std::string const &name,
- nvinfer1::PluginFieldCollection const *fc
-
GatedDeltaNetPlugin() = delete#
-
GatedDeltaNetPlugin(GatedDeltaNetPlugin const&) = delete#
-
~GatedDeltaNetPlugin() override#
- nvinfer1::IPluginCapability *getCapabilityInterface(
- nvinfer1::PluginCapabilityType type
-
nvinfer1::IPluginV3 *clone() noexcept override#
-
char const *getPluginName() const noexcept override#
-
char const *getPluginVersion() const noexcept override#
-
char const *getPluginNamespace() const noexcept override#
-
int32_t getNbOutputs() const noexcept override#
- int32_t getOutputDataTypes(
- nvinfer1::DataType *outputTypes,
- int32_t nbOutputs,
- nvinfer1::DataType const *inputTypes,
- int32_t nbInputs
- int32_t getOutputShapes(
- nvinfer1::DimsExprs const *inputs,
- int32_t nbInputs,
- nvinfer1::DimsExprs const *shapeInputs,
- int32_t nbShapeInputs,
- nvinfer1::DimsExprs *outputs,
- int32_t nbOutputs,
- nvinfer1::IExprBuilder &exprBuilder
- bool supportsFormatCombination(
- int32_t pos,
- nvinfer1::DynamicPluginTensorDesc const *inOut,
- int32_t nbInputs,
- int32_t nbOutputs
- int32_t configurePlugin(
- nvinfer1::DynamicPluginTensorDesc const *in,
- int32_t nbInputs,
- nvinfer1::DynamicPluginTensorDesc const *out,
- int32_t nbOutputs
- size_t getWorkspaceSize(
- nvinfer1::DynamicPluginTensorDesc const *inputs,
- int32_t nbInputs,
- nvinfer1::DynamicPluginTensorDesc const *outputs,
- int32_t nbOutputs
-
int32_t getAliasedInput(int32_t outputIndex) noexcept override#
- int32_t enqueue(
- nvinfer1::PluginTensorDesc const *inputDesc,
- nvinfer1::PluginTensorDesc const *outputDesc,
- void const *const *inputs,
- void *const *outputs,
- void *workspace,
- cudaStream_t stream
- int32_t onShapeChange(
- nvinfer1::PluginTensorDesc const *in,
- int32_t nbInputs,
- nvinfer1::PluginTensorDesc const *out,
- int32_t nbOutputs
- nvinfer1::IPluginV3 *attachToContext(
- nvinfer1::IPluginResourceContext *context
- nvinfer1::PluginFieldCollection const *getFieldsToSerialize(
-
void setPluginNamespace(char const *pluginNamespace) noexcept#
-
class GatedDeltaNetPluginCreator : public nvinfer1::IPluginCreatorV3One#
Public Functions
-
GatedDeltaNetPluginCreator()#
-
~GatedDeltaNetPluginCreator() override = default#
-
char const *getPluginName() const noexcept override#
-
char const *getPluginVersion() const noexcept override#
- nvinfer1::PluginFieldCollection const *getFieldNames(
-
char const *getPluginNamespace() const noexcept override#
-
void setPluginNamespace(char const *pluginNamespace) noexcept#
- nvinfer1::IPluginV3 *createPlugin(
- char const *name,
- nvinfer1::PluginFieldCollection const *fc,
- nvinfer1::TensorRTPhase phase
-
GatedDeltaNetPluginCreator()#