Sparse Attention Development Guide
This guide is for contributors adding a new sparse attention algorithm to TensorRT LLM. It walks through the framework hooks each algorithm plugs into and the registration steps needed for the runtime to pick up the new backend.
For the user-facing configuration surface, see Sparse Attention. For the design rationale and high-level architecture diagrams, see the Sparse Attention tech blog.
Two integration levels
TensorRT LLM’s sparse attention algorithms fall into two categories.
Framework-level: the algorithm runs a prediction step that emits sparse indices, which a shared AttentionOp then consumes to produce sparse KV cache updates and/or sparse attention computation. Examples: RocketKV (page-level, MQA/GQA), DSA (token-level, MLA).
Kernel-level: sparsity is implemented entirely inside the attention kernel, with no external prediction or gather step. The kernel decides what to skip from runtime values such as Softmax scores. Example: Skip Softmax Attention (BLASST). The only framework dependency is the sparse_attention_config plumbing that selects the backend; everything else lives in the kernel.
This guide focuses primarily on the framework-level path. Kernel-level algorithms reuse the same configuration surface but skip the prediction and memory-management sections below.
Lowered sparse parameters
Sparse attention has two configuration layers.
User-facing sparse configs live in tensorrt_llm/llmapi/llm_args.py for LLM and tensorrt_llm/visual_gen/sparse_attention.py for VisualGen. They form the Python/YAML surface and may also merge data from the checkpoint config.json.
Lowered sparse params live under tensorrt_llm/_torch/attention_backend/sparse/. They are backend-owned runtime objects consumed by attention implementations and metadata builders.
The lowering boundary is intentional: AttentionBackend instances
should not keep or interpret user-facing config objects. Before an
attention backend is constructed, the model layer calls
to_sparse_params(...) on the user config. That method resolves
per-model, per-layer, checkpoint, and default values into an
algorithm-specific SparseParams dataclass, or returns None when the
algorithm should not apply to that layer. The resolved object is then
passed to create_attention(..., sparse_params=...) and stored on the
backend instance.
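Here is a minimal sketch of this lowering boundary. It assumes a hypothetical `MyAlgoConfig` with a per-layer override map and a hypothetical `MySparseParams` dataclass. The real `to_sparse_params(...)` signature and resolution order in TensorRT LLM may differ. The `page_size` checkpoint key is also illustrative.

```python
from dataclasses import dataclass, field
from typing import Dict, FrozenSet, Optional


@dataclass(frozen=True)
class MySparseParams:
    """Hypothetical lowered params: only what the backend reads at runtime."""
    topk: int
    page_size: int


@dataclass
class MyAlgoConfig:
    """Hypothetical user-facing config standing in for a BaseSparseAttentionConfig subclass."""
    topk: int = 64
    layer_topk: Dict[int, int] = field(default_factory=dict)  # per-layer overrides
    dense_layers: FrozenSet[int] = frozenset()  # layers that stay dense

    def to_sparse_params(
        self, layer_idx: int, checkpoint_cfg: Optional[dict] = None
    ) -> Optional[MySparseParams]:
        if layer_idx in self.dense_layers:
            return None  # the algorithm does not apply to this layer
        # Resolution order assumed here: per-layer override, then model-wide value.
        topk = self.layer_topk.get(layer_idx, self.topk)
        # Hypothetical checkpoint key; falls back to a default of 64.
        page_size = (checkpoint_cfg or {}).get("page_size", 64)
        return MySparseParams(topk=topk, page_size=page_size)


class MyBackend:
    """Hypothetical backend: keeps the lowered params, never the user config."""

    def __init__(self, sparse_params: Optional[MySparseParams]):
        self.sparse_params = sparse_params

    @property
    def is_sparse(self) -> bool:
        return self.sparse_params is not None


cfg = MyAlgoConfig(topk=64, layer_topk={3: 128}, dense_layers=frozenset({0}))
backends = [MyBackend(cfg.to_sparse_params(i, {"page_size": 32})) for i in range(4)]
```

Because the backend only ever sees `MySparseParams` or `None`, it needs no logic to interpret defaults, overrides, or checkpoint values.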
Algorithms that need sparse metadata, auxiliary buffers, or per-batch
runtime state also implement to_sparse_metadata_params(...). This
returns an algorithm-specific SparseMetadataParams object for
AttentionMetadata, analogous to how to_sparse_params(...) returns
SparseParams for AttentionBackend. Keep them separate: metadata
params describe allocation and runtime metadata state, while sparse
params describe per-attention-layer kernel or prediction behavior.
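The split can be sketched as follows. This assumes a hypothetical `MyMetadataParams` that sizes worst-case index buffers for a batch. The field and method names are illustrative, not the TensorRT LLM API. The point is that these values describe allocation shared across layers, not the per-layer kernel behavior.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class MyMetadataParams:
    """Hypothetical metadata params: allocation sizes shared by every layer."""
    max_batch_size: int
    num_kv_heads: int
    max_selected_tokens: int  # upper bound on kept tokens per request

    def index_buffer_shape(self) -> Tuple[int, int]:
        # Worst case for a (nHeads, nTokens) index buffer over a full batch.
        return (self.num_kv_heads, self.max_batch_size * self.max_selected_tokens)

    def offsets_buffer_len(self) -> int:
        # nBatch + 1 sample boundaries.
        return self.max_batch_size + 1


@dataclass
class MyAlgoConfig:
    """Hypothetical user-facing config."""
    topk: int = 64       # pages kept per request
    page_size: int = 64  # tokens per page

    def to_sparse_metadata_params(self, max_batch_size: int, num_kv_heads: int) -> MyMetadataParams:
        return MyMetadataParams(
            max_batch_size=max_batch_size,
            num_kv_heads=num_kv_heads,
            max_selected_tokens=self.topk * self.page_size,
        )


meta = MyAlgoConfig(topk=4, page_size=16).to_sparse_metadata_params(max_batch_size=8, num_kv_heads=2)
```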
When adding a new algorithm, define concrete parameter dataclasses next to the backend implementation, implement the two lowering methods on the public config class, and make backend code consume only the lowered params.
Framework-level sparse attention
Framework-level sparse attention primarily targets approaches that leverage token/sequence sparsity — for many queries only a small fraction of historical tokens meaningfully contribute to the output, and the framework exploits that in a GPU-friendly, structured way. The attention operator provides unified APIs for both sparse computation and sparse KV cache, so algorithm authors only need to identify the important query/key pairs; everything else (KV cache layout, kernel dispatch, page alignment) is handled by the framework.
It is built around three layers:
Prediction module: generates sparse_kv_indices (which KV tokens to keep in the cache) and sparse_attn_indices (which KV pages or tokens to attend to during computation).
AttentionOp: consumes those indices via pre/post kernels and drives the core attention kernels. The op already understands page-level sparsity for MQA/GQA in the generation phase, token-level sparsity for MLA in both phases, and token-level KV compression in the context phase for MQA/GQA.
Auxiliary memory subsystem: manages any extra pools (KT cache, indexer K cache, …) alongside the main KV cache.
Figure 1: Framework support for sparse attention in TensorRT LLM.
Architecturally, each sparse attention algorithm subclasses the shared
AttentionBackend and supplies its own sparse_kv_predict /
sparse_attn_predict implementations. Different attention layers
within a single model can use different backends, so a model can mix
sparse attention strategies layer by layer. The shared AttentionOp
performs the actual computation and is not modified by individual
algorithms.
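The per-layer mixing can be illustrated with a small sketch. `AttentionBackendSketch`, its `(None, None)` "dense" convention, the dict-based `metadata`, and `RecentWindowBackend` are all hypothetical stand-ins for the real `AttentionBackend` and `AttentionMetadata`. The sketch models a single request and a single KV head, with a block size of 1.

```python
from typing import List, Optional, Tuple

Indices = Optional[List[List[int]]]
Offsets = Optional[List[int]]


class AttentionBackendSketch:
    """Stand-in for AttentionBackend; returning (None, None) means "dense" here."""

    def sparse_kv_predict(self, q, k, metadata, **kwargs) -> Tuple[Indices, Offsets]:
        return None, None

    def sparse_attn_predict(self, q, k, metadata, **kwargs) -> Tuple[Indices, Offsets]:
        return None, None


class RecentWindowBackend(AttentionBackendSketch):
    """Hypothetical algorithm: attend only to the last `window` tokens (block size 1)."""

    def __init__(self, window: int):
        self.window = window

    def sparse_attn_predict(self, q, k, metadata, **kwargs) -> Tuple[Indices, Offsets]:
        seq_len = metadata["seq_len"]  # single request, single KV head for brevity
        start = max(0, seq_len - self.window)
        return [list(range(start, seq_len))], [0, seq_len - start]


# Layer 0 stays dense, layer 1 uses the sparse strategy; the shared op only sees indices.
layers: List[AttentionBackendSketch] = [AttentionBackendSketch(), RecentWindowBackend(window=4)]
outs = [layer.sparse_attn_predict(None, None, {"seq_len": 10}) for layer in layers]
```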
The current capability matrix is:
| Attention type | Context phase | Generation phase |
|---|---|---|
| MQA / MHA / GQA | sparse KV cache | sparse computation (page-level) |
| MLA | sparse computation (token-level) | sparse computation (token-level) |
Context-phase sparse computation for MQA/GQA and dynamic generation-phase KV eviction are tracked as future work.
Prediction hooks
AttentionBackend exposes two prediction methods that algorithm-specific
subclasses override:
sparse_kv_indices, sparse_kv_offsets = self.sparse_kv_predict(q, k, metadata, **kwargs)
sparse_attn_indices, sparse_attn_offsets = self.sparse_attn_predict(q, k, metadata, **kwargs)
Different KV heads are allowed to emit different sparse index sets; Q heads that map to the same KV head share the KV head’s sparse pattern.
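For grouped-query attention, the Q-to-KV head mapping that decides which sparse pattern a Q head uses can be sketched as the common contiguous grouping. This helper is illustrative, not a TensorRT LLM function. It assumes `num_q_heads` is a multiple of `num_kv_heads` and that consecutive Q heads share a KV head.

```python
from typing import Dict, List


def kv_head_for_q_head(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Map a Q head to the KV head it shares (contiguous GQA grouping assumed)."""
    assert num_q_heads % num_kv_heads == 0, "GQA needs num_q_heads % num_kv_heads == 0"
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size


def indices_for_q_heads(kv_indices: List[List[int]], num_q_heads: int) -> Dict[int, List[int]]:
    """Expand per-KV-head sparse indices to every Q head in the group."""
    num_kv_heads = len(kv_indices)
    return {
        q: kv_indices[kv_head_for_q_head(q, num_q_heads, num_kv_heads)]
        for q in range(num_q_heads)
    }


# Two KV heads with different patterns; eight Q heads in groups of four.
per_q = indices_for_q_heads([[0, 5, 9], [1, 2, 7]], num_q_heads=8)
```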
Algorithm implementations live under tensorrt_llm/_torch/attention_backend/sparse/:
rocket.py, dsa.py: concrete algorithms.
kernel.py: custom Triton kernels (importance scoring, Top-K).
utils.py: dispatch helpers.
AttentionOp behavior
Figure 2: Sparse attention operator workflow in TensorRT LLM.
For MQA/GQA, the op runs gatherKvPageOffsetsKernel before the
generation-phase attention kernel. It takes the (potentially unordered
or finer-grained) sparse indices and maps them to ordered, page-aligned
KV cache offsets, also producing an updated per-head effective KV
length. The downstream attention kernel reads only those pages. Today
MQA/GQA sparse computation is supported at block (page) granularity
in the generation phase only.
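To make the mapping concrete, here is a simplified CPU model of what this step computes for one request and one KV head. It covers coarsening token indices to pages, deduplicating and sorting them, looking up physical blocks, and recomputing the effective KV length. It is an illustration under the assumption of a simple per-request block table, not the CUDA kernel's actual logic or interface.

```python
from typing import List, Tuple


def gather_kv_page_offsets(
    sparse_indices: List[int],
    block_table: List[int],
    page_size: int,
    seq_len: int,
) -> Tuple[List[int], int]:
    """Simplified model: sparse token indices -> ordered physical pages + effective length."""
    # Coarsen finer-grained indices to logical pages, deduplicate, and sort.
    logical_pages = sorted({i // page_size for i in sparse_indices})
    physical = [block_table[p] for p in logical_pages]
    # The last logical page of the sequence may be only partially filled.
    effective_len = sum(min(page_size, seq_len - p * page_size) for p in logical_pages)
    return physical, effective_len


# 10 tokens in pages of 4: logical pages 0..2, and page 2 holds only 2 tokens.
pages, kv_len = gather_kv_page_offsets([9, 1, 2, 8], block_table=[7, 3, 5], page_size=4, seq_len=10)
```

Here the unordered input `[9, 1, 2, 8]` collapses to logical pages 0 and 2. That gives physical blocks `[7, 5]` and an effective length of 4 + 2 = 6 tokens.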
After context attention, updateSparseKvCacheAfterFmha post-processes
the KV cache: it selects the important KV tokens and rewrites the
corresponding K/V vectors in place to shrink the cache. The indices
must be sorted so the in-place gather is safe; this preserves
compatibility with features such as chunked prefill at the cost of an
extra write.
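The reason sorting makes the in-place gather safe can be shown with a single-head Python model. This is an illustration, not `updateSparseKvCacheAfterFmha` itself. With strictly increasing indices, `keep[j] >= j > i` for every later `j`, so writing slot `i` never overwrites a slot that is still to be read.

```python
from typing import List


def compact_in_place(cache: List[str], keep: List[int]) -> int:
    """Move cache[keep[i]] to cache[i] in place and return the new length."""
    assert all(a < b for a, b in zip(keep, keep[1:])), "indices must be sorted and unique"
    for dst, src in enumerate(keep):
        # src >= dst always holds here, so src has not been overwritten yet.
        cache[dst] = cache[src]
    return len(keep)


kv = ["t0", "t1", "t2", "t3", "t4", "t5"]
new_len = compact_in_place(kv, [1, 3, 4])
# kv[:new_len] == ["t1", "t3", "t4"]
```

With unsorted indices such as `[3, 0]`, the first copy would overwrite slot 0 with `t3` before the second copy reads it, so the result would be `["t3", "t3"]` instead of `["t3", "t0"]`.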
For sparse MLA, the kernel consumes token-level indices directly, so
gatherKvPageOffsetsKernel is bypassed — both context and generation
phases are supported at token granularity. The sparse MLA path
currently expects global KV cache pool addresses with token-level
offsets, not request-local logical positions. Sparse KV cache for MLA
is not yet supported.
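The distinction between request-local positions and global pool addresses can be sketched with a flat pool model. This assumes a slot is addressed as `block_id * tokens_per_block + in_block_offset` and that `block_table[i]` holds the physical block of the request's i-th logical block. The real pool layout (layers, heads, strides) is more involved, and both helpers are hypothetical.

```python
from typing import List


def global_token_slot(block_table: List[int], local_pos: int, tokens_per_block: int) -> int:
    """Translate a request-local token position into a global KV pool slot (flat model)."""
    block_id = block_table[local_pos // tokens_per_block]
    return block_id * tokens_per_block + local_pos % tokens_per_block


def to_global(block_table: List[int], local_indices: List[int], tokens_per_block: int) -> List[int]:
    return [global_token_slot(block_table, p, tokens_per_block) for p in local_indices]


# The request occupies physical blocks 12 and 4, with 64 tokens per block.
slots = to_global([12, 4], [0, 63, 64, 70], tokens_per_block=64)
```

Local positions 64 and 70 land in physical block 4, so their global slots (256 and 262) are smaller than those of positions 0 and 63 (768 and 831). Sparse MLA indices must be expressed in this global form.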
Auxiliary memory pools
Two paths exist for managing auxiliary tensors today; new algorithms
should prefer KVCacheManagerV2 when starting fresh.
KVCacheManagerV2 (recommended for new work): Python-side and hierarchical; supports heterogeneous pools per layer with automatic coalescing within a lifecycle group. Adding an auxiliary pool only requires defining a per-layer AttentionLayerConfig and BufferConfig.
KVCacheManager (legacy path used by RocketKV/DSA today): either inherit from it at the Python level (RocketKV's RocketKVCacheManager) or integrate directly into the C++ KVCacheManager (DSA's indexer K cache). The Python path is faster to iterate on; the C++ path is required for KV cache reuse and disaggregated serving.
Note: algorithms that evict KV blocks generally cannot coexist with the standard KV cache block reuse, because eviction changes block contents per request. Low-rank-only approaches like DSA’s indexer K cache can still reuse blocks.
Adding a new framework-level algorithm
The four steps below cover what the runtime needs in order to dispatch a new algorithm end-to-end. The order matches the natural development flow — config first, then prediction, then memory, then registration.
1. Configuration class
Define a configuration class in tensorrt_llm/llmapi/llm_args.py
inheriting from BaseSparseAttentionConfig. Hold all user-tunable
parameters here and pick a unique algorithm discriminator literal.
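from typing import Literal

class MySparseAttentionConfig(BaseSparseAttentionConfig):
    # Unique discriminator used to select this class in the SparseAttentionConfig union.
    algorithm: Literal["my_algo"] = "my_algo"
    # Example user-tunable parameter.
    topk: int = 64
    # ... other parameters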
Add the new class to the discriminated SparseAttentionConfig union at
the bottom of the file.
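What the discriminated union buys can be sketched with plain dataclasses. The `algorithm` literal picks the concrete config class, and unknown algorithms or stray fields are rejected. The classes below are stand-ins (`OtherAlgoConfig` is hypothetical). The real union is declared on the config model types in llm_args.py and parsed by that framework, not by this helper.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Type


@dataclass
class BaseConfigStandIn:
    """Plain-dataclass stand-in for BaseSparseAttentionConfig."""
    algorithm: str = ""


@dataclass
class OtherAlgoConfig(BaseConfigStandIn):
    algorithm: str = "other_algo"


@dataclass
class MySparseAttentionConfig(BaseConfigStandIn):
    algorithm: str = "my_algo"
    topk: int = 64


# The union's job: choose the concrete class from the "algorithm" field.
_BY_ALGORITHM: Dict[str, Type[BaseConfigStandIn]] = {
    cls().algorithm: cls for cls in (OtherAlgoConfig, MySparseAttentionConfig)
}


def parse_sparse_config(raw: Dict[str, Any]) -> BaseConfigStandIn:
    algorithm = raw.get("algorithm")
    cls = _BY_ALGORITHM.get(algorithm)
    if cls is None:
        raise ValueError(f"unknown sparse attention algorithm: {algorithm!r}")
    unknown = set(raw) - {f.name for f in fields(cls)}
    if unknown:
        raise ValueError(f"unexpected fields for {cls.__name__}: {sorted(unknown)}")
    return cls(**raw)


cfg = parse_sparse_config({"algorithm": "my_algo", "topk": 128})
```

Forgetting to add a new class to the union has the same effect as leaving it out of `_BY_ALGORITHM` here: its `algorithm` value is rejected.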
2. Prediction module
Create a new backend class inheriting from TrtllmAttention (or
VanillaAttention if appropriate) in
tensorrt_llm/_torch/attention_backend/sparse/. Override one or both
prediction methods.
sparse_kv_predict(self, q, k, metadata, **kwargs)
Behavior: return the indices of tokens to retain in the KV cache.
Outputs:
sparse_kv_indices: shape (nHeads, nTokens); token indices on the sequence dimension, where nHeads is the number of KV heads and nTokens is the total number of selected tokens across the batch.
sparse_kv_offsets: shape (nBatch + 1); sample boundaries. The indices for head h and sample n are sparse_kv_indices[h, sparse_kv_offsets[n]:sparse_kv_offsets[n+1]].
Constraint: indices must be sorted so the post-attention in-place gather (updateSparseKvCacheAfterFmha) is safe. The sort cost buys compatibility with chunked prefill and similar features.
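The output layout can be illustrated by packing per-request, per-head selections into the documented `(indices, offsets)` pair. This is a pure-Python sketch; `pack_sparse_kv` and `slice_sample` are hypothetical helpers. Because one offsets array is shared by all heads, the sketch assumes every KV head of a request keeps the same number of tokens, and that token indices are positions within each request's sequence.

```python
from typing import List, Sequence, Tuple


def pack_sparse_kv(
    per_request: Sequence[Sequence[Sequence[int]]],
) -> Tuple[List[List[int]], List[int]]:
    """per_request[n][h]: sorted kept tokens for request n, KV head h."""
    num_heads = len(per_request[0])
    indices: List[List[int]] = [[] for _ in range(num_heads)]
    offsets = [0]
    for heads in per_request:
        counts = {len(toks) for toks in heads}
        assert len(counts) == 1, "all heads of a request must keep the same count"
        for h, toks in enumerate(heads):
            assert list(toks) == sorted(toks), "indices must be sorted"
            indices[h].extend(toks)
        offsets.append(offsets[-1] + counts.pop())
    return indices, offsets


def slice_sample(indices: List[List[int]], offsets: List[int], h: int, n: int) -> List[int]:
    # Equivalent of sparse_kv_indices[h, sparse_kv_offsets[n]:sparse_kv_offsets[n+1]].
    return indices[h][offsets[n]:offsets[n + 1]]


idx, off = pack_sparse_kv([
    [[0, 4, 7], [1, 2, 7]],  # request 0: 2 KV heads, 3 tokens each
    [[2, 3], [0, 5]],        # request 1: 2 KV heads, 2 tokens each
])
```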
sparse_attn_predict(self, q, k, metadata, **kwargs)
Behavior: return the sparse indices used by the generation-phase attention computation.
Outputs:
sparse_attn_indices: shape (nHeads, nBlocks); block indices on the KV sequence dimension. The algorithm sets the block size via sparse_attn_indices_block_size (arbitrary values are supported).
sparse_attn_offsets: shape (nBatch + 1); same semantics as sparse_kv_offsets above.
Constraint: today only page-level granularity is supported for MQA/GQA sparse computation, and the generation-phase path uses TRTLLM-GEN kernels (NVIDIA Blackwell SM 100+).
Prediction is on the critical path and can dominate latency in low-latency scenarios. Plan for custom kernels (Triton or CUDA) rather than relying on generic PyTorch ops.
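As a CPU reference for the shape of a `sparse_attn_predict` result, the sketch below selects the top-k KV blocks per request and KV head from precomputed importance scores. It emits them in ascending order together with the shared offsets. The scores are assumed inputs (producing them is the algorithm-specific part), and `heapq.nlargest` stands in for the Top-K kernel that a production path would implement in Triton or CUDA.

```python
import heapq
from typing import List, Sequence, Tuple


def topk_blocks(scores: Sequence[float], k: int) -> List[int]:
    """Indices of the k highest-scoring KV blocks, returned in ascending order."""
    k = min(k, len(scores))
    best = heapq.nlargest(k, range(len(scores)), key=lambda b: scores[b])
    return sorted(best)


def sparse_attn_predict_sketch(
    block_scores: Sequence[Sequence[Sequence[float]]], k: int
) -> Tuple[List[List[int]], List[int]]:
    """block_scores[n][h][b]: importance of KV block b for request n, KV head h."""
    num_heads = len(block_scores[0])
    indices: List[List[int]] = [[] for _ in range(num_heads)]
    offsets = [0]
    for per_head in block_scores:
        picked = [topk_blocks(s, k) for s in per_head]
        for h, blocks in enumerate(picked):
            indices[h].extend(blocks)
        # All heads of a request see the same number of blocks, so counts match.
        offsets.append(offsets[-1] + len(picked[0]))
    return indices, offsets


idx, off = sparse_attn_predict_sketch(
    [
        [[0.1, 0.9, 0.3, 0.8], [0.7, 0.2, 0.6, 0.1]],  # request 0: 4 blocks
        [[0.5, 0.4], [0.1, 0.3]],                      # request 1: 2 blocks
    ],
    k=2,
)
```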
3. Auxiliary memory
If the algorithm needs extra tensors beyond the main KV cache:
KVCacheManagerV2 (preferred for new algorithms): define a per-layer AttentionLayerConfig and a BufferConfig for the auxiliary buffer; the V2 manager groups layers by lifecycle and coalesces buffers automatically. No C++ changes are required.
Python-level custom manager (legacy KVCacheManager): subclass KVCacheManager, reuse BlockManager for the auxiliary pool, and override get_cache_size_per_token / get_cache_bytes_per_token so the runtime allocates enough GPU memory, plus add_dummy_requests / prepare_resources so the pool gets the right resources at request time. This path is easier to iterate on but supports neither KV cache reuse nor disaggregated serving.
C++ integrated manager: extend the C++ KVCacheManager itself. This is required for advanced features (KV cache reuse, disaggregated serving), at a significantly higher implementation cost.
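For the legacy Python-level path, the per-token accounting boils down to adding the auxiliary pool's footprint to the main KV footprint, so that the memory budget covers both. The arithmetic below is a hedged illustration with hypothetical helpers, not the signature of the KVCacheManager overrides. The auxiliary layout (one `aux_dim` vector per KV head per layer, shared by `tokens_per_entry` tokens) is an assumption, and real algorithms will differ.

```python
def kv_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int, dtype_bytes: int) -> int:
    # K and V for every layer and KV head.
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes


def aux_bytes_per_token(
    num_layers: int, num_kv_heads: int, aux_dim: int, dtype_bytes: int, tokens_per_entry: int = 1
) -> int:
    # One auxiliary vector per `tokens_per_entry` tokens (e.g. a per-page summary).
    per_entry = num_layers * num_kv_heads * aux_dim * dtype_bytes
    return -(-per_entry // tokens_per_entry)  # ceiling division


main = kv_bytes_per_token(32, 8, 128, 2)
aux = aux_bytes_per_token(32, 8, 128, 2, tokens_per_entry=64)
total = main + aux
```

With 32 layers, 8 KV heads, a 128-dim head, and 2-byte elements, the main cache needs 131072 bytes per token. A per-page auxiliary vector adds only 1024 bytes per token, which is why under-reporting it is easy to miss until allocation fails.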
4. Registration and dispatch
Register the new config and backend in tensorrt_llm/_torch/attention_backend/sparse/utils.py and tensorrt_llm/_torch/pyexecutor/_util.py so the runtime routes requests to your backend when the config is present.
If your algorithm exposes new C++ parameters, plumb them through cpp/tensorrt_llm/thop/attentionOp.cpp and cpp/tensorrt_llm/kernels/sparseAttentionKernels.h.
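Conceptually, registration maps the config's `algorithm` literal to a backend class. The sketch below uses a hypothetical decorator-based registry to illustrate that mapping and the failure modes worth guarding against (duplicates, unregistered algorithms). The actual dispatch code in utils.py and _util.py may be structured differently, and `register_sparse_backend` and `get_sparse_backend` are not TensorRT LLM APIs.

```python
from typing import Callable, Dict, Type


class AttentionBackendStub:
    """Stand-in for the shared AttentionBackend base class."""

    def __init__(self, sparse_params=None):
        self.sparse_params = sparse_params


_SPARSE_BACKENDS: Dict[str, Type[AttentionBackendStub]] = {}


def register_sparse_backend(algorithm: str) -> Callable[[type], type]:
    """Decorator mapping a config's algorithm literal to a backend class."""
    def wrap(cls: type) -> type:
        if algorithm in _SPARSE_BACKENDS:
            raise ValueError(f"duplicate sparse backend for {algorithm!r}")
        _SPARSE_BACKENDS[algorithm] = cls
        return cls
    return wrap


def get_sparse_backend(algorithm: str) -> Type[AttentionBackendStub]:
    try:
        return _SPARSE_BACKENDS[algorithm]
    except KeyError:
        raise ValueError(f"no sparse backend registered for {algorithm!r}") from None


@register_sparse_backend("my_algo")
class MySparseBackend(AttentionBackendStub):
    pass
```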
Kernel-level sparse attention
Kernel-level algorithms reuse the same sparse_attention_config
selection but bypass the prediction and memory-management hooks
entirely. Implementation lives inside the attention kernel; the only
framework wiring is:
A new config subclass with its own algorithm discriminator.
A lowered SparseParams object that carries the resolved kernel settings.
A switch inside the attention backend (e.g., _torch/attention_backend/trtllm_gen.py) that reads the lowered params and enables the kernel-side fast path.
Skip Softmax Attention follows this pattern — see the BLASST tech blog for the kernel-side specifics.
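The kernel-level wiring amounts to a branch on the lowered params, with no prediction hooks. This is a minimal sketch with hypothetical names (`MyKernelSparseParams`, `threshold`, `kernel_options`). It does not reproduce the actual switch in trtllm_gen.py or Skip Softmax Attention's real parameters.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass(frozen=True)
class MyKernelSparseParams:
    """Hypothetical lowered params for a kernel-level algorithm."""
    threshold: float  # hypothetical kernel knob, resolved during lowering


class KernelLevelBackendSketch:
    """Hypothetical backend: no prediction step, only a kernel switch."""

    def __init__(self, sparse_params: Optional[MyKernelSparseParams] = None):
        self.sparse_params = sparse_params

    def kernel_options(self) -> Dict[str, Any]:
        # Without lowered params the dense kernel path is used unchanged.
        if self.sparse_params is None:
            return {"sparse_fast_path": False}
        return {"sparse_fast_path": True, "threshold": self.sparse_params.threshold}
```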
Roadmap
Sparse computation in context phase for MQA/MHA/GQA — extend framework coverage to context-phase sparse compute.
Dynamic eviction in generation phase — exploring block-level eviction as a compromise that keeps KV cache flexibility manageable.
Unified auxiliary memory management — let custom auxiliary pools inherit KV-cache features (reuse, offloading) by default.
Code refactoring — as more algorithms land, unify the framework-level scaffolding for maintainability.