KV Cache Connector
Source NVIDIA/TensorRT-LLM.
1
2import os
3import sys
4from dataclasses import dataclass, field
5from pathlib import Path
6from tempfile import TemporaryDirectory
7
8import click
9import torch
10
11from tensorrt_llm import LLM, SamplingParams, logger
12from tensorrt_llm._torch.pyexecutor.kv_cache_connector import (
13 KvCacheConnectorScheduler, KvCacheConnectorWorker, SchedulerOutput)
14from tensorrt_llm.bindings.executor import ExecutorConfig
15from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
16from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig
17
# This is a simple example of how to use the KV cache connector.
# It persists KV cache contents into a folder and can load them back on subsequent runs.
# See tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py for details about the KV cache connector interface.
# NOTE: This example connector implementation is NOT suitable for production use.
22
23CONNECTOR_CACHE_FOLDER_KEY = "CONNECTOR_CACHE_FOLDER"
24
25
@dataclass
class PersistentKvCacheConnectorMetadata:
    """Work list of block transfers for one scheduling step.

    Instances are built by ``build_connector_meta``; the worker iterates
    entries of this shape via ``self._metadata.load`` / ``.save``.
    Each entry pairs the cache file of one KV block with the ID of the block
    in the registered KV cache tensor that it is copied into or out of.
    """

    # (cache file, destination block ID) pairs to read from disk.
    load: list[tuple[Path, int]] = field(default_factory=list)
    # (cache file, source block ID) pairs to write to disk.
    save: list[tuple[Path, int]] = field(default_factory=list)
30
31
class PersistentKvCacheConnectorWorker(KvCacheConnectorWorker):
    """Worker side of the example connector: copies KV blocks to and from disk.

    Every transfer is synchronous and handled one block at a time. The list of
    blocks to move is read from ``self._metadata`` (entries of
    ``PersistentKvCacheConnectorMetadata``).
    """

    def __init__(self, executor_config: ExecutorConfig):
        super().__init__(executor_config)

        # Set exactly once by register_kv_caches(); indexed by block ID.
        self.kv_cache_tensor = None

    def register_kv_caches(self, kv_cache_tensor: torch.Tensor):
        """Store the KV cache tensor that later loads and saves index into.

        Raises:
            AssertionError: if a tensor has already been registered.
        """
        assert self.kv_cache_tensor is None, "KV cache tensor already registered"
        self.kv_cache_tensor = kv_cache_tensor

    def start_load_kv(self, stream: torch.cuda.Stream):
        """Copy every block listed in ``self._metadata.load`` from disk into the cache.

        ``stream`` is not used; the copies are blocking.
        """
        # Do all loads synchronously, and blockwise.
        for path, block_id in self._metadata.load:
            # SECURITY: torch.load unpickles the file. The cache files only
            # ever hold a plain tensor, so restrict unpickling to tensor data
            # instead of allowing arbitrary objects.
            cpu_tensor = torch.load(path,
                                    map_location="cpu",
                                    weights_only=True)

            # Copy into the device block.
            self.kv_cache_tensor[block_id].copy_(cpu_tensor, non_blocking=False)

    def wait_for_layer_load(self, layer_idx: int, stream: torch.cuda.Stream):
        """No-op: start_load_kv() already finished all loads synchronously."""
        pass

    def save_kv_layer(self, layer_idx: int, stream: torch.cuda.Stream):
        """No-op: all saving is done in wait_for_save()."""
        pass

    def wait_for_save(self, stream: torch.cuda.Stream):
        """Write every block listed in ``self._metadata.save`` to disk.

        Blocks whose cache file already exists are skipped. Returns only after
        all remaining blocks have been written.
        """
        # Make sure the forward pass is complete before beginning our save.
        stream.synchronize()

        for path, block_id in self._metadata.save:
            # Don't write anything if this specific block already exists.
            # Checking first also avoids a pointless copy of the block to host.
            if Path(path).exists():
                continue

            cpu_tensor = self.kv_cache_tensor[block_id].cpu()

            # Do a blocking save to the file. This way, we only return once all saves are complete.
            torch.save(cpu_tensor, path)

    def get_finished(
            self, finished_gen_req_ids: list[int],
            started_loading_req_ids: list[int]) -> tuple[list[int], list[int]]:
        """Always return two empty lists; both arguments are ignored."""
        return [], []
77
78
class PersistentKvCacheConnectorLeader(KvCacheConnectorScheduler):
    """Scheduler side of the example connector.

    Decides, per request, which full token blocks can be loaded from the
    on-disk cache folder and which ones should be saved to it. Each block is
    stored in a file named after the hash of that block's tokens.
    """

    def __init__(self, executor_config: ExecutorConfig):
        super().__init__(executor_config)

        self.block_size = self._config.tokens_per_block
        # request ID -> cache files matched for that request, in block order.
        self.pending_loads: dict[int, list[Path]] = {}

        # The folder can be overridden through the environment variable named
        # by CONNECTOR_CACHE_FOLDER_KEY.
        self.cache_folder = os.environ.get(CONNECTOR_CACHE_FOLDER_KEY,
                                           "./connector_cache")

        os.makedirs(self.cache_folder, exist_ok=True)

    def build_connector_meta(self, scheduler_output: SchedulerOutput):
        """Build the load/save work list for the new requests in this step.

        Returns:
            A ``PersistentKvCacheConnectorMetadata`` whose ``load`` entries
            come from the matches recorded by ``get_num_new_matched_tokens``
            and whose ``save`` entries cover the remaining full blocks.
        """
        # NOTE: This is a simplified implementation, and does not work with chunked prefill.

        metadata = PersistentKvCacheConnectorMetadata()

        for req in scheduler_output.new_requests:
            # If we don't have any pending loads for this request, we can skip it.
            pending_load = self.pending_loads.get(req.request_id)
            if pending_load is None:
                continue

            num_computed_blocks = req.computed_position // self.block_size
            block_ids = req.new_block_ids

            # Matched files fill the blocks right after the computed ones.
            for file_path, block_pos in zip(
                    pending_load, range(num_computed_blocks, len(block_ids))):
                metadata.load.append((file_path, block_ids[block_pos]))

            # Break up the token sequence into chunks.
            chunks = self._chunk_tokens(req.new_tokens)

            # For each chunk that isn't already on device, and isn't in our connector cache, we need to save it.
            first_unsaved = num_computed_blocks + len(pending_load)
            # Bound by len(chunks) as well: if block_ids lists more blocks than
            # the tokens fill, chunks[block_pos] would raise IndexError.
            last_block = min(len(block_ids), len(chunks))
            for block_pos in range(first_unsaved, last_block):
                chunk = chunks[block_pos]
                # Only full blocks are persisted.
                if len(chunk) != self.block_size:
                    continue

                file_path = self._file_path(self._hash_tokens(chunk))
                metadata.save.append((file_path, block_ids[block_pos]))

        # Matches are only valid for the step they were computed for.
        self.pending_loads = {}

        return metadata

    def _hash_tokens(self, tokens: list[int]) -> int:
        """Return a non-negative hash of the token IDs in one block."""
        # NOTE(review): the key covers only this block's own tokens, not the
        # tokens that precede it. Presumably acceptable for this demo; confirm
        # before reusing the scheme elsewhere.
        return abs(hash(tuple(tokens)))

    def _file_path(self, hash_value: int) -> Path:
        """Return the cache file path for a block with the given hash."""
        return Path(self.cache_folder) / f"{hash_value}.pt"

    def _chunk_tokens(self, tokens: list[int]) -> list[list[int]]:
        """Split ``tokens`` into consecutive chunks of ``self.block_size``.

        The last chunk is shorter when the length is not a multiple of the
        block size.
        """
        return [
            tokens[i:i + self.block_size]
            for i in range(0, len(tokens), self.block_size)
        ]

    def get_num_new_matched_tokens(
            self, request: LlmRequest,
            num_computed_tokens: int) -> tuple[int, bool]:
        """Record which leading uncomputed blocks exist in the on-disk cache.

        Returns:
            ``(num_matched_tokens, False)`` where ``num_matched_tokens`` is
            the number of matched blocks times the block size. The second
            element is always ``False``.
        """
        matched: list[Path] = []
        self.pending_loads[request.request_id] = matched

        # Don't bother with sequences with partial matches.
        if (num_computed_tokens % self.block_size) != 0:
            return 0, False

        # Get all the tokens that don't have a cache hit on device.
        # num_computed_tokens is a whole number of blocks at this point.
        remaining_tokens = request.get_tokens(0)[num_computed_tokens:]

        # For each chunk, check if it exists in our cache.
        for chunk in self._chunk_tokens(remaining_tokens):
            # Only do full blocks.
            if len(chunk) != self.block_size:
                continue

            file_path = self._file_path(self._hash_tokens(chunk))

            # If we get a cache hit, we want to load it into device.
            # Otherwise, we can stop looking.
            if not file_path.exists():
                break
            matched.append(file_path)

        logger.info(
            f"KV CONNECTOR: Matched {len(matched)} blocks for request {request.request_id}"
        )

        return len(matched) * self.block_size, False

    def request_finished(self, request: LlmRequest,
                         cache_block_ids: list[int]) -> bool:
        """Always return ``False``; both arguments are ignored."""
        # We don't do any asynchronous saving, so always return False
        return False

    def update_state_after_alloc(self, request: LlmRequest,
                                 block_ids: list[int]):
        """No-op: this connector keeps no state about allocated blocks."""
        pass
187
188
@click.command()
@click.argument("model", type=str)
def main(model: str):
    """Generate the same prompt with two LLM instances and compare outputs.

    The first instance runs with an empty connector cache folder. The second
    is created afterwards with the same folder, and the script asserts that
    both generated texts are identical.
    """
    # NOTE(review): this adds the parent of this file's directory to sys.path,
    # presumably so the connector module can be imported by name; confirm
    # that the parent directory is what is intended.
    sys.path.append(os.path.join(
        os.path.dirname(__file__),
        "..",
    ))

    # Module name of this file (file name without directory or ".py").
    # Path.stem is independent of the platform's path separator.
    this_module = Path(__file__).stem

    kv_connector_config = KvCacheConnectorConfig(
        connector_module=this_module,
        connector_scheduler_class="PersistentKvCacheConnectorLeader",
        connector_worker_class="PersistentKvCacheConnectorWorker",
    )

    test_text = (
        "Nvidia Corporation is an American technology company headquartered in Santa Clara, California."
        "Founded in 1993 by Jensen Huang, Chris Malachowsky, and Curtis Priem, it develops graphics processing units (GPUs), "
        "system on a chips (SoCs), and application programming interfaces (APIs) for data science, high-performance computing, "
        "and mobile and automotive applications. Tell me about the company.")

    sampling_params = SamplingParams(max_tokens=32)

    def build_llm():
        # Both runs must use an identical configuration.
        return LLM(model=model,
                   backend="pytorch",
                   cuda_graph_config=None,
                   kv_connector_config=kv_connector_config)

    # The context manager removes the cache folder even when generation or
    # the final comparison fails.
    with TemporaryDirectory() as connector_cache_dir:
        # The connector leader reads the cache folder from this variable.
        os.environ[CONNECTOR_CACHE_FOLDER_KEY] = connector_cache_dir

        llm = build_llm()

        output = llm.generate([test_text], sampling_params)
        text0 = output[0].outputs[0].text

        print("First output: ", text0)
        print("Loading new LLM instance...")

        # Drop the first instance before creating the second one.
        del llm

        llm = build_llm()

        output = llm.generate([test_text], sampling_params)
        text1 = output[0].outputs[0].text

        print("Second output (using connector cache): ", text1)

        assert text0 == text1


if __name__ == "__main__":
    main()