KV Cache Connector
Source: NVIDIA/TensorRT-LLM.
'''
This script demonstrates the KV cache connector feature in TensorRT-LLM, which enables
custom persistence and reuse of KV cache blocks across different LLM instances.

**Scenario:**
The script implements a persistent KV cache connector that saves computed KV cache blocks
to disk and loads them back in subsequent runs, eliminating redundant computation for
recurring prompts.

**What is a KV Cache Connector?**

A KV cache connector is a customizable interface that allows you to:
1. **Save KV Cache:** Persist computed KV cache blocks to external storage
   (disk, database, distributed cache, etc.)
2. **Load KV Cache:** Retrieve previously computed cache blocks instead of recomputing them
3. **Share Cache Across Instances:** Reuse cache blocks across different LLM instances
   or sessions, unlike regular block reuse, which is limited to a single instance

**How It Works:**

This example implements a persistent KV cache connector as two cooperating classes:

* **PersistentKvCacheConnectorLeader (Scheduler):**
  - Hashes token sequences to create unique identifiers for each cache block
  - Checks if cached blocks exist on disk for incoming requests
  - Schedules load operations for cache hits
  - Schedules save operations for newly computed blocks

* **PersistentKvCacheConnectorWorker:**
  - Executes the actual load/save operations between GPU and disk
  - Loads cached blocks from disk files into GPU memory
  - Saves newly computed blocks from GPU to disk files

**Demonstration:**

The script processes the same prompt twice using two separate LLM instances:

1. **First Run (Instance 1):**
   - The LLM computes the KV cache for the input prompt
   - The connector saves the computed cache blocks to disk (as .pt files)
   - The generation completes and the LLM instance is destroyed

2. **Second Run (Instance 2):**
   - A new LLM instance is created with the same connector configuration
   - When processing the same prompt, the connector finds matching cache blocks on disk
   - The cache is loaded from disk instead of being recomputed
   - **Expected Outcome:** Faster prefill, because cache blocks are loaded rather than computed
   - Both outputs should be identical, demonstrating deterministic cache reuse

**Key Benefits:**

- **Cross-Instance Cache Sharing:** Share computed caches across multiple LLM instances
- **Persistent Storage:** Cache survives beyond the lifetime of a single LLM instance
- **Custom Storage Backends:** Implement any storage mechanism (shown here: disk files)
- **Reduced Computation:** Eliminate redundant KV cache computation for repeated prompts

**How to Run:**

```bash
python llm_kv_cache_connector.py <model_path>
```

Example:
```bash
python llm_kv_cache_connector.py meta-llama/Llama-3.1-8B-Instruct
```

**Implementation Notes:**

- This example uses content-based hashing to identify cache blocks
- Cache files are stored in a temporary directory (cleaned up after the demo)
- The implementation is simplified and not optimized for production use
- Chunked prefill is not supported in this example
- See `tensorrt_llm/_torch/pyexecutor/connectors/kv_cache_connector.py` for the full connector interface

**NOTE:** This example connector implementation is designed for demonstration purposes
and is NOT suitable for production use without additional optimizations and error handling.
'''
79
import hashlib
import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional

import click
import torch

from tensorrt_llm import LLM, SamplingParams, logger
from tensorrt_llm._torch.pyexecutor.connectors.kv_cache_connector import (
    KvCacheConnectorScheduler, KvCacheConnectorWorker, SchedulerOutput)
from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig, TorchLlmArgs
95
# Environment variable that main() sets to tell the connector leader where to
# persist KV cache blocks; the leader uses "./connector_cache" when it is unset.
CONNECTOR_CACHE_FOLDER_KEY: str = "CONNECTOR_CACHE_FOLDER"
97
98
@dataclass
class PersistentKvCacheConnectorMetadata:
    """Per-step transfer plan handed from the connector leader to its workers.

    Each entry pairs a cache file with the id of the KV cache block it maps
    to. The leader fills both lists with ``pathlib.Path`` objects (built by
    ``PersistentKvCacheConnectorLeader._file_path``); the worker passes them
    to ``torch.load`` / ``torch.save``, which also accept plain strings.
    """

    # Blocks to copy from disk into the device KV cache.
    load: list[tuple[Path, int]] = field(default_factory=list)
    # Device blocks to write out to disk once the forward pass is done.
    save: list[tuple[Path, int]] = field(default_factory=list)
103
104
class PersistentKvCacheConnectorWorker(KvCacheConnectorWorker):
    """Executes the disk <-> device KV block transfers planned by the leader.

    The transfer plan (a ``PersistentKvCacheConnectorMetadata``) is read from
    ``self._metadata``, which this class never sets — presumably populated by
    the base class / connector framework before each step (TODO confirm).
    Every transfer is synchronous and done one block at a time.

    NOTE(review): file paths come from the leader and do not include a rank,
    so with tensor parallelism every rank would read/write the same file for
    a block id — confirm this example is only meant for single-rank runs.
    """

    def __init__(self, llm_args: TorchLlmArgs):
        super().__init__(llm_args)

        # The device KV cache pool; set exactly once by register_kv_caches().
        self.kv_cache_tensor: Optional[torch.Tensor] = None

    def register_kv_caches(self, kv_cache_tensor: torch.Tensor):
        """Store the KV cache pool tensor; it is indexed by block id below."""
        assert self.kv_cache_tensor is None, "KV cache tensor already registered"
        self.kv_cache_tensor = kv_cache_tensor

    def start_load_kv(self, stream: torch.cuda.Stream):
        """Copy every block in ``self._metadata.load`` from disk to the device.

        All loads have finished when this method returns.
        """
        for path, block_id in self._metadata.load:
            # weights_only=True limits unpickling to tensors and plain
            # containers, so a tampered file in a shared cache folder cannot
            # execute code. Files written by wait_for_save hold one tensor,
            # which this mode supports.
            cpu_tensor = torch.load(path, map_location="cpu", weights_only=True)

            # Blocking copy into the device block.
            self.kv_cache_tensor[block_id].copy_(cpu_tensor, non_blocking=False)

    def wait_for_layer_load(self, layer_idx: int, stream: torch.cuda.Stream):
        # Nothing to wait for: start_load_kv() copies whole blocks synchronously.
        pass

    def save_kv_layer(self, layer_idx: int, stream: torch.cuda.Stream):
        # Nothing to do per layer: whole blocks are saved in wait_for_save().
        pass

    def wait_for_save(self, stream: torch.cuda.Stream):
        """Write every block in ``self._metadata.save`` to disk.

        Blocks whose file already exists are skipped. Each file is written
        under a temporary name and then renamed into place with
        ``os.replace``. A crash mid-write therefore never leaves a truncated
        ``<hash>.pt`` file that the leader would later report as a cache hit.
        """
        # Make sure the forward pass is complete before beginning our save.
        stream.synchronize()

        for path, block_id in self._metadata.save:
            target = Path(path)

            # Check first: an existing block needs no device-to-host copy.
            if target.exists():
                continue

            cpu_tensor = self.kv_cache_tensor[block_id].cpu()

            # Blocking save, so we only return once every file is complete.
            tmp_path = target.with_name(f"{target.name}.{os.getpid()}.tmp")
            torch.save(cpu_tensor, tmp_path)
            os.replace(tmp_path, target)

    def get_finished(
            self, finished_gen_req_ids: list[int],
            started_loading_req_ids: list[int]) -> tuple[list[int], list[int]]:
        """Report asynchronously completed requests.

        Always empty here, since every load and save above is synchronous.
        """
        return [], []
150
151
class PersistentKvCacheConnectorLeader(KvCacheConnectorScheduler):
    """Scheduler-side half of the persistent connector.

    ``get_num_new_matched_tokens`` finds which of a new request's full token
    blocks are already on disk. ``build_connector_meta`` then tells the
    workers which allocated blocks to load from disk and which to save.

    Block keys are chained: block ``i``'s key hashes its tokens together with
    block ``i - 1``'s key. A block's KV values depend on every token before
    it. Keying a block on its own tokens alone would let two prompts that
    share a block of text at a different offset, or after a different prefix,
    load each other's unrelated KV data.

    NOTE(review): keys cover only token ids and ``cache_salt_id``. Pointing
    two different models (or parallelism/dtype configurations) at the same
    cache folder would produce false hits, so use one folder per configuration.
    """

    def __init__(self, llm_args: TorchLlmArgs):
        super().__init__(llm_args)

        self.block_size = self._llm_args.kv_cache_config.tokens_per_block
        # request_id -> cache files matched by get_num_new_matched_tokens().
        # Each entry is consumed (popped) when its request reaches
        # build_connector_meta().
        self.pending_loads: dict[int, list[Path]] = {}

        self.cache_folder = os.environ.get(CONNECTOR_CACHE_FOLDER_KEY,
                                           "./connector_cache")

        os.makedirs(self.cache_folder, exist_ok=True)

    def build_connector_meta(self, scheduler_output: SchedulerOutput):
        """Plan this step's disk loads and saves for newly scheduled requests.

        NOTE: This is a simplified implementation, and does not work with
        chunked prefill.

        Assumes ``req.new_block_ids`` and ``req.new_tokens`` both start at the
        beginning of the sequence (block / token 0), as the absolute indexing
        below requires. TODO confirm against ``SchedulerOutput``.
        """
        metadata = PersistentKvCacheConnectorMetadata()

        for req in scheduler_output.new_requests:
            # Consume only this request's matches. Matches of requests that
            # are not scheduled in this step stay available for a later step;
            # requests that were never matched are skipped.
            pending_load = self.pending_loads.pop(req.request_id, None)
            if pending_load is None:
                continue

            num_computed_blocks = req.computed_position // self.block_size
            block_ids = req.new_block_ids

            # Matched files map onto the blocks right after those on device.
            for file_path, block_pos in zip(
                    pending_load, range(num_computed_blocks, len(block_ids))):
                metadata.load.append((file_path, block_ids[block_pos]))

            # Each full block that is neither on device nor in our connector
            # cache needs to be saved. Bounding by len(block_keys) skips the
            # trailing partial block and any allocated block past the prompt
            # (which previously raised IndexError).
            block_keys = self._block_hashes(req.new_tokens, req.cache_salt_id)
            first_new_block = num_computed_blocks + len(pending_load)
            for block_pos in range(first_new_block,
                                   min(len(block_ids), len(block_keys))):
                metadata.save.append((self._file_path(block_keys[block_pos]),
                                      block_ids[block_pos]))

        return metadata

    def _hash_tokens(self,
                     tokens: list[int],
                     cache_salt_id: Optional[int],
                     parent_hash: Optional[int] = None) -> int:
        """Return a stable, non-negative key for one block of tokens.

        SHA-256 is used instead of the built-in ``hash()``. Keys must match
        across processes and runs, but built-in hashes are only guaranteed
        stable within a process (e.g. ``hash(None)`` is address-based before
        Python 3.12).

        Args:
            tokens: Token ids of the block.
            cache_salt_id: Request salt. It participates in the key so that
                requests carrying different salts (or no salt) never share
                cache files.
            parent_hash: Key of the preceding block, or ``None`` for the
                first block of a sequence.

        Returns:
            The first 128 bits of the SHA-256 digest, as an int.
        """
        payload = repr((
            parent_hash,
            None if cache_salt_id is None else int(cache_salt_id),
            tuple(int(token) for token in tokens),
        ))
        digest = hashlib.sha256(payload.encode("utf-8")).digest()
        return int.from_bytes(digest[:16], "big")

    def _block_hashes(self, tokens: list[int],
                      cache_salt_id: Optional[int]) -> list[int]:
        """Return the chained keys of every *full* block in ``tokens``.

        Entry ``i`` is the key of tokens ``[i * block_size, (i + 1) *
        block_size)``. A trailing partial block gets no key.
        """
        keys: list[int] = []
        parent: Optional[int] = None
        for chunk in self._chunk_tokens(tokens):
            if len(chunk) != self.block_size:
                break  # only the final chunk can be partial
            parent = self._hash_tokens(chunk, cache_salt_id, parent)
            keys.append(parent)
        return keys

    def _file_path(self, hash_value: int) -> Path:
        """Return the cache file that stores the block with key ``hash_value``."""
        return Path(self.cache_folder) / f"{hash_value}.pt"

    def _chunk_tokens(self, tokens: list[int]) -> list[list[int]]:
        """Split ``tokens`` into consecutive ``block_size`` slices (last may be short)."""
        return [
            tokens[i:i + self.block_size]
            for i in range(0, len(tokens), self.block_size)
        ]

    def get_num_new_matched_tokens(
            self, request: LlmRequest,
            num_computed_tokens: int) -> tuple[int, bool]:
        """Count how many tokens past ``num_computed_tokens`` are cached on disk.

        Matching starts at the first block not already on device and stops at
        the first block whose file is missing. The matched files are recorded
        in ``self.pending_loads`` for ``build_connector_meta``.

        Returns:
            ``(matched_tokens, False)``, where ``matched_tokens`` is a multiple
            of the block size. The second element is always ``False``,
            presumably meaning "load is not asynchronous" (TODO confirm against
            the base interface).
        """
        matched: list[Path] = []
        self.pending_loads[request.request_id] = matched

        # Don't bother with sequences with partial matches.
        if (num_computed_tokens % self.block_size) != 0:
            return 0, False

        computed_blocks = num_computed_tokens // self.block_size

        # Keys are chained from the very first token, so blocks already on
        # device still contribute to the keys of the blocks after them.
        block_keys = self._block_hashes(request.get_tokens(0),
                                        request.cache_salt_id)

        for key in block_keys[computed_blocks:]:
            file_path = self._file_path(key)

            # If we get a cache hit, we want to load it into device.
            # Otherwise, we can stop looking.
            if not file_path.exists():
                break
            matched.append(file_path)

        logger.info(
            f"KV CONNECTOR: Matched {len(self.pending_loads[request.request_id])} blocks for request {request.request_id}"
        )

        return len(matched) * self.block_size, False

    def request_finished(self, request: LlmRequest,
                         cache_block_ids: list[int]) -> bool:
        """Drop leftover match state for ``request``; always returns ``False``.

        ``False`` because saves are synchronous, so nothing is still in flight.
        """
        # Normally already popped by build_connector_meta(). This only matters
        # if the hook is called for a request that was never scheduled.
        self.pending_loads.pop(request.request_id, None)
        return False

    def update_state_after_alloc(self, request: LlmRequest,
                                 block_ids: list[int]):
        # Nothing to record: block ids are read in build_connector_meta().
        pass
264
265
@click.command()
@click.argument("model", type=str)
def main(model: str):
    """Run one prompt through two LLM instances that share an on-disk KV cache.

    The first instance computes the prompt's KV cache, and the connector saves
    its full blocks to a temporary directory. The second instance should load
    those blocks instead of recomputing them. Both generations must produce
    identical text.
    """
    sys.path.append(os.path.join(
        os.path.dirname(__file__),
        "..",
    ))

    # The connector classes are looked up by module name: this file's name
    # without directory or ".py" suffix. Path.stem handles any path separator,
    # unlike slicing on "/".
    this_module = Path(__file__).stem

    # --- KV Cache Connector Config ---
    kv_connector_config = KvCacheConnectorConfig(
        connector_module=this_module,
        connector_scheduler_class="PersistentKvCacheConnectorLeader",
        connector_worker_class="PersistentKvCacheConnectorWorker",
    )

    # Adjacent literals are concatenated, so each sentence boundary needs an
    # explicit trailing space (previously "California.Founded").
    test_text = (
        "Nvidia Corporation is an American technology company headquartered in Santa Clara, California. "
        "Founded in 1993 by Jensen Huang, Chris Malachowsky, and Curtis Priem, it develops graphics processing units (GPUs), "
        "system on a chips (SoCs), and application programming interfaces (APIs) for data science, high-performance computing, "
        "and mobile and automotive applications. Tell me about the company.")

    sampling_params = SamplingParams(max_tokens=32)

    # The context manager removes the cache directory even if a step below
    # raises (e.g. the final output comparison fails).
    with TemporaryDirectory() as connector_cache_dir:
        # The connector leader reads its cache folder from this variable, so
        # it must be set before any LLM instance is created.
        os.environ[CONNECTOR_CACHE_FOLDER_KEY] = connector_cache_dir

        # Create LLM instance with KV Cache Connector
        llm = LLM(model=model,
                  backend="pytorch",
                  cuda_graph_config=None,
                  kv_connector_config=kv_connector_config)

        # Generate text with the first LLM instance and save the kv cache blocks by the connector.
        output = llm.generate([test_text], sampling_params)
        text0 = output[0].outputs[0].text

        print("First output: ", text0)
        print("Loading new LLM instance...")

        del llm

        # Create a new LLM instance with the same connector configuration
        llm = LLM(model=model,
                  backend="pytorch",
                  cuda_graph_config=None,
                  kv_connector_config=kv_connector_config)

        # Generate text with the second LLM instance and it should reuse the kv cache blocks from the connector.
        output = llm.generate([test_text], sampling_params)
        text1 = output[0].outputs[0].text

        print("Second output (using connector cache): ", text1)

        # Verify that the two outputs are identical
        assert text0 == text1, (
            f"Outputs differ:\n  first:  {text0!r}\n  second: {text1!r}")


if __name__ == "__main__":
    main()