KV Cache Connector

Source NVIDIA/TensorRT-LLM.

'''
This script demonstrates the KV cache connector feature in TensorRT-LLM, which enables
custom persistence and reuse of KV cache blocks across different LLM instances.

**Scenario:**
The script implements a persistent KV cache connector that saves computed KV cache blocks
to disk and loads them back in subsequent runs, eliminating redundant computation for
recurring prompts.

**What is a KV Cache Connector?**

A KV cache connector is a customizable interface that allows you to:
1.  **Save KV Cache:** Persist computed KV cache blocks to external storage
    (disk, database, distributed cache, etc.)
2.  **Load KV Cache:** Retrieve previously computed cache blocks instead of recomputing them
3.  **Share Cache Across Instances:** Reuse cache blocks across different LLM instances
    or sessions, unlike regular block reuse, which is limited to a single instance

**How It Works:**

This example implements a persistent KV cache connector with two key components:

* **PersistentKvCacheConnectorLeader (Scheduler):**
    - Hashes the tokens of each full block (with the request's cache salt) to name its cache file
    - Checks which leading blocks of an incoming request already exist on disk
    - Schedules load operations for cache hits
    - Schedules save operations for newly computed blocks

* **PersistentKvCacheConnectorWorker:**
    - Executes the actual load/save operations between GPU and disk
    - Loads cached blocks from disk files into GPU memory
    - Saves newly computed blocks from GPU to disk files
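
The leader's lookup step can be sketched on its own using only the standard library. This is a minimal sketch under stated assumptions: `chunk_tokens`, `block_file`, and `count_matched_tokens` are hypothetical helpers, not part of the connector interface, and the SHA-256 file naming is an illustrative choice. As in `get_num_new_matched_tokens` below, only full blocks count and the scan stops at the first miss, because a cached block is only usable if every block before it is available too.

```python
import hashlib
import tempfile
from pathlib import Path
from typing import Optional


def chunk_tokens(tokens: list[int], block_size: int) -> list[list[int]]:
    # Split a token sequence into consecutive blocks; the last one may be partial.
    return [tokens[i:i + block_size] for i in range(0, len(tokens), block_size)]


def block_file(cache_dir: Path, block: list[int], salt: Optional[int]) -> Path:
    # Hypothetical naming scheme: a digest of (salt, block tokens) plus ".pt".
    key = hashlib.sha256(repr((salt, tuple(block))).encode()).hexdigest()
    return cache_dir / f"{key}.pt"


def count_matched_tokens(tokens: list[int], block_size: int, cache_dir: Path,
                         salt: Optional[int] = None) -> int:
    # Return the number of tokens covered by leading full blocks found on disk.
    matched = 0
    for block in chunk_tokens(tokens, block_size):
        if len(block) < block_size:
            break  # a partial trailing block is never loaded
        if not block_file(cache_dir, block, salt).exists():
            break  # stop at the first miss
        matched += block_size
    return matched


with tempfile.TemporaryDirectory() as tmp:
    cache_dir = Path(tmp)
    tokens = list(range(10))  # two full blocks of 4 tokens plus a partial block
    block_file(cache_dir, tokens[0:4], None).touch()  # pretend only block 0 was saved
    print(count_matched_tokens(tokens, 4, cache_dir))  # prints 4
```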

**Demonstration:**

The script processes the same prompt twice using two separate LLM instances:

1.  **First Run (Instance 1):**
    - The LLM computes the KV cache for the input prompt
    - The connector saves each full computed cache block to disk as a `.pt` file
    - The generation completes and the LLM instance is destroyed

2.  **Second Run (Instance 2):**
    - A new LLM instance is created with the same connector configuration
    - When processing the same prompt, the connector finds the matching cache blocks on disk
    - The cache is loaded from disk instead of being recomputed
    - **Expected Outcome:** Faster prefill, since the matched blocks are loaded rather than computed
    - The two outputs should be identical, showing that loaded blocks give the same result as recomputed ones
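
The two-run flow can be mimicked without a GPU or a model. In this toy sketch, `compute_block` is a hypothetical stand-in for prefill and `pickle` stands in for `torch.save`/`torch.load`; it does not reflect how TensorRT-LLM computes or stores KV blocks. The two simulated instances share nothing except the cache directory, which is what lets the second run skip all computation.

```python
import hashlib
import pickle
import tempfile
from pathlib import Path

BLOCK_SIZE = 4


def compute_block(prefix: list[int]) -> list[int]:
    # Hypothetical stand-in for prefill: the "KV" of a block depends on its whole prefix.
    return [sum(prefix) + t for t in prefix[-BLOCK_SIZE:]]


def run_instance(tokens: list[int], cache_dir: Path) -> tuple[list[list[int]], int]:
    # One simulated "LLM instance": load each full block from disk if present,
    # otherwise compute it and save it. Returns the blocks and how many were computed.
    blocks, computed = [], 0
    for end in range(BLOCK_SIZE, len(tokens) + 1, BLOCK_SIZE):
        key = hashlib.sha256(repr(tokens[:end]).encode()).hexdigest()
        path = cache_dir / f"{key}.pkl"
        if path.exists():
            blocks.append(pickle.loads(path.read_bytes()))
        else:
            block = compute_block(tokens[:end])
            path.write_bytes(pickle.dumps(block))
            blocks.append(block)
            computed += 1
    return blocks, computed


with tempfile.TemporaryDirectory() as tmp:
    prompt = list(range(12))
    first, n_first = run_instance(prompt, Path(tmp))    # computes and saves 3 blocks
    second, n_second = run_instance(prompt, Path(tmp))  # loads all 3 blocks from disk
    print(n_first, n_second, first == second)  # prints 3 0 True
```

Keying each file on the whole prefix, rather than on a single block's tokens, keeps a block from being reused after a different prefix.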

**Key Benefits:**

- **Cross-Instance Cache Sharing:** Share computed caches across multiple LLM instances
- **Persistent Storage:** Cache survives beyond the lifetime of a single LLM instance
- **Custom Storage Backends:** Implement any storage mechanism (shown here: disk files)
- **Reduced Computation:** Eliminate redundant KV cache computation for repeated prompts

**How to Run:**

```bash
python llm_kv_cache_connector.py <model_path>
```

Example:
```bash
python llm_kv_cache_connector.py meta-llama/Llama-3.1-8B-Instruct
```

**Implementation Notes:**

- Cache blocks are identified by hashing each full block's tokens together with the request's
  cache salt. The hash does not include the tokens before the block, so a block can match a
  file saved after a different prefix or at a different position; a production connector
  should fold the prefix into the key
- Cache files are stored in a temporary directory (cleaned up after the demo)
- The implementation is simplified and not optimized for production use
- Chunked prefill is not supported in this example
- See the `tensorrt_llm._torch.pyexecutor.connectors.kv_cache_connector` module (imported below) for the full connector interface
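
One way to make keys prefix-aware without rehashing the whole prefix for every block is to chain them: each block's key is a digest of its parent block's key plus its own tokens. The sketch below is illustrative only; `chained_block_keys` is a hypothetical helper, SHA-256 is an assumed choice of digest, and it is not claimed to match the connector in this file or TensorRT-LLM's own block-reuse hashing.

```python
import hashlib
from typing import Optional


def chained_block_keys(tokens: list[int], block_size: int,
                       salt: Optional[int] = None) -> list[str]:
    # One key per full block; each key folds in the previous block's key,
    # so it identifies the block's tokens *and* everything before them.
    keys: list[str] = []
    parent = repr(salt)  # the salt seeds the chain
    for start in range(0, len(tokens) - block_size + 1, block_size):
        block = tokens[start:start + block_size]
        parent = hashlib.sha256(f"{parent}|{block}".encode()).hexdigest()
        keys.append(parent)
    return keys


a = chained_block_keys([1, 2, 3, 4, 5, 6, 7, 8], 4)
b = chained_block_keys([9, 9, 9, 9, 5, 6, 7, 8], 4)
print(a[1] == b[1])  # prints False: same second block, different prefix
```

Because `hashlib` digests do not depend on the process, keys like these stay valid across runs, which is what a persistent cache needs.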

**NOTE:** This example connector implementation is designed for demonstration purposes
and is NOT suitable for production use without additional optimizations and error handling.
'''

import hashlib
import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional

import click
import torch

from tensorrt_llm import LLM, SamplingParams, logger
from tensorrt_llm._torch.pyexecutor.connectors.kv_cache_connector import (
    KvCacheConnectorScheduler, KvCacheConnectorWorker, SchedulerOutput)
from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig, TorchLlmArgs

CONNECTOR_CACHE_FOLDER_KEY = "CONNECTOR_CACHE_FOLDER"


@dataclass
class PersistentKvCacheConnectorMetadata:
    # (cache file path, device block id) pairs scheduled by the leader.
    load: list[tuple[Path, int]] = field(default_factory=list)
    save: list[tuple[Path, int]] = field(default_factory=list)


class PersistentKvCacheConnectorWorker(KvCacheConnectorWorker):

    def __init__(self, llm_args: TorchLlmArgs):
        super().__init__(llm_args)

        self.kv_cache_tensor = None

    def register_kv_caches(self, kv_cache_tensor: torch.Tensor):
        assert self.kv_cache_tensor is None, "KV cache tensor already registered"
        self.kv_cache_tensor = kv_cache_tensor

    def start_load_kv(self, stream: torch.cuda.Stream):
        # Do all loads synchronously, and blockwise.
        for path, block_id in self._metadata.load:
            cpu_tensor = torch.load(path, map_location="cpu")

            # Copy into the device block.
            self.kv_cache_tensor[block_id].copy_(cpu_tensor, non_blocking=False)

    def wait_for_layer_load(self, layer_idx: int, stream: torch.cuda.Stream):
        # Loads are performed synchronously in start_load_kv; no per-layer work.
        pass

    def save_kv_layer(self, layer_idx: int, stream: torch.cuda.Stream):
        # Saves cover all layers at once in wait_for_save; no per-layer work.
        pass

    def wait_for_save(self, stream: torch.cuda.Stream):

        # Make sure the forward pass is complete before beginning our save.
        stream.synchronize()

        for path, block_id in self._metadata.save:
            # Don't write anything if this specific block already exists,
            # and skip the device-to-host copy in that case too.
            if Path(path).exists():
                continue

            cpu_tensor = self.kv_cache_tensor[block_id].cpu()

            # Do a blocking save to the file. This way, we only return once all saves are complete.
            torch.save(cpu_tensor, path)

    def get_finished(
            self, finished_gen_req_ids: list[int],
            started_loading_req_ids: list[int]) -> tuple[list[int], list[int]]:
        # All loads and saves are synchronous, so no request is ever left pending.
        return [], []


class PersistentKvCacheConnectorLeader(KvCacheConnectorScheduler):

    def __init__(self, llm_args: TorchLlmArgs):
        super().__init__(llm_args)

        self.block_size = self._llm_args.kv_cache_config.tokens_per_block
        # Maps request_id -> cache files to load for that request.
        self.pending_loads = {}

        self.cache_folder = os.environ.get(CONNECTOR_CACHE_FOLDER_KEY,
                                           "./connector_cache")

        os.makedirs(self.cache_folder, exist_ok=True)

    def build_connector_meta(self, scheduler_output: SchedulerOutput):
        # NOTE: This is a simplified implementation, and does not work with chunked prefill.

        metadata = PersistentKvCacheConnectorMetadata()

        for req in scheduler_output.new_requests:
            # Skip requests that get_num_new_matched_tokens never saw. An empty
            # entry still falls through, so that its new blocks get saved.
            if req.request_id not in self.pending_loads:
                continue

            num_computed_blocks = req.computed_position // self.block_size
            block_ids = req.new_block_ids

            pending_load = self.pending_loads[req.request_id]

            # The matched blocks directly follow the blocks already on device.
            for file_path, block_pos in zip(
                    pending_load, range(num_computed_blocks, len(block_ids))):
                metadata.load.append((file_path, block_ids[block_pos]))

            # Break up the remainder of the token sequence into chunks.
            chunks = self._chunk_tokens(req.new_tokens)

            # For each chunk that isn't already on device, and isn't in our connector cache, we need to save it.
            for block_pos in range(num_computed_blocks + len(pending_load),
                                   len(block_ids)):
                if len(chunks[block_pos]) == self.block_size:
                    hashed_tokens = self._hash_tokens(chunks[block_pos],
                                                      req.cache_salt_id)

                    file_path = self._file_path(hashed_tokens)

                    metadata.save.append((file_path, block_ids[block_pos]))

        # Every pending load has now been scheduled.
        self.pending_loads = {}

        return metadata

    def _hash_tokens(self, tokens: list[int],
                     cache_salt_id: Optional[int]) -> int:
        # cache_salt_id must participate in the hash so that requests carrying
        # different salts (or no salt) cannot collide on the same cache file.
        # Use a stable digest rather than the built-in hash(): file names must match
        # across processes, and before Python 3.12 hash(None) is derived from an
        # object address, so it can differ between processes.
        key = repr((cache_salt_id, tuple(tokens))).encode()
        return int.from_bytes(hashlib.sha256(key).digest()[:8], "little")

    def _file_path(self, hash_value: int) -> Path:
        return Path(self.cache_folder) / f"{hash_value}.pt"

    def _chunk_tokens(self, tokens: list[int]) -> list[list[int]]:
        return [
            tokens[i:i + self.block_size]
            for i in range(0, len(tokens), self.block_size)
        ]

    def get_num_new_matched_tokens(
            self, request: LlmRequest,
            num_computed_tokens: int) -> tuple[int, bool]:
        self.pending_loads[request.request_id] = []

        # Skip requests whose on-device match ends partway through a block.
        if (num_computed_tokens % self.block_size) != 0:
            return 0, False

        computed_blocks = num_computed_tokens // self.block_size

        # Get all the tokens that don't have a cache hit on device.
        remaining_tokens = request.get_tokens(0)[computed_blocks *
                                                 self.block_size:]

        remaining_chunks = self._chunk_tokens(remaining_tokens)

        # For each chunk, check if it exists in our cache.
        for chunk in remaining_chunks:
            # Only do full blocks.
            if len(chunk) == self.block_size:
                hashed_tokens = self._hash_tokens(chunk, request.cache_salt_id)

                file_path = self._file_path(hashed_tokens)

                # On a cache hit, schedule the block to be loaded into device.
                # On a miss, stop: later blocks are only usable if every earlier one is.
                if file_path.exists():
                    self.pending_loads[request.request_id].append(file_path)
                else:
                    break

        logger.info(
            f"KV CONNECTOR: Matched {len(self.pending_loads[request.request_id])} blocks for request {request.request_id}"
        )

        return len(
            self.pending_loads[request.request_id]) * self.block_size, False

    def request_finished(self, request: LlmRequest,
                         cache_block_ids: list[int]) -> bool:
        # We don't do any asynchronous saving, so always return False.
        return False

    def update_state_after_alloc(self, request: LlmRequest,
                                 block_ids: list[int]):
        pass


@click.command()
@click.argument("model", type=str)
def main(model: str):
    sys.path.append(os.path.join(
        os.path.dirname(__file__),
        "..",
    ))

    # The connector classes are referenced by module name: this file's name without ".py".
    this_module = Path(__file__).stem

    # --- KV Cache Connector Config ---
    kv_connector_config = KvCacheConnectorConfig(
        connector_module=this_module,
        connector_scheduler_class="PersistentKvCacheConnectorLeader",
        connector_worker_class="PersistentKvCacheConnectorWorker",
    )

    # Both LLM instances share one temporary cache folder, which the connector
    # leader reads from the environment. It is removed even if a step fails.
    with TemporaryDirectory() as connector_cache_dir:
        os.environ[CONNECTOR_CACHE_FOLDER_KEY] = connector_cache_dir

        # Create LLM instance with KV Cache Connector
        llm = LLM(model=model,
                  backend="pytorch",
                  cuda_graph_config=None,
                  kv_connector_config=kv_connector_config)

        test_text = (
            "Nvidia Corporation is an American technology company headquartered in Santa Clara, California. "
            "Founded in 1993 by Jensen Huang, Chris Malachowsky, and Curtis Priem, it develops graphics processing units (GPUs), "
            "systems on a chip (SoCs), and application programming interfaces (APIs) for data science, high-performance computing, "
            "and mobile and automotive applications. Tell me about the company.")

        sampling_params = SamplingParams(max_tokens=32)

        # Generate text with the first LLM instance; the connector saves the KV cache blocks.
        output = llm.generate([test_text], sampling_params)
        text0 = output[0].outputs[0].text

        print("First output: ", text0)
        print("Loading new LLM instance...")

        del llm

        # Create a new LLM instance with the same connector configuration.
        llm = LLM(model=model,
                  backend="pytorch",
                  cuda_graph_config=None,
                  kv_connector_config=kv_connector_config)

        # Generate text with the second LLM instance; it should reuse the KV cache blocks saved by the connector.
        output = llm.generate([test_text], sampling_params)
        text1 = output[0].outputs[0].text

        print("Second output (using connector cache): ", text1)

        # Verify that the two outputs are identical.
        assert text0 == text1, "Outputs from the two LLM instances differ"


if __name__ == "__main__":
    main()