1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import logging
17import multiprocessing
18import os
19import queue
20import tempfile
21import threading
22from concurrent.futures import ThreadPoolExecutor
23from enum import Enum
24from typing import Any, Dict, Iterator, List, Optional, Union
25
26from .config import StorageClientConfig
27from .constants import MEMORY_LOAD_LIMIT
28from .file import ObjectFile, PosixFile
29from .instrumentation.utils import instrumented
30from .providers.posix_file import PosixFileStorageProvider
31from .retry import retry
32from .types import MSC_PROTOCOL, ObjectMetadata, Range
33from .utils import join_paths
34
# Module-level logger obtained via getLogger() so it is registered in the
# logging hierarchy: handlers/levels configured on the root logger (e.g. via
# logging.basicConfig) apply to it. Instantiating logging.Logger directly
# creates a detached logger with no parent, so its records never propagate.
logger = logging.getLogger(__name__)
36
37
class _SyncOp(Enum):
    """Work-item tags passed from the sync_from() producer to consumer threads."""

    ADD = "add"  # copy a source file to the target
    DELETE = "delete"  # remove a target file that has no source counterpart
    STOP = "stop"  # sentinel (one per consumer thread) telling it to exit
42
43
[docs]
44@instrumented
45class StorageClient:
46 """
47 A client for interacting with different storage providers.
48 """
49
50 _config: StorageClientConfig
51
52 def __init__(self, config: StorageClientConfig):
53 """
54 Initializes the :py:class:`StorageClient` with the given configuration.
55
56 :param config: The configuration object for the storage client.
57 """
58 self._initialize_providers(config)
59
60 def _initialize_providers(self, config: StorageClientConfig) -> None:
61 self._config = config
62 self._credentials_provider = self._config.credentials_provider
63 self._storage_provider = self._config.storage_provider
64 self._metadata_provider = self._config.metadata_provider
65 self._cache_config = self._config.cache_config
66 self._retry_config = self._config.retry_config
67 self._cache_manager = self._config.cache_manager
68
69 def _build_cache_path(self, path: str) -> str:
70 """
71 Build cache path with or without etag.
72 """
73 cache_path = f"{path}:{None}"
74
75 if self._metadata_provider:
76 if self._cache_manager and self._cache_manager.use_etag():
77 metadata = self._metadata_provider.get_object_metadata(path)
78 cache_path = f"{path}:{metadata.etag}"
79 else:
80 if self._cache_manager and self._cache_manager.use_etag():
81 metadata = self._storage_provider.get_object_metadata(path)
82 cache_path = f"{path}:{metadata.etag}"
83
84 return cache_path
85
86 def _is_cache_enabled(self) -> bool:
87 return self._cache_manager is not None and not self._is_posix_file_storage_provider()
88
    def _is_posix_file_storage_provider(self) -> bool:
        # True when the backing storage is the local POSIX filesystem provider;
        # used to disable caching (_is_cache_enabled) and to choose PosixFile
        # over ObjectFile in open().
        return isinstance(self._storage_provider, PosixFileStorageProvider)
91
[docs]
92 def is_default_profile(self) -> bool:
93 """
94 Return True if the storage client is using the default profile.
95 """
96 return self._config.profile == "default"
97
    @property
    def profile(self) -> str:
        """The name of the configuration profile this client was created from."""
        return self._config.profile
101
    @retry
    def read(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Reads an object from the storage provider at the specified path.

        :param path: The path of the object to read.
        :param byte_range: Optional byte range to read; when given, only that
            slice of the object is fetched and the cache is bypassed.
        :return: The content of the object.
        :raises FileNotFoundError: If a metadata provider is configured and it
            does not know about the path.
        """
        if self._metadata_provider:
            # Translate the virtual path into its physical location.
            path, exists = self._metadata_provider.realpath(path)
            if not exists:
                raise FileNotFoundError(f"The file at path '{path}' was not found.")

        # Never cache range-read requests
        if byte_range:
            return self._storage_provider.get_object(path, byte_range=byte_range)

        # Read from cache if the file exists
        if self._is_cache_enabled():
            assert self._cache_manager is not None
            cache_path = self._build_cache_path(path)
            data = self._cache_manager.read(cache_path)

            # Cache miss: fetch the full object and populate the cache.
            if data is None:
                data = self._storage_provider.get_object(path)
                self._cache_manager.set(cache_path, data)

            return data

        # Range reads returned above, so no byte_range applies on this path.
        return self._storage_provider.get_object(path, byte_range=byte_range)
132
[docs]
133 def info(self, path: str, strict: bool = True) -> ObjectMetadata:
134 """
135 Retrieves metadata or information about an object stored at the specified path.
136
137 :param path: The path to the object for which metadata or information is being retrieved.
138 :param strict: If True, performs additional validation to determine whether the path refers to a directory.
139
140 :return: A dictionary containing metadata or information about the object.
141 """
142 if self._metadata_provider:
143 return self._metadata_provider.get_object_metadata(path)
144 else:
145 return self._storage_provider.get_object_metadata(path, strict=strict)
146
147 @retry
148 def download_file(self, remote_path: str, local_path: str) -> None:
149 """
150 Downloads a file from the storage provider to the local file system.
151
152 :param remote_path: The path of the file in the storage provider.
153 :param local_path: The local path where the file should be downloaded.
154 """
155
156 if self._metadata_provider:
157 real_path, exists = self._metadata_provider.realpath(remote_path)
158 if not exists:
159 raise FileNotFoundError(f"The file at path '{remote_path}' was not found by metadata provider.")
160 metadata = self._metadata_provider.get_object_metadata(remote_path)
161 self._storage_provider.download_file(real_path, local_path, metadata)
162 else:
163 self._storage_provider.download_file(remote_path, local_path)
164
165 @retry
166 def upload_file(self, remote_path: str, local_path: str) -> None:
167 """
168 Uploads a file from the local file system to the storage provider.
169
170 :param remote_path: The path where the file should be stored in the storage provider.
171 :param local_path: The local path of the file to upload.
172 """
173 virtual_path = remote_path
174 if self._metadata_provider:
175 remote_path, exists = self._metadata_provider.realpath(remote_path)
176 if exists:
177 raise FileExistsError(
178 f"The file at path '{virtual_path}' already exists; "
179 f"overwriting is not yet allowed when using a metadata provider."
180 )
181 self._storage_provider.upload_file(remote_path, local_path)
182 if self._metadata_provider:
183 metadata = self._storage_provider.get_object_metadata(remote_path)
184 self._metadata_provider.add_file(virtual_path, metadata)
185
186 @retry
187 def write(self, path: str, body: bytes) -> None:
188 """
189 Writes an object to the storage provider at the specified path.
190
191 :param path: The path where the object should be written.
192 :param body: The content to write to the object.
193 """
194 virtual_path = path
195 if self._metadata_provider:
196 path, exists = self._metadata_provider.realpath(path)
197 if exists:
198 raise FileExistsError(
199 f"The file at path '{virtual_path}' already exists; "
200 f"overwriting is not yet allowed when using a metadata provider."
201 )
202 self._storage_provider.put_object(path, body)
203 if self._metadata_provider:
204 # TODO(NGCDP-3016): Handle eventual consistency of Swiftstack, without wait.
205 metadata = self._storage_provider.get_object_metadata(path)
206 self._metadata_provider.add_file(virtual_path, metadata)
207
[docs]
208 def copy(self, src_path: str, dest_path: str) -> None:
209 """
210 Copies an object from source to destination in the storage provider.
211
212 :param src_path: The virtual path of the source object to copy.
213 :param dest_path: The virtual path of the destination.
214 """
215 virtual_dest_path = dest_path
216 if self._metadata_provider:
217 src_path, exists = self._metadata_provider.realpath(src_path)
218 if not exists:
219 raise FileNotFoundError(f"The file at path '{src_path}' was not found.")
220
221 dest_path, exists = self._metadata_provider.realpath(dest_path)
222 if exists:
223 raise FileExistsError(
224 f"The file at path '{virtual_dest_path}' already exists; "
225 f"overwriting is not yet allowed when using a metadata provider."
226 )
227
228 self._storage_provider.copy_object(src_path, dest_path)
229 if self._metadata_provider:
230 metadata = self._storage_provider.get_object_metadata(dest_path)
231 self._metadata_provider.add_file(virtual_dest_path, metadata)
232
[docs]
233 def delete(self, path: str) -> None:
234 """
235 Deletes an object from the storage provider at the specified path.
236
237 :param path: The virtual path of the object to delete.
238 """
239 virtual_path = path
240 if self._metadata_provider:
241 path, exists = self._metadata_provider.realpath(path)
242 if not exists:
243 raise FileNotFoundError(f"The file at path '{virtual_path}' was not found.")
244 self._metadata_provider.remove_file(virtual_path)
245
246 # Delete the cached file if it exists
247 if self._is_cache_enabled():
248 assert self._cache_manager is not None
249 cache_path = self._build_cache_path(path)
250 self._cache_manager.delete(cache_path)
251
252 self._storage_provider.delete_object(path)
253
[docs]
254 def glob(self, pattern: str, include_url_prefix: bool = False) -> List[str]:
255 """
256 Matches and retrieves a list of objects in the storage provider that
257 match the specified pattern.
258
259 :param pattern: The pattern to match object paths against, supporting wildcards (e.g., ``*.txt``).
260 :param include_url_prefix: Whether to include the URL prefix ``msc://profile`` in the result.
261
262 :return: A list of object paths that match the pattern.
263 """
264 if self._metadata_provider:
265 results = self._metadata_provider.glob(pattern)
266 else:
267 results = self._storage_provider.glob(pattern)
268
269 if include_url_prefix:
270 results = [join_paths(f"{MSC_PROTOCOL}{self._config.profile}", path) for path in results]
271
272 return results
273
[docs]
274 def list(
275 self,
276 prefix: str = "",
277 start_after: Optional[str] = None,
278 end_at: Optional[str] = None,
279 include_directories: bool = False,
280 include_url_prefix: bool = False,
281 ) -> Iterator[ObjectMetadata]:
282 """
283 Lists objects in the storage provider under the specified prefix.
284
285 :param prefix: The prefix to list objects under.
286 :param start_after: The key to start after (i.e. exclusive). An object with this key doesn't have to exist.
287 :param end_at: The key to end at (i.e. inclusive). An object with this key doesn't have to exist.
288 :param include_directories: Whether to include directories in the result. When True, directories are returned alongside objects.
289 :param include_url_prefix: Whether to include the URL prefix ``msc://profile`` in the result.
290
291 :return: An iterator over objects.
292 """
293 if self._metadata_provider:
294 objects = self._metadata_provider.list_objects(prefix, start_after, end_at, include_directories)
295 else:
296 objects = self._storage_provider.list_objects(prefix, start_after, end_at, include_directories)
297
298 for object in objects:
299 if include_url_prefix:
300 object.key = join_paths(f"{MSC_PROTOCOL}{self._config.profile}", object.key)
301 yield object
302
[docs]
303 def open(
304 self,
305 path: str,
306 mode: str = "rb",
307 buffering: int = -1,
308 encoding: Optional[str] = None,
309 disable_read_cache: bool = False,
310 memory_load_limit: int = MEMORY_LOAD_LIMIT,
311 ) -> Union[PosixFile, ObjectFile]:
312 """
313 Returns a file-like object from the storage provider at the specified path.
314
315 :param path: The path of the object to read.
316 :param mode: The file mode, only "w", "r", "a", "wb", "rb" and "ab" are supported.
317 :param buffering: The buffering mode. Only applies to PosixFile.
318 :param encoding: The encoding to use for text files.
319 :param disable_read_cache: When set to True, disables caching for the file content.
320 This parameter is only applicable to ObjectFile when the mode is "r" or "rb".
321 :param memory_load_limit: Size limit in bytes for loading files into memory. Defaults to 512MB.
322 This parameter is only applicable to ObjectFile when the mode is "r" or "rb".
323
324 :return: A file-like object (PosixFile or ObjectFile) for the specified path.
325 """
326 if self._is_posix_file_storage_provider():
327 return PosixFile(self, path=path, mode=mode, buffering=buffering, encoding=encoding)
328 else:
329 return ObjectFile(
330 self,
331 remote_path=path,
332 mode=mode,
333 encoding=encoding,
334 disable_read_cache=disable_read_cache,
335 memory_load_limit=memory_load_limit,
336 )
337
[docs]
338 def is_file(self, path: str) -> bool:
339 """
340 Checks whether the specified path points to a file (rather than a directory or folder).
341
342 :param path: The path to check.
343
344 :return: ``True`` if the path points to a file, ``False`` otherwise.
345 """
346 if self._metadata_provider:
347 _, exists = self._metadata_provider.realpath(path)
348 return exists
349 return self._storage_provider.is_file(path)
350
[docs]
351 def commit_updates(self, prefix: Optional[str] = None) -> None:
352 """
353 Commits any pending updates to the metadata provider. No-op if not using a metadata provider.
354
355 :param prefix: If provided, scans the prefix to find files to commit.
356 """
357 if self._metadata_provider:
358 if prefix:
359 for obj in self._storage_provider.list_objects(prefix=prefix):
360 fullpath = os.path.join(prefix, obj.key)
361 self._metadata_provider.add_file(fullpath, obj)
362 self._metadata_provider.commit_updates()
363
[docs]
364 def is_empty(self, path: str) -> bool:
365 """
366 Checks whether the specified path is empty. A path is considered empty if there are no
367 objects whose keys start with the given path as a prefix.
368
369 :param path: The path to check. This is typically a prefix representing a directory or folder.
370
371 :return: ``True`` if no objects exist under the specified path prefix, ``False`` otherwise.
372 """
373 if self._metadata_provider:
374 objects = self._metadata_provider.list_objects(path)
375 else:
376 objects = self._storage_provider.list_objects(path)
377
378 try:
379 return next(objects) is None
380 except StopIteration:
381 pass
382 return True
383
384 def __getstate__(self) -> Dict[str, Any]:
385 state = self.__dict__.copy()
386 del state["_credentials_provider"]
387 del state["_storage_provider"]
388 del state["_metadata_provider"]
389 del state["_cache_manager"]
390 return state
391
392 def __setstate__(self, state: Dict[str, Any]) -> None:
393 config = state["_config"]
394 self._initialize_providers(config)
395
[docs]
    def sync_from(
        self,
        source_client: "StorageClient",
        source_path: str = "",
        target_path: str = "",
        delete_unmatched_files: bool = False,
    ) -> None:
        """
        Syncs files from ``source_path`` on the source storage client to
        ``target_path`` on this client. A producer thread merge-walks the two
        listings and enqueues work items that a pool of worker
        processes/threads drains.

        :param source_client: The source storage client.
        :param source_path: The path to sync from.
        :param target_path: The path to sync to.
        :param delete_unmatched_files: Whether to delete files at the target that are not present at the source.
        :raises ValueError: If source and target paths overlap on the same client.
        """
        if source_client == self and (source_path.startswith(target_path) or target_path.startswith(source_path)):
            raise ValueError("Source and target paths cannot overlap on same StorageClient.")

        # Attempt to balance the number of worker processes and threads.
        cpu_count = multiprocessing.cpu_count()
        default_processes = "8" if cpu_count > 8 else str(cpu_count)
        num_worker_processes = int(os.getenv("MSC_NUM_PROCESSES", default_processes))
        # getenv's default may be an int here; int() accepts both str and int.
        num_worker_threads = int(os.getenv("MSC_NUM_THREADS_PER_PROCESS", max(cpu_count // num_worker_processes, 16)))

        if num_worker_processes == 1:
            # Single-process mode: a plain thread-safe queue suffices.
            file_queue = queue.Queue(maxsize=2000)
            result_queue = None
        else:
            # Multi-process mode: queues must be shareable across processes.
            manager = multiprocessing.Manager()
            file_queue = manager.Queue(maxsize=2000)
            # Only need result_queue if using a metadata provider and multiprocessing.
            result_queue = manager.Queue() if self._metadata_provider else None

        def match_file_metadata(source_info: ObjectMetadata, target_info: ObjectMetadata) -> bool:
            # If target and source have valid etags defined, use etag and file size to compare.
            if source_info.etag and target_info.etag:
                return source_info.etag == target_info.etag and source_info.content_length == target_info.content_length
            # Else, check file size is the same and the target's last_modified is newer than the source.
            return (
                source_info.content_length == target_info.content_length
                and source_info.last_modified <= target_info.last_modified
            )

        def producer():
            """Lists source files and adds them to the queue."""
            # NOTE: the merge below relies on both listings yielding keys in
            # sorted order (the </> comparisons advance the lagging side).
            source_iter = iter(source_client.list(prefix=source_path))
            target_iter = iter(self.list(prefix=target_path))

            source_file = next(source_iter, None)
            target_file = next(target_iter, None)

            while source_file or target_file:
                if source_file and target_file:
                    # Compare keys relative to their respective prefixes.
                    source_key = source_file.key[len(source_path) :].lstrip("/")
                    target_key = target_file.key[len(target_path) :].lstrip("/")

                    if source_key < target_key:
                        # Present only at source -> copy it over.
                        file_queue.put((_SyncOp.ADD, source_file))
                        source_file = next(source_iter, None)
                    elif source_key > target_key:
                        # Present only at target.
                        if delete_unmatched_files:
                            file_queue.put((_SyncOp.DELETE, target_file))
                        target_file = next(target_iter, None)  # Skip unmatched target file
                    else:
                        # Both exist, compare metadata
                        if not match_file_metadata(source_file, target_file):
                            file_queue.put((_SyncOp.ADD, source_file))
                        source_file = next(source_iter, None)
                        target_file = next(target_iter, None)
                elif source_file:
                    # Target listing exhausted -> remaining source files are new.
                    file_queue.put((_SyncOp.ADD, source_file))
                    source_file = next(source_iter, None)
                else:
                    # Source listing exhausted -> remaining target files are unmatched.
                    if delete_unmatched_files:
                        assert target_file is not None
                        file_queue.put((_SyncOp.DELETE, target_file))
                    target_file = next(target_iter, None)

            # One STOP sentinel per consumer thread across all processes.
            for _ in range(num_worker_threads * num_worker_processes):
                file_queue.put((_SyncOp.STOP, None))  # Signal consumers to stop

        producer_thread = threading.Thread(target=producer, daemon=True)
        producer_thread.start()

        if num_worker_processes == 1:
            # Single process does not require multiprocessing.
            _sync_worker_process(
                source_client, source_path, self, target_path, num_worker_threads, file_queue, result_queue
            )
        else:
            with multiprocessing.Pool(processes=num_worker_processes) as pool:
                pool.apply(
                    _sync_worker_process,
                    args=(source_client, source_path, self, target_path, num_worker_threads, file_queue, result_queue),
                )

        producer_thread.join()

        # Pull from result_queue to collect pending updates from each multiprocessing worker.
        if result_queue and self._metadata_provider:
            while not result_queue.empty():
                op, target_file_path, physical_metadata = result_queue.get()
                if op == _SyncOp.ADD:
                    # Use realpath() to get physical path so metadata provider can
                    # track the logical/physical mapping.
                    phys_path, _ = self._metadata_provider.realpath(target_file_path)
                    physical_metadata.key = phys_path
                    self._metadata_provider.add_file(target_file_path, physical_metadata)
                elif op == _SyncOp.DELETE:
                    self._metadata_provider.remove_file(target_file_path)
                else:
                    raise RuntimeError(f"Unknown operation: {op}")

        self.commit_updates()
510
511
def _sync_worker_process(
    source_client: StorageClient,
    source_path: str,
    target_client: StorageClient,
    target_path: str,
    num_worker_threads: int,
    file_queue: queue.Queue,
    result_queue: Optional[queue.Queue],
):
    """Worker-process entry point for :py:meth:`StorageClient.sync_from`.

    Spawns ``num_worker_threads`` consumer threads that drain ``file_queue``
    and copy/delete files. Defined at module top level so it is picklable by
    :py:mod:`multiprocessing`.

    :param source_client: Client files are read from.
    :param source_path: Prefix being synced from; stripped from source keys.
    :param target_client: Client that receives the synced files.
    :param target_path: Prefix under which files are written at the target.
    :param num_worker_threads: Number of consumer threads to run.
    :param file_queue: Queue of ``(_SyncOp, ObjectMetadata)`` work items,
        terminated by one ``STOP`` sentinel per consumer thread.
    :param result_queue: Present only when the target uses a metadata provider
        with multiple worker processes; pending add/delete results are
        reported here so the parent process can apply them. Otherwise ``None``.
    """

    def _sync_consumer() -> None:
        """Consume work items from the queue until a STOP sentinel arrives."""
        while True:
            op, file_metadata = file_queue.get()
            if op == _SyncOp.STOP:
                break

            # NOTE(review): for DELETE ops file_metadata comes from the target
            # listing, yet len(source_path) is stripped here — this only yields
            # the right relative key when both prefixes have the same length;
            # confirm against the producer in sync_from().
            source_key = file_metadata.key[len(source_path) :].lstrip("/")
            target_file_path = os.path.join(target_path, source_key)

            if op == _SyncOp.ADD:
                logger.debug(f"sync {file_metadata.key} -> {target_file_path}")
                if file_metadata.content_length < MEMORY_LOAD_LIMIT:
                    # Small file: copy straight through memory.
                    file_content = source_client.read(file_metadata.key)
                    target_client.write(target_file_path, file_content)
                else:
                    # Large file: stage on local disk to bound memory usage.
                    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
                        temp_filename = temp_file.name

                    try:
                        source_client.download_file(file_metadata.key, temp_filename)
                        target_client.upload_file(target_file_path, temp_filename)
                    finally:
                        os.remove(temp_filename)  # Ensure the temporary file is removed
            elif op == _SyncOp.DELETE:
                logger.debug(f"rm {file_metadata.key}")
                target_client.delete(file_metadata.key)
            else:
                raise ValueError(f"Unknown operation: {op}")

            if result_queue:
                if op == _SyncOp.ADD:
                    # result_queue is only provided when the target client has a
                    # metadata provider, so the attribute access below is safe.
                    # include_pending=True because the file was just written and
                    # may not be committed yet.
                    physical_metadata = target_client._metadata_provider.get_object_metadata(
                        target_file_path, include_pending=True
                    )
                    result_queue.put((op, target_file_path, physical_metadata))
                elif op == _SyncOp.DELETE:
                    result_queue.put((op, target_file_path, None))
                else:
                    raise RuntimeError(f"Unknown operation: {op}")

    # Fan out to the consumer threads and surface the first exception, if any.
    # (A stray no-op string literal that previously sat here has been folded
    # into the function docstring above.)
    with ThreadPoolExecutor(max_workers=num_worker_threads) as executor:
        futures = [executor.submit(_sync_consumer) for _ in range(num_worker_threads)]
        for future in futures:
            future.result()  # Ensure all threads complete