1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import hashlib
17import os
18import stat
19import tempfile
20import threading
21from collections import OrderedDict
22from dataclasses import dataclass
23from datetime import datetime
24from typing import Any, List, Optional, Union
25
26from filelock import BaseFileLock, FileLock, Timeout
27
28from .instrumentation.utils import CacheManagerMetricsHelper
29from .caching.cache_item import CacheItem
30from .caching.eviction_policy import FIFO, EvictionPolicyFactory
31import time
32
# Default upper bound on the total cache size, in megabytes (about 10 GB).
DEFAULT_CACHE_SIZE_MB = 10 * 1000
# Default number of seconds between background cache refresh/eviction passes (5 minutes).
DEFAULT_CACHE_REFRESH_INTERVAL = 5 * 60
# Default timeout, in seconds, for acquiring a per-key cache file lock (10 minutes).
DEFAULT_LOCK_TIMEOUT = 10 * 60
36
37
[docs]
@dataclass
class CacheConfig:
    """Settings that control where the local file cache lives and how it is managed.

    :param location: Filesystem directory that holds the cache.
    :param size_mb: Upper bound on the total cache size, in megabytes.
    :param use_etag: Whether ETags should be used to refresh cached files.
    :param eviction_policy: Name of the policy that decides which cached files are evicted first.
    """

    location: str
    size_mb: int
    use_etag: bool = True
    eviction_policy: str = FIFO

    @staticmethod
    def create(location: str, size_mb: int, use_etag: bool = True, eviction_policy: str = FIFO) -> "CacheConfig":
        """Build a :py:class:`CacheConfig` from individual settings.

        :param location: Filesystem directory that holds the cache.
        :param size_mb: Upper bound on the total cache size, in megabytes.
        :param use_etag: Whether ETags should be used to refresh cached files.
        :param eviction_policy: Name of the policy that decides which cached files are evicted first.

        :return: A new CacheConfig populated with the given values.
        """
        # Positional order matches the dataclass field order: location, size_mb, use_etag, eviction_policy.
        config = CacheConfig(location, size_mb, use_etag, eviction_policy)
        return config

    def size_bytes(self) -> int:
        """
        Express the configured cache size limit in bytes.

        :return: ``size_mb`` multiplied by 1 MiB (1024 * 1024 bytes).
        """
        bytes_per_megabyte = 1024**2
        return self.size_mb * bytes_per_megabyte
74
class CacheManager:
    """
    A cache manager that stores files in a local directory and evicts them according to the eviction
    policy named in the :py:class:`CacheConfig` (FIFO by default).

    Files for this manager's profile are stored in ``<location>/<profile>/`` under the MD5 hex digest of
    their key; per-key lock files live next to them as ``.<digest>.lock``.
    """

    def __init__(
        self, profile: str, cache_config: CacheConfig, cache_refresh_interval: int = DEFAULT_CACHE_REFRESH_INTERVAL
    ):
        """
        :param profile: Profile name; used as the subdirectory of the cache location for this manager's files.
        :param cache_config: Cache location, size limit, etag and eviction-policy settings.
        :param cache_refresh_interval: Number of seconds that must pass before :py:meth:`set` starts another
            background :py:meth:`refresh_cache`.
        """
        self._profile = profile
        self._cache_config = cache_config
        self._max_cache_size = cache_config.size_bytes()
        self._cache_refresh_interval = cache_refresh_interval
        self._last_refresh_time = datetime.now()

        # Create eviction policy using factory
        policy: str = self._cache_config.eviction_policy  # type: ignore
        self._eviction_policy = EvictionPolicyFactory.create(policy)

        # Create the cache root and this profile's subdirectory
        self._cache_dir = cache_config.location
        os.makedirs(self._cache_dir, exist_ok=True)
        os.makedirs(os.path.join(self._cache_dir, self._profile), exist_ok=True)

        # Metrics
        self._metrics_helper = CacheManagerMetricsHelper()

        # Non-blocking lock at the cache root so only one process at a time scans and evicts files
        self._cache_refresh_lock_file = FileLock(
            os.path.join(self._cache_dir, ".cache_refresh.lock"), timeout=0, blocking=False
        )
        # Populate cache with existing files in the cache directory
        self.refresh_cache()

    def _get_file_size(self, file_path: str) -> Optional[int]:
        """
        Get the size of the file in bytes.

        :return: The file size in bytes, or ``None`` if the file cannot be stat'ed (e.g. it was removed).
        """
        try:
            return os.path.getsize(file_path)
        except OSError:
            return None

    @staticmethod
    def _delete_path(file_path: str) -> None:
        """
        Delete a cached file given its full path, then its ``.<name>.lock`` sibling.

        The lock file is only removed after the data file itself was removed. Errors are ignored because
        another process may already have deleted the file.
        """
        directory, file_name = os.path.split(file_path)
        try:
            os.unlink(file_path)
            os.unlink(os.path.join(directory, f".{file_name}.lock"))
        except OSError:
            pass

    def _delete(self, file_name: str) -> None:
        """
        Delete a file (and its lock file) from this profile's cache directory.
        """
        self._delete_path(os.path.join(self._cache_dir, self._profile, file_name))

    def _get_cache_key(self, file_name: str) -> str:
        """
        Hash the file name using MD5; the hex digest is used as the on-disk file name.
        """
        return hashlib.md5(file_name.encode()).hexdigest()

    def _should_refresh_cache(self) -> bool:
        """
        Check if more than ``cache_refresh_interval`` seconds have passed since the last refresh
        (or since construction).
        """
        elapsed = datetime.now() - self._last_refresh_time
        # Use total_seconds(): timedelta.seconds is only the 0-86399 seconds component and ignores whole days.
        return elapsed.total_seconds() > self._cache_refresh_interval

    def use_etag(self) -> bool:
        """
        Check if ``use_etag`` is set in the cache config.
        """
        return self._cache_config.use_etag

    def get_max_cache_size(self) -> int:
        """
        Return the cache size in bytes from the cache config.
        """
        return self._max_cache_size

    def get_cache_dir(self) -> str:
        """
        Return the path to the local cache directory.

        :return: The full path to this profile's cache directory.
        """
        return os.path.join(self._cache_dir, self._profile)

    def get_cache_file_path(self, key: str) -> str:
        """
        Return the path to the local cache file for the given key.

        :return: The full path to the cached file.
        """
        hashed_name = self._get_cache_key(key)
        return os.path.join(self._cache_dir, self._profile, hashed_name)

    def read(self, key: str) -> Optional[bytes]:
        """
        Read the contents of a file from the cache if it exists.

        :param key: The key corresponding to the file to be read.

        :return: The contents of the file as bytes if found in the cache, otherwise None.
        """
        success = True
        try:
            try:
                if self.contains(key):
                    file_path = self.get_cache_file_path(key)
                    with open(file_path, "rb") as fp:
                        data = fp.read()
                    # Update access time based on eviction policy
                    self.update_access_time(file_path)
                    return data
            except OSError:
                # The file may have been evicted between the existence check and the read
                pass

            # cache miss
            success = False
            return None
        finally:
            self._metrics_helper.increase(operation="READ", success=success)

    def open(self, key: str, mode: str = "rb") -> Optional[Any]:
        """
        Open a file from the cache and return the file object.

        :param key: The key corresponding to the file to be opened.
        :param mode: The mode in which to open the file (default is ``rb`` for read binary).

        :return: The file object if the file is found in the cache, otherwise None. The caller must close it.
        """
        success = True
        try:
            try:
                if self.contains(key):
                    file_path = self.get_cache_file_path(key)
                    # Update access time based on eviction policy
                    self.update_access_time(file_path)
                    return open(file_path, mode)
            except OSError:
                # The file may have been evicted between the existence check and the open
                pass

            # cache miss
            success = False
            return None
        finally:
            self._metrics_helper.increase(operation="OPEN", success=success)

    @staticmethod
    def _write_bytes_atomically(file_path: str, data: bytes) -> None:
        """
        Write ``data`` to a hidden temporary file in ``file_path``'s directory and rename it onto ``file_path``.

        The temporary file is removed if writing or renaming fails, so failed writes do not leave orphaned
        hidden files in the cache directory.
        """
        temp_file_path: Optional[str] = None
        try:
            # Same directory as the target so the final rename stays on one filesystem
            with tempfile.NamedTemporaryFile(
                mode="wb", delete=False, dir=os.path.dirname(file_path), prefix="."
            ) as temp_file:
                temp_file_path = temp_file.name
                temp_file.write(data)
            os.replace(temp_file_path, file_path)
        except BaseException:
            if temp_file_path is not None:
                try:
                    os.unlink(temp_file_path)
                except OSError:
                    pass
            raise

    def set(self, key: str, source: Union[str, bytes]) -> None:
        """
        Store a file in the cache, replacing any existing entry for ``key``.

        Failures are not raised to the caller; they are only reported through the ``SET`` metric.

        :param key: The key corresponding to the file to be stored.
        :param source: The source data to be stored, either a path to a file (which is moved into the cache)
            or bytes.
        """
        success = True
        try:
            file_path = self.get_cache_file_path(key)

            if isinstance(source, str):
                # Move the file into the cache directory; os.replace overwrites an existing entry on all platforms
                os.replace(source, file_path)
            else:
                self._write_bytes_atomically(file_path, source)
            # Only allow the owner to read and write the file
            os.chmod(file_path, mode=stat.S_IRUSR | stat.S_IWUSR)

            # update access time if applicable
            self.update_access_time(file_path)

            # Refresh cache in the background once the refresh interval has elapsed
            if self._should_refresh_cache():
                thread = threading.Thread(target=self.refresh_cache)
                thread.start()
        except Exception:
            success = False
        finally:
            self._metrics_helper.increase(operation="SET", success=success)

    def contains(self, key: str) -> bool:
        """
        Check if the cache contains a file corresponding to the given key.

        :param key: The key corresponding to the file.

        :return: True if the file is found in the cache, False otherwise.
        """
        return os.path.exists(self.get_cache_file_path(key))

    def delete(self, key: str) -> None:
        """
        Delete the file for ``key`` from the cache; a missing file is ignored.
        """
        try:
            hashed_name = self._get_cache_key(key)
            self._delete(hashed_name)
        finally:
            self._metrics_helper.increase(operation="DELETE", success=True)

    def cache_size(self) -> int:
        """
        Return the current size of the cache in bytes.

        All files under the cache location are counted (every profile directory), except ``*.lock`` files.

        :return: The cache size in bytes.
        """
        file_size = 0

        # Traverse the directory and subdirectories
        for dirpath, _, filenames in os.walk(self._cache_dir):
            for file_name in filenames:
                file_path = os.path.join(dirpath, file_name)
                if os.path.isfile(file_path) and not file_name.endswith(".lock"):
                    size = self._get_file_size(file_path)
                    if size:
                        file_size += size

        return file_size

    def _evict_files(self) -> None:
        """
        Evict cache entries based on the configured eviction policy.

        Data files under the whole cache location (every profile directory) count toward the size limit;
        lock files and hidden files are ignored. Files are deleted in the order produced by the policy
        until the total size fits within the limit.
        """
        cache_items: List[CacheItem] = []

        # Traverse the directory and subdirectories
        for dirpath, _, filenames in os.walk(self._cache_dir):
            for file_name in filenames:
                # Skip lock files and hidden files (e.g. temporary files being written by ``set``)
                if file_name.endswith(".lock") or file_name.startswith("."):
                    continue
                file_path = os.path.join(dirpath, file_name)
                try:
                    if os.path.isfile(file_path):
                        cache_item = CacheItem.from_path(file_path, file_name)
                        if cache_item and cache_item.file_size:
                            cache_items.append(cache_item)
                except OSError:
                    # Ignore if file has already been evicted
                    pass

        # Sort items according to eviction policy (first item = first to evict)
        cache_items = self._eviction_policy.sort_items(cache_items)

        # Rebuild the cache
        cache = OrderedDict()
        cache_size = 0
        for item in cache_items:
            cache[item.file_path] = item.file_size
            cache_size += item.file_size

        # Evict files while the existing files exceed the cache size (guard against an empty dict
        # when the configured limit is negative)
        while cache and cache_size > self._max_cache_size:
            oldest_file, file_size = cache.popitem(last=False)
            cache_size -= file_size
            # Delete by full path: the walk covers every profile directory, so the file is not necessarily in
            # this manager's profile directory. Assumes CacheItem.file_path is the path given to from_path.
            self._delete_path(oldest_file)

    def refresh_cache(self) -> bool:
        """
        Scan the cache directory and evict cache entries according to the configured eviction policy.

        This method is protected by a :py:class:`filelock.FileLock` that only allows a single process to evict
        the cached files. If the lock is already held, the call returns immediately without evicting.

        :return: ``True`` if this call performed the refresh, ``False`` if the lock could not be acquired.
        """
        try:
            # If the process acquires the lock, then proceed with the cache eviction
            with self._cache_refresh_lock_file.acquire(blocking=False):
                self._evict_files()
                self._last_refresh_time = datetime.now()
                return True
        except Timeout:
            # If the process cannot acquire the lock, ignore and wait for the next turn
            return False

    def acquire_lock(self, key: str) -> BaseFileLock:
        """
        Create a :py:class:`filelock.FileLock` object for a given key.

        The lock file is ``.<digest>.lock`` in this profile's cache directory, with a
        ``DEFAULT_LOCK_TIMEOUT``-second timeout. The lock is not acquired here.

        :return: :py:class:`filelock.FileLock` object.
        """
        hashed_name = self._get_cache_key(key)
        lock_file = os.path.join(self._cache_dir, self._profile, f".{hashed_name}.lock")
        return FileLock(lock_file, timeout=DEFAULT_LOCK_TIMEOUT)

    def update_access_time(self, file_path: str) -> None:
        """Set the file's access time to now, for access-time-based (LRU) eviction.

        Only ``atime`` is changed; ``mtime`` is preserved so FIFO ordering is unaffected.

        :param file_path: Path to the file whose access time should be updated.
        """
        current_time = time.time()
        try:
            # Local is not named ``stat`` to avoid shadowing the ``stat`` module
            file_stat = os.stat(file_path)
            os.utime(file_path, (current_time, file_stat.st_mtime))
        except OSError:
            # File might be deleted by another process or have permission issues;
            # continue without updating the access time
            pass
384 # File might be deleted by another process or have permission issues
385 # Just continue without updating the access time
386 pass