Causal Conv1d

void mamba_ssm::invokeCausalConv1d(
    trt_edgellm::rt::Tensor const &x,
    trt_edgellm::rt::Tensor const &weight,
    trt_edgellm::rt::OptionalInputTensor bias,
    trt_edgellm::rt::Tensor &out,
    int32_t stride,
    int32_t padding,
    int32_t dilation,
    trt_edgellm::rt::OptionalInputTensor contextLengths,
    cudaStream_t stream
)

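Causal depthwise conv1d for the prefill phase: each channel of x is convolved with its own width-tap filter, and each output position depends only on the current and earlier input positions.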

x: [batch, seq_len, dim]
weight: [dim, 1, width]
bias: [dim] (optional)
contextLengths: [batch] INT32, per-batch actual token count (optional, prefill only)
out: [batch, out_seq_len, dim]
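The prefill semantics can be illustrated with a small CPU reference. This is a minimal sketch, not the library's implementation: `causalConv1dPrefillRef` is a hypothetical name, `std::vector<float>` stands in for `trt_edgellm::rt::Tensor` and FP16, and it assumes stride 1, a causal left padding of `dilation * (width - 1)` (so that `out_seq_len == seq_len`), tap `width - 1` aligned with the current token, and output rows at or past `contextLengths[b]` left as zero. The `[dim, 1, width]` weight layout matches PyTorch's depthwise `Conv1d` weight with `groups = dim`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for invokeCausalConv1d (prefill).
// Assumptions, not taken from the API: float instead of FP16/tensor types,
// stride == 1, left (causal) padding of dilation * (width - 1) so that
// out_seq_len == seqLen, and rows at or past contextLengths[b] left as zero.
std::vector<float> causalConv1dPrefillRef(
    const std::vector<float>& x,             // [batch, seqLen, dim]
    const std::vector<float>& weight,        // [dim, 1, width]
    const std::vector<float>* bias,          // [dim] or nullptr
    const std::vector<int>* contextLengths,  // [batch] or nullptr
    int batch, int seqLen, int dim, int width, int dilation)
{
    assert(x.size() == static_cast<std::size_t>(batch) * seqLen * dim);
    assert(weight.size() == static_cast<std::size_t>(dim) * width);
    std::vector<float> out(x.size(), 0.0f);
    for (int b = 0; b < batch; ++b) {
        const int len = contextLengths ? (*contextLengths)[b] : seqLen;
        assert(len >= 0 && len <= seqLen);
        for (int t = 0; t < len; ++t) {
            for (int c = 0; c < dim; ++c) {
                float acc = bias ? (*bias)[c] : 0.0f;
                // Tap k looks (width - 1 - k) * dilation tokens back;
                // tap width - 1 is the current token.
                for (int k = 0; k < width; ++k) {
                    const int src = t - (width - 1 - k) * dilation;
                    if (src < 0) continue;  // implicit zero padding before the sequence
                    acc += weight[static_cast<std::size_t>(c) * width + k] *
                           x[(static_cast<std::size_t>(b) * seqLen + src) * dim + c];
                }
                out[(static_cast<std::size_t>(b) * seqLen + t) * dim + c] = acc;
            }
        }
    }
    return out;
}
```

For example, with one channel, `x = {1, 2, 3}`, `weight = {0.5, 1}`, no bias and dilation 1, the sketch yields `{1, 2.5, 4}`.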

void mamba_ssm::invokeCaptureConvState(
    trt_edgellm::rt::Tensor const &x,
    trt_edgellm::rt::Tensor &convState,
    trt_edgellm::rt::OptionalInputTensor contextLengths,
    cudaStream_t stream
)

Capture the convolution state from the prefill input, so that subsequent decode calls can continue from the end of each sequence.

x: [batch, seqLen, dim]
contextLengths: [batch] INT32, per-batch actual token count (optional)
convState: [batch, dim, width] (output; must be zero-initialized before the call)
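A sketch of what the captured state could contain, under an assumption not stated in the API: `convState[b, c, :]` holds the last `width` valid inputs of channel `c`, right-aligned so that slot `width - 1` is the most recent token. Slots with no matching token keep their zero initial value, which would explain the zero-initialization requirement. `captureConvStateRef` is a hypothetical name, and `std::vector<float>` replaces the tensor type.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for invokeCaptureConvState.
// Assumption: the last min(len, width) valid tokens are copied, right-aligned,
// into convState; remaining leading slots keep their zero initial value.
void captureConvStateRef(
    const std::vector<float>& x,             // [batch, seqLen, dim]
    std::vector<float>& convState,           // [batch, dim, width], zero-initialized
    const std::vector<int>* contextLengths,  // [batch] or nullptr
    int batch, int seqLen, int dim, int width)
{
    assert(x.size() == static_cast<std::size_t>(batch) * seqLen * dim);
    assert(convState.size() == static_cast<std::size_t>(batch) * dim * width);
    for (int b = 0; b < batch; ++b) {
        const int len = contextLengths ? (*contextLengths)[b] : seqLen;
        assert(len >= 0 && len <= seqLen);
        const int n = std::min(len, width);  // tokens that fit in the state
        for (int c = 0; c < dim; ++c) {
            for (int i = 0; i < n; ++i) {
                const int t = len - n + i;     // source token index
                const int slot = width - n + i;  // right-aligned destination slot
                convState[(static_cast<std::size_t>(b) * dim + c) * width + slot] =
                    x[(static_cast<std::size_t>(b) * seqLen + t) * dim + c];
            }
        }
    }
}
```

With `width = 4` and a single-channel sequence `{1, 2, 3}`, the sketch produces the state `{0, 1, 2, 3}`; with `contextLengths = {2}` it produces `{0, 0, 1, 2}`.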

void mamba_ssm::invokeCausalConv1dDecode(
    trt_edgellm::rt::Tensor &convState,
    trt_edgellm::rt::Tensor const &newCol,
    trt_edgellm::rt::Tensor const &weight,
    trt_edgellm::rt::OptionalInputTensor bias,
    trt_edgellm::rt::Tensor &out,
    cudaStream_t stream
)

Single-token decode conv1d: shift convState left by one column, insert the new column, and compute the per-channel dot product with the weights (plus bias, if provided).

convState: [batch, dim, width] (updated in place)
newCol: [batch, 1, dim] (new single-token input)
weight: [dim, 1, width]
bias: [dim] (optional)
out: [batch, 1, dim]
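The shift, insert, and dot-product steps map directly onto a per-channel loop. The sketch below is a hypothetical CPU reference (`causalConv1dDecodeRef`), using `std::vector<float>` in place of the FP16 tensors and assuming the newest column occupies slot `width - 1`, consistent with the shift-left description.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for invokeCausalConv1dDecode.
// For every (b, c): shift convState[b, c, :] left by one slot, write
// newCol[b, 0, c] into slot width - 1, then
// out[b, 0, c] = dot(convState[b, c, :], weight[c, 0, :]) + bias[c].
void causalConv1dDecodeRef(
    std::vector<float>& convState,     // [batch, dim, width], updated in place
    const std::vector<float>& newCol,  // [batch, 1, dim]
    const std::vector<float>& weight,  // [dim, 1, width]
    const std::vector<float>* bias,    // [dim] or nullptr
    std::vector<float>& out,           // [batch, 1, dim], resized here
    int batch, int dim, int width)
{
    assert(convState.size() == static_cast<std::size_t>(batch) * dim * width);
    assert(newCol.size() == static_cast<std::size_t>(batch) * dim);
    assert(weight.size() == static_cast<std::size_t>(dim) * width);
    out.assign(static_cast<std::size_t>(batch) * dim, 0.0f);
    for (int b = 0; b < batch; ++b) {
        for (int c = 0; c < dim; ++c) {
            float* s = &convState[(static_cast<std::size_t>(b) * dim + c) * width];
            for (int k = 0; k + 1 < width; ++k) s[k] = s[k + 1];  // drop the oldest column
            s[width - 1] = newCol[static_cast<std::size_t>(b) * dim + c];
            float acc = bias ? (*bias)[c] : 0.0f;
            for (int k = 0; k < width; ++k)
                acc += s[k] * weight[static_cast<std::size_t>(c) * width + k];
            out[static_cast<std::size_t>(b) * dim + c] = acc;
        }
    }
}
```

For example, a single-channel state `{0, 1, 2, 3}`, new input `4`, all-ones weights and bias `0.5` shift the state to `{1, 2, 3, 4}` and produce `10.5`.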

void mamba_ssm::invokeCausalConv1dDecodeMTP(
    trt_edgellm::rt::Tensor &convState,
    trt_edgellm::rt::Tensor const &newCols,
    trt_edgellm::rt::Tensor const &weight,
    trt_edgellm::rt::OptionalInputTensor bias,
    trt_edgellm::rt::Tensor &out,
    trt_edgellm::rt::Tensor &intermediateConvStates,
    int32_t T,
    cudaStream_t stream
)

MTP (multi-token prediction) decode: process T draft tokens in a single call, checkpointing the conv state after each step.

For each draft token t in [0, T):

  1. Shift convState left by one column and insert newCols[:, t, :] as the last column

  2. Compute out[:, t, :] = dot(convState, weight) + bias (per channel, over the width axis; bias only if provided)

  3. Save the updated convState to intermediateConvStates[:, t, :, :]

convState: [batch, dim, width] FP16 (updated in place to the final state)
newCols: [batch, T, dim] FP16 (T draft token inputs)
weight: [dim, 1, width] FP16
bias: [dim] FP16 (optional)
out: [batch, T, dim] FP16 (T outputs)
intermediateConvStates: [batch, T, dim, width] FP16 (per-step state cache for rollback)
T: number of draft tokens
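A CPU sketch of the MTP loop, plus one way the checkpoints might be used for rollback in speculative decoding. Everything here is hypothetical: `causalConv1dDecodeMTPRef` and `rollbackConvStateRef` are illustrative names, `float` stands in for FP16, and the rollback rule (after accepting the first `k >= 1` draft tokens, restore from `intermediateConvStates[:, k - 1]`) is an assumption about how callers use the cache, not part of this API.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for invokeCausalConv1dDecodeMTP.
// intermediateConvStates[b, t] is the state after consuming draft token t.
void causalConv1dDecodeMTPRef(
    std::vector<float>& convState,               // [batch, dim, width], in place
    const std::vector<float>& newCols,           // [batch, T, dim]
    const std::vector<float>& weight,            // [dim, 1, width]
    const std::vector<float>* bias,              // [dim] or nullptr
    std::vector<float>& out,                     // [batch, T, dim], resized here
    std::vector<float>& intermediateConvStates,  // [batch, T, dim, width], resized here
    int T, int batch, int dim, int width)
{
    const std::size_t stateSize = static_cast<std::size_t>(dim) * width;  // per batch entry
    assert(convState.size() == static_cast<std::size_t>(batch) * stateSize);
    assert(newCols.size() == static_cast<std::size_t>(batch) * T * dim);
    out.assign(static_cast<std::size_t>(batch) * T * dim, 0.0f);
    intermediateConvStates.assign(static_cast<std::size_t>(batch) * T * stateSize, 0.0f);
    for (int b = 0; b < batch; ++b) {
        for (int t = 0; t < T; ++t) {
            for (int c = 0; c < dim; ++c) {
                float* s = &convState[b * stateSize + static_cast<std::size_t>(c) * width];
                // 1. Shift left by one and insert newCols[b, t, c].
                for (int k = 0; k + 1 < width; ++k) s[k] = s[k + 1];
                s[width - 1] = newCols[(static_cast<std::size_t>(b) * T + t) * dim + c];
                // 2. out[b, t, c] = dot(state, weight[c]) + bias[c].
                float acc = bias ? (*bias)[c] : 0.0f;
                for (int k = 0; k < width; ++k)
                    acc += s[k] * weight[static_cast<std::size_t>(c) * width + k];
                out[(static_cast<std::size_t>(b) * T + t) * dim + c] = acc;
                // 3. Checkpoint into intermediateConvStates[b, t, c, :].
                float* dst = &intermediateConvStates[
                    (static_cast<std::size_t>(b) * T + t) * stateSize +
                    static_cast<std::size_t>(c) * width];
                std::copy(s, s + width, dst);
            }
        }
    }
}

// Hypothetical rollback helper: after accepting `accepted` (1..T) draft tokens
// for batch entry b, restore that entry's state from the matching checkpoint.
void rollbackConvStateRef(std::vector<float>& convState,
                          const std::vector<float>& intermediateConvStates,
                          int b, int accepted, int T, int dim, int width)
{
    assert(accepted >= 1 && accepted <= T);
    const std::size_t stateSize = static_cast<std::size_t>(dim) * width;
    const float* src = &intermediateConvStates[
        (static_cast<std::size_t>(b) * T + (accepted - 1)) * stateSize];
    std::copy(src, src + stateSize, &convState[static_cast<std::size_t>(b) * stateSize]);
}
```

For example, with one channel, `width = 2`, initial state `{0, 1}`, draft inputs `{2, 3}` and weights `{1, 10}`, the outputs are `{21, 32}` and the checkpoints are `{1, 2}` and `{2, 3}`; accepting only the first draft token restores the state to `{1, 2}`.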