KV Cache Connector
Source: NVIDIA/TensorRT-LLM.
import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from tempfile import TemporaryDirectory

import click
import torch

from tensorrt_llm import LLM, SamplingParams, logger
from tensorrt_llm._torch.pyexecutor.kv_cache_connector import (
    KvCacheConnectorScheduler, KvCacheConnectorWorker, SchedulerOutput)
from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig

# This is a simple example of the use of the KV cache connector.
# It persists KV cache contents into a folder, and can load them back on subsequent runs.
# See tensorrt_llm/_torch/pyexecutor/connector.py for details about the KV cache connector interface.
# NOTE: This example connector implementation is NOT suitable for production use.
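#
# The example is split into two classes: PersistentKvCacheConnectorLeader runs on the
# scheduler side and decides which blocks to load from or save to the cache folder,
# while PersistentKvCacheConnectorWorker performs the corresponding copies between
# those files and the device KV cache.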

CONNECTOR_CACHE_FOLDER_KEY = "CONNECTOR_CACHE_FOLDER"


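# Container for the per-iteration work items shipped from the scheduler to the
# workers: each entry is a (cache file path, device block index) pair.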
@dataclass
class PersistentKvCacheConnectorMetadata:
    load: list[tuple[str, int]] = field(default_factory=list)
    save: list[tuple[str, int]] = field(default_factory=list)


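# Worker side of the connector: it receives the device KV cache tensor once at
# startup and carries out the load/save work items described by the metadata.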
class PersistentKvCacheConnectorWorker(KvCacheConnectorWorker):

    def __init__(self):
        super().__init__()

        self.kv_cache_tensor = None

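    # Called with the device KV cache pool tensor; individual cache blocks are
    # indexed along its first dimension.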
    def register_kv_caches(self, kv_cache_tensor: torch.Tensor):
        assert self.kv_cache_tensor is None, "KV cache tensor already registered"
        self.kv_cache_tensor = kv_cache_tensor

    def start_load_kv(self, stream: torch.cuda.Stream):
        # Do all loads synchronously, and blockwise.
        for path, block_id in self._metadata.load:
            cpu_tensor = torch.load(path, map_location="cpu")

            # Copy into the device block.
            self.kv_cache_tensor[block_id].copy_(cpu_tensor, non_blocking=False)

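    # This example does not stream KV data layer by layer, so the per-layer
    # hooks are no-ops; all saving happens after the forward pass in
    # wait_for_save().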
    def wait_for_layer_load(self, layer_idx: int, stream: torch.cuda.Stream):
        pass

    def save_kv_layer(self, layer_idx: int, stream: torch.cuda.Stream):
        pass

    def wait_for_save(self, stream: torch.cuda.Stream):

        # Make sure the forward pass is complete before beginning our save.
        stream.synchronize()

        for path, block_id in self._metadata.save:
            cpu_tensor = self.kv_cache_tensor[block_id].cpu()

            # Don't write anything if this specific block already exists.
            if Path(path).exists():
                continue

            # Do a blocking save to the file. This way, we only return once all saves are complete.
            torch.save(cpu_tensor, path)

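    # Every load and save above is blocking, so no request is ever left
    # in-flight and there is nothing to report here.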
    def get_finished(
            self, finished_gen_req_ids: list[int],
            started_loading_req_ids: list[int]) -> tuple[list[int], list[int]]:

        return [], []


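# Scheduler side of the connector: it tracks which full token blocks already
# exist as files in the cache folder (keyed by a hash of the block's tokens)
# and builds the per-iteration load/save metadata consumed by the worker.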
class PersistentKvCacheConnectorLeader(KvCacheConnectorScheduler):

    def __init__(self, tokens_per_block):
        super().__init__()

        self.block_size = tokens_per_block
        self.pending_loads = {}

        self.cache_folder = os.environ.get(CONNECTOR_CACHE_FOLDER_KEY,
                                           "./connector_cache")

        os.makedirs(self.cache_folder, exist_ok=True)

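    # Translates the cache hits recorded in get_num_new_matched_tokens(),
    # together with the newly allocated block ids, into concrete (file, block)
    # load and save entries for the worker.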
    def build_connector_meta(self, scheduler_output: SchedulerOutput):
        # NOTE: This is a simplified implementation, and does not work with chunked prefill.

        metadata = PersistentKvCacheConnectorMetadata()

        for req in scheduler_output.new_requests:
            # If we don't have any pending loads for this request, we can skip it.
            if req.request_id not in self.pending_loads:
                continue

            num_computed_blocks = req.computed_position // self.block_size
            block_ids = req.new_block_ids

            pending_load = self.pending_loads[req.request_id]

            for file_path, block_pos in zip(
                    pending_load, range(num_computed_blocks, len(block_ids))):
                metadata.load.append((file_path, block_ids[block_pos]))

            # Break up the remainder of the token sequence into chunks.
            chunks = self._chunk_tokens(req.new_tokens)

            # For each chunk that isn't already on device, and isn't in our connector cache, we need to save it.
            for block_pos in range(num_computed_blocks + len(pending_load),
                                   len(block_ids)):
                if len(chunks[block_pos]) == self.block_size:
                    hashed_tokens = self._hash_tokens(chunks[block_pos])

                    file_path = self._file_path(hashed_tokens)

                    metadata.save.append((file_path, block_ids[block_pos]))

        self.pending_loads = {}

        return metadata

    def _hash_tokens(self, tokens: list[int]) -> int:
        return abs(hash(tuple(tokens)))

    def _file_path(self, hash_value: int) -> Path:
        return Path(self.cache_folder) / f"{hash_value}.pt"

    def _chunk_tokens(self, tokens: list[int]) -> list[list[int]]:
        return [
            tokens[i:i + self.block_size]
            for i in range(0, len(tokens), self.block_size)
        ]

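    # Returns how many additional prompt tokens (beyond the num_computed_tokens
    # already covered by on-device block reuse) can be filled from the on-disk
    # cache. The boolean indicates asynchronous loading, which this example
    # never uses.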
    def get_num_new_matched_tokens(
            self, request: LlmRequest,
            num_computed_tokens: int) -> tuple[int, bool]:
        self.pending_loads[request.request_id] = []

        # Don't bother with sequences with partial matches.
        if (num_computed_tokens % self.block_size) != 0:
            return 0, False

        computed_blocks = num_computed_tokens // self.block_size

        # Get all the tokens that don't have a cache hit on device.
        remaining_tokens = request.get_tokens(0)[computed_blocks *
                                                 self.block_size:]

        remaining_chunks = self._chunk_tokens(remaining_tokens)

        # For each chunk, check if it exists in our cache.
        for chunk in remaining_chunks:
            # Only do full blocks.
            if len(chunk) == self.block_size:
                hashed_tokens = self._hash_tokens(chunk)

                file_path = self._file_path(hashed_tokens)

                # If we get a cache hit, we want to load it into device.
                # Otherwise, we can stop looking.
                if file_path.exists():
                    self.pending_loads[request.request_id].append(file_path)
                else:
                    break

        logger.info(
            f"KV CONNECTOR: Matched {len(self.pending_loads[request.request_id])} blocks for request {request.request_id}"
        )

        return len(
            self.pending_loads[request.request_id]) * self.block_size, False

    def request_finished(self, request: LlmRequest,
                         cache_block_ids: list[int]) -> bool:
        # We don't do any asynchronous saving, so always return False
        return False

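    # Nothing to do on allocation; the newly allocated block ids are read from
    # the SchedulerOutput in build_connector_meta() instead.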
    def update_state_after_alloc(self, request: LlmRequest,
                                 block_ids: list[int]):
        pass


@click.command()
@click.argument("model", type=str)
def main(model: str):
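    # The connector classes are referenced by module name and class name in the
    # KvCacheConnectorConfig below, so the script adjusts sys.path and derives
    # its own module name first.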
    sys.path.append(os.path.join(
        os.path.dirname(__file__),
        "..",
    ))

    this_module = __file__[__file__.rfind("/") + 1:__file__.rfind(".py")]

    kv_connector_config = KvCacheConnectorConfig(
        connector_module=this_module,
        connector_scheduler_class="PersistentKvCacheConnectorLeader",
        connector_worker_class="PersistentKvCacheConnectorWorker",
    )

    connector_cache_dir = TemporaryDirectory()
    os.environ[CONNECTOR_CACHE_FOLDER_KEY] = connector_cache_dir.name

    llm = LLM(model=model,
              backend="pytorch",
              cuda_graph_config=None,
              kv_connector_config=kv_connector_config)

    test_text = (
        "Nvidia Corporation is an American technology company headquartered in Santa Clara, California."
        "Founded in 1993 by Jensen Huang, Chris Malachowsky, and Curtis Priem, it develops graphics processing units (GPUs), "
        "system on a chips (SoCs), and application programming interfaces (APIs) for data science, high-performance computing, "
        "and mobile and automotive applications. Tell me about the company.")

    sampling_params = SamplingParams(max_tokens=32)

    output = llm.generate([test_text], sampling_params)
    text0 = output[0].outputs[0].text

    print("First output: ", text0)
    print("Loading new LLM instance...")

    del llm

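    # Build a completely fresh LLM instance. Its device KV cache starts empty,
    # so any prefix reuse in this second run comes from the files written by the
    # connector during the first run.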
    llm = LLM(model=model,
              backend="pytorch",
              cuda_graph_config=None,
              kv_connector_config=kv_connector_config)

    output = llm.generate([test_text], sampling_params)
    text1 = output[0].outputs[0].text

    print("Second output (using connector cache): ", text1)

    assert text0 == text1

    connector_cache_dir.cleanup()


if __name__ == "__main__":
    main()
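To try the example, run the script with a single positional argument giving a Hugging Face model name or a local checkpoint path, for example (the file name here is only an assumption about where you saved the script):

    python llm_kv_cache_connector.py <model>

The script generates from the same prompt twice: the first run starts with an empty connector cache and persists the KV blocks it produces, and the second run builds a fresh LLM instance that reloads those blocks through the connector before asserting that both outputs match.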