Mamba Cache Manager#

class MambaCacheManager#

Per-layer Mamba (recurrent + conv) state manager. Each recurrent layer owns two device tensors:

  • recurrent state: [maxBatchSize, recurrentStateNumHeads, recurrentStateHeadDim, recurrentStateSize]

  • conv state: [maxBatchSize, convDim, convKernel]

When numRecurrentLayers == 0 the manager is a no-op and allocates nothing.

Public Functions

MambaCacheManager() noexcept = default#

Default constructor (no allocation).

MambaCacheManager(Config const &config, cudaStream_t stream)#

Construct and initialize per-layer Mamba state buffers.

If numRecurrentLayers == 0 the constructor returns immediately without allocating. Otherwise it allocates one recurrent-state tensor and one conv-state tensor per layer, zero-initialises them, and logs total GPU memory consumed.

Parameters:
  • config – Cache configuration

  • stream – CUDA stream for allocation and memset

Throws:

std::runtime_error – if config validation fails

~MambaCacheManager() noexcept#

Destructor.

MambaCacheManager(MambaCacheManager const&) = delete#

Deleted copy constructor to avoid large data copy.

MambaCacheManager &operator=(MambaCacheManager const&) = delete#

Deleted copy assignment to avoid large data copy.

MambaCacheManager(MambaCacheManager&&) noexcept#

Move constructor.

MambaCacheManager &operator=(MambaCacheManager&&) noexcept#

Move assignment operator.

Returns:

Reference to this

rt::Tensor &getRecurrentState(int32_t recurrentLayerIdx) noexcept#

Get the recurrent state tensor for a given recurrent layer (owned tensor reference). Shape: [maxBatchSize, recurrentStateNumHeads, recurrentStateHeadDim, recurrentStateSize]

Parameters:

recurrentLayerIdx – The recurrent layer index.

Returns:

A reference to the owned device tensor.

rt::Tensor &getConvState(int32_t recurrentLayerIdx) noexcept#

Get the conv state tensor for a given recurrent layer (owned tensor reference). Shape: [maxBatchSize, convDim, convKernel]

Parameters:

recurrentLayerIdx – The recurrent layer index.

Returns:

A reference to the owned device tensor.

void clearStates(cudaStream_t stream)#

Zero all recurrent and conv state buffers (all layers, all batch slots). Called after warmup inference and before CUDA graph capture to ensure a clean starting state. No-op when numRecurrentLayers == 0.

Parameters:

stream – CUDA stream for memset operations.

std::vector<rt::Tensor> captureRecurrentStates(
int32_t batchIdx,
cudaStream_t stream
)#

Copy one batch slot’s recurrent states into freshly allocated tensors (one per layer). Used to snapshot states when saving a system prompt cache entry. Returns an empty vector when numRecurrentLayers == 0.

Parameters:
  • batchIdx – The batch slot index to capture.

  • stream – CUDA stream for copy operations.

Returns:

Vector of device tensors with shape [1, recurrentStateNumHeads, recurrentStateHeadDim, recurrentStateSize].

std::vector<rt::Tensor> captureConvStates(
int32_t batchIdx,
cudaStream_t stream
)#

Copy one batch slot’s conv states into freshly allocated tensors (one per layer). Used to snapshot states when saving a system prompt cache entry. Returns an empty vector when numRecurrentLayers == 0.

Parameters:
  • batchIdx – The batch slot index to capture.

  • stream – CUDA stream for copy operations.

Returns:

Vector of device tensors with shape [1, convDim, convKernel].

int32_t numLayers() const noexcept#

Get the number of recurrent layers.

Returns:

Number of recurrent layers

Config const &getConfig() const noexcept#

Get cache configuration.

Returns:

Cache configuration