Gated Delta Net Plugin#

class GatedDeltaNetPlugin : public nvinfer1::IPluginV2DynamicExt#

TensorRT plugin for Gated Delta Net (V2 — IPluginV2DynamicExt).

Registered as “gated_delta_net”. Dispatches to decode (seq_len==1) or prefill (seq_len>1) CuTe DSL kernels. Requires SM80+ and K=V=128.

Dimension notation

n = batch size h = number of Q/K heads hv = number of V heads k = head dimension K (must be 128) v = head dimension V (must be 128)

Inputs

[0] q [n, seq_len, h, k] FP16 query [1] k [n, seq_len, h, k] FP16 key [2] v [n, seq_len, hv, v] FP16 value [3] a [n, seq_len, hv] FP16 input gate [4] b [n, seq_len, hv] FP16 output gate [5] A_log [hv] FP32 log decay [6] dt_bias [hv] FP16 delta-time bias [7] h0_source [n, hv, k, v] FP32 recurrent state in (batch-dense) [8] context_lengths [n] INT32 valid token count per batch row

Outputs

[0] o [n, seq_len, hv, v] FP16 output [1] h0_out [n, hv, k, v] FP32 recurrent state out

Public Functions

GatedDeltaNetPlugin(
std::string const &name,
int32_t kDim = 128,
int32_t vDim = 128
)#
Parameters:
  • name – Plugin instance name

  • kDim – Head dimension K (must be 128 for CuTe DSL kernel)

  • vDim – Head dimension V (must be 128 for CuTe DSL kernel)

GatedDeltaNetPlugin(
std::string const &name,
void const *data,
size_t length
)#

Deserialization constructor.

GatedDeltaNetPlugin() = delete#
GatedDeltaNetPlugin(GatedDeltaNetPlugin const&) = delete#
~GatedDeltaNetPlugin() override#
nvinfer1::IPluginV2DynamicExt *clone() const noexcept override#
int32_t getNbOutputs() const noexcept override#
nvinfer1::DataType getOutputDataType(
int32_t index,
nvinfer1::DataType const *inputTypes,
int32_t nbInputs
) const noexcept override#
nvinfer1::DimsExprs getOutputDimensions(
int32_t outputIndex,
nvinfer1::DimsExprs const *inputs,
int32_t nbInputs,
nvinfer1::IExprBuilder &exprBuilder
) noexcept override#
bool supportsFormatCombination(
int32_t pos,
nvinfer1::PluginTensorDesc const *inOut,
int32_t nbInputs,
int32_t nbOutputs
) noexcept override#
void configurePlugin(
nvinfer1::DynamicPluginTensorDesc const *in,
int32_t nbInputs,
nvinfer1::DynamicPluginTensorDesc const *out,
int32_t nbOutputs
) noexcept override#
size_t getWorkspaceSize(
nvinfer1::PluginTensorDesc const *inputs,
int32_t nbInputs,
nvinfer1::PluginTensorDesc const *outputs,
int32_t nbOutputs
) const noexcept override#
int32_t enqueue(
nvinfer1::PluginTensorDesc const *inputDesc,
nvinfer1::PluginTensorDesc const *outputDesc,
void const *const *inputs,
void *const *outputs,
void *workspace,
cudaStream_t stream
) noexcept override#
size_t getSerializationSize() const noexcept override#
void serialize(void *buffer) const noexcept override#
char const *getPluginType() const noexcept override#
char const *getPluginNamespace() const noexcept override#
void setPluginNamespace(
char const *pluginNamespace
) noexcept override#
char const *getPluginVersion() const noexcept override#
int32_t initialize() noexcept override#
void terminate() noexcept override#
void destroy() noexcept override#
class GatedDeltaNetPluginCreator : public nvinfer1::IPluginCreator#

Public Functions

GatedDeltaNetPluginCreator()#
~GatedDeltaNetPluginCreator() override = default#
char const *getPluginName() const noexcept override#
char const *getPluginVersion() const noexcept override#
nvinfer1::PluginFieldCollection const *getFieldNames(
) noexcept override#
char const *getPluginNamespace() const noexcept override#
void setPluginNamespace(
char const *pluginNamespace
) noexcept override#
nvinfer1::IPluginV2 *createPlugin(
char const *name,
nvinfer1::PluginFieldCollection const *fc
) noexcept override#
nvinfer1::IPluginV2 *deserializePlugin(
char const *name,
void const *serialData,
size_t serialLength
) noexcept override#