Selective State Update

void mamba_ssm::invokeSelectiveStateUpdate(
trt_edgellm::rt::Tensor const &x,
trt_edgellm::rt::Tensor const &A,
trt_edgellm::rt::Tensor const &B,
trt_edgellm::rt::Tensor const &C,
trt_edgellm::rt::Tensor const &dt,
trt_edgellm::rt::OptionalInputTensor dt_bias,
trt_edgellm::rt::OptionalInputTensor D,
trt_edgellm::rt::OptionalInputTensor z,
trt_edgellm::rt::Tensor &state,
trt_edgellm::rt::Tensor &output,
bool dt_softplus,
cudaStream_t stream
)

Launch the decode selective state update kernel (seq_len == 1).

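Computes the following for each batch element and head $h$. Here $g$ is the group that head $h$ belongs to, $d$ indexes dim, and $i$ indexes dstate.

$$\mathrm{state}'_{d,i} = \mathrm{state}_{d,i}\,\exp(A_h\,\Delta_h) + B_{g,i}\,\Delta_h\,x_{d}$$

$$\mathrm{output}_{d} = \sum_{i} \mathrm{state}'_{d,i}\,C_{g,i} + D_h\,x_{d}$$

If z is present, the output is additionally gated: $\mathrm{output}_{d} \leftarrow \mathrm{output}_{d}\cdot\mathrm{silu}(z_{d})$. The $D_h\,x_d$ term is omitted when D is absent.

$\Delta_h$ is dt for head $h$, with dt_bias added when provided and softplus applied when dt_softplus is true. state is overwritten with $\mathrm{state}'$.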

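Parameters:

- x: [batch, nheads, dim]
- A: [nheads], FP32
- B, C: [batch, ngroups, dstate]
- dt: [batch, nheads]
- dt_bias: [nheads] (optional)
- D: [nheads] (optional skip connection)
- z: same shape as x (optional SiLU gate)
- state: [batch, nheads, dim, dstate], updated in place
- output: [batch, nheads, dim]
- dt_softplus: whether to apply softplus to dt (after adding dt_bias)
- stream: CUDA stream on which the kernel is launched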

void mamba_ssm::invokeSelectiveStateUpdatePrefill(
trt_edgellm::rt::Tensor const &x,
trt_edgellm::rt::Tensor const &A,
trt_edgellm::rt::Tensor const &B,
trt_edgellm::rt::Tensor const &C,
trt_edgellm::rt::Tensor const &dt,
trt_edgellm::rt::OptionalInputTensor dt_bias,
trt_edgellm::rt::OptionalInputTensor D,
trt_edgellm::rt::OptionalInputTensor z,
trt_edgellm::rt::Tensor &state,
trt_edgellm::rt::Tensor &output,
bool dt_softplus,
cudaStream_t stream
)

Launch the prefill selective state update kernel (seq_len > 1).

Processes all seq_len tokens in a single kernel launch. x must be 4D: [batch, seq_len, nheads, dim].