# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Block KV cache for causal attention with a fixed-size local window."""
from dataclasses import dataclass, field
import torch
from torch import Tensor
from typing_extensions import Self
[docs]
@dataclass
class BlockKVCache:
"""
KV cache for causal attention with a fixed-size local window, CUDA-graph compatible.
Keys and values can have arbitrary shape ``[..., total_size, ...]``; the sequence
(rolling) dimension is given by ``seq_dim`` (dimension index, can be negative).
Layout along that dimension: [sink tokens | local window tokens]. Sink tokens are
never evicted; the local window rolls left as new chunks are added if full. Chunks are
non-overlapping: each update adds one chunk of ``chunk_size`` tokens at the
next logical position in the full sequence.
Note: Currently only supports ``total_size`` (``sink_size + window_size``) divisible by ``chunk_size``.
Phases:
- Filling: cache not yet full; tokens are written contiguously;
``cached_k()`` / ``cached_v()`` return only the valid prefix.
- Steady-state: cache full; each new chunk triggers a left-roll of the
local window and overwrites the rightmost positions;
``cached_k()`` / ``cached_v()`` return the full buffer.
The argument ``chunk_idx`` (0, 1, 2, ...) is the index of the new chunk in the full
sequence (not an index into the cache). If ``chunk_idx`` is greater than
the previous one, the chunk is appended (or, in steady-state, written after
the roll). If ``chunk_idx`` equals the previous one, the same cache positions
are overwritten.
Per-step usage:
1. before_update(chunk_idx) — prepare (roll local window if steady-state).
2. update(k, v) — write the new chunk's keys/values into the cache.
3. cached_k() / cached_v() — get cached keys/values for attention.
4. after_update(chunk_idx) — update internal bookkeeping.
"""
k_shape: tuple[int, ...]
"""Shape of the keys. Must be the same as the values shape except for the last dimension."""
v_shape: tuple[int, ...]
"""Shape of the values. Must be the same as the keys shape except for the last dimension."""
seq_dim: int
"""Sequence dimension that will be rolled. Can be negative."""
chunk_size: int
"""Number of tokens processed each time."""
window_size: int
"""Size of the local attention window (excluding sink tokens)."""
sink_size: int = 0
"""Number of sink tokens at the start of the cache that are never evicted. Defaults to 0."""
device: torch.device | str = torch.device("cuda")
"""Device to store the cache on."""
dtype: torch.dtype = torch.float16
"""Data type to store the cache in."""
_prev_chunk_idx: int = -1
"""Chunk index of the last written chunk; -1 when empty."""
_curr_chunk_idx: int | None = None
"""The index of the current chunk that is being processed. None when empty."""
_n_cached: int = 0
"""Number of valid tokens currently in the cache."""
_k: Tensor = field(init=False)
"""Cached keys. shape ``[..., total_size, ..., Dk]``, where the ``total_size`` is the length of the cache buffer at ``seq_dim`` dimension."""
_v: Tensor = field(init=False)
"""Cached values. shape ``[..., total_size, ..., Dv]``, where the ``total_size`` is the length of the cache buffer at ``seq_dim`` dimension."""
@property
def size(self) -> int:
"""Number of valid cached tokens visible to attention."""
if self._curr_chunk_idx is None:
return self._n_cached
return self._visible_end()
@property
def write_end(self) -> int:
"""Right edge of the current chunk in the physical cache layout."""
assert self._curr_chunk_idx is not None, (
"Must call before_update() before write_end"
)
return self.size
[docs]
@classmethod
def from_tensor(cls, k: Tensor, v: Tensor, seq_dim: int) -> Self:
"""Build a single-chunk cache pre-filled with the given key and value tensors."""
cache = cls(
k_shape=k.shape,
v_shape=v.shape,
seq_dim=seq_dim,
chunk_size=k.shape[seq_dim],
window_size=k.shape[seq_dim],
device=k.device,
dtype=k.dtype,
)
cache.before_update(0)
cache.update(k, v)
cache.after_update(0)
cache._curr_chunk_idx = 0
return cache
def __post_init__(self) -> None:
assert self.k_shape[:-1] == self.v_shape[:-1], (
"k and v must have the same shape except for the last dimension"
)
tensor_dim = len(self.k_shape)
assert -tensor_dim <= self.seq_dim < tensor_dim, (
f"seq_dim must be in [-{tensor_dim}, {tensor_dim}), got {self.seq_dim}"
)
# Normalize seq_dim to a non-negative index so downstream
# indexing math doesn't have to special-case negatives.
self.seq_dim = self.seq_dim if self.seq_dim >= 0 else self.seq_dim + tensor_dim
assert self.sink_size >= 0, "sink_size must be non-negative"
expected_length = self.sink_size + self.window_size
assert self.k_shape[self.seq_dim] == expected_length, (
f"k_shape[seq_dim] ({self.k_shape[self.seq_dim]}) must equal sink_size + window_size ({expected_length})"
)
assert (self.window_size + self.sink_size) % self.chunk_size == 0, (
f"window_size + sink_size ({self.window_size + self.sink_size}) must be divisible by chunk_size ({self.chunk_size})"
)
self._k = torch.empty(self.k_shape, device=self.device, dtype=self.dtype)
self._v = torch.empty(self.v_shape, device=self.device, dtype=self.dtype)
def _seq_slice(self, start: int | None, end: int | None) -> tuple[slice | int, ...]:
"""Return an index tuple selecting ``[start:end]`` on ``seq_dim`` and all elements elsewhere."""
idx: list[slice | int] = [slice(None)] * len(self.k_shape)
idx[self.seq_dim] = slice(start, end)
return tuple(idx)
def _roll_local_window_left(self) -> None:
"""Shift the local window left by chunk_size tokens (steady-state only)."""
total_size = self._k.shape[self.seq_dim]
assert total_size == self._n_cached, (
f"Expected full cache: {total_size=} != {self._n_cached=}"
)
tokens_to_keep = self.window_size - self.chunk_size
if tokens_to_keep > 0:
src_start = self.sink_size + self.chunk_size
src_end = total_size
dst_start = self.sink_size
dst_end = self.sink_size + tokens_to_keep
dst_slice = self._seq_slice(dst_start, dst_end)
src_slice = self._seq_slice(src_start, src_end)
self._k[dst_slice] = self._k[src_slice].clone()
self._v[dst_slice] = self._v[src_slice].clone()
def _current_chunk_overlaps_sink(self) -> bool:
assert self._curr_chunk_idx is not None, (
"Must call before_update() before checking sink overlap"
)
return (
self.sink_size > 0
and self._curr_chunk_idx * self.chunk_size < self.sink_size
)
def _current_write_bounds(self) -> tuple[int, int]:
"""Return the physical cache range written by the current update."""
assert self._curr_chunk_idx is not None, (
"Must call before_update() before computing write bounds"
)
total_size = self._k.shape[self.seq_dim]
assert self.chunk_size <= total_size, (
f"chunk_size ({self.chunk_size}) must be <= cache size ({total_size})"
)
if self._curr_chunk_idx == self._prev_chunk_idx + 1:
write_start = torch.sym_min(self._n_cached, total_size - self.chunk_size)
write_end = write_start + self.chunk_size
elif self._curr_chunk_idx == self._prev_chunk_idx:
write_end = torch.sym_min(self._n_cached, total_size)
write_start = torch.sym_max(write_end - self.chunk_size, 0)
else:
raise ValueError(
f"{self._curr_chunk_idx=} should be either {self._prev_chunk_idx + 1} or {self._prev_chunk_idx}."
)
return write_start, write_end
def _write_current_chunk(self, k: Tensor, v: Tensor) -> None:
"""Write the current chunk through a filling/steady compatible path."""
write_start, write_end = self._current_write_bounds()
read_start = 0
read_end = write_end - write_start
if (
self.sink_size > 0
and not self._current_chunk_overlaps_sink()
and write_start < self.sink_size
):
write_start = self.sink_size
keep_size = write_end - write_start
read_end = self.chunk_size
read_start = read_end - keep_size
sl_read = self._seq_slice(read_start, read_end)
sl_write = self._seq_slice(write_start, write_end)
self._k[sl_write] = k[sl_read]
self._v[sl_write] = v[sl_read]
def _visible_end(self) -> int:
"""Right edge of cached tokens visible to attention during this update."""
assert self._curr_chunk_idx is not None, (
"Must call before_update() before computing visible cache size"
)
total_size = self._k.shape[self.seq_dim]
if self._curr_chunk_idx == self._prev_chunk_idx + 1:
return torch.sym_min(self._n_cached + self.chunk_size, total_size)
if self._curr_chunk_idx == self._prev_chunk_idx:
return torch.sym_min(self._n_cached, total_size)
raise ValueError(
f"{self._curr_chunk_idx=} should be either {self._prev_chunk_idx + 1} or {self._prev_chunk_idx}."
)
[docs]
def is_steady_state(self) -> bool:
"""Return True if the cache is full (steady-state phase)."""
assert self._curr_chunk_idx is not None, (
"Must call before_update() before is_steady_state()"
)
total_size = self._k.shape[self.seq_dim]
is_full = total_size == self._n_cached
is_overlapping_with_sink = (
self.sink_size > 0
and self._curr_chunk_idx * self.chunk_size
< self.sink_size # start < sink_size
)
return is_full and not is_overlapping_with_sink
[docs]
def before_update(self, chunk_idx: int) -> None:
"""
Prepare the cache before writing new tokens.
If ``chunk_idx`` equals the previous chunk index, this is a no-op. Otherwise,
we expect the ``chunk_idx`` to be +1 from the previous chunk index. In this case,
we will roll the local window left if the cache is in steady-state, or no op
if the cache is in filling phase.
Args:
chunk_idx: Chunk index of the new chunk in the full sequence.
"""
assert self._curr_chunk_idx is None, (
"Must call after_update() before before_update()"
)
self._curr_chunk_idx = chunk_idx
if chunk_idx == self._prev_chunk_idx:
return
assert chunk_idx == self._prev_chunk_idx + 1, (
"Expected the new chunk_idx to be +1 from the previous chunk_idx, "
f"got {chunk_idx} != {self._prev_chunk_idx} + 1"
)
if self.is_steady_state():
self._roll_local_window_left()
[docs]
def update(self, k: Tensor, v: Tensor) -> None:
"""
Write the new chunk's keys and values into the cache.
Must be called after ``before_update()`` and before ``after_update()``.
Args:
k: Keys; shape must match cached keys except at seq_dim, where length must be chunk_size.
v: Values; shape must match cached values except at seq_dim, where length must be chunk_size.
"""
assert self._curr_chunk_idx is not None, (
"Must call before_update() before update()"
)
chunk_size_k = k.shape[self.seq_dim]
chunk_size_v = v.shape[self.seq_dim]
assert chunk_size_k == self.chunk_size, (
f"Expected input k to have chunk_size ({chunk_size_k}) at seq_dim ({self.seq_dim}), "
f"got {chunk_size_k} != {self.chunk_size}"
)
assert chunk_size_v == self.chunk_size, (
f"Expected input v to have chunk_size ({chunk_size_v}) at seq_dim ({self.seq_dim}), "
f"got {chunk_size_v} != {self.chunk_size}"
)
self._write_current_chunk(k, v)
[docs]
def after_update(self, chunk_idx: int) -> None:
"""
Finalize bookkeeping after writing new tokens.
Updates ``_prev_chunk_idx`` and, in filling phase, ``_n_cached``.
Args:
chunk_idx: The index of the new chunk in the full sequence.
"""
assert chunk_idx == self._curr_chunk_idx, (
f"Expected chunk_idx to be {self._curr_chunk_idx}, got {chunk_idx}"
)
if self._curr_chunk_idx == self._prev_chunk_idx + 1:
if self.is_steady_state():
pass
else:
self._n_cached += self.chunk_size
self._prev_chunk_idx += 1
elif self._curr_chunk_idx == self._prev_chunk_idx:
pass
else:
raise ValueError(
f"{self._curr_chunk_idx=} should be either {self._prev_chunk_idx + 1} or {self._prev_chunk_idx}."
)
self._curr_chunk_idx = None
[docs]
def cached_k(self) -> Tensor:
"""
Return cached keys for attention (valid prefix in filling phase, full buffer in steady-state).
"""
return self._k[self._seq_slice(0, self._visible_end())]
[docs]
def cached_v(self) -> Tensor:
"""
Return cached values for attention (valid prefix in filling phase, full buffer in steady-state).
"""
return self._v[self._seq_slice(0, self._visible_end())]
[docs]
def reset(self) -> None:
"""Reset the cache to its initial empty state."""
self._prev_chunk_idx = -1
self._n_cached = 0