KV Cache Connector

Source: NVIDIA/TensorRT-LLM.
'''
This script demonstrates the KV cache connector feature in TensorRT-LLM, which enables
custom persistence and reuse of KV cache blocks across different LLM instances.

**Scenario:**
The script implements a persistent KV cache connector that saves computed KV cache blocks
to disk and loads them back in subsequent runs, eliminating redundant computation for
recurring prompts.

**What is a KV Cache Connector?**

A KV cache connector is a customizable interface that allows you to:
1. **Save KV Cache:** Persist computed KV cache blocks to external storage
   (disk, database, distributed cache, etc.)
2. **Load KV Cache:** Retrieve previously computed cache blocks instead of recomputing them
3. **Share Cache Across Instances:** Reuse cache blocks across different LLM instances
   or sessions, unlike regular block reuse, which is limited to a single instance
   (see the wiring sketch below)

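A minimal sketch of how such a connector is wired into an `LLM` instance, mirroring what `main()` does at the bottom of this script (the module and model names are simply the ones used in this example):

```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig

# Name the module that defines the scheduler/worker classes (this file),
# plus the two classes themselves.
kv_connector_config = KvCacheConnectorConfig(
    connector_module="llm_kv_cache_connector",
    connector_scheduler_class="PersistentKvCacheConnectorLeader",
    connector_worker_class="PersistentKvCacheConnectorWorker",
)

# Every LLM built with this config routes KV cache save/load decisions
# through the connector classes named above.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
          backend="pytorch",
          cuda_graph_config=None,
          kv_connector_config=kv_connector_config)
```
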
**How It Works:**

This example implements a persistent KV cache connector with two key components
(a small illustration of the underlying data movement follows the list):

* **PersistentKvCacheConnectorLeader (Scheduler):**
  - Hashes token sequences to create unique identifiers for each cache block
  - Checks if cached blocks exist on disk for incoming requests
  - Schedules load operations for cache hits
  - Schedules save operations for newly computed blocks

* **PersistentKvCacheConnectorWorker:**
  - Executes the actual load/save operations between GPU and disk
  - Loads cached blocks from disk files into GPU memory
  - Saves newly computed blocks from GPU to disk files

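At the lowest level, the worker's save and load operations are per-block `torch.save` / `torch.load` round trips against the KV cache pool handed over via `register_kv_caches()`. A self-contained sketch of that mechanic (the pool shape and file path are invented for illustration; the real pool lives on the GPU):

```python
import torch

# Stand-in for a KV cache pool: 4 blocks with an arbitrary per-block shape.
kv_cache_pool = torch.randn(4, 2, 16, 8)

# Save path (wait_for_save): copy one block to CPU and persist it.
block_id = 2
torch.save(kv_cache_pool[block_id].cpu(), "/tmp/block_example.pt")

# Load path (start_load_kv): read the file back and copy it into a block
# slot of another pool, e.g. in a fresh LLM instance.
new_pool = torch.zeros_like(kv_cache_pool)
new_pool[0].copy_(torch.load("/tmp/block_example.pt", map_location="cpu"))
```
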
**Demonstration:**

The script processes the same prompt twice using two separate LLM instances:

1. **First Run (Instance 1):**
   - The LLM computes the KV cache for the input prompt
   - The connector saves the computed cache blocks to disk (as .pt files)
   - The generation completes and the LLM instance is destroyed

2. **Second Run (Instance 2):**
   - A new LLM instance is created with the same connector configuration
   - When processing the same prompt, the connector finds matching cache blocks on disk
   - The cache is loaded from disk instead of being recomputed
   - **Expected Outcome:** Faster prefill, as cache blocks are loaded rather than computed
   - Both outputs should be identical, demonstrating deterministic cache reuse

**Key Benefits:**

- **Cross-Instance Cache Sharing:** Share computed caches across multiple LLM instances
- **Persistent Storage:** The cache survives beyond the lifetime of a single LLM instance
- **Custom Storage Backends:** Implement any storage mechanism (shown here: disk files)
- **Reduced Computation:** Eliminate redundant KV cache computation for repeated prompts

**How to Run:**

```bash
python llm_kv_cache_connector.py <model_path>
```

Example:
```bash
python llm_kv_cache_connector.py meta-llama/Llama-3.1-8B-Instruct
```

**Implementation Notes:**

- This example uses content-based hashing to identify cache blocks (sketched below)
- Cache files are stored in a temporary directory (cleaned up after the demo)
- The implementation is simplified and not optimized for production use
- Chunked prefill is not supported in this example
- See `tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py` for the full connector interface

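Concretely, the content addressing splits the token sequence into `tokens_per_block`-sized chunks and derives a file name from each full chunk's hash, matching `_hash_tokens()` / `_file_path()` below (the block size and token values here are arbitrary examples):

```python
from pathlib import Path

block_size = 32           # stands in for kv_cache_config.tokens_per_block
tokens = list(range(80))  # stands in for a tokenized prompt

# Split into chunks; only complete blocks are ever cached.
chunks = [tokens[i:i + block_size] for i in range(0, len(tokens), block_size)]

for chunk in chunks:
    if len(chunk) == block_size:
        # One .pt file per full block, named by the chunk's content hash.
        print(Path("./connector_cache") / f"{abs(hash(tuple(chunk)))}.pt")
```
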
**NOTE:** This example connector implementation is designed for demonstration purposes
and is NOT suitable for production use without additional optimizations and error handling.
'''

import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from tempfile import TemporaryDirectory

import click
import torch

from tensorrt_llm import LLM, SamplingParams, logger
from tensorrt_llm._torch.pyexecutor.kv_cache_connector import (
    KvCacheConnectorScheduler, KvCacheConnectorWorker, SchedulerOutput)
from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig, TorchLlmArgs

CONNECTOR_CACHE_FOLDER_KEY = "CONNECTOR_CACHE_FOLDER"


@dataclass
class PersistentKvCacheConnectorMetadata:
    load: list[tuple[str, int]] = field(default_factory=list)
    save: list[tuple[str, int]] = field(default_factory=list)


class PersistentKvCacheConnectorWorker(KvCacheConnectorWorker):

    def __init__(self, llm_args: TorchLlmArgs):
        super().__init__(llm_args)

        self.kv_cache_tensor = None

    def register_kv_caches(self, kv_cache_tensor: torch.Tensor):
        assert self.kv_cache_tensor is None, "KV cache tensor already registered"
        self.kv_cache_tensor = kv_cache_tensor

    def start_load_kv(self, stream: torch.cuda.Stream):
        # Do all loads synchronously, and blockwise.
        for path, block_id in self._metadata.load:
            cpu_tensor = torch.load(path, map_location="cpu")

            # Copy into the device block.
            self.kv_cache_tensor[block_id].copy_(cpu_tensor, non_blocking=False)

    def wait_for_layer_load(self, layer_idx: int, stream: torch.cuda.Stream):
        pass

    def save_kv_layer(self, layer_idx: int, stream: torch.cuda.Stream):
        pass

    def wait_for_save(self, stream: torch.cuda.Stream):

        # Make sure the forward pass is complete before beginning our save.
        stream.synchronize()

        for path, block_id in self._metadata.save:
            cpu_tensor = self.kv_cache_tensor[block_id].cpu()

            # Don't write anything if this specific block already exists.
            if Path(path).exists():
                continue

            # Do a blocking save to the file. This way, we only return once all saves are complete.
            torch.save(cpu_tensor, path)

    def get_finished(
            self, finished_gen_req_ids: list[int],
            started_loading_req_ids: list[int]) -> tuple[list[int], list[int]]:

        return [], []


class PersistentKvCacheConnectorLeader(KvCacheConnectorScheduler):

    def __init__(self, llm_args: TorchLlmArgs):
        super().__init__(llm_args)

        self.block_size = self._llm_args.kv_cache_config.tokens_per_block
        self.pending_loads = {}

        self.cache_folder = os.environ.get(CONNECTOR_CACHE_FOLDER_KEY,
                                           "./connector_cache")

        os.makedirs(self.cache_folder, exist_ok=True)

    def build_connector_meta(self, scheduler_output: SchedulerOutput):
        # NOTE: This is a simplified implementation, and does not work with chunked prefill.

        metadata = PersistentKvCacheConnectorMetadata()

        for req in scheduler_output.new_requests:
            # If we don't have any pending loads for this request, we can skip it.
            if req.request_id not in self.pending_loads:
                continue

            num_computed_blocks = req.computed_position // self.block_size
            block_ids = req.new_block_ids

            pending_load = self.pending_loads[req.request_id]

            for file_path, block_pos in zip(
                    pending_load, range(num_computed_blocks, len(block_ids))):
                metadata.load.append((file_path, block_ids[block_pos]))

            # Break up the remainder of the token sequence into chunks.
            chunks = self._chunk_tokens(req.new_tokens)

            # For each chunk that isn't already on device, and isn't in our connector cache, we need to save it.
            for block_pos in range(num_computed_blocks + len(pending_load),
                                   len(block_ids)):
                if len(chunks[block_pos]) == self.block_size:
                    hashed_tokens = self._hash_tokens(chunks[block_pos])

                    file_path = self._file_path(hashed_tokens)

                    metadata.save.append((file_path, block_ids[block_pos]))

        self.pending_loads = {}

        return metadata

    def _hash_tokens(self, tokens: list[int]) -> int:
        return abs(hash(tuple(tokens)))

    def _file_path(self, hash_value: int) -> Path:
        return Path(self.cache_folder) / f"{hash_value}.pt"

    def _chunk_tokens(self, tokens: list[int]) -> list[list[int]]:
        return [
            tokens[i:i + self.block_size]
            for i in range(0, len(tokens), self.block_size)
        ]

    def get_num_new_matched_tokens(
            self, request: LlmRequest,
            num_computed_tokens: int) -> tuple[int, bool]:
        self.pending_loads[request.request_id] = []

        # Don't bother with sequences with partial matches.
        if (num_computed_tokens % self.block_size) != 0:
            return 0, False

        computed_blocks = num_computed_tokens // self.block_size

        # Get all the tokens that don't have a cache hit on device.
        remaining_tokens = request.get_tokens(0)[computed_blocks *
                                                 self.block_size:]

        remaining_chunks = self._chunk_tokens(remaining_tokens)

        # For each chunk, check if it exists in our cache.
        for chunk in remaining_chunks:
            # Only do full blocks.
            if len(chunk) == self.block_size:
                hashed_tokens = self._hash_tokens(chunk)

                file_path = self._file_path(hashed_tokens)

                # If we get a cache hit, we want to load it into device.
                # Otherwise, we can stop looking.
                if file_path.exists():
                    self.pending_loads[request.request_id].append(file_path)
                else:
                    break

        logger.info(
            f"KV CONNECTOR: Matched {len(self.pending_loads[request.request_id])} blocks for request {request.request_id}"
        )

        return len(
            self.pending_loads[request.request_id]) * self.block_size, False

    def request_finished(self, request: LlmRequest,
                         cache_block_ids: list[int]) -> bool:
        # We don't do any asynchronous saving, so always return False.
        return False

    def update_state_after_alloc(self, request: LlmRequest,
                                 block_ids: list[int]):
        pass


@click.command()
@click.argument("model", type=str)
def main(model: str):
    sys.path.append(os.path.join(
        os.path.dirname(__file__),
        "..",
    ))

    this_module = __file__[__file__.rfind("/") + 1:__file__.rfind(".py")]

    # --- KV Cache Connector Config ---
    kv_connector_config = KvCacheConnectorConfig(
        connector_module=this_module,
        connector_scheduler_class="PersistentKvCacheConnectorLeader",
        connector_worker_class="PersistentKvCacheConnectorWorker",
    )

    connector_cache_dir = TemporaryDirectory()
    os.environ[CONNECTOR_CACHE_FOLDER_KEY] = connector_cache_dir.name

    # Create an LLM instance with the KV cache connector.
    llm = LLM(model=model,
              backend="pytorch",
              cuda_graph_config=None,
              kv_connector_config=kv_connector_config)

    test_text = (
        "Nvidia Corporation is an American technology company headquartered in Santa Clara, California. "
        "Founded in 1993 by Jensen Huang, Chris Malachowsky, and Curtis Priem, it develops graphics processing units (GPUs), "
        "system on a chips (SoCs), and application programming interfaces (APIs) for data science, high-performance computing, "
        "and mobile and automotive applications. Tell me about the company.")

    sampling_params = SamplingParams(max_tokens=32)

    # Generate text with the first LLM instance; the connector saves the computed KV cache blocks to disk.
    output = llm.generate([test_text], sampling_params)
    text0 = output[0].outputs[0].text

    print("First output: ", text0)
    print("Loading new LLM instance...")

    del llm

    # Create a new LLM instance with the same connector configuration.
    llm = LLM(model=model,
              backend="pytorch",
              cuda_graph_config=None,
              kv_connector_config=kv_connector_config)

    # Generate text with the second LLM instance; it should reuse the KV cache blocks saved by the connector.
    output = llm.generate([test_text], sampling_params)
    text1 = output[0].outputs[0].text

    print("Second output (using connector cache): ", text1)

    # Verify that the two outputs are identical.
    assert text0 == text1

    connector_cache_dir.cleanup()


if __name__ == "__main__":
    main()