Selective State Update
- void mamba_ssm::invokeSelectiveStateUpdate(
- trt_edgellm::rt::Tensor const &x,
- trt_edgellm::rt::Tensor const &A,
- trt_edgellm::rt::Tensor const &B,
- trt_edgellm::rt::Tensor const &C,
- trt_edgellm::rt::Tensor const &dt,
- trt_edgellm::rt::OptionalInputTensor dt_bias,
- trt_edgellm::rt::OptionalInputTensor D,
- trt_edgellm::rt::OptionalInputTensor z,
- trt_edgellm::rt::Tensor &state,
- trt_edgellm::rt::Tensor &output,
- bool dt_softplus,
- cudaStream_t stream
- )
Launch the decode selective state update kernel (seq_len == 1).
Computes:
new_state = state * exp(A * dt) + B * dt * x
output = sum_i(new_state_i * C_i) + D * x
if z is present: output *= silu(z)
Tensor shapes:
- x: [batch, nheads, dim]
- A: [nheads], FP32
- B, C: [batch, ngroups, dstate]
- dt: [batch, nheads]
- dt_bias: [nheads] (optional)
- D: [nheads] (optional skip connection)
- z: same shape as x (optional SiLU gate)
- state: [batch, nheads, dim, dstate], updated in-place
- output: [batch, nheads, dim]
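As a cross-check for the formulas above, the decode update can be written as a plain CPU loop. The following is a minimal sketch, not the library's kernel: `SsuDims`, `selectiveStateUpdateRef`, the flat row-major `std::vector<float>` layout, and the raw pointers used for the optional inputs are all hypothetical. It also assumes, following the upstream Mamba reference implementation, that `dt_bias` is added to `dt` before the optional softplus, and that head `h` reads `B` and `C` from group `h / (nheads / ngroups)`; neither detail is stated in the description above.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical dimension bundle; not part of the documented API.
struct SsuDims { std::size_t batch, nheads, dim, ngroups, dstate; };

inline float softplusRef(float v) { return v > 20.0f ? v : std::log1p(std::exp(v)); }
inline float siluRef(float v) { return v / (1.0f + std::exp(-v)); }

// CPU reference for one decode step (seq_len == 1). Optional inputs are
// passed as raw pointers and may be nullptr (an illustrative convention).
void selectiveStateUpdateRef(SsuDims const& d,
    std::vector<float> const& x,   // [batch, nheads, dim]
    std::vector<float> const& A,   // [nheads]
    std::vector<float> const& B,   // [batch, ngroups, dstate]
    std::vector<float> const& C,   // [batch, ngroups, dstate]
    std::vector<float> const& dt,  // [batch, nheads]
    float const* dtBias,           // [nheads] or nullptr
    float const* D,                // [nheads] or nullptr
    float const* z,                // same shape as x, or nullptr
    std::vector<float>& state,     // [batch, nheads, dim, dstate], in-place
    std::vector<float>& output,    // [batch, nheads, dim]
    bool dtSoftplus)
{
    assert(d.ngroups > 0 && d.nheads % d.ngroups == 0);
    std::size_t const headsPerGroup = d.nheads / d.ngroups;  // assumed mapping
    for (std::size_t b = 0; b < d.batch; ++b) {
        for (std::size_t h = 0; h < d.nheads; ++h) {
            float dtv = dt[b * d.nheads + h];
            if (dtBias) dtv += dtBias[h];           // assumed: bias before softplus
            if (dtSoftplus) dtv = softplusRef(dtv);
            float const decay = std::exp(A[h] * dtv);
            std::size_t const g = h / headsPerGroup;
            float const* Bp = &B[(b * d.ngroups + g) * d.dstate];
            float const* Cp = &C[(b * d.ngroups + g) * d.dstate];
            for (std::size_t p = 0; p < d.dim; ++p) {
                std::size_t const xi = (b * d.nheads + h) * d.dim + p;
                float* s = &state[xi * d.dstate];
                float acc = 0.0f;
                for (std::size_t n = 0; n < d.dstate; ++n) {
                    // new_state = state * exp(A * dt) + B * dt * x
                    s[n] = s[n] * decay + Bp[n] * dtv * x[xi];
                    acc += s[n] * Cp[n];  // sum_i(new_state_i * C_i)
                }
                if (D) acc += D[h] * x[xi];     // optional skip connection
                if (z) acc *= siluRef(z[xi]);   // optional SiLU gate
                output[xi] = acc;
            }
        }
    }
}
```

A reference like this is mainly useful for comparing kernel output on small random inputs within a tolerance; it makes no attempt to match the kernel's data types or memory layout.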
- void mamba_ssm::invokeSelectiveStateUpdatePrefill(
- trt_edgellm::rt::Tensor const &x,
- trt_edgellm::rt::Tensor const &A,
- trt_edgellm::rt::Tensor const &B,
- trt_edgellm::rt::Tensor const &C,
- trt_edgellm::rt::Tensor const &dt,
- trt_edgellm::rt::OptionalInputTensor dt_bias,
- trt_edgellm::rt::OptionalInputTensor D,
- trt_edgellm::rt::OptionalInputTensor z,
- trt_edgellm::rt::Tensor &state,
- trt_edgellm::rt::Tensor &output,
- bool dt_softplus,
- cudaStream_t stream
- )
Launch the prefill selective state update kernel (seq_len > 1).
Processes all seq_len tokens in a single kernel launch. x must be 4D: [batch, seq_len, nheads, dim].
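One way to picture the prefill path is as the decode recurrence applied token by token, with the state carried from one token to the next. The sketch below illustrates that reading on the CPU and is not the kernel itself. Only the 4D shape of `x` comes from the description above; the shapes used for `B`, `C`, `dt`, and `output` (each gaining a `seq_len` axis after `batch`), the head-to-group mapping, and the names `PrefillDims` and `selectiveStateUpdatePrefillRef` are assumptions. The optional `dt_bias`, `D`, and `z` inputs are left out to keep the loop short.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical dimension bundle; not part of the documented API.
struct PrefillDims { std::size_t batch, seqLen, nheads, dim, ngroups, dstate; };

// CPU sketch of a sequential scan over seq_len tokens (core recurrence only).
void selectiveStateUpdatePrefillRef(PrefillDims const& d,
    std::vector<float> const& x,   // [batch, seq_len, nheads, dim]
    std::vector<float> const& A,   // [nheads]
    std::vector<float> const& B,   // assumed [batch, seq_len, ngroups, dstate]
    std::vector<float> const& C,   // assumed [batch, seq_len, ngroups, dstate]
    std::vector<float> const& dt,  // assumed [batch, seq_len, nheads]
    std::vector<float>& state,     // [batch, nheads, dim, dstate], in-place
    std::vector<float>& output)    // assumed [batch, seq_len, nheads, dim]
{
    assert(d.ngroups > 0 && d.nheads % d.ngroups == 0);
    std::size_t const headsPerGroup = d.nheads / d.ngroups;  // assumed mapping
    for (std::size_t b = 0; b < d.batch; ++b) {
        for (std::size_t t = 0; t < d.seqLen; ++t) {  // tokens in order
            std::size_t const tok = b * d.seqLen + t;
            for (std::size_t h = 0; h < d.nheads; ++h) {
                float const dtv = dt[tok * d.nheads + h];
                float const decay = std::exp(A[h] * dtv);
                std::size_t const g = h / headsPerGroup;
                float const* Bp = &B[(tok * d.ngroups + g) * d.dstate];
                float const* Cp = &C[(tok * d.ngroups + g) * d.dstate];
                for (std::size_t p = 0; p < d.dim; ++p) {
                    std::size_t const xi = (tok * d.nheads + h) * d.dim + p;
                    // State has no seq_len axis: it is overwritten per token.
                    float* s = &state[((b * d.nheads + h) * d.dim + p) * d.dstate];
                    float acc = 0.0f;
                    for (std::size_t n = 0; n < d.dstate; ++n) {
                        s[n] = s[n] * decay + Bp[n] * dtv * x[xi];
                        acc += s[n] * Cp[n];
                    }
                    output[xi] = acc;
                }
            }
        }
    }
}
```

Under these assumptions, `state` holds the value after the last token when the function returns, which is what a following decode step would consume.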