Mamba Cache Manager#
-
class MambaCacheManager#
Per-layer Mamba (recurrent + conv) state manager. Each recurrent layer owns two device tensors:
recurrent state: [maxBatchSize, recurrentStateNumHeads, recurrentStateHeadDim, recurrentStateSize]
conv state: [maxBatchSize, convDim, convKernel]
When numRecurrentLayers == 0 the manager is a no-op and allocates nothing.
Public Functions
-
MambaCacheManager() noexcept = default#
Default constructor (no allocation).
-
MambaCacheManager(Config const &config, cudaStream_t stream)#
Construct and initialize per-layer Mamba state buffers.
If numRecurrentLayers == 0 the constructor returns immediately without allocating. Otherwise it allocates one recurrent-state tensor and one conv-state tensor per layer, zero-initializes them, and logs the total GPU memory consumed.
- Parameters:
config – Cache configuration
stream – CUDA stream for allocation and memset
- Throws:
std::runtime_error – if config validation fails
-
~MambaCacheManager() noexcept#
Destructor.
-
MambaCacheManager(MambaCacheManager const&) = delete#
Deleted copy constructor to avoid copying the large state buffers.
-
MambaCacheManager &operator=(MambaCacheManager const&) = delete#
Deleted copy assignment to avoid copying the large state buffers.
-
MambaCacheManager(MambaCacheManager&&) noexcept#
Move constructor.
-
MambaCacheManager &operator=(MambaCacheManager&&) noexcept#
Move assignment operator.
- Returns:
Reference to this
-
rt::Tensor &getRecurrentState(int32_t recurrentLayerIdx) noexcept#
Get the recurrent state tensor for a given recurrent layer (owned tensor reference). Shape: [maxBatchSize, recurrentStateNumHeads, recurrentStateHeadDim, recurrentStateSize]
- Parameters:
recurrentLayerIdx – The recurrent layer index.
- Returns:
A reference to the owned device tensor.
-
rt::Tensor &getConvState(int32_t recurrentLayerIdx) noexcept#
Get the conv state tensor for a given recurrent layer (owned tensor reference). Shape: [maxBatchSize, convDim, convKernel]
- Parameters:
recurrentLayerIdx – The recurrent layer index.
- Returns:
A reference to the owned device tensor.
-
void clearStates(cudaStream_t stream)#
Zero all recurrent and conv state buffers (all layers, all batch slots). Called after warmup inference and before CUDA graph capture to ensure a clean starting state. No-op when numRecurrentLayers == 0.
- Parameters:
stream – CUDA stream for memset operations.
-
std::vector<rt::Tensor> captureRecurrentStates(int32_t batchIdx, cudaStream_t stream)#
Copy one batch slot’s recurrent states into freshly allocated tensors (one per layer). Used to snapshot states when saving a system prompt cache entry. Returns an empty vector when numRecurrentLayers == 0.
- Parameters:
batchIdx – The batch slot index to capture.
stream – CUDA stream for copy operations.
- Returns:
Vector of device tensors with shape [1, numHeads, headDim, stateSize].
-
std::vector<rt::Tensor> captureConvStates(int32_t batchIdx, cudaStream_t stream)#
Copy one batch slot’s conv states into freshly allocated tensors (one per layer). Used to snapshot states when saving a system prompt cache entry. Returns an empty vector when numRecurrentLayers == 0.
- Parameters:
batchIdx – The batch slot index to capture.
stream – CUDA stream for copy operations.
- Returns:
Vector of device tensors with shape [1, convDim, convKernel].
-
int32_t numLayers() const noexcept#
Get the number of recurrent layers.
- Returns:
Number of recurrent layers
-
Config const &getConfig() const noexcept#
Get cache configuration.
- Returns:
Cache configuration