Mamba Cache Manager

class MambaCacheManager

Per-layer Mamba (recurrent + conv) state manager. Each recurrent layer owns two device tensors:

  • recurrent state: [maxBatchSize, recurrentStateNumHeads, recurrentStateHeadDim, recurrentStateSize]

  • conv state: [maxBatchSize, convDim, convKernel]

When numRecurrentLayers == 0 the manager is a no-op and allocates nothing.
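The two shapes above determine how much device memory the manager needs. The following is a minimal host-side sketch of that arithmetic, not the library's implementation: `MambaShapeSketch` and the helper functions are hypothetical names, the fields only mirror the dimension names used in this page, and the per-element byte sizes are assumptions passed in by the caller.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical mirror of the shape-related settings; the real Config may differ.
struct MambaShapeSketch
{
    int32_t numRecurrentLayers;
    int32_t maxBatchSize;
    int32_t recurrentStateNumHeads;
    int32_t recurrentStateHeadDim;
    int32_t recurrentStateSize;
    int32_t convDim;
    int32_t convKernel;
};

// Element count of one layer's recurrent state:
// [maxBatchSize, recurrentStateNumHeads, recurrentStateHeadDim, recurrentStateSize]
inline std::size_t recurrentStateElems(MambaShapeSketch const& s)
{
    return static_cast<std::size_t>(s.maxBatchSize) * static_cast<std::size_t>(s.recurrentStateNumHeads)
        * static_cast<std::size_t>(s.recurrentStateHeadDim) * static_cast<std::size_t>(s.recurrentStateSize);
}

// Element count of one layer's conv state: [maxBatchSize, convDim, convKernel]
inline std::size_t convStateElems(MambaShapeSketch const& s)
{
    return static_cast<std::size_t>(s.maxBatchSize) * static_cast<std::size_t>(s.convDim)
        * static_cast<std::size_t>(s.convKernel);
}

// Total bytes over all layers. Element sizes are caller-supplied assumptions
// (for example 4 for FP32 or 2 for FP16); zero layers means zero bytes.
inline std::size_t totalStateBytes(
    MambaShapeSketch const& s, std::size_t recurrentElemBytes, std::size_t convElemBytes)
{
    assert(s.numRecurrentLayers >= 0);
    if (s.numRecurrentLayers == 0)
    {
        return 0;
    }
    std::size_t const perLayer
        = recurrentStateElems(s) * recurrentElemBytes + convStateElems(s) * convElemBytes;
    return static_cast<std::size_t>(s.numRecurrentLayers) * perLayer;
}
```

An estimate like this can help choose maxBatchSize before construction, since both tensors scale linearly with it.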

Public Functions

MambaCacheManager() noexcept = default

Default constructor. Allocates nothing.

MambaCacheManager(Config const &config, cudaStream_t stream)

Construct and initialize per-layer Mamba state buffers.

If numRecurrentLayers == 0 the constructor returns immediately without allocating. Otherwise it allocates one recurrent-state tensor and one conv-state tensor per layer, zero-initialises them, and logs total GPU memory consumed.

Parameters:
  • config – Cache configuration

  • stream – CUDA stream for allocation and memset

Throws:

std::runtime_error – if config validation fails

~MambaCacheManager() noexcept

Destructor.

MambaCacheManager(MambaCacheManager const&) = delete

Deleted copy constructor, to avoid copying large device buffers.

MambaCacheManager &operator=(MambaCacheManager const&) = delete

Deleted copy assignment, to avoid copying large device buffers.

Returns:

Reference to this

MambaCacheManager(MambaCacheManager&&) noexcept

Move constructor.

MambaCacheManager &operator=(MambaCacheManager&&) noexcept

Move assignment operator.

Returns:

Reference to this
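The copy and move declarations above follow the usual pattern for a type that owns large buffers: copying is disabled and ownership is transferred by move. The sketch below shows that pattern on a host-side stand-in. `StateOwnerSketch` is a hypothetical class that uses a `std::vector<float>` in place of device tensors, so it only illustrates the ownership semantics, not how MambaCacheManager is implemented.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <utility>
#include <vector>

// Hypothetical stand-in for a manager that owns large state buffers.
class StateOwnerSketch
{
public:
    StateOwnerSketch() = default;

    // Allocate and zero-initialise numElems values.
    explicit StateOwnerSketch(std::size_t numElems)
        : mData(numElems, 0.0f)
    {
    }

    // Copying would duplicate the whole buffer, so it is disabled.
    StateOwnerSketch(StateOwnerSketch const&) = delete;
    StateOwnerSketch& operator=(StateOwnerSketch const&) = delete;

    // Moving transfers ownership of the buffer without copying its contents.
    StateOwnerSketch(StateOwnerSketch&& other) noexcept
        : mData(std::move(other.mData))
    {
    }

    StateOwnerSketch& operator=(StateOwnerSketch&& other) noexcept
    {
        mData = std::move(other.mData);
        return *this;
    }

    std::size_t size() const noexcept
    {
        return mData.size();
    }

private:
    std::vector<float> mData;
};

static_assert(!std::is_copy_constructible<StateOwnerSketch>::value, "copy construction is disabled");
static_assert(!std::is_copy_assignable<StateOwnerSketch>::value, "copy assignment is disabled");
static_assert(std::is_nothrow_move_constructible<StateOwnerSketch>::value, "move construction is noexcept");
```

With this pattern a manager can be returned from a factory function or stored in a container by moving it, while an accidental copy fails at compile time.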

rt::Tensor &getRecurrentState(int32_t recurrentLayerIdx) noexcept

Get the recurrent state tensor for a given recurrent layer (owned tensor reference). Shape: [maxBatchSize, recurrentStateNumHeads, recurrentStateHeadDim, recurrentStateSize]

Parameters:

recurrentLayerIdx – The recurrent layer index.

Returns:

A reference to the owned device tensor.

rt::Tensor &getConvState(int32_t recurrentLayerIdx) noexcept

Get the conv state tensor for a given recurrent layer (owned tensor reference). Shape: [maxBatchSize, convDim, convKernel]

Parameters:

recurrentLayerIdx – The recurrent layer index.

Returns:

A reference to the owned device tensor.

void clearStates(cudaStream_t stream)

Zero all recurrent and conv state buffers (all layers, all batch slots). Called after warmup inference and before CUDA graph capture to ensure a clean starting state. No-op when numRecurrentLayers == 0.

Parameters:

stream – CUDA stream for memset operations.

std::vector<rt::Tensor> captureRecurrentStates(int32_t batchIdx, cudaStream_t stream)

Copy one batch slot's recurrent states into freshly allocated tensors (one per layer). Used to snapshot states when saving a system prompt cache entry. Returns an empty vector when numRecurrentLayers == 0.

Parameters:
  • batchIdx – The batch slot index to capture.

  • stream – CUDA stream for copy operations.

Returns:

Vector of device tensors with shape [1, numHeads, headDim, stateSize].

std::vector<rt::Tensor> captureConvStates(int32_t batchIdx, cudaStream_t stream)

Copy one batch slot's conv states into freshly allocated tensors (one per layer). Used to snapshot states when saving a system prompt cache entry. Returns an empty vector when numRecurrentLayers == 0.

Parameters:
  • batchIdx – The batch slot index to capture.

  • stream – CUDA stream for copy operations.

Returns:

Vector of device tensors with shape [1, convDim, convKernel].
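Both capture functions copy a single batch slot out of a tensor whose leading dimension is maxBatchSize. A minimal host-side sketch of that slicing step is shown below, assuming a contiguous row-major layout. `captureSlot` is a hypothetical helper operating on a `std::vector<float>`; the real functions work on device tensors and a CUDA stream.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Copy batch slot `batchIdx` out of a contiguous row-major pool that holds
// `maxBatchSize` equally sized slots. The result corresponds to a tensor
// whose leading dimension is 1.
inline std::vector<float> captureSlot(
    std::vector<float> const& pool, std::size_t maxBatchSize, std::size_t batchIdx)
{
    assert(maxBatchSize > 0 && batchIdx < maxBatchSize);
    assert(pool.size() % maxBatchSize == 0);

    // Elements per slot, e.g. numHeads * headDim * stateSize or convDim * convKernel.
    std::size_t const slotElems = pool.size() / maxBatchSize;
    auto const first = pool.begin() + static_cast<std::ptrdiff_t>(batchIdx * slotElems);
    return std::vector<float>(first, first + static_cast<std::ptrdiff_t>(slotElems));
}
```

Because the snapshot is a fresh copy, later updates to the live state pool do not change a saved cache entry.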

rt::Tensor &getIntermediateRecurrentState(int32_t recurrentLayerIdx) noexcept

Get the MTP intermediate recurrent state tensor for a given layer. Shape: [maxBatchSize, maxIntermediateSeqLen, recurrentStateNumHeads, recurrentStateHeadDim, recurrentStateSize]

Parameters:

recurrentLayerIdx – The recurrent layer index.

Returns:

A reference to the owned device tensor.

rt::Tensor &getIntermediateConvState(int32_t recurrentLayerIdx) noexcept

Get the MTP intermediate conv state tensor for a given layer. Shape: [maxBatchSize, maxIntermediateSeqLen, convDim, convKernel]

Parameters:

recurrentLayerIdx – The recurrent layer index.

Returns:

A reference to the owned device tensor.

void reshapeIntermediateStates(int32_t activeBatchSize, int32_t seqLen)

Reshape MTP intermediate state tensors to actual runtime dimensions.

TRT writes intermediate state outputs contiguously as [activeBatchSize, seqLen, …], but the buffers are allocated at [maxBatchSize, maxIntermediateSeqLen, …]. Call this before any code that reads the tensor shape for stride calculations. No-op when intermediate states are not allocated.

Parameters:
  • activeBatchSize – Current active batch size

  • seqLen – Actual sequence length (e.g. verifyTreeSize)
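The reshape matters because offsets derived from the allocated shape do not match where contiguously written data actually lives. The sketch below works through that arithmetic under the assumption of a row-major layout; `stateOffset` is a hypothetical helper, and `innerElems` stands for the product of the trailing dimensions.

```cpp
#include <cassert>
#include <cstddef>

// Row-major offset of the first element of (batch b, step s) in a tensor
// shaped [batch, seqLen, inner...], where innerElems is the product of the
// trailing dimensions.
inline std::size_t stateOffset(
    std::size_t b, std::size_t s, std::size_t seqLen, std::size_t innerElems)
{
    assert(s < seqLen);
    return (b * seqLen + s) * innerElems;
}
```

For example, with maxIntermediateSeqLen = 8, an actual seqLen of 3, and innerElems = 10, batch slot 1 starts at offset `stateOffset(1, 0, 3, 10)` = 30 in the data as written, whereas the allocated shape would give `stateOffset(1, 0, 8, 10)` = 80. Reading with the allocated shape would therefore pick up the wrong elements for every batch slot after the first.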

bool hasIntermediateRecurrentStates() const noexcept

Check whether intermediate recurrent state buffers are allocated (MTP enabled).

bool hasIntermediateConvStates() const noexcept

Check whether intermediate conv state buffers are allocated (MTP enabled).

int32_t numLayers() const noexcept

Get the number of recurrent layers.

Returns:

Number of recurrent layers

Config const &getConfig() const noexcept

Get cache configuration.

Returns:

Cache configuration

void scatterMtpStates(rt::Tensor const &acceptLengths, cudaStream_t stream)

Scatter MTP intermediate states to main state pools after verify (one batched launch per state kind). No-op when MTP is not enabled.
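The scatter step can be pictured as selecting, for each batch slot, one intermediate state and writing it into that slot of the main pool. The host-side sketch below is one plausible reading and rests on assumptions not stated in this page: it assumes the state to keep for slot b is the one at step `acceptLengths[b] - 1`, that accept lengths are at least 1, and that both buffers are contiguous and row-major. `scatterAcceptedStates` is a hypothetical helper; the real function runs as a batched device launch.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Host-side sketch. `intermediate` is laid out as [batch, seqLen, innerElems]
// and `mainPool` as [batch, innerElems]. For each slot b, the state at step
// acceptLengths[b] - 1 (an assumed convention) is copied into mainPool[b].
inline void scatterAcceptedStates(std::vector<float> const& intermediate,
    std::vector<int32_t> const& acceptLengths, std::size_t seqLen, std::size_t innerElems,
    std::vector<float>& mainPool)
{
    std::size_t const batch = acceptLengths.size();
    assert(intermediate.size() >= batch * seqLen * innerElems);
    assert(mainPool.size() >= batch * innerElems);

    for (std::size_t b = 0; b < batch; ++b)
    {
        int32_t const len = acceptLengths[b];
        assert(len >= 1 && static_cast<std::size_t>(len) <= seqLen);

        std::size_t const step = static_cast<std::size_t>(len) - 1;
        std::size_t const src = (b * seqLen + step) * innerElems;
        std::size_t const dst = b * innerElems;
        std::copy_n(intermediate.begin() + static_cast<std::ptrdiff_t>(src), innerElems,
            mainPool.begin() + static_cast<std::ptrdiff_t>(dst));
    }
}
```

In this sketch the same routine would be called once for the recurrent states and once for the conv states, which matches the "one batched launch per state kind" wording only loosely, since the loop here is sequential.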