KV Cache Connector

Source: NVIDIA/TensorRT-LLM.

'''
This script demonstrates the KV cache connector feature in TensorRT-LLM, which enables
custom persistence and reuse of KV cache blocks across different LLM instances.

**Scenario:**
The script implements a persistent KV cache connector that saves computed KV cache blocks
to disk and loads them back in subsequent runs, eliminating redundant computation for
recurring prompts.

**What is a KV Cache Connector?**

A KV cache connector is a customizable interface that allows you to:
1.  **Save KV Cache:** Persist computed KV cache blocks to external storage
    (disk, database, distributed cache, etc.)
2.  **Load KV Cache:** Retrieve previously computed cache blocks instead of recomputing them
3.  **Share Cache Across Instances:** Reuse cache blocks across different LLM instances
    or sessions, unlike regular block reuse, which is limited to a single instance
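
These three capabilities reduce to a small contract between the connector and its storage:
check whether a block exists, store it, and fetch it. The following minimal sketch shows that
contract in plain Python. `BlockStore` and `InMemoryBlockStore` are hypothetical names, not
TensorRT-LLM APIs, and a dict stands in for disk, a database, or a distributed cache.

```python
from typing import Protocol


class BlockStore(Protocol):
    # Hypothetical storage contract; not a TensorRT-LLM interface.
    def contains(self, key: str) -> bool: ...

    def put(self, key: str, block: bytes) -> None: ...

    def get(self, key: str) -> bytes: ...


class InMemoryBlockStore:
    # Dict-backed stand-in for disk, a database, or a distributed cache.
    def __init__(self) -> None:
        self._blocks: dict[str, bytes] = {}

    def contains(self, key: str) -> bool:
        return key in self._blocks

    def put(self, key: str, block: bytes) -> None:
        # A block's content never changes for a given key, so keep the first write.
        self._blocks.setdefault(key, block)

    def get(self, key: str) -> bytes:
        return self._blocks[key]


store: BlockStore = InMemoryBlockStore()
store.put("block-a", b"kv-bytes")
print(store.contains("block-a"), store.get("block-a"))
```

Keeping the first write in `put` mirrors this example's worker, which skips any block whose
file already exists.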

**How It Works:**

This example implements a persistent KV cache connector as two cooperating classes:

* **PersistentKvCacheConnectorLeader (Scheduler):**
    - Hashes each full block of tokens, together with the request's cache salt, to create an identifier for that block
    - Checks whether cached blocks exist on disk for incoming requests
    - Schedules load operations for cache hits
    - Schedules save operations for newly computed blocks

* **PersistentKvCacheConnectorWorker:**
    - Executes the actual load/save operations between GPU and disk
    - Loads cached blocks from disk files into GPU memory
    - Saves newly computed blocks from GPU to disk files
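
The leader's lookup can be modeled without TensorRT-LLM at all. The following is a simplified
sketch of what `get_num_new_matched_tokens` does: split the tokens after the on-device prefix
into block-sized chunks and count how many consecutive full blocks are already stored. The
helper names `block_key` and `count_matched_tokens` are hypothetical, a Python set stands in
for the files on disk, and SHA-256 is an assumed key function (the script derives its keys in
`_hash_tokens`).

```python
import hashlib
from typing import Optional


def block_key(tokens: list[int], cache_salt: Optional[str]) -> str:
    # Stable digest of one full block of token ids plus the request's salt.
    payload = repr((cache_salt, tuple(tokens))).encode("utf-8")
    return hashlib.sha256(payload).hexdigest()


def count_matched_tokens(prompt: list[int], block_size: int,
                         num_computed_tokens: int, stored: set[str],
                         cache_salt: Optional[str] = None) -> int:
    # Give up if the on-device prefix ends partway through a block.
    if num_computed_tokens % block_size != 0:
        return 0
    matched_blocks = 0
    # Walk the blocks after the on-device prefix and stop at the first miss;
    # a trailing partial block is never stored, so it also ends the scan.
    for start in range(num_computed_tokens, len(prompt), block_size):
        chunk = prompt[start:start + block_size]
        if len(chunk) < block_size or block_key(chunk, cache_salt) not in stored:
            break
        matched_blocks += 1
    return matched_blocks * block_size


prompt = list(range(10))
stored = {block_key(prompt[0:4], None), block_key(prompt[4:8], None)}
print(count_matched_tokens(prompt, block_size=4, num_computed_tokens=0, stored=stored))  # 8
```

Stopping at the first miss matters: loaded blocks must form a contiguous prefix, because a
block's KV values depend on every token before it.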

**Demonstration:**

The script processes the same prompt twice using two separate LLM instances:

1.  **First Run (Instance 1):**
    - The LLM computes the KV cache for the input prompt
    - The connector saves the computed cache blocks to disk (as .pt files)
    - The generation completes and the LLM instance is destroyed

2.  **Second Run (Instance 2):**
    - A new LLM instance is created with the same connector configuration
    - When processing the same prompt, the connector finds matching cache blocks on disk
    - The cache is loaded from disk instead of being recomputed
    - **Expected Outcome:** Faster prefill, since cache blocks are loaded rather than computed
    - Both outputs should be identical, demonstrating deterministic cache reuse
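
The two-run flow can be reproduced in miniature. In this sketch, the `ToyDiskCache` class and
the `.blk` file suffix are hypothetical, and writing a block's `repr` stands in for saving a
KV tensor. Two independent objects share only a folder: the first run finds nothing and
persists every full block, and the second run finds all of them.

```python
import hashlib
from pathlib import Path
from tempfile import TemporaryDirectory


class ToyDiskCache:
    # Hypothetical stand-in for one LLM instance's connector; instances share
    # only the folder, never in-memory state.
    def __init__(self, folder: str, block_size: int) -> None:
        self.folder = Path(folder)
        self.block_size = block_size

    def _path(self, block: list[int]) -> Path:
        key = hashlib.sha256(repr(tuple(block)).encode()).hexdigest()
        return self.folder / f"{key}.blk"

    def _full_blocks(self, tokens: list[int]) -> list[list[int]]:
        n = len(tokens) // self.block_size * self.block_size
        return [tokens[i:i + self.block_size] for i in range(0, n, self.block_size)]

    def run(self, tokens: list[int]) -> int:
        # Returns how many blocks were loaded instead of computed.
        blocks = self._full_blocks(tokens)
        hits = 0
        for block in blocks:
            if not self._path(block).exists():
                break
            hits += 1
        # "Compute" and persist every block that was not loaded.
        for block in blocks[hits:]:
            self._path(block).write_bytes(repr(block).encode())
        return hits


prompt = list(range(50))
with TemporaryDirectory() as folder:
    first = ToyDiskCache(folder, block_size=16).run(prompt)
    second = ToyDiskCache(folder, block_size=16).run(prompt)
print(first, second)  # 0 3
```

The 50-token prompt holds three full 16-token blocks; the last two tokens form a partial block
that is never persisted, so they are recomputed on every run.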

**Key Benefits:**

- **Cross-Instance Cache Sharing:** Share computed caches across multiple LLM instances
- **Persistent Storage:** Cache survives beyond the lifetime of a single LLM instance
- **Custom Storage Backends:** Implement any storage mechanism (shown here: disk files)
- **Reduced Computation:** Eliminate redundant KV cache computation for repeated prompts

**How to Run:**

```bash
python llm_kv_cache_connector.py <model_path>
```

Example:
```bash
python llm_kv_cache_connector.py meta-llama/Llama-3.1-8B-Instruct
```

**Implementation Notes:**

- This example identifies cache blocks by content: each full block of tokens is hashed together with the request's cache salt, and partial blocks are never saved or loaded
- Cache files are stored in a temporary directory that is cleaned up after the demo
- The implementation is simplified and not optimized for production use
- Chunked prefill is not supported in this example
- See `tensorrt_llm/_torch/pyexecutor/connectors/kv_cache_connector.py` for the full connector interface
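
A note on the hash itself: Python's built-in `hash()` is randomized per process for `str` and
`bytes` values (see `PYTHONHASHSEED`), so a key that has to be found again by a later process
is safer when built with `hashlib`. The sketch below shows one such key, assuming SHA-256 over
the `repr` of the salt and the token tuple. `stable_block_key` is a hypothetical helper, not
part of this script or of TensorRT-LLM.

```python
import hashlib
from typing import Optional


def stable_block_key(tokens: list[int], cache_salt: Optional[str]) -> str:
    # repr() keeps None and the string "None" distinct, and the tuple framing
    # keeps the salt from running into the token ids.
    payload = repr((cache_salt, tuple(tokens))).encode("utf-8")
    # SHA-256 yields the same digest in every process, unlike hash() on a str.
    return hashlib.sha256(payload).hexdigest()


block = [101, 2023, 2003, 1037]
print(stable_block_key(block, None) == stable_block_key(block, None))        # True
print(stable_block_key(block, None) == stable_block_key(block, "tenant-a"))  # False
```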

**NOTE:** This example connector implementation is designed for demonstration purposes
and is NOT suitable for production use without additional optimizations and error handling.
'''

import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional

import click
import torch

from tensorrt_llm import LLM, SamplingParams, logger
from tensorrt_llm._torch.pyexecutor.connectors.kv_cache_connector import (
    KvCacheConnectorScheduler, KvCacheConnectorWorker, SchedulerOutput)
from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig, TorchLlmArgs

CONNECTOR_CACHE_FOLDER_KEY = "CONNECTOR_CACHE_FOLDER"


@dataclass
class PersistentKvCacheConnectorMetadata:
    # Each entry is a (cache file path, device block id) pair.
    load: list[tuple[Path, int]] = field(default_factory=list)
    save: list[tuple[Path, int]] = field(default_factory=list)


class PersistentKvCacheConnectorWorker(KvCacheConnectorWorker):

    def __init__(self, llm_args: TorchLlmArgs):
        super().__init__(llm_args)

        self.kv_cache_tensor = None

    def register_kv_caches(self, kv_cache_tensor: torch.Tensor):
        assert self.kv_cache_tensor is None, "KV cache tensor already registered"
        self.kv_cache_tensor = kv_cache_tensor

    def start_load_kv(self, stream: torch.cuda.Stream):
        # Do all loads synchronously, and blockwise.
        for path, block_id in self._metadata.load:
            cpu_tensor = torch.load(path, map_location="cpu")

            # Copy into the device block.
            self.kv_cache_tensor[block_id].copy_(cpu_tensor, non_blocking=False)

    def wait_for_layer_load(self, layer_idx: int, stream: torch.cuda.Stream):
        # Nothing to wait for: start_load_kv already loaded every block synchronously.
        pass

    def save_kv_layer(self, layer_idx: int, stream: torch.cuda.Stream):
        # No per-layer work: whole blocks are saved at once in wait_for_save.
        pass

    def wait_for_save(self, stream: torch.cuda.Stream):

        # Make sure the forward pass is complete before beginning our save.
        stream.synchronize()

        for path, block_id in self._metadata.save:
            # Don't write anything if this specific block already exists,
            # and skip the device-to-host copy in that case as well.
            if Path(path).exists():
                continue

            cpu_tensor = self.kv_cache_tensor[block_id].cpu()

            # Do a blocking save to the file. This way, we only return once all saves are complete.
            torch.save(cpu_tensor, path)

    def get_finished(
            self, finished_gen_req_ids: list[int],
            started_loading_req_ids: list[int]) -> tuple[list[int], list[int]]:
        # All loads and saves are synchronous, so nothing is ever left pending.
        return [], []


class PersistentKvCacheConnectorLeader(KvCacheConnectorScheduler):

    def __init__(self, llm_args: TorchLlmArgs):
        super().__init__(llm_args)

        self.block_size = self._llm_args.kv_cache_config.tokens_per_block
        # Request id -> cache files to load; filled in by
        # get_num_new_matched_tokens and consumed by build_connector_meta.
        self.pending_loads = {}

        self.cache_folder = os.environ.get(CONNECTOR_CACHE_FOLDER_KEY,
                                           "./connector_cache")

        os.makedirs(self.cache_folder, exist_ok=True)

    def build_connector_meta(self, scheduler_output: SchedulerOutput):
        # NOTE: This is a simplified implementation, and does not work with chunked prefill.

        metadata = PersistentKvCacheConnectorMetadata()

        for req in scheduler_output.new_requests:
            # Skip requests that never went through get_num_new_matched_tokens.
            # Requests with no cache hits still have an (empty) entry, so their
            # blocks are scheduled for saving below.
            if req.request_id not in self.pending_loads:
                continue

            num_computed_blocks = req.computed_position // self.block_size
            block_ids = req.new_block_ids

            pending_load = self.pending_loads[req.request_id]

            for file_path, block_pos in zip(
                    pending_load, range(num_computed_blocks, len(block_ids))):
                metadata.load.append((file_path, block_ids[block_pos]))

            # Break up the remainder of the token sequence into chunks.
            chunks = self._chunk_tokens(req.new_tokens)

            # For each chunk that isn't already on device, and isn't in our connector cache, we need to save it.
            for block_pos in range(num_computed_blocks + len(pending_load),
                                   len(block_ids)):
                if len(chunks[block_pos]) == self.block_size:
                    hashed_tokens = self._hash_tokens(chunks[block_pos],
                                                      req.cache_salt)

                    file_path = self._file_path(hashed_tokens)

                    metadata.save.append((file_path, block_ids[block_pos]))

        self.pending_loads = {}

        return metadata

    def _hash_tokens(self, tokens: list[int], cache_salt: Optional[str]) -> str:
        import hashlib

        # cache_salt must participate in the hash so that requests carrying
        # different salts (or no salt) cannot collide on the same cache file.
        # A hashlib digest is used instead of the built-in hash(), which is
        # randomized per process for str values (PYTHONHASHSEED) and would stop
        # salted blocks from being found again in a later run.
        payload = repr((cache_salt, tuple(int(t) for t in tokens)))
        return hashlib.sha256(payload.encode("utf-8")).hexdigest()

    def _file_path(self, hash_value: str) -> Path:
        return Path(self.cache_folder) / f"{hash_value}.pt"

    def _chunk_tokens(self, tokens: list[int]) -> list[list[int]]:
        return [
            tokens[i:i + self.block_size]
            for i in range(0, len(tokens), self.block_size)
        ]

    def get_num_new_matched_tokens(
            self, request: LlmRequest,
            num_computed_tokens: int) -> tuple[int, bool]:
        self.pending_loads[request.request_id] = []

        # Skip requests whose on-device match ends partway through a block.
        if (num_computed_tokens % self.block_size) != 0:
            return 0, False

        computed_blocks = num_computed_tokens // self.block_size

        # Get all the tokens that don't have a cache hit on device.
        remaining_tokens = request.get_tokens(0)[computed_blocks *
                                                 self.block_size:]

        remaining_chunks = self._chunk_tokens(remaining_tokens)

        # For each chunk, check whether it exists in our cache.
        for chunk in remaining_chunks:
            # Only do full blocks.
            if len(chunk) == self.block_size:
                hashed_tokens = self._hash_tokens(chunk, request.cache_salt)

                file_path = self._file_path(hashed_tokens)

                # If we get a cache hit, we want to load it into device.
                # Otherwise, we can stop looking.
                if file_path.exists():
                    self.pending_loads[request.request_id].append(file_path)
                else:
                    break

        logger.info(
            f"KV CONNECTOR: Matched {len(self.pending_loads[request.request_id])} blocks for request {request.request_id}"
        )

        return len(
            self.pending_loads[request.request_id]) * self.block_size, False

    def request_finished(self, request: LlmRequest,
                         cache_block_ids: list[int]) -> bool:
        # We don't do any asynchronous saving, so always return False
        return False

    def update_state_after_alloc(self, request: LlmRequest,
                                 block_ids: list[int]):
        # Block ids are read from scheduler_output in build_connector_meta instead.
        pass


@click.command()
@click.argument("model", type=str)
def main(model: str):
    sys.path.append(os.path.join(
        os.path.dirname(__file__),
        "..",
    ))

    # Module name of this script (its file name without ".py"); the connector
    # config names it as the module that defines the leader and worker classes.
    this_module = Path(__file__).stem

    # --- KV Cache Connector Config ---
    kv_connector_config = KvCacheConnectorConfig(
        connector_module=this_module,
        connector_scheduler_class="PersistentKvCacheConnectorLeader",
        connector_worker_class="PersistentKvCacheConnectorWorker",
    )

    # Point the leader at a fresh cache folder for this demo.
    connector_cache_dir = TemporaryDirectory()
    os.environ[CONNECTOR_CACHE_FOLDER_KEY] = connector_cache_dir.name

    # Create LLM instance with KV Cache Connector
    llm = LLM(model=model,
              backend="pytorch",
              cuda_graph_config=None,
              kv_connector_config=kv_connector_config)

    test_text = (
        "Nvidia Corporation is an American technology company headquartered in Santa Clara, California. "
        "Founded in 1993 by Jensen Huang, Chris Malachowsky, and Curtis Priem, it develops graphics processing units (GPUs), "
        "systems on a chip (SoCs), and application programming interfaces (APIs) for data science, high-performance computing, "
        "and mobile and automotive applications. Tell me about the company.")

    sampling_params = SamplingParams(max_tokens=32)

    # Generate text with the first LLM instance; the connector saves its KV cache blocks to disk.
    output = llm.generate([test_text], sampling_params)
    text0 = output[0].outputs[0].text

    print("First output: ", text0)
    print("Loading new LLM instance...")

    # Release the first instance; the second run must rely on the connector, not in-memory block reuse.
    del llm

    # Create a new LLM instance with the same connector configuration
    llm = LLM(model=model,
              backend="pytorch",
              cuda_graph_config=None,
              kv_connector_config=kv_connector_config)

    # Generate text with the second LLM instance; it should reuse the KV cache blocks saved by the connector.
    output = llm.generate([test_text], sampling_params)
    text1 = output[0].outputs[0].text

    print("Second output (using connector cache): ", text1)

    # Verify that the two outputs are identical
    assert text0 == text1

    connector_cache_dir.cleanup()


if __name__ == "__main__":
    main()
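
The block bookkeeping in `build_connector_meta` is easy to misread, so here is the same arithmetic in isolation. This is a simplified sketch: `plan_blocks` is a hypothetical helper whose arguments stand in for `req.computed_position`, `req.new_block_ids`, the request's tokens, and the request's entry in `pending_loads`. Positions before `computed_position // block_size` are already on device, the next `len(pending_load)` positions are filled from the cache, and each remaining position is saved only if its chunk of tokens is full. As an added assumption, the sketch also bounds each position by the number of chunks so that it never indexes past the token list.

```python
def plan_blocks(computed_position: int, block_ids: list[int], tokens: list[int],
                pending_load: list[str],
                block_size: int) -> tuple[list[tuple[str, int]], list[int]]:
    # Blocks before this position are already on device (regular block reuse).
    num_computed_blocks = computed_position // block_size
    # Cache hits fill the next positions in order, one file per block.
    load = [(path, block_ids[pos]) for path, pos in zip(
        pending_load, range(num_computed_blocks, len(block_ids)))]
    chunks = [tokens[i:i + block_size] for i in range(0, len(tokens), block_size)]
    save: list[int] = []
    for pos in range(num_computed_blocks + len(pending_load), len(block_ids)):
        # Only full chunks are persisted; a trailing partial block is skipped.
        if pos < len(chunks) and len(chunks[pos]) == block_size:
            save.append(block_ids[pos])
    return load, save


# 14 tokens in blocks of 4: one block on device, one cache hit, one full block
# left to save, and a 2-token partial block that is not saved.
load, save = plan_blocks(4, [7, 3, 9, 5], list(range(14)), ["a.pt"], 4)
print(load, save)  # [('a.pt', 3)] [9]
```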