KV Cache Connector#
The KV Cache Connector is a flexible interface in TensorRT-LLM that enables remote or external access to the Key-Value (KV) cache. It allows developers to implement custom logic for loading, saving, and managing KV cache blocks, extending the capabilities of the standard KV cache manager.
This document explains the KV Cache Connector architecture, describes common use cases, and walks through the included example in detail.
Use Cases#
The KV Cache Connector is designed to support a variety of advanced serving scenarios:
KV Cache Offloading: Move KV cache blocks from GPU memory to cheaper/larger storage (CPU RAM, NVMe SSD, or network storage) when they are not immediately needed, and reload them when required.
Custom Disaggregated Serving: Separate the prefill (context processing) and decode (token generation) phases onto different instances or machines. The connector can be used to transmit the KV cache generated during prefill to the decode instances.
KV Cache Sharing / P2P Transfer: Share KV cache states between different model instances or across peer-to-peer connections.
Architecture#
The connector architecture is split into two main components:
Scheduler (Leader): Responsible for orchestration. It decides what needs to be loaded or saved and builds metadata instructions. It runs only on the leader rank (rank 0).
Worker: Responsible for execution. It receives metadata from the scheduler and performs the actual data transfers (loading/saving) on the KV cache tensors. It runs on all ranks.
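The split can be pictured with a small, library-free sketch. Everything here is hypothetical (ToyMetadata, ToyLeader, ToyWorker, and run_step are illustration names, not TensorRT-LLM classes), and the broadcast is simulated with a pickle round trip to show why the metadata has to be picklable.

```python
import pickle
from dataclasses import asdict, dataclass, field


@dataclass
class ToyMetadata:
    """Hypothetical metadata: which block IDs to load and which to save."""
    load_block_ids: list[int] = field(default_factory=list)
    save_block_ids: list[int] = field(default_factory=list)


class ToyLeader:
    """Stand-in for the scheduler side; conceptually runs on rank 0 only."""

    def build_metadata(self, cached: set[int], allocated: list[int]) -> ToyMetadata:
        # Blocks already present in the external store are loaded; the rest are saved.
        return ToyMetadata(
            load_block_ids=[b for b in allocated if b in cached],
            save_block_ids=[b for b in allocated if b not in cached],
        )


class ToyWorker:
    """Stand-in for the worker side; one instance per rank."""

    def __init__(self, rank: int):
        self.rank = rank
        self.loaded: list[int] = []
        self.saved: list[int] = []

    def execute(self, meta: ToyMetadata) -> None:
        # A real worker would move tensor data here; this only records the plan.
        self.loaded.extend(meta.load_block_ids)
        self.saved.extend(meta.save_block_ids)


def run_step(leader: ToyLeader, workers: list[ToyWorker],
             cached: set[int], allocated: list[int]) -> None:
    meta = leader.build_metadata(cached, allocated)
    payload = pickle.dumps(asdict(meta))  # serialize once on the leader
    for worker in workers:
        # Simulated broadcast: every rank receives its own deserialized copy.
        worker.execute(ToyMetadata(**pickle.loads(payload)))
```

Calling run_step(ToyLeader(), [ToyWorker(0), ToyWorker(1)], cached={1, 3}, allocated=[1, 2, 3, 4]) leaves every worker with the same load and save plan, which is the point of keeping all decisions on the leader.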
API Reference#
To implement a custom connector, you must subclass KvCacheConnectorScheduler and KvCacheConnectorWorker.
1. Scheduler (Leader) Interface (KvCacheConnectorScheduler)#
These methods run on the leader process and drive the connector’s behavior.
build_connector_meta(self, scheduler_output: SchedulerOutput) -> object
Description: The core orchestration method. Called during the scheduling phase. It examines the current requests and decides which blocks need to be loaded from or saved to the external store.
Arguments: scheduler_output contains information about new requests, allocated blocks, and current request states.
Returns: An arbitrary picklable metadata object that describes the tasks for the workers. This object is broadcast to all workers.
get_num_new_matched_tokens(self, request: LlmRequest, num_computed_tokens: int) -> tuple[int, bool]
Description: Called when a new request arrives. It checks whether any KV cache can be loaded from an external KV store.
Returns: A tuple (num_tokens, is_async). num_tokens is the number of tokens found in the external cache. is_async indicates whether the load happens asynchronously (in the background) or blocks.
request_finished(self, request: LlmRequest, cache_block_ids: list[int]) -> bool
Description: Called when a request completes generation.
Returns: A boolean indicating whether an asynchronous save operation is underway. If True, the system waits for the operation to complete before releasing the KV cache blocks.
update_state_after_alloc(self, request: LlmRequest, block_ids: list[int])
Description: A callback to update internal state after KV cache blocks have been allocated for the prefill.
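To make the four hooks concrete, here is a duck-typed sketch that mirrors the method names and signatures listed above without importing TensorRT-LLM. A real connector would subclass KvCacheConnectorScheduler instead. ToyRequest, its fields, BLOCK_SIZE, the prefix-keyed store, and the assumption that block_ids lists all of the request's blocks in order are illustrative guesses, not the library's behavior.

```python
from dataclasses import dataclass

BLOCK_SIZE = 4  # hypothetical number of tokens per KV block


@dataclass
class ToyRequest:
    """Stand-in for LlmRequest; both fields are assumptions."""
    request_id: int
    tokens: list[int]


class ToyScheduler:
    """Duck-typed sketch of the leader-side hooks (not the real base class)."""

    def __init__(self, stored_prefixes: set[tuple[int, ...]]):
        self.stored = stored_prefixes  # token prefixes present in the external store
        self.matched = {}              # request_id -> (first block index, hit count)
        self.pending_loads = {}        # request_id -> block IDs to fill from the store

    def get_num_new_matched_tokens(self, request, num_computed_tokens):
        # Count consecutive full blocks, beyond what is already computed,
        # whose token prefix exists in the external store.
        start = num_computed_tokens // BLOCK_SIZE
        n_full = len(request.tokens) // BLOCK_SIZE
        hits = 0
        for i in range(start, n_full):
            if tuple(request.tokens[: (i + 1) * BLOCK_SIZE]) not in self.stored:
                break
            hits += 1
        self.matched[request.request_id] = (start, hits)
        return hits * BLOCK_SIZE, False  # False: this sketch loads synchronously

    def update_state_after_alloc(self, request, block_ids):
        # Remember which freshly allocated blocks must be filled from the store.
        start, hits = self.matched.pop(request.request_id, (0, 0))
        if hits:
            self.pending_loads[request.request_id] = block_ids[start:start + hits]

    def build_connector_meta(self, scheduler_output=None):
        # Hand the accumulated plan to the workers, then reset it.
        meta = {"load": dict(self.pending_loads)}
        self.pending_loads.clear()
        return meta

    def request_finished(self, request, cache_block_ids):
        return False  # no asynchronous save in flight, so blocks can be released
```

The metadata is a plain dict here only to keep the sketch short. Any picklable object works, as the build_connector_meta description states.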
2. Worker Interface (KvCacheConnectorWorker)#
These methods run on all workers (GPU processes) and interact with the actual GPU data.
register_kv_caches(self, kv_cache_tensor: torch.Tensor)
Description: Called at initialization. Provides the worker with the GPU KV cache tensors.
Arguments: kv_cache_tensor is the underlying storage tensor for the KV cache.
start_load_kv(self, stream: torch.cuda.Stream)
Description: Initiates the loading of KV blocks from the external source into GPU memory.
Arguments: stream is the CUDA stream on which the forward pass executes.
wait_for_layer_load(self, layer_idx: int, stream: torch.cuda.Stream)
Description: A synchronization point. Ensures that the KV cache for a specific layer is fully loaded before the model attempts to perform the forward pass on that layer.
save_kv_layer(self, layer_idx: int, stream: torch.cuda.Stream)
Description: Triggers the saving of a specific layer’s KV cache.
wait_for_save(self, stream: torch.cuda.Stream)
Description: A synchronization point to ensure all save operations are enqueued or completed.
get_finished(self, finished_gen_req_ids, started_loading_req_ids) -> tuple[list[int], list[int]]
Description: Polled by the runtime to check the status of asynchronous operations.
Returns: Two lists of request IDs: those that have finished saving, and those that have finished loading.
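The descriptions above suggest a per-layer ordering of the worker hooks during a forward pass. The sketch below records that ordering with a stand-in class. It does not use torch or TensorRT-LLM, the stream arguments are left as None, and toy_forward is a hypothetical driver. The actual call order is decided by the runtime and may differ from this guess.

```python
class RecordingWorker:
    """Duck-typed stand-in for a worker; records the order of hook calls."""

    def __init__(self):
        self.calls = []

    def start_load_kv(self, stream=None):
        self.calls.append("start_load_kv")

    def wait_for_layer_load(self, layer_idx, stream=None):
        self.calls.append(f"wait_for_layer_load({layer_idx})")

    def save_kv_layer(self, layer_idx, stream=None):
        self.calls.append(f"save_kv_layer({layer_idx})")

    def wait_for_save(self, stream=None):
        self.calls.append("wait_for_save")

    def get_finished(self, finished_gen_req_ids, started_loading_req_ids):
        # Everything is synchronous in this sketch, so every reported
        # request has already finished saving or loading.
        return list(finished_gen_req_ids), list(started_loading_req_ids)


def toy_forward(worker, num_layers):
    """Hypothetical driver: one plausible ordering of the hooks."""
    worker.start_load_kv()                 # kick off loads for the whole step
    for layer in range(num_layers):
        worker.wait_for_layer_load(layer)  # this layer's KV must be ready
        # ... attention for this layer would run here ...
        worker.save_kv_layer(layer)        # this layer's KV can now be persisted
    worker.wait_for_save()                 # final synchronization point
```

A worker with asynchronous transfers would return from get_finished only the request IDs whose transfers have actually completed, rather than echoing its inputs.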
Example Implementation#
The file examples/llm-api/llm_kv_cache_connector.py provides a reference implementation of a Persistent KV Cache.
Overview#
This example implements a file-system based KV cache.
1. Save: When a request finishes or needs to be swapped out, its KV blocks are saved to disk as .pt files.
2. Load: When a new request arrives with the same prompt prefix, the connector identifies the cached files and loads them back into GPU memory, skipping re-computation.
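The save and load round trip can be tried without a GPU. This is a minimal sketch under stated assumptions: pickle and plain Python lists stand in for torch.save and KV tensors, the .pkl extension is used instead of .pt to avoid implying torch's format, and save_block and load_block are hypothetical helper names.

```python
import pickle
import tempfile
from pathlib import Path


def save_block(cache_dir: Path, key: str, block: list[float]) -> Path:
    """Persist one block under a key-derived filename."""
    path = cache_dir / f"{key}.pkl"
    with path.open("wb") as f:
        pickle.dump(block, f)
    return path


def load_block(cache_dir: Path, key: str):
    """Return the cached block, or None on a cache miss."""
    path = cache_dir / f"{key}.pkl"
    if not path.exists():
        return None
    with path.open("rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    cache_dir = Path(tmp)
    save_block(cache_dir, "prefix-a", [0.1, 0.2, 0.3])
    hit = load_block(cache_dir, "prefix-a")   # same key: data comes back from disk
    miss = load_block(cache_dir, "prefix-b")  # unknown key: must be recomputed
```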
Implementation Details#
Metadata: The example defines a PersistentKvCacheConnectorMetadata dataclass containing lists of (file_path, block_id) tuples for both loading and saving. This simple structure allows the Scheduler to tell the Worker exactly which file corresponds to which GPU block index.
Hashing Strategy: The PersistentKvCacheConnectorLeader hashes the token sequence of a block to generate a unique filename (e.g., hash_value.pt). This acts as the lookup key.
Worker Logic:
start_load_kv: Iterates through the load list provided in the metadata, loads the .pt file to CPU, and copies it to the specific block_id in the GPU tensor.
wait_for_save: Performs the reverse. It copies data from the GPU block_id to CPU and saves it to disk using torch.save.
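The metadata shape and the hashing idea can be sketched as follows. This is an illustration, not the example's code: ToyConnectorMetadata, block_filenames, and plan are hypothetical names, SHA-256 is an arbitrary choice of hash, and this sketch hashes the whole token prefix up to and including each block (so that a block's key reflects everything before it), which may differ from what the example file does.

```python
import hashlib
from dataclasses import dataclass, field


@dataclass
class ToyConnectorMetadata:
    # Each entry pairs a cache file with the GPU block index it maps to.
    load: list[tuple[str, int]] = field(default_factory=list)
    save: list[tuple[str, int]] = field(default_factory=list)


def block_filenames(tokens: list[int], block_size: int) -> list[str]:
    """Return one filename per full block, keyed by the tokens up to that block."""
    names = []
    n_full = len(tokens) // block_size  # a trailing partial block is ignored
    for i in range(n_full):
        prefix = tokens[: (i + 1) * block_size]
        digest = hashlib.sha256(",".join(map(str, prefix)).encode()).hexdigest()
        names.append(f"{digest}.pt")
    return names


def plan(tokens: list[int], block_ids: list[int], block_size: int,
         existing: set[str]) -> ToyConnectorMetadata:
    """Split blocks into loads (file already exists) and saves (file missing)."""
    meta = ToyConnectorMetadata()
    for name, block_id in zip(block_filenames(tokens, block_size), block_ids):
        target = meta.load if name in existing else meta.save
        target.append((name, block_id))
    return meta
```

Because the key depends only on token IDs, two requests that share a prompt prefix map to the same filenames for the shared blocks, which is what makes the files reusable across requests and across LLM instances.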
Limitations & Patterns#
This example illustrates the API mechanics but has several limitations that make it unsuitable for high-performance production use without modification:
Blocking I/O: The example uses torch.load and torch.save synchronously. In a real implementation, these should be offloaded to a background thread or asynchronous I/O handler to avoid stalling the GPU.
Simplified Block Matching: The get_num_new_matched_tokens implementation in the example only matches full blocks. It does not handle partial cache hits.
Filesystem Latency: Storing one file per block can create high filesystem overhead.
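One way to address the blocking I/O limitation is to hand writes to a thread pool and report completion when polled, in the spirit of request_finished returning True and get_finished later reporting the request. The sketch below is an assumption-laden illustration: BackgroundSaver is a hypothetical helper, pickle stands in for torch.save, and a real worker would also have to order the device-to-host copy against the CUDA stream before the data is written.

```python
import pickle
import tempfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path


class BackgroundSaver:
    """Hypothetical helper: writes blocks off the main thread and reports completion."""

    def __init__(self, cache_dir: Path, max_workers: int = 2):
        self.cache_dir = cache_dir
        self.pool = ThreadPoolExecutor(max_workers=max_workers)
        self.in_flight = {}  # request_id -> list of futures

    def _write(self, name: str, block) -> None:
        with (self.cache_dir / name).open("wb") as f:
            pickle.dump(block, f)

    def submit(self, request_id: int, name: str, block) -> None:
        # Returns immediately; the write happens on a pool thread.
        future = self.pool.submit(self._write, name, block)
        self.in_flight.setdefault(request_id, []).append(future)

    def poll_finished(self) -> list[int]:
        """Return IDs of requests whose writes have all completed."""
        done = [rid for rid, futures in self.in_flight.items()
                if all(f.done() for f in futures)]
        for rid in done:
            for f in self.in_flight.pop(rid):
                f.result()  # re-raise any I/O error from the background thread
        return done

    def close(self) -> None:
        self.pool.shutdown(wait=True)  # block until queued writes finish
```

The caller keeps a request's KV blocks allocated until poll_finished reports its ID, so the source data is not overwritten while a write is still pending.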
Usage#
To run the example:
python examples/llm-api/llm_kv_cache_connector.py <model_path>
The script demonstrates:
Generating text for a prompt (First run).
Destroying the LLM instance.
Creating a new LLM instance with the same connector config.
Generating text for the same prompt (Second run).
Asserting that the outputs match, proving the state was correctly restored from the disk cache.