Gated Delta Net Plugin#

class GatedDeltaNetPlugin : public nvinfer1::IPluginV3, public nvinfer1::IPluginV3OneCore, public nvinfer1::IPluginV3OneBuildV2, public nvinfer1::IPluginV3OneRuntime#

TensorRT plugin for Gated Delta Net (V3 — IPluginV3).

Registered as “gated_delta_net”. Dispatches to decode (seq_len==1), prefill (seq_len>1), or MTP verify (use_mtp=true) CuTe DSL kernels. Requires SM80+ and K=V=128.

Dimension notation

n  = batch size
h  = number of Q/K heads
hv = number of V heads
k  = head dimension K (must be 128)
v  = head dimension V (must be 128)

Inputs

[0] q                [n, seq_len, h, k]    FP16   query
[1] k                [n, seq_len, h, k]    FP16   key
[2] v                [n, seq_len, hv, v]   FP16   value
[3] a                [n, seq_len, hv]      FP16   input gate
[4] b                [n, seq_len, hv]      FP16   output gate
[5] A_log            [hv]                  FP32   log decay
[6] dt_bias          [hv]                  FP16   delta-time bias
[7] h0_source        [n, hv, k, v]         FP32   recurrent state in (batch-dense)
[8] context_lengths  [n]                   INT32  valid token count per batch row

Outputs

[0] o                    [n, seq_len, hv, v]     FP16  output
[1] h0_out               [n, hv, k, v]           FP32  recurrent state out
[2] intermediate_states  [n, seq_len, hv, k, v]  FP32  (MTP only, optional)
    Per-step recurrent state cache for speculative-decoding rollback.
    Only populated when use_mtp=true (plugin attribute) and seq_len>1.
    When use_mtp=false this output is a 1-element dummy.

Public Functions

GatedDeltaNetPlugin(
std::string const &name,
int32_t kDim = 128,
int32_t vDim = 128,
bool useMTP = false
)#
Parameters:
  • name – Plugin instance name

  • kDim – Head dimension K (must be 128 for CuTe DSL kernel)

  • vDim – Head dimension V (must be 128 for CuTe DSL kernel)

  • useMTP – Select the MTP verify path (use_mtp); defaults to false

GatedDeltaNetPlugin(
std::string const &name,
nvinfer1::PluginFieldCollection const *fc
)#
GatedDeltaNetPlugin() = delete#
GatedDeltaNetPlugin(GatedDeltaNetPlugin const&) = delete#
~GatedDeltaNetPlugin() override#
nvinfer1::IPluginCapability *getCapabilityInterface(
nvinfer1::PluginCapabilityType type
) noexcept override#
nvinfer1::IPluginV3 *clone() noexcept override#
char const *getPluginName() const noexcept override#
char const *getPluginVersion() const noexcept override#
char const *getPluginNamespace() const noexcept override#
int32_t getNbOutputs() const noexcept override#
int32_t getOutputDataTypes(
nvinfer1::DataType *outputTypes,
int32_t nbOutputs,
nvinfer1::DataType const *inputTypes,
int32_t nbInputs
) const noexcept override#
int32_t getOutputShapes(
nvinfer1::DimsExprs const *inputs,
int32_t nbInputs,
nvinfer1::DimsExprs const *shapeInputs,
int32_t nbShapeInputs,
nvinfer1::DimsExprs *outputs,
int32_t nbOutputs,
nvinfer1::IExprBuilder &exprBuilder
) noexcept override#
bool supportsFormatCombination(
int32_t pos,
nvinfer1::DynamicPluginTensorDesc const *inOut,
int32_t nbInputs,
int32_t nbOutputs
) noexcept override#
int32_t configurePlugin(
nvinfer1::DynamicPluginTensorDesc const *in,
int32_t nbInputs,
nvinfer1::DynamicPluginTensorDesc const *out,
int32_t nbOutputs
) noexcept override#
size_t getWorkspaceSize(
nvinfer1::DynamicPluginTensorDesc const *inputs,
int32_t nbInputs,
nvinfer1::DynamicPluginTensorDesc const *outputs,
int32_t nbOutputs
) const noexcept override#
int32_t getAliasedInput(int32_t outputIndex) noexcept override#
int32_t enqueue(
nvinfer1::PluginTensorDesc const *inputDesc,
nvinfer1::PluginTensorDesc const *outputDesc,
void const *const *inputs,
void *const *outputs,
void *workspace,
cudaStream_t stream
) noexcept override#
int32_t onShapeChange(
nvinfer1::PluginTensorDesc const *in,
int32_t nbInputs,
nvinfer1::PluginTensorDesc const *out,
int32_t nbOutputs
) noexcept override#
nvinfer1::IPluginV3 *attachToContext(
nvinfer1::IPluginResourceContext *context
) noexcept override#
nvinfer1::PluginFieldCollection const *getFieldsToSerialize(
) noexcept override#
void setPluginNamespace(char const *pluginNamespace) noexcept#
class GatedDeltaNetPluginCreator : public nvinfer1::IPluginCreatorV3One#

Public Functions

GatedDeltaNetPluginCreator()#
~GatedDeltaNetPluginCreator() override = default#
char const *getPluginName() const noexcept override#
char const *getPluginVersion() const noexcept override#
nvinfer1::PluginFieldCollection const *getFieldNames(
) noexcept override#
char const *getPluginNamespace() const noexcept override#
void setPluginNamespace(char const *pluginNamespace) noexcept#
nvinfer1::IPluginV3 *createPlugin(
char const *name,
nvinfer1::PluginFieldCollection const *fc,
nvinfer1::TensorRTPhase phase
) noexcept override#