# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import multiprocessing
import os
import queue
import tempfile
import threading
from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from pathlib import PurePosixPath
from typing import Any, Optional, Union, cast

from .config import StorageClientConfig
from .constants import MEMORY_LOAD_LIMIT
from .file import ObjectFile, PosixFile
from .instrumentation.utils import instrumented
from .progress_bar import ProgressBar
from .providers.posix_file import PosixFileStorageProvider
from .retry import retry
from .types import MSC_PROTOCOL, ObjectMetadata, Range, SourceVersionCheckMode
from .utils import NullStorageClient, calculate_worker_processes_and_threads, join_paths

logger = logging.getLogger(__name__)


class _SyncOp(Enum):
    ADD = "add"
    DELETE = "delete"
    STOP = "stop"


@instrumented
class StorageClient:
    """
    A client for interacting with different storage providers.
    """

    _config: StorageClientConfig

    def __init__(self, config: StorageClientConfig):
        """
        Initializes the :py:class:`StorageClient` with the given configuration.

        :param config: The configuration object for the storage client.
        """
        self._initialize_providers(config)

    def _initialize_providers(self, config: StorageClientConfig) -> None:
        self._config = config
        self._credentials_provider = self._config.credentials_provider
        self._storage_provider = self._config.storage_provider
        self._metadata_provider = self._config.metadata_provider
        self._cache_config = self._config.cache_config
        self._retry_config = self._config.retry_config
        self._cache_manager = self._config.cache_manager

    def _get_source_version(self, path: str) -> Optional[str]:
        """
        Get etag from metadata provider or storage provider.
        """
        if self._metadata_provider:
            metadata = self._metadata_provider.get_object_metadata(path)
        else:
            metadata = self._storage_provider.get_object_metadata(path)
        return metadata.etag

    def _is_cache_enabled(self) -> bool:
        return self._cache_manager is not None and not self._is_posix_file_storage_provider()

    def _is_posix_file_storage_provider(self) -> bool:
        return isinstance(self._storage_provider, PosixFileStorageProvider)

    def is_default_profile(self) -> bool:
        """
        Return ``True`` if the storage client is using the default profile.
        """
        return self._config.profile == "default"

    @property
    def profile(self) -> str:
        return self._config.profile

    @retry
    def read(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Reads an object from the specified logical path.

        :param path: The logical path of the object to read.
        :param byte_range: An optional byte range to read from the object; range reads bypass the cache.
        :return: The content of the object.
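
        Example (illustrative sketch; ``client`` is an initialized :py:class:`StorageClient` and the paths are hypothetical)::

            data = client.read("datasets/sample.bin")
            # Byte-range read; assumes the Range type takes offset and size.
            header = client.read("datasets/sample.bin", byte_range=Range(offset=0, size=16))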
105 """
106 if self._metadata_provider:
107 path, exists = self._metadata_provider.realpath(path)
108 if not exists:
109 raise FileNotFoundError(f"The file at path '{path}' was not found.")
110
111 # Never cache range-read requests
112 if byte_range:
113 return self._storage_provider.get_object(path, byte_range=byte_range)
114
115 # Read from cache if the file exists
116 if self._is_cache_enabled():
117 assert self._cache_manager is not None
118 source_version = self._get_source_version(path)
119 data = self._cache_manager.read(path, source_version)
120
121 if data is None:
122 data = self._storage_provider.get_object(path)
123 self._cache_manager.set(path, data, source_version)
124
125 return data
126
127 return self._storage_provider.get_object(path, byte_range=byte_range)
128
    def info(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Retrieves metadata or information about an object stored at the specified path.

        :param path: The logical path to the object for which metadata or information is being retrieved.
        :param strict: If True, performs additional validation to determine whether the path refers to a directory.

        :return: An :py:class:`ObjectMetadata` object containing metadata about the object.
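
        Example (illustrative; ``client`` and the path are hypothetical)::

            meta = client.info("datasets/sample.bin")
            print(meta.content_length, meta.etag)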
137 """
138 if not self._metadata_provider:
139 return self._storage_provider.get_object_metadata(path, strict=strict)
140
141 # For metadata_provider, first check if the path exists as a file, then fallback to detecting if path is a directory.
142 try:
143 return self._metadata_provider.get_object_metadata(path)
144 except FileNotFoundError:
145 # Try listing from the parent to determine if path is a valid directory
146 parent = os.path.dirname(path.rstrip("/")) + "/"
147 parent = "" if parent == "/" else parent
148 target = path.rstrip("/") + "/"
149
150 try:
151 entries = self._metadata_provider.list_objects(parent, include_directories=True)
152 for entry in entries:
153 if entry.key == target and entry.type == "directory":
154 return entry
155 except Exception:
156 pass
157 raise # Raise original FileNotFoundError
158
    @retry
    def download_file(self, remote_path: str, local_path: str) -> None:
        """
        Downloads a file to the local file system.

        :param remote_path: The logical path of the file in the storage provider.
        :param local_path: The local path where the file should be downloaded.
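
        Example (illustrative; the paths are hypothetical)::

            client.download_file("models/checkpoint.pt", "/tmp/checkpoint.pt")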
166 """
167
168 if self._metadata_provider:
169 real_path, exists = self._metadata_provider.realpath(remote_path)
170 if not exists:
171 raise FileNotFoundError(f"The file at path '{remote_path}' was not found by metadata provider.")
172 metadata = self._metadata_provider.get_object_metadata(remote_path)
173 self._storage_provider.download_file(real_path, local_path, metadata)
174 else:
175 self._storage_provider.download_file(remote_path, local_path)
176
    @retry
    def upload_file(self, remote_path: str, local_path: str, attributes: Optional[dict[str, str]] = None) -> None:
        """
        Uploads a file from the local file system.

        :param remote_path: The logical path where the file should be stored.
        :param local_path: The local path of the file to upload.
        :param attributes: The attributes to add to the file.
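
        Example (illustrative; the paths and attributes are hypothetical)::

            client.upload_file("models/checkpoint.pt", "/tmp/checkpoint.pt", attributes={"stage": "trained"})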
185 """
186 virtual_path = remote_path
187 if self._metadata_provider:
188 remote_path, exists = self._metadata_provider.realpath(remote_path)
189 if exists:
190 raise FileExistsError(
191 f"The file at path '{virtual_path}' already exists; "
192 f"overwriting is not yet allowed when using a metadata provider."
193 )
194 self._storage_provider.upload_file(remote_path, local_path, attributes)
195 if self._metadata_provider:
196 metadata = self._storage_provider.get_object_metadata(remote_path)
197 self._metadata_provider.add_file(virtual_path, metadata)
198
    @retry
    def write(self, path: str, body: bytes, attributes: Optional[dict[str, str]] = None) -> None:
        """
        Writes an object at the specified path.

        :param path: The logical path where the object should be written.
        :param body: The content to write to the object.
        :param attributes: The attributes to add to the file.
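
        Example (illustrative; the path is hypothetical)::

            client.write("datasets/labels.json", b'{"cat": 0}')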
207 """
208 virtual_path = path
209 if self._metadata_provider:
210 path, exists = self._metadata_provider.realpath(path)
211 if exists:
212 raise FileExistsError(
213 f"The file at path '{virtual_path}' already exists; "
214 f"overwriting is not yet allowed when using a metadata provider."
215 )
216 self._storage_provider.put_object(path, body, attributes=attributes)
217 if self._metadata_provider:
218 # TODO(NGCDP-3016): Handle eventual consistency of Swiftstack, without wait.
219 metadata = self._storage_provider.get_object_metadata(path)
220 self._metadata_provider.add_file(virtual_path, metadata)
221
    def copy(self, src_path: str, dest_path: str) -> None:
        """
        Copies an object from source to destination path.

        :param src_path: The logical path of the source object to copy.
        :param dest_path: The logical path of the destination.
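
        Example (illustrative; the paths are hypothetical)::

            client.copy("datasets/sample.bin", "backups/sample.bin")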
228 """
229 virtual_dest_path = dest_path
230 if self._metadata_provider:
231 src_path, exists = self._metadata_provider.realpath(src_path)
232 if not exists:
233 raise FileNotFoundError(f"The file at path '{src_path}' was not found.")
234
235 dest_path, exists = self._metadata_provider.realpath(dest_path)
236 if exists:
237 raise FileExistsError(
238 f"The file at path '{virtual_dest_path}' already exists; "
239 f"overwriting is not yet allowed when using a metadata provider."
240 )
241
242 self._storage_provider.copy_object(src_path, dest_path)
243 if self._metadata_provider:
244 metadata = self._storage_provider.get_object_metadata(dest_path)
245 self._metadata_provider.add_file(virtual_dest_path, metadata)
246
    def delete(self, path: str, recursive: bool = False) -> None:
        """
        Deletes an object at the specified path.

        :param path: The logical path of the object to delete.
        :param recursive: Whether to delete objects in the path recursively.
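
        Example (illustrative; the paths are hypothetical)::

            client.delete("datasets/sample.bin")
            client.delete("datasets/tmp/", recursive=True)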
253 """
254 if recursive:
255 self.sync_from(
256 NullStorageClient(),
257 path,
258 path,
259 delete_unmatched_files=True,
260 num_worker_processes=1,
261 description="Deleting",
262 )
263 # If this is a posix storage provider, we need to also delete remaining directory stubs.
264 if self._is_posix_file_storage_provider():
265 posix_storage_provider = cast(PosixFileStorageProvider, self._storage_provider)
266 posix_storage_provider.rmtree(path)
267 return
268
269 virtual_path = path
270 if self._metadata_provider:
271 path, exists = self._metadata_provider.realpath(path)
272 if not exists:
273 raise FileNotFoundError(f"The file at path '{virtual_path}' was not found.")
274 self._metadata_provider.remove_file(virtual_path)
275
276 # Delete the cached file if it exists
277 if self._is_cache_enabled():
278 assert self._cache_manager is not None
279 self._cache_manager.delete(virtual_path)
280
281 self._storage_provider.delete_object(path)
282
    def glob(
        self,
        pattern: str,
        include_url_prefix: bool = False,
        attribute_filter_expression: Optional[str] = None,
    ) -> list[str]:
        """
        Matches and retrieves a list of objects in the storage provider that
        match the specified pattern.

        :param pattern: The pattern to match object paths against, supporting wildcards (e.g., ``*.txt``).
        :param include_url_prefix: Whether to include the URL prefix ``msc://profile`` in the result.
        :param attribute_filter_expression: The attribute filter expression to apply to the result.

        :return: A list of object paths that match the pattern.
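
        Example (illustrative; the pattern is hypothetical)::

            text_files = client.glob("datasets/*.txt")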
298 """
299 if self._metadata_provider:
300 results = self._metadata_provider.glob(pattern)
301 else:
302 results = self._storage_provider.glob(pattern, attribute_filter_expression)
303
304 if include_url_prefix:
305 results = [join_paths(f"{MSC_PROTOCOL}{self._config.profile}", path) for path in results]
306
307 return results
308
    def list(
        self,
        prefix: str = "",
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        include_url_prefix: bool = False,
        attribute_filter_expression: Optional[str] = None,
    ) -> Iterator[ObjectMetadata]:
        """
        Lists objects in the storage provider under the specified prefix.

        :param prefix: The prefix to list objects under.
        :param start_after: The key to start after (i.e. exclusive). An object with this key doesn't have to exist.
        :param end_at: The key to end at (i.e. inclusive). An object with this key doesn't have to exist.
        :param include_directories: Whether to include directories in the result. When True, directories are returned alongside objects.
        :param include_url_prefix: Whether to include the URL prefix ``msc://profile`` in the result.
        :param attribute_filter_expression: The attribute filter expression to apply to the result.

        :return: An iterator over objects.
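
        Example (illustrative; the prefix is hypothetical)::

            for obj in client.list(prefix="datasets/train/"):
                print(obj.key, obj.content_length)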
329 """
330 if self._metadata_provider:
331 objects = self._metadata_provider.list_objects(prefix, start_after, end_at, include_directories)
332 else:
333 objects = self._storage_provider.list_objects(
334 prefix, start_after, end_at, include_directories, attribute_filter_expression
335 )
336
337 for object in objects:
338 if include_url_prefix:
339 if self.is_default_profile():
340 object.key = str(PurePosixPath("/") / object.key)
341 else:
342 object.key = join_paths(f"{MSC_PROTOCOL}{self._config.profile}", object.key)
343 yield object
344
    def open(
        self,
        path: str,
        mode: str = "rb",
        buffering: int = -1,
        encoding: Optional[str] = None,
        disable_read_cache: bool = False,
        memory_load_limit: int = MEMORY_LOAD_LIMIT,
        atomic: bool = True,
        check_source_version: SourceVersionCheckMode = SourceVersionCheckMode.INHERIT,
        attributes: Optional[dict[str, str]] = None,
    ) -> Union[PosixFile, ObjectFile]:
        """
        Returns a file-like object from the specified path.

        :param path: The logical path of the object to read.
        :param mode: The file mode; only "w", "r", "a", "wb", "rb" and "ab" are supported.
        :param buffering: The buffering mode. Only applies to PosixFile.
        :param encoding: The encoding to use for text files.
        :param disable_read_cache: When set to True, disables caching for the file content.
            This parameter is only applicable to ObjectFile when the mode is "r" or "rb".
        :param memory_load_limit: Size limit in bytes for loading files into memory. Defaults to 512MB.
            This parameter is only applicable to ObjectFile when the mode is "r" or "rb".
        :param atomic: When set to True, the file will be written atomically (rename upon close).
            This parameter is only applicable to PosixFile in write mode.
        :param check_source_version: Whether to check the source version of cached objects.
        :param attributes: The attributes to add to the file. This parameter is only applicable when the mode is "w", "wb", "a" or "ab".
        :return: A file-like object (PosixFile or ObjectFile) for the specified path.
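
        Example (illustrative; assumes the returned file-like object is used as a context manager and the path is hypothetical)::

            with client.open("datasets/sample.bin", mode="rb") as f:
                payload = f.read()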
373 """
374 if self._is_posix_file_storage_provider():
375 return PosixFile(
376 self, path=path, mode=mode, buffering=buffering, encoding=encoding, atomic=atomic, attributes=attributes
377 )
378 else:
379 if atomic is False:
380 logger.warning("Non-atomic writes are not supported for object storage providers.")
381
382 return ObjectFile(
383 self,
384 remote_path=path,
385 mode=mode,
386 encoding=encoding,
387 disable_read_cache=disable_read_cache,
388 memory_load_limit=memory_load_limit,
389 check_source_version=check_source_version,
390 attributes=attributes,
391 )
392
    def is_file(self, path: str) -> bool:
        """
        Checks whether the specified path points to a file (rather than a directory or folder).

        :param path: The logical path to check.

        :return: ``True`` if the path points to a file, ``False`` otherwise.
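
        Example (illustrative; the path is hypothetical)::

            if client.is_file("datasets/sample.bin"):
                data = client.read("datasets/sample.bin")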
400 """
401 if self._metadata_provider:
402 _, exists = self._metadata_provider.realpath(path)
403 return exists
404 return self._storage_provider.is_file(path)
405
422
    def is_empty(self, path: str) -> bool:
        """
        Checks whether the specified path is empty. A path is considered empty if there are no
        objects whose keys start with the given path as a prefix.

        :param path: The logical path to check. This is typically a prefix representing a directory or folder.

        :return: ``True`` if no objects exist under the specified path prefix, ``False`` otherwise.
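
        Example (illustrative; the prefix is hypothetical)::

            done = client.is_empty("datasets/pending/")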
431 """
432 if self._metadata_provider:
433 objects = self._metadata_provider.list_objects(path)
434 else:
435 objects = self._storage_provider.list_objects(path)
436
437 try:
438 return next(objects) is None
439 except StopIteration:
440 pass
441 return True
442
    def __getstate__(self) -> dict[str, Any]:
        # Providers and the cache manager may hold resources that cannot be pickled
        # (e.g., network clients), so drop them and rebuild from the config on unpickling.
        state = self.__dict__.copy()
        del state["_credentials_provider"]
        del state["_storage_provider"]
        del state["_metadata_provider"]
        del state["_cache_manager"]
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        config = state["_config"]
        self._initialize_providers(config)

    def sync_from(
        self,
        source_client: "StorageClient",
        source_path: str = "",
        target_path: str = "",
        delete_unmatched_files: bool = False,
        description: str = "Syncing",
        num_worker_processes: Optional[int] = None,
    ) -> None:
        """
        Syncs files from the source storage client to the target path.

        :param source_client: The source storage client.
        :param source_path: The logical path to sync from.
        :param target_path: The logical path to sync to.
        :param delete_unmatched_files: Whether to delete files at the target that are not present at the source.
        :param description: Description of the sync process for logging purposes.
        :param num_worker_processes: The number of worker processes to use.
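
        Example (illustrative; ``source_client`` and the paths are hypothetical)::

            # Mirror "raw/" from the source into "mirror/raw/" on this client,
            # deleting target files that no longer exist at the source.
            client.sync_from(source_client, "raw/", "mirror/raw/", delete_unmatched_files=True)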
473 """
474 source_path = source_path.lstrip("/")
475 target_path = target_path.lstrip("/")
476
477 if source_client == self and (source_path.startswith(target_path) or target_path.startswith(source_path)):
478 raise ValueError("Source and target paths cannot overlap on same StorageClient.")
479
480 logger.debug(f"Starting sync operation {description}")
481
482 # Attempt to balance the number of worker processes and threads.
483 num_worker_processes, num_worker_threads = calculate_worker_processes_and_threads(num_worker_processes)
484
485 if num_worker_processes == 1:
486 file_queue = queue.Queue(maxsize=100000)
487 result_queue = queue.Queue()
488 else:
489 manager = multiprocessing.Manager()
490 file_queue = manager.Queue(maxsize=100000)
491 result_queue = manager.Queue()
492
493 progress = ProgressBar(desc=description, show_progress=True, total_items=0)
494
495 def match_file_metadata(source_info: ObjectMetadata, target_info: ObjectMetadata) -> bool:
496 # If target and source have valid etags defined, use etag and file size to compare.
497 if source_info.etag and target_info.etag:
498 return source_info.etag == target_info.etag and source_info.content_length == target_info.content_length
499 # Else, check file size is the same and the target's last_modified is newer than the source.
500 return (
501 source_info.content_length == target_info.content_length
502 and source_info.last_modified <= target_info.last_modified
503 )
504
        def producer():
            """Lists source files and adds them to the queue."""
            source_iter = iter(source_client.list(prefix=source_path))
            target_iter = iter(self.list(prefix=target_path))
            total_count = 0

            source_file = next(source_iter, None)
            target_file = next(target_iter, None)

            while source_file or target_file:
                # Update progress and count each pair (or single) considered for syncing
                if total_count % 1000 == 0:
                    progress.update_total(total_count)
                total_count += 1

                if source_file and target_file:
                    source_key = source_file.key[len(source_path) :].lstrip("/")
                    target_key = target_file.key[len(target_path) :].lstrip("/")

                    if source_key < target_key:
                        file_queue.put((_SyncOp.ADD, source_file))
                        source_file = next(source_iter, None)
                    elif source_key > target_key:
                        if delete_unmatched_files:
                            file_queue.put((_SyncOp.DELETE, target_file))
                        else:
                            progress.update_progress()
                        target_file = next(target_iter, None)  # Skip unmatched target file
                    else:
                        # Both exist, compare metadata
                        if not match_file_metadata(source_file, target_file):
                            file_queue.put((_SyncOp.ADD, source_file))
                        else:
                            progress.update_progress()

                        source_file = next(source_iter, None)
                        target_file = next(target_iter, None)
                elif source_file:
                    file_queue.put((_SyncOp.ADD, source_file))
                    source_file = next(source_iter, None)
                else:
                    if delete_unmatched_files:
                        file_queue.put((_SyncOp.DELETE, target_file))
                    else:
                        progress.update_progress()
                    target_file = next(target_iter, None)

            progress.update_total(total_count)

            for _ in range(num_worker_threads * num_worker_processes):
                file_queue.put((_SyncOp.STOP, None))  # Signal consumers to stop

        producer_thread = threading.Thread(target=producer, daemon=True)
        producer_thread.start()

        def _result_consumer():
            # Pull from result_queue to collect pending updates from each multiprocessing worker.
            while True:
                op, target_file_path, physical_metadata = result_queue.get()
                if op == _SyncOp.STOP:
                    break

                if self._metadata_provider:
                    if op == _SyncOp.ADD:
                        # Use realpath() to get physical path so metadata provider can
                        # track the logical/physical mapping.
                        phys_path, _ = self._metadata_provider.realpath(target_file_path)
                        physical_metadata.key = phys_path
                        self._metadata_provider.add_file(target_file_path, physical_metadata)
                    elif op == _SyncOp.DELETE:
                        self._metadata_provider.remove_file(target_file_path)
                    else:
                        raise RuntimeError(f"Unknown operation: {op}")
                progress.update_progress()

        result_consumer_thread = threading.Thread(target=_result_consumer, daemon=True)
        result_consumer_thread.start()

        if num_worker_processes == 1:
            # Single process does not require multiprocessing.
            _sync_worker_process(
                source_client, source_path, self, target_path, num_worker_threads, file_queue, result_queue
            )
        else:
            with multiprocessing.Pool(processes=num_worker_processes) as pool:
                pool.apply(
                    _sync_worker_process,
                    args=(source_client, source_path, self, target_path, num_worker_threads, file_queue, result_queue),
                )

        producer_thread.join()

        result_queue.put((_SyncOp.STOP, None, None))
        result_consumer_thread.join()

        self.commit_metadata()
        progress.close()
        logger.debug(f"Completed sync operation {description}")


def _sync_worker_process(
    source_client: StorageClient,
    source_path: str,
    target_client: StorageClient,
    target_path: str,
    num_worker_threads: int,
    file_queue: queue.Queue,
    result_queue: Optional[queue.Queue],
):
    """Helper function for sync_from, defined at top-level for multiprocessing."""

    def _sync_consumer() -> None:
        """Processes files from the queue and copies them."""
        while True:
            op, file_metadata = file_queue.get()
            if op == _SyncOp.STOP:
                break

            source_key = file_metadata.key[len(source_path) :].lstrip("/")
            target_file_path = os.path.join(target_path, source_key)

            if op == _SyncOp.ADD:
                logger.debug(f"sync {file_metadata.key} -> {target_file_path}")
                if file_metadata.content_length < MEMORY_LOAD_LIMIT:
                    file_content = source_client.read(file_metadata.key)
                    target_client.write(target_file_path, file_content)
                else:
                    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
                        temp_filename = temp_file.name

                    try:
                        source_client.download_file(file_metadata.key, temp_filename)
                        target_client.upload_file(target_file_path, temp_filename)
                    finally:
                        os.remove(temp_filename)  # Ensure the temporary file is removed
            elif op == _SyncOp.DELETE:
                logger.debug(f"rm {file_metadata.key}")
                target_client.delete(file_metadata.key)
            else:
                raise ValueError(f"Unknown operation: {op}")

            if result_queue:
                if op == _SyncOp.ADD:
                    # Add tuple of (virtual_path, physical_metadata) to result_queue.
                    if target_client._metadata_provider:
                        physical_metadata = target_client._metadata_provider.get_object_metadata(
                            target_file_path, include_pending=True
                        )
                    else:
                        physical_metadata = None
                    result_queue.put((op, target_file_path, physical_metadata))
                elif op == _SyncOp.DELETE:
                    result_queue.put((op, target_file_path, None))
                else:
                    raise RuntimeError(f"Unknown operation: {op}")

    # Worker process that spawns threads to handle syncing.
    with ThreadPoolExecutor(max_workers=num_worker_threads) as executor:
        futures = [executor.submit(_sync_consumer) for _ in range(num_worker_threads)]
        for future in futures:
            future.result()  # Ensure all threads complete