1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16from __future__ import annotations
17
18from abc import ABC, abstractmethod
19from collections.abc import Iterator
20from dataclasses import asdict, dataclass, field
21from datetime import datetime, timezone
22from enum import Enum
23from typing import IO, Any, NamedTuple, Optional, Tuple, Union
24
25from dateutil.parser import parse as dateutil_parser
26
# URL scheme name used for MSC paths, and the "<scheme>://" prefix derived from it.
MSC_PROTOCOL_NAME = "msc"
MSC_PROTOCOL = f"{MSC_PROTOCOL_NAME}://"

# Default values for the retry strategy (attempt count, base delay in seconds, backoff multiplier).
DEFAULT_RETRY_ATTEMPTS = 3
DEFAULT_RETRY_DELAY = 1.0
DEFAULT_RETRY_BACKOFF_MULTIPLIER = 2.0

# Timezone-aware (UTC) counterpart of datetime.min, which is a naive datetime.
#
# datetime.astimezone(timezone.utc) on a naive datetime assumes the local timezone. For datetime.min,
# applying the local UTC offset can push the value below the representable limit (an underflow),
# which surfaces as `ValueError: year 0 is out of range`. Using an aware minimum avoids that.
AWARE_DATETIME_MIN = datetime.min.replace(tzinfo=timezone.utc)
40
41
[docs]
@dataclass
class Credentials:
    """
    A data class representing the credentials needed to access a storage provider.
    """

    #: The access key for authentication.
    access_key: str
    #: The secret key for authentication.
    secret_key: str
    #: An optional security token for temporary credentials.
    token: Optional[str]
    #: The expiration time of the credentials in ISO 8601 format. ``None`` or an empty string means no expiry.
    expiration: Optional[str]
    #: A dictionary for storing custom key-value pairs.
    custom_fields: dict[str, Any] = field(default_factory=dict)

    @staticmethod
    def _as_aware_utc(moment: datetime) -> datetime:
        """
        Normalizes a datetime so it can be compared against timezone-aware values.

        :param moment: The datetime to normalize.
        :return: ``moment`` unchanged if it already carries a UTC offset, otherwise ``moment`` interpreted as UTC.
        """
        if moment.tzinfo is None or moment.utcoffset() is None:
            return moment.replace(tzinfo=timezone.utc)
        return moment

    def is_expired(self) -> bool:
        """
        Checks if the credentials are expired based on the expiration time.

        An expiration timestamp without a UTC offset is interpreted as UTC; comparing a naive
        datetime with the timezone-aware current time would otherwise raise ``TypeError``.

        :return: ``True`` if the credentials are expired, ``False`` otherwise (including when no expiration is set).
        """
        if not self.expiration:
            return False
        expiry = self._as_aware_utc(dateutil_parser(self.expiration))
        return expiry <= datetime.now(tz=timezone.utc)

    def get_custom_field(self, key: str, default: Any = None) -> Any:
        """
        Retrieves a value from custom fields by its key.

        :param key: The key to look up in custom fields.
        :param default: The default value to return if the key is not found.
        :return: The value associated with the key, or the default value if not found.
        """
        return self.custom_fields.get(key, default)
79
80
131
132
[docs]
class CredentialsProvider(ABC):
    """
    Interface for objects that supply credentials for a storage provider.

    Concrete subclasses must implement both :py:meth:`get_credentials` and :py:meth:`refresh_credentials`.
    """

    @abstractmethod
    def get_credentials(self) -> Credentials:
        """
        Returns the credentials currently held by this provider.

        :return: The current credentials used for authentication.
        """

    @abstractmethod
    def refresh_credentials(self) -> None:
        """
        Renews the credentials when they have expired or are about to expire.
        """
153
154
[docs]
@dataclass
class Range:
    """
    A byte range used for read operations, expressed as a start offset plus a length.
    """

    #: Offset, in bytes, at which the range starts.
    offset: int
    #: Number of bytes covered by the range.
    size: int
165
166
[docs]
class StorageProvider(ABC):
    """
    Interface that every storage provider implementation must satisfy.

    All methods are abstract; the base implementations do nothing and return ``None``.
    """

    @abstractmethod
    def put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> None:
        """
        Uploads an object to the storage provider.

        :param path: The path where the object will be stored.
        :param body: The content of the object to store.
        :param if_match: Optional if-match value to use for a conditional upload.
        :param if_none_match: Optional if-none-match value to use for a conditional upload.
        :param attributes: The attributes to add to the file.
        """

    @abstractmethod
    def get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Retrieves an object from the storage provider.

        :param path: The path where the object is stored.
        :param byte_range: Optional byte range to read.

        :return: The content of the retrieved object.
        """

    @abstractmethod
    def copy_object(self, src_path: str, dest_path: str) -> None:
        """
        Copies an object from source to destination in the storage provider.

        :param src_path: The path of the source object to copy.
        :param dest_path: The path of the destination.
        """

    @abstractmethod
    def delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Deletes an object from the storage provider.

        :param path: The path of the object to delete.
        :param if_match: Optional if-match value to use for conditional deletion.
        """

    @abstractmethod
    def list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        attribute_filter_expression: Optional[str] = None,
        show_attributes: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        """
        Lists objects in the storage provider under the specified path.

        :param path: The path to list objects under. It must be a valid file or subdirectory path; a partial path or bare "prefix" is not allowed.
        :param start_after: The key to start after (i.e. exclusive). An object with this key doesn't have to exist.
        :param end_at: The key to end at (i.e. inclusive). An object with this key doesn't have to exist.
        :param include_directories: Whether to include directories in the result. When ``True``, directories are returned alongside objects.
        :param attribute_filter_expression: The attribute filter expression to apply to the result.
        :param show_attributes: Whether to return attributes in the result. Enabling this has a performance cost because object metadata must be fetched for each object.
        :param follow_symlinks: Whether to follow symbolic links. Only applicable for POSIX file storage providers.

        :return: An iterator over objects metadata under the specified path.
        """

    @abstractmethod
    def upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> None:
        """
        Uploads a file from the local file system to the storage provider.

        :param remote_path: The path where the object will be stored.
        :param f: The source file to upload. Either a string holding the local file path,
            or a file-like object (e.g., an open file handle).
        :param attributes: The attributes to add to the file if a new file is created.
        """

    @abstractmethod
    def download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> None:
        """
        Downloads a file from the storage provider to the local file system.

        :param remote_path: The path of the file to download.
        :param f: The destination for the downloaded file. Either a string holding the local
            file path where the file will be saved, or a file-like object that receives the
            downloaded content.
        :param metadata: Metadata about the object to download.
        """

    @abstractmethod
    def glob(self, pattern: str, attribute_filter_expression: Optional[str] = None) -> list[str]:
        """
        Returns the object keys in the storage provider that match the specified pattern.

        :param pattern: The pattern to match object keys against, supporting wildcards (e.g., ``*.txt``).
        :param attribute_filter_expression: The attribute filter expression to apply to the result.

        :return: A list of object keys that match the specified pattern.
        """

    @abstractmethod
    def is_file(self, path: str) -> bool:
        """
        Checks whether the specified key in the storage provider points to a file (as opposed to a folder or directory).

        :param path: The path to check.

        :return: ``True`` if the key points to a file, ``False`` if it points to a directory or folder.
        """
306
307
[docs]
class ResolvedPathState(str, Enum):
    """
    The possible states of a resolved path.
    """

    #: File currently exists.
    EXISTS = "exists"
    #: File existed before but has been deleted.
    DELETED = "deleted"
    #: File never existed or was never tracked.
    UNTRACKED = "untracked"
316
317
[docs]
class ResolvedPath(NamedTuple):
    """
    The outcome of resolving a virtual path to a physical path.

    :param physical_path: The physical path in storage backend
    :param state: The state of the path (EXISTS, DELETED, or UNTRACKED)
    :param profile: Optional profile name for routing in CompositeStorageClient.
        ``None`` means use the current client's storage provider.
        A string means route to the named child StorageClient.

    State meanings:
    - EXISTS: File currently exists in metadata
    - DELETED: File existed before but has been deleted (soft delete)
    - UNTRACKED: File never existed or was never tracked
    """

    physical_path: str
    state: ResolvedPathState
    profile: Optional[str] = None

    @property
    def exists(self) -> bool:
        """Backward compatibility property: ``True`` if state is EXISTS."""
        currently_exists = self.state == ResolvedPathState.EXISTS
        return currently_exists
342
343
495
496
[docs]
@dataclass
class StorageProviderConfig:
    """
    The configuration needed to initialize a storage provider.
    """

    #: Name or type of the storage provider (e.g., ``s3``, ``gcs``, ``oci``, ``azure``).
    type: str
    #: Extra options used to configure the storage provider (e.g., endpoint URLs, region, etc.). Defaults to ``None``.
    options: Optional[dict[str, Any]] = None
507
508
[docs]
@dataclass
class StorageBackend:
    """
    Configuration for a single storage backend.
    """

    #: Configuration of the storage provider for this backend.
    storage_provider_config: StorageProviderConfig
    #: Optional credentials provider for this backend. Defaults to ``None``.
    credentials_provider: Optional[CredentialsProvider] = None
    #: Replicas configured for this backend. Each instance gets its own, initially empty, list.
    replicas: list["Replica"] = field(default_factory=list)
518
519
[docs]
class ProviderBundle(ABC):
    """
    Abstract container for the providers (storage, credentials, and metadata) that interact with a
    storage service. The :py:class:`ProviderBundle` abstracts access to these providers, allowing for
    flexible implementations of cloud storage solutions.
    """

    @property
    @abstractmethod
    def storage_provider_config(self) -> StorageProviderConfig:
        """
        :return: The configuration for the storage provider, which includes the provider
            name/type and additional options.
        """

    @property
    @abstractmethod
    def credentials_provider(self) -> Optional[CredentialsProvider]:
        """
        :return: The credentials provider responsible for managing authentication credentials
            required to access the storage service.
        """

    @property
    @abstractmethod
    def metadata_provider(self) -> Optional[MetadataProvider]:
        """
        :return: The metadata provider responsible for retrieving metadata about objects in the storage service.
        """

    @property
    @abstractmethod
    def replicas(self) -> list["Replica"]:
        """
        :return: The replicas configuration for this provider bundle, if any.
        """
560
561
[docs]
class ProviderBundleV2(ABC):
    """
    Abstract container for the providers (storage, credentials, and metadata) that interact with one
    or multiple storage services. The :py:class:`ProviderBundleV2` abstracts access to these providers,
    allowing for flexible implementations of cloud storage solutions.
    """

    @property
    @abstractmethod
    def storage_backends(self) -> dict[str, StorageBackend]:
        """
        :return: Mapping of storage backend name -> StorageBackend. Must have at least one backend.
        """

    @property
    @abstractmethod
    def metadata_provider(self) -> Optional[MetadataProvider]:
        """
        :return: The metadata provider responsible for retrieving metadata about objects in the storage service. If there are multiple backends, this is required.
        """
585
586
[docs]
@dataclass
class RetryConfig:
    """
    The configuration for the retry strategy.

    Values are validated on construction; an invalid value raises ``ValueError``.
    """

    #: Number of attempts before giving up. Must be at least 1.
    attempts: int = DEFAULT_RETRY_ATTEMPTS
    #: Base delay (in seconds) for exponential backoff. Must be a non-negative value.
    delay: float = DEFAULT_RETRY_DELAY
    #: Backoff multiplier for exponential backoff. Must be at least 1.0.
    backoff_multiplier: float = DEFAULT_RETRY_BACKOFF_MULTIPLIER

    def __post_init__(self) -> None:
        """
        Validates the configured values, checking attempts, then delay, then backoff multiplier.

        :raises ValueError: If any value is outside its allowed range.
        """
        too_few_attempts = self.attempts < 1
        if too_few_attempts:
            raise ValueError("Attempts must be at least 1.")

        negative_delay = self.delay < 0
        if negative_delay:
            raise ValueError("Delay must be a non-negative number.")

        shrinking_backoff = self.backoff_multiplier < 1.0
        if shrinking_backoff:
            raise ValueError("Backoff multiplier must be at least 1.0.")
607
608
[docs]
class RetryableError(Exception):
    """
    Raised for errors that should trigger a retry.
    """
615
616
[docs]
class PreconditionFailedError(Exception):
    """
    Raised when a precondition (e.g. if-match, if-none-match) fails.
    """
623
624
[docs]
class NotModifiedError(Exception):
    """
    Raised when a conditional operation fails because the resource has not been modified.

    Typically this happens when if-none-match is used with a specific generation/etag
    and the resource's current generation/etag is equal to the one specified.
    """
634
635
[docs]
class SourceVersionCheckMode(Enum):
    """
    Controls the source version checking behavior.
    """

    #: Inherit from configuration (cache config).
    INHERIT = "inherit"
    #: Always check source version.
    ENABLE = "enable"
    #: Never check source version.
    DISABLE = "disable"
644
645
[docs]
@dataclass
class Replica:
    """
    A tier of storage that can be used to store data.
    """

    #: Profile associated with this replica (presumably a profile name; confirm against callers).
    replica_profile: str
    #: Read priority of this replica. NOTE(review): ordering semantics are not visible here; confirm.
    read_priority: int
654
655
[docs]
class AutoCommitConfig:
    """
    The configuration for auto commit.

    This is a plain class with a hand-written ``__init__`` (it is not decorated with ``@dataclass``).
    """

    #: The interval in minutes for auto commit.
    interval_minutes: Optional[float]
    #: If ``True``, commit on program exit.
    at_exit: bool = False

    def __init__(self, interval_minutes: Optional[float] = None, at_exit: bool = False) -> None:
        """
        :param interval_minutes: The interval in minutes for auto commit.
        :param at_exit: If ``True``, commit on program exit.
        """
        self.at_exit = at_exit
        self.interval_minutes = interval_minutes
667
668
[docs]
class ExecutionMode(Enum):
    """
    Controls the execution mode in sync operations.
    """

    LOCAL = "local"
    RAY = "ray"
676
677
[docs]
class PatternType(Enum):
    """
    The kind of pattern operation used for include/exclude filtering.
    """

    INCLUDE = "include"
    EXCLUDE = "exclude"
685
686
# Type alias for pattern matching: a list of (PatternType, pattern string) tuples.
PatternList = list[Tuple[PatternType, str]]
689
690
[docs]
@dataclass
class SyncResult:
    """
    The summary of a sync operation.
    """

    #: Total number of work units tracked for progress (including files from both source and target after filtering). Each work unit represents an ADD or DELETE operation.
    total_work_units: int = 0
    #: Total number of files processed to the target.
    total_files_added: int = 0
    #: Total number of files deleted from the target.
    total_files_deleted: int = 0
    #: Total number of bytes transferred to the target.
    total_bytes_added: int = 0
    #: Total number of bytes deleted from the target.
    total_bytes_deleted: int = 0
    #: Total time taken to process the sync operation.
    total_time_seconds: float = 0.0

    def __str__(self) -> str:
        """
        Renders the statistics as a multi-line report, one statistic per line.

        :return: The report text, starting with a ``Sync statistics:`` header line.
        """
        report_lines = [
            "Sync statistics:",
            f" Work units: {self.total_work_units}",
            f" Files added: {self.total_files_added}",
            f" Files deleted: {self.total_files_deleted}",
            f" Bytes added: {self.total_bytes_added}",
            f" Bytes deleted: {self.total_bytes_deleted}",
            f" Time elapsed: {self.total_time_seconds:.2f}s",
        ]
        return "\n".join(report_lines)
720
721
[docs]
class SyncError(RuntimeError):
    """
    Exception raised when errors occur during a sync operation.

    This exception includes the partial SyncResult showing what was accomplished
    before the error occurred, allowing users to understand the state of the sync.

    :param message: The error message describing what went wrong.
    :param sync_result: The partial SyncResult with statistics from the failed sync operation.
    """

    def __init__(self, message: str, sync_result: SyncResult):
        super().__init__(message)
        self.sync_result = sync_result

    def __reduce__(self):
        """
        Supports pickling and copying.

        Only ``message`` is forwarded to ``RuntimeError.__init__``, so the default exception
        reduction would rebuild the instance with a single argument and fail because
        ``sync_result`` is required. Both constructor arguments are returned explicitly, together
        with the instance ``__dict__`` so any extra attributes survive the round trip.

        :return: A ``(callable, args, state)`` tuple as expected by the pickle protocol.
        """
        message = self.args[0] if self.args else ""
        return (type(self), (message, self.sync_result), dict(vars(self)))

    def __str__(self) -> str:
        """
        :return: The error message followed by a blank line and the statistics report,
            whose header is relabelled as ``Partial sync statistics:``.
        """
        sync_stats = str(self.sync_result).replace("Sync statistics:", "Partial sync statistics:")
        return f"{super().__str__()}\n\n{sync_stats}"