# KV Cache Connector
# Source: NVIDIA/TensorRT-LLM.
'''
This script demonstrates the KV cache connector feature in TensorRT-LLM, which enables
custom persistence and reuse of KV cache blocks across different LLM instances.

**Scenario:**
The script implements a persistent KV cache connector that saves computed KV cache blocks
to disk and loads them back in subsequent runs, eliminating redundant computation for
recurring prompts.

**What is a KV Cache Connector?**

A KV cache connector is a customizable interface that allows you to:
1. **Save KV Cache:** Persist computed KV cache blocks to external storage
   (disk, database, distributed cache, etc.)
2. **Load KV Cache:** Retrieve previously computed cache blocks instead of recomputing them
3. **Share Cache Across Instances:** Reuse cache blocks across different LLM instances
   or sessions, unlike regular block reuse, which is limited to a single instance

**How It Works:**

This example implements a persistent KV cache connector with two key components:

* **PersistentKvCacheConnectorLeader (Scheduler):**
  - Hashes token sequences to create unique identifiers for each cache block
  - Checks whether cached blocks exist on disk for incoming requests
  - Schedules load operations for cache hits
  - Schedules save operations for newly computed blocks

* **PersistentKvCacheConnectorWorker:**
  - Executes the actual load/save operations between GPU and disk
  - Loads cached blocks from disk files into GPU memory
  - Saves newly computed blocks from GPU to disk files

**Demonstration:**

The script processes the same prompt twice using two separate LLM instances:

1. **First Run (Instance 1):**
   - The LLM computes the KV cache for the input prompt
   - The connector saves the computed cache blocks to disk (as .pt files)
   - The generation completes and the LLM instance is destroyed

2. **Second Run (Instance 2):**
   - A new LLM instance is created with the same connector configuration
   - When processing the same prompt, the connector finds matching cache blocks on disk
   - The cache is loaded from disk instead of being recomputed
   - **Expected Outcome:** Faster prefill, as cache blocks are loaded rather than computed
   - Both outputs should be identical, demonstrating deterministic cache reuse

**Key Benefits:**

- **Cross-Instance Cache Sharing:** Share computed caches across multiple LLM instances
- **Persistent Storage:** Cache survives beyond the lifetime of a single LLM instance
- **Custom Storage Backends:** Implement any storage mechanism (shown here: disk files)
- **Reduced Computation:** Eliminate redundant KV cache computation for repeated prompts

**How to Run:**

```bash
python llm_kv_cache_connector.py <model_path>
```

Example:
```bash
python llm_kv_cache_connector.py meta-llama/Llama-3.1-8B-Instruct
```

**Implementation Notes:**

- This example uses content-based hashing to identify cache blocks
- The connector stores cache files in the directory named by the
  `CONNECTOR_CACHE_FOLDER` environment variable (default: `./connector_cache`);
  this demo points it at a temporary directory that is cleaned up afterwards
- The implementation is simplified and not optimized for production use
- Chunked prefill is not supported in this example
- See `tensorrt_llm/_torch/pyexecutor/connectors/kv_cache_connector.py` for the full
  connector interface

**NOTE:** This example connector implementation is designed for demonstration purposes
and is NOT suitable for production use without additional optimizations and error handling.
'''

import hashlib
import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional

import click
import torch

from tensorrt_llm import LLM, SamplingParams, logger
from tensorrt_llm._torch.pyexecutor.connectors.kv_cache_connector import (
    KvCacheConnectorScheduler, KvCacheConnectorWorker, SchedulerOutput)
from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig, TorchLlmArgs

# Environment variable holding the directory in which the connector leader
# keeps its cache files; main() sets it before creating each LLM instance.
CONNECTOR_CACHE_FOLDER_KEY: str = "CONNECTOR_CACHE_FOLDER"


@dataclass
class PersistentKvCacheConnectorMetadata:
    """Per-step plan handed from the connector leader to its worker.

    Each entry pairs a cache file path with the id of the KV cache block that
    file corresponds to.
    """

    # (cache file, block id) pairs to copy from disk into the KV cache.
    load: list[tuple[Path, int]] = field(default_factory=list)
    # (cache file, block id) pairs to copy from the KV cache out to disk.
    save: list[tuple[Path, int]] = field(default_factory=list)


class PersistentKvCacheConnectorWorker(KvCacheConnectorWorker):
    """Worker side of the persistent KV cache connector.

    Executes the plan in ``self._metadata`` (a
    ``PersistentKvCacheConnectorMetadata`` built by the leader). ``load``
    entries are copied from their cache files into the registered KV cache
    tensor, and ``save`` entries are copied out of it into cache files. All
    transfers are synchronous and done one block at a time.
    """

    def __init__(self, llm_args: TorchLlmArgs):
        super().__init__(llm_args)

        # Set once by register_kv_caches(); indexed by block id afterwards.
        self.kv_cache_tensor = None

    def register_kv_caches(self, kv_cache_tensor: torch.Tensor):
        """Store the KV cache tensor that block ids index into.

        Must be called at most once per worker.
        """
        assert self.kv_cache_tensor is None, "KV cache tensor already registered"
        self.kv_cache_tensor = kv_cache_tensor

    def start_load_kv(self, stream: torch.cuda.Stream):
        """Copy every block in the current ``load`` plan from disk into the KV
        cache tensor.

        ``stream`` is not used; every copy is blocking.
        """
        # Do all loads synchronously, and blockwise.
        for path, block_id in self._metadata.load:
            # The cache files only ever hold a plain tensor, so refuse to
            # unpickle anything else from the (configurable, possibly shared)
            # cache folder.
            cpu_tensor = torch.load(path,
                                    map_location="cpu",
                                    weights_only=True)

            # Copy into the device block.
            self.kv_cache_tensor[block_id].copy_(cpu_tensor, non_blocking=False)

    def wait_for_layer_load(self, layer_idx: int, stream: torch.cuda.Stream):
        # Nothing to wait for: start_load_kv() has already finished every load.
        pass

    def save_kv_layer(self, layer_idx: int, stream: torch.cuda.Stream):
        # Per-layer saving is not used; whole blocks are saved in
        # wait_for_save().
        pass

    def wait_for_save(self, stream: torch.cuda.Stream):
        """Write every block in the current ``save`` plan to disk.

        Returns only once all writes are complete. Blocks whose cache file
        already exists are skipped.
        """
        # Make sure the forward pass is complete before beginning our save.
        stream.synchronize()

        for path, block_id in self._metadata.save:
            # Skip blocks that are already cached *before* copying them off the
            # device, so we don't pay for a transfer we would throw away.
            if Path(path).exists():
                continue

            cpu_tensor = self.kv_cache_tensor[block_id].cpu()

            # Write to a per-process temporary file, then atomically rename it
            # into place. A reader that sees ``path`` exist therefore never
            # loads a half-written file.
            tmp_path = f"{path}.{os.getpid()}.tmp"
            torch.save(cpu_tensor, tmp_path)
            os.replace(tmp_path, path)

    def get_finished(
            self, finished_gen_req_ids: list[int],
            started_loading_req_ids: list[int]) -> tuple[list[int], list[int]]:
        """Report asynchronously completed saves and loads.

        Every transfer in this connector is synchronous, so there is never
        anything to report.
        """
        return [], []


class PersistentKvCacheConnectorLeader(KvCacheConnectorScheduler):
    """Scheduler side of the persistent KV cache connector.

    For each new request, ``get_num_new_matched_tokens`` decides which KV cache
    blocks can be loaded from the on-disk cache. ``build_connector_meta`` then
    turns that into a per-step plan of blocks to load and freshly computed
    blocks to save.

    Cache files live in the directory named by the ``CONNECTOR_CACHE_FOLDER``
    environment variable (default ``./connector_cache``). Each file is keyed
    by a stable hash of the request's cache salt plus every token up to and
    including its block.
    """

    def __init__(self, llm_args: TorchLlmArgs):
        super().__init__(llm_args)

        # Tokens per KV cache block; matching and saving work on whole blocks.
        self.block_size = self._llm_args.kv_cache_config.tokens_per_block
        # request_id -> cache files to load, in block order. Filled by
        # get_num_new_matched_tokens() and consumed by build_connector_meta().
        self.pending_loads = {}

        self.cache_folder = os.environ.get(CONNECTOR_CACHE_FOLDER_KEY,
                                           "./connector_cache")

        os.makedirs(self.cache_folder, exist_ok=True)

    def build_connector_meta(self, scheduler_output: SchedulerOutput):
        """Build the load/save plan for the requests scheduled this step.

        NOTE: This is a simplified implementation, and does not work with
        chunked prefill.
        """
        metadata = PersistentKvCacheConnectorMetadata()

        for req in scheduler_output.new_requests:
            # If we don't have any pending loads for this request, we can skip it.
            if req.request_id not in self.pending_loads:
                continue

            num_computed_blocks = req.computed_position // self.block_size
            block_ids = req.new_block_ids

            pending_load = self.pending_loads[req.request_id]

            # Matched files map onto the blocks right after the on-device prefix.
            for file_path, block_pos in zip(
                    pending_load, range(num_computed_blocks, len(block_ids))):
                metadata.load.append((file_path, block_ids[block_pos]))

            # One prefix hash per *full* block of the request's tokens; a
            # trailing partial block gets none and is never saved.
            block_hashes = self._block_hashes(req.new_tokens, req.cache_salt)

            # Save every full block that is neither already on device nor being
            # loaded from our cache. Bound by both lists, because more blocks
            # may be allocated than the tokens fill.
            first_to_save = num_computed_blocks + len(pending_load)
            end = min(len(block_ids), len(block_hashes))
            for block_pos in range(first_to_save, end):
                file_path = self._file_path(block_hashes[block_pos])
                metadata.save.append((file_path, block_ids[block_pos]))

        self.pending_loads = {}

        return metadata

    def _block_hashes(self, tokens: list[int],
                      cache_salt: Optional[str]) -> list[str]:
        """Return one stable, prefix-chained hash per full block of ``tokens``.

        The hash for block ``i`` covers the salt and every token in blocks
        ``0..i``. A block's KV values depend on all preceding tokens and on its
        position, not only on its own tokens.

        ``hashlib`` is used instead of the builtin ``hash()``, whose result
        for ``str`` (and, on older CPython versions, ``None``) differs between
        processes. That would defeat reuse across runs and instances.
        A trailing partial block gets no hash.
        """
        digest = hashlib.sha256()
        # repr() keeps None distinct from the string "None" and escapes
        # newlines, so the salt header cannot bleed into the token stream.
        digest.update(repr(cache_salt).encode("utf-8") + b"\n")

        hashes = []
        for chunk in self._chunk_tokens(tokens):
            if len(chunk) < self.block_size:
                break
            # Fixed-size blocks plus a ';' terminator keep the encoding
            # unambiguous.
            digest.update((",".join(map(str, chunk)) + ";").encode("ascii"))
            # copy() snapshots the running digest, so each prefix costs O(block).
            hashes.append(digest.copy().hexdigest())
        return hashes

    def _file_path(self, hash_value: str) -> Path:
        """Return the cache file path for a block hash."""
        return Path(self.cache_folder) / f"{hash_value}.pt"

    def _chunk_tokens(self, tokens: list[int]) -> list[list[int]]:
        """Split ``tokens`` into consecutive ``block_size`` chunks; the last
        chunk may be shorter."""
        return [
            tokens[i:i + self.block_size]
            for i in range(0, len(tokens), self.block_size)
        ]

    def get_num_new_matched_tokens(
            self, request: LlmRequest,
            num_computed_tokens: int) -> tuple[int, bool]:
        """Count how many tokens beyond the on-device prefix can be loaded from
        the on-disk cache, and remember which files to load.

        Returns:
            ``(num_matched_tokens, False)``; always a whole number of blocks.
        """
        self.pending_loads[request.request_id] = []

        # Don't bother with sequences with partial matches.
        if (num_computed_tokens % self.block_size) != 0:
            return 0, False

        computed_blocks = num_computed_tokens // self.block_size

        # Leave the final prompt token unmatched, so prefill still has at least
        # one token to run (it must produce logits for the first output
        # token). NOTE(review): assumed to mirror the executor's own on-device
        # reuse limit; confirm. This only costs a block when the prompt length
        # is an exact multiple of the block size.
        matchable_tokens = request.get_tokens(0)[:-1]
        block_hashes = self._block_hashes(matchable_tokens, request.cache_salt)

        # Walk the blocks not already on device and stop at the first miss:
        # loads must form a contiguous run after the device prefix.
        for block_hash in block_hashes[computed_blocks:]:
            file_path = self._file_path(block_hash)
            if not file_path.exists():
                break
            self.pending_loads[request.request_id].append(file_path)

        logger.info(
            f"KV CONNECTOR: Matched {len(self.pending_loads[request.request_id])} blocks for request {request.request_id}"
        )

        return len(
            self.pending_loads[request.request_id]) * self.block_size, False

    def request_finished(self, request: LlmRequest,
                         cache_block_ids: list[int]) -> bool:
        # We don't do any asynchronous saving, so always return False
        return False

    def update_state_after_alloc(self, request: LlmRequest,
                                 block_ids: list[int]):
        # Block ids are read from the scheduler output in
        # build_connector_meta(), so nothing to record here.
        pass


@click.command()
@click.argument("model", type=str)
def main(model: str):
    """Run one prompt through two LLM instances that share an on-disk KV cache.

    The first instance computes the prompt's KV cache, and the connector saves
    it. The second instance should load those blocks instead of recomputing
    them, and both generations must match.

    Args:
        model: Model name or path passed to ``LLM(model=...)``.

    Raises:
        RuntimeError: If the two generated outputs differ.
    """
    # Add the parent of this script's directory to the import path.
    sys.path.append(os.path.join(
        os.path.dirname(__file__),
        "..",
    ))

    # Module name of this script (file name without directory or ".py"); the
    # connector config refers to the classes defined here through it.
    # Path.stem is used instead of splitting on "/", which is not portable.
    this_module = Path(__file__).stem

    # --- KV Cache Connector Config ---
    kv_connector_config = KvCacheConnectorConfig(
        connector_module=this_module,
        connector_scheduler_class="PersistentKvCacheConnectorLeader",
        connector_worker_class="PersistentKvCacheConnectorWorker",
    )

    # The context manager removes the cache directory even if generation or
    # the output check below fails.
    with TemporaryDirectory() as connector_cache_dir:
        os.environ[CONNECTOR_CACHE_FOLDER_KEY] = connector_cache_dir

        # Create LLM instance with KV Cache Connector
        llm = LLM(model=model,
                  backend="pytorch",
                  cuda_graph_config=None,
                  kv_connector_config=kv_connector_config)

        test_text = (
            "Nvidia Corporation is an American technology company headquartered in Santa Clara, California."
            "Founded in 1993 by Jensen Huang, Chris Malachowsky, and Curtis Priem, it develops graphics processing units (GPUs), "
            "system on a chips (SoCs), and application programming interfaces (APIs) for data science, high-performance computing, "
            "and mobile and automotive applications. Tell me about the company.")

        sampling_params = SamplingParams(max_tokens=32)

        # Generate text with the first LLM instance and save the kv cache blocks by the connector.
        output = llm.generate([test_text], sampling_params)
        text0 = output[0].outputs[0].text

        print("First output: ", text0)
        print("Loading new LLM instance...")

        del llm

        # Create a new LLM instance with the same connector configuration
        llm = LLM(model=model,
                  backend="pytorch",
                  cuda_graph_config=None,
                  kv_connector_config=kv_connector_config)

        # Generate text with the second LLM instance and it should reuse the kv cache blocks from the connector.
        output = llm.generate([test_text], sampling_params)
        text1 = output[0].outputs[0].text

        print("Second output (using connector cache): ", text1)

        # Verify that the two outputs are identical. Raise explicitly rather
        # than assert, so the check still runs under `python -O`.
        if text0 != text1:
            raise RuntimeError(
                "Outputs differ between the first run and the run using the "
                f"connector cache: {text0!r} != {text1!r}")


if __name__ == "__main__":
    main()