KV Cache Connector

Source: NVIDIA/TensorRT-LLM.

import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from tempfile import TemporaryDirectory

import click
import torch

from tensorrt_llm import LLM, SamplingParams, logger
from tensorrt_llm._torch.pyexecutor.kv_cache_connector import (
    KvCacheConnectorScheduler, KvCacheConnectorWorker, SchedulerOutput)
from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig

# This is a simple example of the use of the KV cache connector.
# It persists KV cache contents into a folder, and can load them back on subsequent runs.
# See tensorrt_llm/_torch/pyexecutor/connector.py for details about the KV cache connector interface.
# NOTE: This example connector implementation is NOT suitable for production use.
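# The scheduler-side class (PersistentKvCacheConnectorLeader) decides which blocks
# should be loaded from or saved to disk and records that in the connector metadata;
# the worker-side class (PersistentKvCacheConnectorWorker) performs the actual copies.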

CONNECTOR_CACHE_FOLDER_KEY = "CONNECTOR_CACHE_FOLDER"


@dataclass
class PersistentKvCacheConnectorMetadata:
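    # Each entry is a (cache_file_path, device_block_id) pair. The scheduler builds
    # these lists every step; the worker consumes them in start_load_kv / wait_for_save.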
    load: list[tuple[str, int]] = field(default_factory=list)
    save: list[tuple[str, int]] = field(default_factory=list)


class PersistentKvCacheConnectorWorker(KvCacheConnectorWorker):

    def __init__(self):
        super().__init__()

        self.kv_cache_tensor = None

    def register_kv_caches(self, kv_cache_tensor: torch.Tensor):
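        # Called once with the runtime's KV cache pool tensor; blocks are addressed
        # along the first dimension by block id (see the copy in start_load_kv).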
        assert self.kv_cache_tensor is None, "KV cache tensor already registered"
        self.kv_cache_tensor = kv_cache_tensor

    def start_load_kv(self, stream: torch.cuda.Stream):
        # Do all loads synchronously, and blockwise.
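        # A production connector would typically issue these copies on the provided
        # CUDA stream so they can overlap with model execution.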
        for path, block_id in self._metadata.load:
            cpu_tensor = torch.load(path, map_location="cpu")

            # Copy into the device block.
            self.kv_cache_tensor[block_id].copy_(cpu_tensor, non_blocking=False)

    def wait_for_layer_load(self, layer_idx: int, stream: torch.cuda.Stream):
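        # Loads already completed synchronously in start_load_kv, so there is
        # nothing to wait for per layer.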
        pass

    def save_kv_layer(self, layer_idx: int, stream: torch.cuda.Stream):
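        # Per-layer saving is unused; all saving happens at once in wait_for_save.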
        pass

    def wait_for_save(self, stream: torch.cuda.Stream):

        # Make sure the forward pass is complete before beginning our save.
        stream.synchronize()

        for path, block_id in self._metadata.save:
            cpu_tensor = self.kv_cache_tensor[block_id].cpu()

            # Don't write anything if this specific block already exists.
            if Path(path).exists():
                continue

            # Do a blocking save to the file. This way, we only return once all saves are complete.
            torch.save(cpu_tensor, path)

    def get_finished(
            self, finished_gen_req_ids: list[int],
            started_loading_req_ids: list[int]) -> tuple[list[int], list[int]]:
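        # Nothing is loaded or saved asynchronously, so no request ids are ever
        # reported as finishing in the background.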

        return [], []


class PersistentKvCacheConnectorLeader(KvCacheConnectorScheduler):

    def __init__(self, tokens_per_block):
        super().__init__()

        self.block_size = tokens_per_block
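        # Maps request_id -> list of cache file paths to load for that request.
        # Populated in get_num_new_matched_tokens, consumed in build_connector_meta.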
        self.pending_loads = {}

        self.cache_folder = os.environ.get(CONNECTOR_CACHE_FOLDER_KEY,
                                           "./connector_cache")

        os.makedirs(self.cache_folder, exist_ok=True)

    def build_connector_meta(self, scheduler_output: SchedulerOutput):
        # NOTE: This is a simplified implementation, and does not work with chunked prefill.
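        # For each newly scheduled request, map the cache hits recorded in
        # pending_loads onto the freshly allocated device blocks (loads), then
        # schedule every remaining full, not-yet-persisted block for saving.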

        metadata = PersistentKvCacheConnectorMetadata()

        for req in scheduler_output.new_requests:
            # If we don't have any pending loads for this request, we can skip it.
            if req.request_id not in self.pending_loads:
                continue

            num_computed_blocks = req.computed_position // self.block_size
            block_ids = req.new_block_ids

            pending_load = self.pending_loads[req.request_id]

            for file_path, block_pos in zip(
                    pending_load, range(num_computed_blocks, len(block_ids))):
                metadata.load.append((file_path, block_ids[block_pos]))

            # Break up the remainder of the token sequence into chunks.
            chunks = self._chunk_tokens(req.new_tokens)

            # For each chunk that isn't already on device, and isn't in our connector cache, we need to save it.
            for block_pos in range(num_computed_blocks + len(pending_load),
                                   len(block_ids)):
                if len(chunks[block_pos]) == self.block_size:
                    hashed_tokens = self._hash_tokens(chunks[block_pos])

                    file_path = self._file_path(hashed_tokens)

                    metadata.save.append((file_path, block_ids[block_pos]))

        self.pending_loads = {}

        return metadata

    def _hash_tokens(self, tokens: list[int]) -> int:
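        # Python's built-in hash() is sufficient for this demo, but it is not
        # collision-resistant; a real connector would likely use a content hash
        # (e.g. SHA-256) over the token ids.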
        return abs(hash(tuple(tokens)))

    def _file_path(self, hash_value: int) -> Path:
        return Path(self.cache_folder) / f"{hash_value}.pt"

    def _chunk_tokens(self, tokens: list[int]) -> list[list[int]]:
        return [
            tokens[i:i + self.block_size]
            for i in range(0, len(tokens), self.block_size)
        ]

    def get_num_new_matched_tokens(
            self, request: LlmRequest,
            num_computed_tokens: int) -> tuple[int, bool]:
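        # Report how many additional tokens, beyond the num_computed_tokens that
        # already have an on-device cache hit, this connector can supply from disk.
        # The boolean return value is always False here, since nothing is loaded
        # asynchronously.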
        self.pending_loads[request.request_id] = []

        # Don't bother with sequences with partial matches.
        if (num_computed_tokens % self.block_size) != 0:
            return 0, False

        computed_blocks = num_computed_tokens // self.block_size

        # Get all the tokens that don't have a cache hit on device.
        remaining_tokens = request.get_tokens(0)[computed_blocks *
                                                 self.block_size:]

        remaining_chunks = self._chunk_tokens(remaining_tokens)

        # For each chunk, check if it exists in our cache.
        for chunk in remaining_chunks:
            # Only do full blocks.
            if len(chunk) == self.block_size:
                hashed_tokens = self._hash_tokens(chunk)

                file_path = self._file_path(hashed_tokens)

                # If we get a cache hit, we want to load it into device.
                # Otherwise, we can stop looking.
                if file_path.exists():
                    self.pending_loads[request.request_id].append(file_path)
                else:
                    break

        logger.info(
            f"KV CONNECTOR: Matched {len(self.pending_loads[request.request_id])} blocks for request {request.request_id}"
        )

        return len(
            self.pending_loads[request.request_id]) * self.block_size, False

    def request_finished(self, request: LlmRequest,
                         cache_block_ids: list[int]) -> bool:
        # We don't do any asynchronous saving, so always return False
        return False

    def update_state_after_alloc(self, request: LlmRequest,
                                 block_ids: list[int]):
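        # Newly allocated block ids are read from the SchedulerOutput in
        # build_connector_meta instead, so there is nothing to record here.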
        pass


@click.command()
@click.argument("model", type=str)
def main(model: str):
    sys.path.append(os.path.join(
        os.path.dirname(__file__),
        "..",
    ))

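    # Derive this file's module name (filename without the .py suffix) so the
    # connector classes can be imported by module and class name below.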
    this_module = __file__[__file__.rfind("/") + 1:__file__.rfind(".py")]

    kv_connector_config = KvCacheConnectorConfig(
        connector_module=this_module,
        connector_scheduler_class="PersistentKvCacheConnectorLeader",
        connector_worker_class="PersistentKvCacheConnectorWorker",
    )

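    # Use a temporary directory as the connector cache. It outlives the first LLM
    # instance below, so the second instance can reload the persisted KV blocks.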
    connector_cache_dir = TemporaryDirectory()
    os.environ[CONNECTOR_CACHE_FOLDER_KEY] = connector_cache_dir.name

    llm = LLM(model=model,
              backend="pytorch",
              cuda_graph_config=None,
              kv_connector_config=kv_connector_config)

    test_text = (
        "Nvidia Corporation is an American technology company headquartered in Santa Clara, California. "
        "Founded in 1993 by Jensen Huang, Chris Malachowsky, and Curtis Priem, it develops graphics processing units (GPUs), "
        "system on a chips (SoCs), and application programming interfaces (APIs) for data science, high-performance computing, "
        "and mobile and automotive applications. Tell me about the company.")

    sampling_params = SamplingParams(max_tokens=32)

    output = llm.generate([test_text], sampling_params)
    text0 = output[0].outputs[0].text

    print("First output: ", text0)
    print("Loading new LLM instance...")

    del llm
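    # Tearing down the first instance frees its on-device KV cache; the blocks
    # persisted by the connector remain on disk in connector_cache_dir.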

    llm = LLM(model=model,
              backend="pytorch",
              cuda_graph_config=None,
              kv_connector_config=kv_connector_config)

    output = llm.generate([test_text], sampling_params)
    text1 = output[0].outputs[0].text

    print("Second output (using connector cache): ", text1)

    assert text0 == text1

    connector_cache_dir.cleanup()


if __name__ == "__main__":
    main()
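Assuming the file is saved as, for example, kv_cache_connector_example.py (the connector module name is derived from the filename at runtime), the example can be run against a model path or name with:

    python kv_cache_connector_example.py <model>

The second generate() call is expected to produce the same text as the first, with the prompt's full KV blocks loaded from the connector cache on disk instead of being recomputed.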