1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import hashlib
17import os
18import stat
19import tempfile
20import threading
21from collections import OrderedDict
22from dataclasses import dataclass
23from datetime import datetime
24from typing import Any, List, Optional, Tuple, Union
25
26from filelock import BaseFileLock, FileLock, Timeout
27
28from .instrumentation.utils import CacheManagerMetricsHelper
29
30DEFAULT_CACHE_SIZE_MB = 10_000 # 10 GB
31DEFAULT_CACHE_REFRESH_INTERVAL = 300 # 5 minutes
32DEFAULT_LOCK_TIMEOUT = 600 # 10 minutes
33
34
[docs]
@dataclass
class CacheConfig:
    """
    Settings consumed by :py:class:`CacheManager`.
    """

    #: Root directory under which cached files are kept.
    location: str
    #: Upper bound for the cache size, expressed in megabytes.
    size_mb: int
    #: Whether etags are used to update cached files.
    use_etag: bool

    def size_bytes(self) -> int:
        """
        Express the configured cache size in bytes.

        :return: ``size_mb`` multiplied by 1024 * 1024.
        """
        bytes_per_megabyte = 1024 * 1024
        return self.size_mb * bytes_per_megabyte
55
56
class CacheManager:
    """
    A cache manager that stores files in a specified directory and evicts the oldest files
    (ordered by last modification time) once the cache grows past a fraction of its maximum size.

    Each entry is stored as ``<location>/<profile>/<md5 of key>``; the per-key lock file created by
    :py:meth:`acquire_lock` lives next to it as ``.<md5 of key>.lock``.
    """

    def __init__(
        self, profile: str, cache_config: CacheConfig, cache_refresh_interval: int = DEFAULT_CACHE_REFRESH_INTERVAL
    ):
        """
        Create the cache directory for ``profile`` and run an initial eviction pass.

        :param profile: Name of the sub-directory (under the cache location) used by this manager.
        :param cache_config: Cache location, maximum size and etag setting.
        :param cache_refresh_interval: Minimum number of seconds between two eviction passes
            triggered by :py:meth:`set`.
        """
        self._profile = profile
        self._cache_config = cache_config
        self._max_cache_size = cache_config.size_bytes()
        self._cache_refresh_interval = cache_refresh_interval
        # Eviction trims the cache until it occupies at most this fraction of the maximum size
        self._cache_load_factor = 0.7
        self._last_refresh_time = datetime.now()

        # Metrics
        self._metrics_helper = CacheManagerMetricsHelper()

        # Create the cache directory (intermediate directories are created as needed)
        self._cache_dir = cache_config.location
        os.makedirs(os.path.join(self._cache_dir, self._profile), exist_ok=True)

        # Lock file shared by every profile under the cache location; it serializes eviction
        self._cache_refresh_lock_file = FileLock(
            os.path.join(self._cache_dir, ".cache_refresh.lock"), timeout=0, blocking=False
        )
        # Account for files that already exist in the cache directory
        self.refresh_cache()

    def _get_file_size(self, file_path: str) -> Optional[int]:
        """
        Get the size of the file in bytes.

        :param file_path: Path of the file to inspect.

        :return: The file size in bytes, or None if the file cannot be inspected.
        """
        try:
            return os.path.getsize(file_path)
        except OSError:
            return None

    @staticmethod
    def _remove_with_lock(file_path: str) -> None:
        """
        Remove a cached file and its sibling ``.<name>.lock`` file, ignoring files that are missing.

        :param file_path: Path of the cached file.
        """
        directory, base_name = os.path.split(file_path)
        lock_path = os.path.join(directory, f".{base_name}.lock")
        # Remove each file independently so a missing data file does not leave its lock file behind
        for path in (file_path, lock_path):
            try:
                os.unlink(path)
            except OSError:
                # Best effort: the file is already gone or cannot be removed right now
                pass

    def _delete(self, file_name: str) -> None:
        """
        Delete a file (and its lock file) from this profile's cache directory.

        :param file_name: Name of the cached file, relative to the profile directory.
        """
        self._remove_with_lock(os.path.join(self._cache_dir, self._profile, file_name))

    def _get_cache_key(self, file_name: str) -> str:
        """
        Hash the file name using MD5.

        :return: The hexadecimal MD5 digest of ``file_name``.
        """
        return hashlib.md5(file_name.encode()).hexdigest()

    def _should_refresh_cache(self) -> bool:
        """
        Check if enough time has passed since the last refresh.

        :return: True if more than ``cache_refresh_interval`` seconds elapsed since the last refresh.
        """
        elapsed = datetime.now() - self._last_refresh_time
        # ``timedelta.seconds`` only holds the seconds component (whole days are dropped),
        # so use ``total_seconds()`` to measure the full elapsed time.
        return elapsed.total_seconds() > self._cache_refresh_interval

    def use_etag(self) -> bool:
        """
        Check if ``use_etag`` is set in the cache config.
        """
        return self._cache_config.use_etag

    def get_max_cache_size(self) -> int:
        """
        Return the cache size in bytes from the cache config.
        """
        return self._max_cache_size

    def get_cache_dir(self) -> str:
        """
        Return the path to the local cache directory.

        :return: The full path to the cache directory of this profile.
        """
        return os.path.join(self._cache_dir, self._profile)

    def get_cache_file_path(self, key: str) -> str:
        """
        Return the path to the local cache file for the given key.

        :param key: The key corresponding to the file.

        :return: The full path to the cached file.
        """
        return os.path.join(self._cache_dir, self._profile, self._get_cache_key(key))

    def read(self, key: str) -> Optional[bytes]:
        """
        Read the contents of a file from the cache if it exists.

        :param key: The key corresponding to the file to be read.

        :return: The contents of the file as bytes if found in the cache, otherwise None.
        """
        # Only flipped to True after the read completed, so failures are never reported as hits
        success = False
        try:
            with open(self.get_cache_file_path(key), "rb") as fp:
                data = fp.read()
            success = True
            return data
        except OSError:
            # cache miss (missing or unreadable file)
            return None
        finally:
            self._metrics_helper.increase(operation="READ", success=success)

    def open(self, key: str, mode: str = "rb") -> Optional[Any]:
        """
        Open a file from the cache and return the file object.

        The caller is responsible for closing the returned file object.

        :param key: The key corresponding to the file to be opened.
        :param mode: The mode in which to open the file (default is ``rb`` for read binary).

        :return: The file object if the file is found in the cache, otherwise None.
        """
        success = False
        try:
            # Check first so that write modes never create an entry that is not in the cache
            if not self.contains(key):
                return None
            handle = open(self.get_cache_file_path(key), mode)
            success = True
            return handle
        except OSError:
            # cache miss (the file vanished or cannot be opened)
            return None
        finally:
            self._metrics_helper.increase(operation="OPEN", success=success)

    def set(self, key: str, source: Union[str, bytes]) -> None:
        """
        Store a file in the cache.

        Failures are not raised; they are only reported through the metrics helper.

        :param key: The key corresponding to the file to be stored.
        :param source: The source data to be stored, either a path to a file (which is moved
            into the cache) or bytes.
        """
        success = True
        # Path of the hidden temporary file while it has not been moved into place yet
        pending_temp_path: Optional[str] = None
        try:
            file_path = self.get_cache_file_path(key)

            if isinstance(source, str):
                # Move the file to the cache directory, replacing any existing entry
                os.replace(src=source, dst=file_path)
            else:
                # Write to a hidden temporary file in the same directory, then move it into place
                with tempfile.NamedTemporaryFile(
                    mode="wb", delete=False, dir=os.path.dirname(file_path), prefix="."
                ) as temp_file:
                    pending_temp_path = temp_file.name
                    temp_file.write(source)
                os.replace(src=pending_temp_path, dst=file_path)
                pending_temp_path = None

            # Only allow the owner to read and write the file
            os.chmod(file_path, mode=stat.S_IRUSR | stat.S_IWUSR)

            # Refresh cache after a few minutes
            if self._should_refresh_cache():
                thread = threading.Thread(target=self.refresh_cache)
                thread.start()
        except Exception:
            success = False
            # Do not leave a partially written temporary file behind
            if pending_temp_path is not None:
                try:
                    os.unlink(pending_temp_path)
                except OSError:
                    pass
        finally:
            self._metrics_helper.increase(operation="SET", success=success)

    def contains(self, key: str) -> bool:
        """
        Check if the cache contains a file corresponding to the given key.

        :param key: The key corresponding to the file.

        :return: True if the file is found in the cache, False otherwise.
        """
        return os.path.exists(self.get_cache_file_path(key))

    def delete(self, key: str) -> None:
        """
        Delete a file from the cache.

        :param key: The key corresponding to the file to be deleted.
        """
        try:
            self._delete(self._get_cache_key(key))
        finally:
            self._metrics_helper.increase(operation="DELETE", success=True)

    def cache_size(self) -> int:
        """
        Return the current size of the cache in bytes.

        Every file below the cache location is counted, except files ending in ``.lock``.

        :return: The cache size in bytes.
        """
        total_size = 0

        # Traverse the directory and subdirectories
        for dirpath, _, filenames in os.walk(self._cache_dir):
            for file_name in filenames:
                if file_name.endswith(".lock"):
                    continue
                file_path = os.path.join(dirpath, file_name)
                if os.path.isfile(file_path):
                    # ``or 0`` covers files that disappeared while walking
                    total_size += self._get_file_size(file_path) or 0

        return total_size

    def _evict_files(self) -> None:
        """
        Evict cache entries based on the last modification time.

        Files are removed oldest-first until the remaining size is at most
        ``max cache size * load factor``. Lock files and hidden files are never candidates.
        """
        # list of (file path, last modified time, file size)
        candidates: List[Tuple[str, float, int]] = []

        # Traverse the directory and subdirectories
        for dirpath, _, filenames in os.walk(self._cache_dir):
            for file_name in filenames:
                # Skip lock and hidden files
                if file_name.endswith(".lock") or file_name.startswith("."):
                    continue
                file_path = os.path.join(dirpath, file_name)
                try:
                    if not os.path.isfile(file_path):
                        continue
                    mtime = os.path.getmtime(file_path)
                except OSError:
                    # Ignore if file has already been evicted
                    continue
                fsize = self._get_file_size(file_path)
                # Empty or vanished files do not contribute to the cache size
                if fsize:
                    candidates.append((file_path, mtime, fsize))

        # Sort the files based on the last modified time (oldest first)
        candidates.sort(key=lambda entry: entry[1])

        # Rebuild the cache
        cache: "OrderedDict[str, int]" = OrderedDict()
        cache_size = 0
        for file_path, _, file_size in candidates:
            cache[file_path] = file_size
            cache_size += file_size

        # Comparing against the product avoids dividing by a zero maximum cache size
        size_limit = self._max_cache_size * self._cache_load_factor

        # Evict old files if necessary in case the existing files exceed cache size
        while cache and cache_size > size_limit:
            # Pop the first (oldest) item in the OrderedDict
            oldest_file, file_size = cache.popitem(last=False)
            cache_size -= file_size
            # ``oldest_file`` is a full path, so delete it directly instead of via ``_delete``
            self._remove_with_lock(oldest_file)

    def refresh_cache(self) -> bool:
        """
        Scan the cache directory and evict cache entries based on the last modification time.
        This method is protected by a :py:class:`filelock.FileLock` that only allows a single process to evict the cached files.

        :return: True if the eviction pass ran, False if the lock could not be acquired.
        """
        try:
            # If the process acquires the lock, then proceed with the cache eviction
            with self._cache_refresh_lock_file.acquire(blocking=False):
                self._evict_files()
                self._last_refresh_time = datetime.now()
                return True
        except Timeout:
            # If the process cannot acquire the lock, ignore and wait for the next turn
            pass

        return False

    def acquire_lock(self, key: str) -> BaseFileLock:
        """
        Create a :py:class:`filelock.FileLock` object for a given key.

        The lock object is only created here; this method does not call ``acquire`` on it.

        :param key: The key corresponding to the file to be locked.

        :return: :py:class:`filelock.FileLock` object.
        """
        hashed_name = self._get_cache_key(key)
        lock_file = os.path.join(self._cache_dir, self._profile, f".{hashed_name}.lock")
        return FileLock(lock_file, timeout=DEFAULT_LOCK_TIMEOUT)