Gated Delta Net Plugin#
-
class GatedDeltaNetPlugin : public nvinfer1::IPluginV2DynamicExt#
TensorRT plugin for Gated Delta Net (V2 — IPluginV2DynamicExt).
Registered under the name “gated_delta_net”. Dispatches to the decode CuTe DSL kernel when seq_len == 1 and to the prefill CuTe DSL kernel when seq_len > 1. Requires SM80 or newer and K = V = 128.
- Dimension notation
  - n = batch size
  - h = number of Q/K heads
  - hv = number of V heads
  - k = head dimension K (must be 128)
  - v = head dimension V (must be 128)
- Inputs
  - [0] q: [n, seq_len, h, k], FP16 — query
  - [1] k: [n, seq_len, h, k], FP16 — key
  - [2] v: [n, seq_len, hv, v], FP16 — value
  - [3] a: [n, seq_len, hv], FP16 — input gate
  - [4] b: [n, seq_len, hv], FP16 — output gate
  - [5] A_log: [hv], FP32 — log decay
  - [6] dt_bias: [hv], FP16 — delta-time bias
  - [7] h0_source: [n, hv, k, v], FP32 — input recurrent state (batch-dense)
  - [8] context_lengths: [n], INT32 — number of valid tokens in each batch row
- Outputs
  - [0] o: [n, seq_len, hv, v], FP16 — output
  - [1] h0_out: [n, hv, k, v], FP32 — output recurrent state
Public Functions
- GatedDeltaNetPlugin(
      std::string const &name,
      int32_t kDim = 128,
      int32_t vDim = 128)
- Parameters:
name – Plugin instance name
kDim – Head dimension K (must be 128 for the CuTe DSL kernel)
vDim – Head dimension V (must be 128 for the CuTe DSL kernel)
- GatedDeltaNetPlugin(
      std::string const &name,
      void const *data,
      size_t length)
Deserialization constructor.
-
GatedDeltaNetPlugin() = delete#
-
GatedDeltaNetPlugin(GatedDeltaNetPlugin const&) = delete#
-
~GatedDeltaNetPlugin() override#
-
nvinfer1::IPluginV2DynamicExt *clone() const noexcept override#
-
int32_t getNbOutputs() const noexcept override#
- nvinfer1::DataType getOutputDataType(
      int32_t index,
      nvinfer1::DataType const *inputTypes,
      int32_t nbInputs)
- nvinfer1::DimsExprs getOutputDimensions(
      int32_t outputIndex,
      nvinfer1::DimsExprs const *inputs,
      int32_t nbInputs,
      nvinfer1::IExprBuilder &exprBuilder)
- bool supportsFormatCombination(
      int32_t pos,
      nvinfer1::PluginTensorDesc const *inOut,
      int32_t nbInputs,
      int32_t nbOutputs)
- void configurePlugin(
      nvinfer1::DynamicPluginTensorDesc const *in,
      int32_t nbInputs,
      nvinfer1::DynamicPluginTensorDesc const *out,
      int32_t nbOutputs)
- size_t getWorkspaceSize(
      nvinfer1::PluginTensorDesc const *inputs,
      int32_t nbInputs,
      nvinfer1::PluginTensorDesc const *outputs,
      int32_t nbOutputs)
- int32_t enqueue(
      nvinfer1::PluginTensorDesc const *inputDesc,
      nvinfer1::PluginTensorDesc const *outputDesc,
      void const *const *inputs,
      void *const *outputs,
      void *workspace,
      cudaStream_t stream)
-
size_t getSerializationSize() const noexcept override#
-
void serialize(void *buffer) const noexcept override#
-
char const *getPluginType() const noexcept override#
-
char const *getPluginNamespace() const noexcept override#
- void setPluginNamespace(char const *pluginNamespace)
-
char const *getPluginVersion() const noexcept override#
-
int32_t initialize() noexcept override#
-
void terminate() noexcept override#
-
void destroy() noexcept override#
-
class GatedDeltaNetPluginCreator : public nvinfer1::IPluginCreator#
Public Functions
-
GatedDeltaNetPluginCreator()#
-
~GatedDeltaNetPluginCreator() override = default#
-
char const *getPluginName() const noexcept override#
-
char const *getPluginVersion() const noexcept override#
- nvinfer1::PluginFieldCollection const *getFieldNames()
-
char const *getPluginNamespace() const noexcept override#
- void setPluginNamespace(char const *pluginNamespace)
- nvinfer1::IPluginV2 *createPlugin(
      char const *name,
      nvinfer1::PluginFieldCollection const *fc)
- nvinfer1::IPluginV2 *deserializePlugin(
      char const *name,
      void const *serialData,
      size_t serialLength)
-
GatedDeltaNetPluginCreator()#