Mamba Plugin#

class MambaPlugin : public nvinfer1::IPluginV3, public nvinfer1::IPluginV3OneCore, public nvinfer1::IPluginV3OneBuild, public nvinfer1::IPluginV3OneRuntime#

TensorRT plugin for Mamba Selective State Update (SSM)

Registered as “update_ssm_state” under the “trt_edgellm” ONNX domain.

Implements the selective state space model update: new_state = state * exp(A * dt) + B * dt * x output = sum_i(new_state_i * C_i) + D * x

SiLU gating (z) is handled externally by the ONNX graph (gated_rms_norm).

Inputs may include an optional seq_len dimension (e.g. x as [batch, seq_len, nheads, dim] instead of [batch, nheads, dim]). When seq_len > 1, the plugin loops over the single-step kernel internally, updating the SSM state in-place after each step.

Performance note: the loop launches one kernel per time step. For decode (seq_len=1) this is optimal. For prefill (seq_len >> 1) this is O(seq_len) serial launches, which is correct but slower than a parallel chunked scan. A future optimization would dispatch to a mamba_chunk_scan_combined kernel when seq_len exceeds a threshold.

Input ordering (see constants defined in mambaPlugin.cpp): [0] x [batch, (seq_len,) nheads, dim] FP16 or FP32 [1] A [nheads] FP32 (always) [2] B [batch, (seq_len,) ngroups, dstate] FP16 or FP32 [3] C [batch, (seq_len,) ngroups, dstate] FP16 or FP32 [4] D [nheads] FP16 or FP32 [5] dt [batch, (seq_len,) nheads] FP16 or FP32 [6] dt_bias [nheads] FP16 or FP32 [7] state [batch, nheads, dim, dstate] FP16 or FP32

All data tensors (everything except A) must use the same type. TRT selects FP32 when the ONNX graph declares FP32, and may optimize to FP16 during the builder phase when the FP16 flag is set.

Outputs: [0] output [batch, (seq_len,) nheads, dim] same as input type [1] state_out [batch, nheads, dim, dstate] same as input type

Public Functions

MambaPlugin(
std::string const &name,
int32_t dim,
int32_t dstate,
int32_t nheads,
int32_t ngroups,
int32_t dtSoftplus
)#
MambaPlugin() = delete#
MambaPlugin(MambaPlugin const&) = delete#
~MambaPlugin() override#
nvinfer1::IPluginCapability *getCapabilityInterface(
nvinfer1::PluginCapabilityType type
) noexcept override#
nvinfer1::IPluginV3 *clone() noexcept override#
char const *getPluginName() const noexcept override#
char const *getPluginVersion() const noexcept override#
char const *getPluginNamespace() const noexcept override#
int32_t getNbOutputs() const noexcept override#
int32_t getOutputDataTypes(
nvinfer1::DataType *outputTypes,
int32_t nbOutputs,
nvinfer1::DataType const *inputTypes,
int32_t nbInputs
) const noexcept override#
int32_t getOutputShapes(
nvinfer1::DimsExprs const *inputs,
int32_t nbInputs,
nvinfer1::DimsExprs const *shapeInputs,
int32_t nbShapeInputs,
nvinfer1::DimsExprs *outputs,
int32_t nbOutputs,
nvinfer1::IExprBuilder &exprBuilder
) noexcept override#
bool supportsFormatCombination(
int32_t pos,
nvinfer1::DynamicPluginTensorDesc const *inOut,
int32_t nbInputs,
int32_t nbOutputs
) noexcept override#
int32_t configurePlugin(
nvinfer1::DynamicPluginTensorDesc const *in,
int32_t nbInputs,
nvinfer1::DynamicPluginTensorDesc const *out,
int32_t nbOutputs
) noexcept override#
size_t getWorkspaceSize(
nvinfer1::DynamicPluginTensorDesc const *inputs,
int32_t nbInputs,
nvinfer1::DynamicPluginTensorDesc const *outputs,
int32_t nbOutputs
) const noexcept override#
int32_t enqueue(
nvinfer1::PluginTensorDesc const *inputDesc,
nvinfer1::PluginTensorDesc const *outputDesc,
void const *const *inputs,
void *const *outputs,
void *workspace,
cudaStream_t stream
) noexcept override#
int32_t onShapeChange(
nvinfer1::PluginTensorDesc const *in,
int32_t nbInputs,
nvinfer1::PluginTensorDesc const *out,
int32_t nbOutputs
) noexcept override#
nvinfer1::IPluginV3 *attachToContext(
nvinfer1::IPluginResourceContext *context
) noexcept override#
nvinfer1::PluginFieldCollection const *getFieldsToSerialize(
) noexcept override#
void setPluginNamespace(char const *pluginNamespace) noexcept#
class MambaPluginCreator : public nvinfer1::IPluginCreatorV3One#

Public Functions

MambaPluginCreator()#
~MambaPluginCreator() override = default#
char const *getPluginName() const noexcept override#
char const *getPluginVersion() const noexcept override#
nvinfer1::PluginFieldCollection const *getFieldNames(
) noexcept override#
void setPluginNamespace(char const *pluginNamespace) noexcept#
char const *getPluginNamespace() const noexcept override#
nvinfer1::IPluginV3 *createPlugin(
char const *name,
nvinfer1::PluginFieldCollection const *fc,
nvinfer1::TensorRTPhase phase
) noexcept override#