KV Cache Connector

Source: NVIDIA/TensorRT-LLM.

  1
  2import os
  3import sys
  4from dataclasses import dataclass, field
  5from pathlib import Path
  6from tempfile import TemporaryDirectory
  7
  8import click
  9import torch
 10
 11from tensorrt_llm import LLM, SamplingParams, logger
 12from tensorrt_llm._torch.pyexecutor.kv_cache_connector import (
 13    KvCacheConnectorScheduler, KvCacheConnectorWorker, SchedulerOutput)
 14from tensorrt_llm.bindings.executor import ExecutorConfig
 15from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
 16from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig
 17
# This is a simple example of how to use the KV cache connector.
# It persists KV cache contents into a folder and can load them back on subsequent runs.
# See tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py for details about the KV cache connector interface.
# NOTE: This example connector implementation is NOT suitable for production use.
 22
 23CONNECTOR_CACHE_FOLDER_KEY = "CONNECTOR_CACHE_FOLDER"
 24
 25
 26@dataclass
 27class PersistentKvCacheConnectorMetadata:
 28    load: list[tuple[str, int]] = field(default_factory=list)
 29    save: list[tuple[str, int]] = field(default_factory=list)
 30
 31
 32class PersistentKvCacheConnectorWorker(KvCacheConnectorWorker):
 33
 34    def __init__(self, executor_config: ExecutorConfig):
 35        super().__init__(executor_config)
 36
 37        self.kv_cache_tensor = None
 38
 39    def register_kv_caches(self, kv_cache_tensor: torch.Tensor):
 40        assert self.kv_cache_tensor is None, "KV cache tensor already registered"
 41        self.kv_cache_tensor = kv_cache_tensor
 42
 43    def start_load_kv(self, stream: torch.cuda.Stream):
 44        # Do all loads synchronously, and blockwise.
 45        for path, block_id in self._metadata.load:
 46            cpu_tensor = torch.load(path, map_location="cpu")
 47
 48            # Copy into the device block.
 49            self.kv_cache_tensor[block_id].copy_(cpu_tensor, non_blocking=False)
 50
 51    def wait_for_layer_load(self, layer_idx: int, stream: torch.cuda.Stream):
 52        pass
 53
 54    def save_kv_layer(self, layer_idx: int, stream: torch.cuda.Stream):
 55        pass
 56
 57    def wait_for_save(self, stream: torch.cuda.Stream):
 58
 59        # Make sure the forward pass is complete before beginning our save.
 60        stream.synchronize()
 61
 62        for path, block_id in self._metadata.save:
 63            cpu_tensor = self.kv_cache_tensor[block_id].cpu()
 64
 65            # Don't write anything if this specific block already exists.
 66            if Path(path).exists():
 67                continue
 68
 69            # Do a blocking save to the file. This way, we only return once all saves are complete.
 70            torch.save(cpu_tensor, path)
 71
 72    def get_finished(
 73            self, finished_gen_req_ids: list[int],
 74            started_loading_req_ids: list[int]) -> tuple[list[int], list[int]]:
 75
 76        return [], []
 77
 78
 79class PersistentKvCacheConnectorLeader(KvCacheConnectorScheduler):
 80
 81    def __init__(self, executor_config: ExecutorConfig):
 82        super().__init__(executor_config)
 83
 84        self.block_size = self._config.tokens_per_block
 85        self.pending_loads = {}
 86
 87        self.cache_folder = os.environ.get(CONNECTOR_CACHE_FOLDER_KEY,
 88                                           "./connector_cache")
 89
 90        os.makedirs(self.cache_folder, exist_ok=True)
 91
 92    def build_connector_meta(self, scheduler_output: SchedulerOutput):
 93        # NOTE: This is a simplified implementation, and does not work with chunked prefill.
 94
 95        metadata = PersistentKvCacheConnectorMetadata()
 96
 97        for req in scheduler_output.new_requests:
 98            # If we don't have any pending loads for this request, we can skip it.
 99            if req.request_id not in self.pending_loads:
100                continue
101
102            num_computed_blocks = req.computed_position // self.block_size
103            block_ids = req.new_block_ids
104
105            pending_load = self.pending_loads[req.request_id]
106
107            for file_path, block_pos in zip(
108                    pending_load, range(num_computed_blocks, len(block_ids))):
109                metadata.load.append((file_path, block_ids[block_pos]))
110
111            # Break up the remainder of the token sequence into chunks.
112            chunks = self._chunk_tokens(req.new_tokens)
113
114            # For each chunk that isn't already on device, and isn't in our connector cache, we need to save it.
115            for block_pos in range(num_computed_blocks + len(pending_load),
116                                   len(block_ids)):
117                if len(chunks[block_pos]) == self.block_size:
118                    hashed_tokens = self._hash_tokens(chunks[block_pos])
119
120                    file_path = self._file_path(hashed_tokens)
121
122                    metadata.save.append((file_path, block_ids[block_pos]))
123
124        self.pending_loads = {}
125
126        return metadata
127
128    def _hash_tokens(self, tokens: list[int]) -> int:
129        return abs(hash(tuple(tokens)))
130
131    def _file_path(self, hash_value: int) -> Path:
132        return Path(self.cache_folder) / f"{hash_value}.pt"
133
134    def _chunk_tokens(self, tokens: list[int]) -> list[list[int]]:
135        return [
136            tokens[i:i + self.block_size]
137            for i in range(0, len(tokens), self.block_size)
138        ]
139
140    def get_num_new_matched_tokens(
141            self, request: LlmRequest,
142            num_computed_tokens: int) -> tuple[int, bool]:
143        self.pending_loads[request.request_id] = []
144
145        # Don't bother with sequences with partial matches.
146        if (num_computed_tokens % self.block_size) != 0:
147            return 0, False
148
149        computed_blocks = num_computed_tokens // self.block_size
150
151        # Get all the tokens that don't have a cache hit on device.
152        remaining_tokens = request.get_tokens(0)[computed_blocks *
153                                                 self.block_size:]
154
155        remaining_chunks = self._chunk_tokens(remaining_tokens)
156
157        # For each chunk, check if it exists in our cache.
158        for chunk in remaining_chunks:
159            # Only do full blocks.
160            if len(chunk) == self.block_size:
161                hashed_tokens = self._hash_tokens(chunk)
162
163                file_path = self._file_path(hashed_tokens)
164
165                # If we get a cache hit, we want to load it into device.
166                # Otherwise, we can stop looking.
167                if file_path.exists():
168                    self.pending_loads[request.request_id].append(file_path)
169                else:
170                    break
171
172        logger.info(
173            f"KV CONNECTOR: Matched {len(self.pending_loads[request.request_id])} blocks for request {request.request_id}"
174        )
175
176        return len(
177            self.pending_loads[request.request_id]) * self.block_size, False
178
179    def request_finished(self, request: LlmRequest,
180                         cache_block_ids: list[int]) -> bool:
181        # We don't do any asynchronous saving, so always return False
182        return False
183
184    def update_state_after_alloc(self, request: LlmRequest,
185                                 block_ids: list[int]):
186        pass
187
188
@click.command()
@click.argument("model", type=str)
def main(model: str):
    """Generate from MODEL twice, in two fresh LLM instances that share a
    connector cache folder, and check that both outputs are identical."""
    # Append the parent of this script's directory to sys.path (kept from the
    # original example; NOTE(review): purpose not evident here — confirm).
    sys.path.append(os.path.join(
        os.path.dirname(__file__),
        "..",
    ))

    # Module in which the connector classes are looked up. Path.stem handles
    # any path separator, unlike slicing on "/".
    this_module = Path(__file__).stem

    kv_connector_config = KvCacheConnectorConfig(
        connector_module=this_module,
        connector_scheduler_class="PersistentKvCacheConnectorLeader",
        connector_worker_class="PersistentKvCacheConnectorWorker",
    )

    test_text = (
        "Nvidia Corporation is an American technology company headquartered in Santa Clara, California. "
        "Founded in 1993 by Jensen Huang, Chris Malachowsky, and Curtis Priem, it develops graphics processing units (GPUs), "
        "systems on a chip (SoCs), and application programming interfaces (APIs) for data science, high-performance computing, "
        "and mobile and automotive applications. Tell me about the company.")

    sampling_params = SamplingParams(max_tokens=32)

    def build_llm() -> LLM:
        # Both runs must use an identical configuration.
        return LLM(model=model,
                   backend="pytorch",
                   cuda_graph_config=None,
                   kv_connector_config=kv_connector_config)

    # The context manager deletes the cache folder even when generation or
    # the final comparison fails.
    with TemporaryDirectory() as connector_cache_dir:
        # Must be set before the LLM is built: the scheduler-side connector
        # reads it in its constructor.
        os.environ[CONNECTOR_CACHE_FOLDER_KEY] = connector_cache_dir

        llm = build_llm()
        output = llm.generate([test_text], sampling_params)
        text0 = output[0].outputs[0].text

        print("First output: ", text0)
        print("Loading new LLM instance...")

        del llm

        llm = build_llm()
        output = llm.generate([test_text], sampling_params)
        text1 = output[0].outputs[0].text

        print("Second output (using connector cache): ", text1)

        # Explicit check instead of `assert`, which is stripped under -O.
        if text0 != text1:
            raise AssertionError(
                f"Connector-cached output differs: {text0!r} != {text1!r}")
242
243
if __name__ == "__main__":
    # Script entry point; click reads the MODEL argument from the command line.
    main()