1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16from __future__ import annotations
17
18from abc import ABC, abstractmethod
19from collections.abc import Iterator, Sequence
20from dataclasses import asdict, dataclass, field, replace
21from datetime import datetime, timezone
22from enum import Enum
23from typing import IO, Any, NamedTuple, Optional, Tuple, Union
24
25from dateutil.parser import parse as dateutil_parser
26
# Protocol scheme used for MSC-style URLs (e.g. "msc://profile/path").
MSC_PROTOCOL_NAME = "msc"
MSC_PROTOCOL = f"{MSC_PROTOCOL_NAME}://"

# Defaults for the exponential-backoff retry policy (see RetryConfig).
DEFAULT_RETRY_ATTEMPTS = 3
DEFAULT_RETRY_DELAY = 1.0
DEFAULT_RETRY_BACKOFF_MULTIPLIER = 2.0

# datetime.min is naive. Calling datetime.astimezone(timezone.utc) on a naive
# datetime assumes the local timezone; if that timezone is behind UTC, the
# conversion subtracts the offset and underflows below the representable
# minimum, raising `ValueError: year 0 is out of range`. Use this aware
# sentinel instead of datetime.min.
AWARE_DATETIME_MIN = datetime.min.replace(tzinfo=timezone.utc)
40
41
class SignerType(str, Enum):
    """Enumerates the signing backends available for presigned URL generation."""

    S3 = "s3"
    CLOUDFRONT = "cloudfront"
    AZURE = "azure"
48
49
@dataclass
class Credentials:
    """
    A data class representing the credentials needed to access a storage provider.
    """

    #: The access key for authentication.
    access_key: str
    #: The secret key for authentication.
    secret_key: str
    #: An optional security token for temporary credentials.
    token: Optional[str]
    #: The expiration time of the credentials in ISO 8601 format.
    expiration: Optional[str]
    #: A dictionary for storing custom key-value pairs.
    custom_fields: dict[str, Any] = field(default_factory=dict)

    def is_expired(self) -> bool:
        """
        Checks if the credentials are expired based on the expiration time.

        Credentials without an expiration are treated as never expiring.

        :return: ``True`` if the credentials are expired, ``False`` otherwise.
        """
        if not self.expiration:
            return False
        expiry = dateutil_parser(self.expiration)
        if expiry.tzinfo is None:
            # Bug fix: comparing a naive datetime against the aware
            # datetime.now(tz=utc) below raises TypeError. Interpret naive
            # expiration timestamps as UTC so the comparison is always valid.
            expiry = expiry.replace(tzinfo=timezone.utc)
        return expiry <= datetime.now(tz=timezone.utc)

    def get_custom_field(self, key: str, default: Any = None) -> Any:
        """
        Retrieves a value from custom fields by its key.

        :param key: The key to look up in custom fields.
        :param default: The default value to return if the key is not found.
        :return: The value associated with the key, or the default value if not found.
        """
        return self.custom_fields.get(key, default)
87
88
143
144
class CredentialsProvider(ABC):
    """
    Abstract base class for providing credentials to access a storage provider.
    """

    @abstractmethod
    def get_credentials(self) -> Credentials:
        """
        Returns the credentials currently in use.

        :return: The current credentials used for authentication.
        """

    @abstractmethod
    def refresh_credentials(self) -> None:
        """
        Refreshes the credentials when they are expired or close to expiring.
        """
165
166
@dataclass
class Range:
    """
    A data class that represents a byte range for read operations.
    """

    #: Start offset of the range, in bytes.
    offset: int
    #: Number of bytes to read from the offset.
    size: int
177
178
[docs]
179class StorageProvider(ABC):
180 """
181 Abstract base class for interacting with a storage provider.
182 """
183
[docs]
184 @abstractmethod
185 def put_object(
186 self,
187 path: str,
188 body: bytes,
189 if_match: Optional[str] = None,
190 if_none_match: Optional[str] = None,
191 attributes: Optional[dict[str, str]] = None,
192 ) -> None:
193 """
194 Uploads an object to the storage provider.
195
196 :param path: The path where the object will be stored.
197 :param body: The content of the object to store.
198 :param if_match: Optional If-Match value for conditional upload.
199 :param if_none_match: Optional If-None-Match value for conditional upload.
200 :param attributes: The attributes to add to the file.
201 """
202 pass
203
[docs]
204 @abstractmethod
205 def get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
206 """
207 Retrieves an object from the storage provider.
208
209 :param path: The path where the object is stored.
210 :param byte_range: Optional byte range (offset, length) to read.
211 :return: The content of the retrieved object.
212 """
213 pass
214
[docs]
215 @abstractmethod
216 def copy_object(self, src_path: str, dest_path: str) -> None:
217 """
218 Copies an object from source to destination in the storage provider.
219
220 :param src_path: The path of the source object to copy.
221 :param dest_path: The path of the destination.
222 """
223 pass
224
[docs]
225 @abstractmethod
226 def delete_object(self, path: str, if_match: Optional[str] = None) -> None:
227 """
228 Deletes an object from the storage provider.
229
230 :param path: The path of the object to delete.
231 :param if_match: Optional if-match value to use for conditional deletion.
232 """
233 pass
234
[docs]
235 @abstractmethod
236 def delete_objects(self, paths: list[str]) -> None:
237 """
238 Deletes multiple objects from the storage provider.
239
240 :param paths: A list of paths of objects to delete.
241 """
242 pass
243
255
[docs]
256 @abstractmethod
257 def list_objects(
258 self,
259 path: str,
260 start_after: Optional[str] = None,
261 end_at: Optional[str] = None,
262 include_directories: bool = False,
263 attribute_filter_expression: Optional[str] = None,
264 show_attributes: bool = False,
265 follow_symlinks: bool = True,
266 ) -> Iterator[ObjectMetadata]:
267 """
268 Lists objects in the storage provider under the specified path.
269
270 :param path: The path to list objects under. The path must be a valid file or subdirectory path, cannot be partial or just "prefix".
271 :param start_after: The key to start after (i.e. exclusive). An object with this key doesn't have to exist.
272 :param end_at: The key to end at (i.e. inclusive). An object with this key doesn't have to exist.
273 :param include_directories: Whether to include directories in the result. When ``True``, directories are returned alongside objects.
274 :param attribute_filter_expression: The attribute filter expression to apply to the result.
275 :param show_attributes: Whether to return attributes in the result. There will be performance impact if this is True as now we need to get object metadata for each object.
276 :param follow_symlinks: Whether to follow symbolic links. Only applicable for POSIX file storage providers.
277
278 :return: An iterator over objects metadata under the specified path.
279 """
280 pass
281
[docs]
282 @abstractmethod
283 def list_objects_recursive(
284 self,
285 path: str = "",
286 start_after: Optional[str] = None,
287 end_at: Optional[str] = None,
288 max_workers: int = 32,
289 look_ahead: int = 2,
290 follow_symlinks: bool = True,
291 ) -> Iterator[ObjectMetadata]:
292 """
293 Lists files recursively in the storage provider under the specified path.
294
295 :param path: The path to list objects under.
296 :param start_after: The key to start after (i.e. exclusive). An object with this key doesn't have to exist.
297 :param end_at: The key to end at (i.e. inclusive). An object with this key doesn't have to exist.
298 :param max_workers: Maximum concurrent workers for provider-level recursive listing.
299 :param look_ahead: Prefixes to buffer per worker for provider-level recursive listing.
300 :param follow_symlinks: Whether to follow symbolic links. Only applicable for POSIX file storage providers.
301 :return: An iterator over object metadata under the specified path.
302 """
303 pass
304
[docs]
305 @abstractmethod
306 def upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> None:
307 """
308 Uploads a file from the local file system to the storage provider.
309
310 :param remote_path: The path where the object will be stored.
311 :param f: The source file to upload. This can either be a string representing the local
312 file path, or a file-like object (e.g., an open file handle).
313 :param attributes: The attributes to add to the file if a new file is created.
314 """
315 pass
316
[docs]
317 @abstractmethod
318 def download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> None:
319 """
320 Downloads a file from the storage provider to the local file system.
321
322 :param remote_path: The path of the file to download.
323 :param f: The destination for the downloaded file. This can either be a string representing
324 the local file path where the file will be saved, or a file-like object to write the
325 downloaded content into.
326 :param metadata: Metadata about the object to download.
327 """
328 pass
329
[docs]
330 @abstractmethod
331 def download_files(
332 self,
333 remote_paths: list[str],
334 local_paths: list[str],
335 metadata: Optional[Sequence[Optional[ObjectMetadata]]] = None,
336 max_workers: int = 16,
337 ) -> None:
338 """
339 Downloads multiple files from the storage provider to the local file system.
340
341 :param remote_paths: List of remote paths of files to download.
342 :param local_paths: List of local file paths to save the downloaded files to.
343 :param metadata: Optional per-file metadata used to decide between regular and multipart download.
344 :param max_workers: Maximum number of concurrent download workers (default: 16).
345 :raises ValueError: If remote_paths and local_paths have different lengths.
346 """
347 pass
348
[docs]
349 @abstractmethod
350 def upload_files(
351 self,
352 local_paths: list[str],
353 remote_paths: list[str],
354 attributes: Optional[Sequence[Optional[dict[str, str]]]] = None,
355 max_workers: int = 16,
356 ) -> None:
357 """
358 Uploads multiple files from the local file system to the storage provider.
359
360 :param local_paths: List of local file paths to upload.
361 :param remote_paths: List of remote paths to upload the files to.
362 :param attributes: Optional list of per-file attributes to add. When provided, must have the same length
363 as local_paths/remote_paths. Each element may be ``None`` for files that need no attributes.
364 :param max_workers: Maximum number of concurrent upload workers (default: 16).
365 :raises ValueError: If local_paths and remote_paths have different lengths.
366 :raises ValueError: If attributes is provided and has a different length than local_paths.
367 """
368 pass
369
[docs]
370 @abstractmethod
371 def glob(self, pattern: str, attribute_filter_expression: Optional[str] = None) -> list[str]:
372 """
373 Matches and retrieves a list of object keys in the storage provider that match the specified pattern.
374
375 :param pattern: The pattern to match object keys against, supporting wildcards (e.g., ``*.txt``).
376 :param attribute_filter_expression: The attribute filter expression to apply to the result.
377
378 :return: A list of object keys that match the specified pattern.
379 """
380 pass
381
[docs]
382 @abstractmethod
383 def is_file(self, path: str) -> bool:
384 """
385 Checks whether the specified key in the storage provider points to a file (as opposed to a folder or directory).
386
387 :param path: The path to check.
388
389 :return: ``True`` if the key points to a file, ``False`` if it points to a directory or folder.
390 """
391 pass
392
[docs]
393 def generate_presigned_url(
394 self,
395 path: str,
396 *,
397 method: str = "GET",
398 signer_type: Optional[SignerType] = None,
399 signer_options: Optional[dict[str, Any]] = None,
400 ) -> str:
401 """
402 Generate a pre-signed URL granting temporary access to the object at *path*.
403
404 :param path: The object path within the storage provider.
405 :param method: The HTTP method the URL should authorise (e.g. ``"GET"``, ``"PUT"``).
406 :param signer_type: The signing backend to use. ``None`` means the provider's native signer.
407 :param signer_options: Backend-specific options forwarded to the signer.
408 :return: A pre-signed URL string.
409 :raises NotImplementedError: If this storage provider does not support presigned URLs.
410 """
411 raise NotImplementedError(f"{type(self).__name__} does not support presigned URL generation.")
412
413
class ResolvedPathState(str, Enum):
    """
    Enum representing the state of a resolved path.
    """

    #: File currently exists.
    EXISTS = "exists"
    #: File existed before but has been deleted.
    DELETED = "deleted"
    #: File never existed or was never tracked.
    UNTRACKED = "untracked"
422
423
class ResolvedPath(NamedTuple):
    """
    Result of resolving a virtual path to a physical path.

    :param physical_path: The physical path in storage backend
    :param state: The state of the path — EXISTS (file currently exists in metadata),
        DELETED (file existed before but was soft-deleted), or UNTRACKED (file never
        existed or was never tracked).
    :param profile: Optional profile name for routing in CompositeStorageClient;
        ``None`` means use the current client's storage provider, while a string
        routes to the named child StorageClient.
    """

    physical_path: str
    state: ResolvedPathState
    profile: Optional[str] = None

    @property
    def exists(self) -> bool:
        """Backward compatibility property: True if state is EXISTS."""
        return self.state == ResolvedPathState.EXISTS
448
449
606
607
@dataclass
class StorageProviderConfig:
    """
    A data class that represents the configuration needed to initialize a storage provider.
    """

    #: Name or type of the storage provider (e.g., ``s3``, ``gcs``, ``oci``, ``azure``).
    type: str
    #: Extra options required to configure the provider (e.g., endpoint URLs, region, etc.).
    options: Optional[dict[str, Any]] = None
618
619
@dataclass
class StorageBackend:
    """
    Represents configuration for a single storage backend.
    """

    #: Configuration describing which storage provider to instantiate.
    storage_provider_config: StorageProviderConfig
    #: Optional provider of credentials for this backend.
    credentials_provider: Optional[CredentialsProvider] = None
    #: Replica configuration for this backend (empty by default).
    replicas: list["Replica"] = field(default_factory=list)
629
630
class ProviderBundle(ABC):
    """
    Abstract base class that serves as a container for various providers (storage, credentials, and metadata)
    that interact with a storage service. The :py:class:`ProviderBundle` abstracts access to these providers,
    allowing flexible implementations of cloud storage solutions.
    """

    @property
    @abstractmethod
    def storage_provider_config(self) -> StorageProviderConfig:
        """
        :return: The storage provider configuration, including the provider name/type and additional options.
        """

    @property
    @abstractmethod
    def credentials_provider(self) -> Optional[CredentialsProvider]:
        """
        :return: The credentials provider that manages the authentication credentials
            required to access the storage service.
        """

    @property
    @abstractmethod
    def metadata_provider(self) -> Optional[MetadataProvider]:
        """
        :return: The metadata provider used to retrieve metadata about objects in the storage service.
        """

    @property
    @abstractmethod
    def replicas(self) -> list["Replica"]:
        """
        :return: The replicas configuration for this provider bundle, if any.
        """
671
672
class ProviderBundleV2(ABC):
    """
    Abstract base class that serves as a container for various providers (storage, credentials, and metadata)
    that interact with one or multiple storage service. The :py:class:`ProviderBundleV2` abstracts access to
    these providers, allowing flexible implementations of cloud storage solutions.
    """

    @property
    @abstractmethod
    def storage_backends(self) -> dict[str, StorageBackend]:
        """
        :return: Mapping of storage backend name -> StorageBackend. Must have at least one backend.
        """

    @property
    @abstractmethod
    def metadata_provider(self) -> Optional[MetadataProvider]:
        """
        :return: The metadata provider used to retrieve metadata about objects in the storage
            service. If there are multiple backends, this is required.
        """
696
697
@dataclass
class RetryConfig:
    """
    A data class that represents the configuration for retry strategy.
    """

    #: Number of attempts before giving up; must be at least 1.
    attempts: int = DEFAULT_RETRY_ATTEMPTS
    #: Base delay (in seconds) for exponential backoff; must be non-negative.
    delay: float = DEFAULT_RETRY_DELAY
    #: Backoff multiplier for exponential backoff; must be at least 1.0.
    backoff_multiplier: float = DEFAULT_RETRY_BACKOFF_MULTIPLIER

    def __post_init__(self) -> None:
        # Validate eagerly so a misconfigured retry policy fails at construction
        # time rather than mid-operation.
        if self.attempts < 1:
            raise ValueError("Attempts must be at least 1.")
        if self.delay < 0:
            raise ValueError("Delay must be a non-negative number.")
        if self.backoff_multiplier < 1.0:
            raise ValueError("Backoff multiplier must be at least 1.0.")
718
719
class RetryableError(Exception):
    """
    Signals an error that should trigger a retry of the failed operation.
    """
726
727
class PreconditionFailedError(Exception):
    """
    Signals that a request precondition (e.g. if-match, if-none-match) was not satisfied.
    """
734
735
class NotModifiedError(Exception):
    """
    Raised when a conditional operation fails because the resource has not been modified.

    This typically occurs when an if-none-match condition names a specific
    generation/etag and the resource's current generation/etag matches it.
    """
745
746
class SourceVersionCheckMode(Enum):
    """
    Enum for controlling source version checking behavior.
    """

    #: Inherit the behavior from configuration (cache config).
    INHERIT = "inherit"
    #: Always check the source version.
    ENABLE = "enable"
    #: Never check the source version.
    DISABLE = "disable"
755
756
@dataclass
class Replica:
    """
    A tier of storage that can be used to store data.
    """

    #: Profile name backing this replica.
    replica_profile: str
    #: Read priority of this replica.
    read_priority: int
765
766
@dataclass
class AutoCommitConfig:
    """
    A data class that represents the configuration for auto commit.

    Converted to an actual ``@dataclass`` for consistency with the other
    configuration classes in this module (e.g. ``RetryConfig``); the generated
    ``__init__`` has the same signature and defaults as the handwritten one it
    replaces.
    """

    #: The interval in minutes for auto commit; ``None`` disables interval-based commits.
    interval_minutes: Optional[float] = None
    #: If ``True``, commit on program exit.
    at_exit: bool = False
778
779
class ExecutionMode(Enum):
    """
    Enum for controlling execution mode in sync operations.
    """

    #: Run the sync on the local machine.
    LOCAL = "local"
    #: Run the sync via Ray.
    RAY = "ray"
787
788
class PatternType(Enum):
    """
    Type of pattern operation for include/exclude filtering.
    """

    #: Keep paths matching the pattern.
    INCLUDE = "include"
    #: Drop paths matching the pattern.
    EXCLUDE = "exclude"
796
797
# Type alias for pattern matching: ordered (operation, glob-pattern) pairs.
# Uses builtin generics (PEP 585) for consistency with the rest of this module
# (`list[str]`, `dict[str, Any]`), replacing the legacy `typing.Tuple`.
PatternList = list[tuple[PatternType, str]]
800
801
@dataclass
class DryrunResult:
    """
    Holds references to JSONL files produced by a dryrun sync operation.

    Each file contains one JSON object per line, matching the :py:class:`ObjectMetadata`
    serialization format (see :py:meth:`ObjectMetadata.to_dict` / :py:meth:`ObjectMetadata.from_dict`).

    The caller is responsible for cleaning up the files when they are no longer needed.
    """

    #: Path to a JSONL file listing source objects that would be added to the target.
    files_to_add: str
    #: Path to a JSONL file listing target objects that would be deleted.
    files_to_delete: str
817
818
@dataclass
class SyncResult:
    """
    A data class that represents the summary of a sync operation.
    """

    #: Total number of work units tracked for progress (files from both source and
    #: target after filtering); each work unit is one ADD or DELETE operation.
    total_work_units: int = 0
    #: Total number of files processed to the target.
    total_files_added: int = 0
    #: Total number of files deleted from the target.
    total_files_deleted: int = 0
    #: Total number of bytes transferred to the target.
    total_bytes_added: int = 0
    #: Total number of bytes deleted from the target.
    total_bytes_deleted: int = 0
    #: Total time taken to process the sync operation.
    total_time_seconds: float = 0.0
    #: Dryrun details with paths to JSONL files; ``None`` for normal (non-dryrun) syncs.
    dryrun: Optional[DryrunResult] = None

    def __str__(self) -> str:
        # Dryrun summaries get a distinct header so logs are unambiguous.
        title = "Sync dryrun statistics:" if self.dryrun else "Sync statistics:"
        parts = [
            title,
            f"  Work units: {self.total_work_units}",
            f"  Files added: {self.total_files_added}",
            f"  Files deleted: {self.total_files_deleted}",
            f"  Bytes added: {self.total_bytes_added}",
            f"  Bytes deleted: {self.total_bytes_deleted}",
            f"  Time elapsed: {self.total_time_seconds:.2f}s",
        ]
        if self.dryrun:
            parts.append(f"  Files to add: {self.dryrun.files_to_add}")
            parts.append(f"  Files to delete: {self.dryrun.files_to_delete}")
        return "\n".join(parts)
854
855
class SyncError(RuntimeError):
    """
    Exception raised when errors occur during a sync operation.

    Carries the partial :py:class:`SyncResult` accumulated before the failure so
    callers can tell how far the sync progressed.

    :param message: The error message describing what went wrong.
    :param sync_result: The partial SyncResult with statistics from the failed sync operation.
    """

    def __init__(self, message: str, sync_result: SyncResult):
        super().__init__(message)
        self.sync_result = sync_result

    def __str__(self) -> str:
        # Relabel the stats header to make clear these are partial numbers.
        stats = str(self.sync_result).replace("Sync statistics:", "Partial sync statistics:")
        base = super().__str__()
        return f"{base}\n\n{stats}"