# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16from abc import ABC, abstractmethod
17from collections.abc import Iterator
18from dataclasses import asdict, dataclass, field
19from datetime import datetime, timezone
20from enum import Enum
21from typing import IO, Any, Optional, Tuple, Union
22
23from dateutil.parser import parse as dateutil_parser
24
# Scheme name for MSC URLs.
MSC_PROTOCOL_NAME = "msc"
# Full scheme prefix for MSC URLs (i.e. "msc://").
MSC_PROTOCOL = MSC_PROTOCOL_NAME + "://"

# Default retry-strategy values consumed by RetryConfig below.
DEFAULT_RETRY_ATTEMPTS = 3
DEFAULT_RETRY_DELAY = 1.0
DEFAULT_RETRY_BACKOFF_MULTIPLIER = 2.0

# datetime.min is a naive datetime.
#
# This creates issues when doing datetime.astimezone(timezone.utc) since it assumes the local timezone for the naive datetime.
# If the local timezone is offset behind UTC, it attempts to subtract off the offset which goes below the representable limit (i.e. an underflow).
# A `ValueError: year 0 is out of range` is thrown as a result.
# Use this aware sentinel instead of datetime.min when a "minimum" aware datetime is needed.
AWARE_DATETIME_MIN = datetime.min.replace(tzinfo=timezone.utc)

@dataclass
class Credentials:
    """
    A data class representing the credentials needed to access a storage provider.
    """

    #: The access key for authentication.
    access_key: str
    #: The secret key for authentication.
    secret_key: str
    #: An optional security token for temporary credentials.
    token: Optional[str]
    #: The expiration time of the credentials in ISO 8601 format.
    expiration: Optional[str]
    #: A dictionary for storing custom key-value pairs.
    custom_fields: dict[str, Any] = field(default_factory=dict)

    def is_expired(self) -> bool:
        """
        Checks if the credentials are expired based on the expiration time.

        Credentials with no expiration (``None`` or empty string) never expire.

        :return: ``True`` if the credentials are expired, ``False`` otherwise.
        """
        if not self.expiration:
            return False
        expiry = dateutil_parser(self.expiration)
        # An expiration string without an explicit UTC offset parses as a naive
        # datetime, and comparing naive with aware datetimes raises TypeError.
        # Assume UTC for naive expirations instead of failing.
        if expiry.tzinfo is None:
            expiry = expiry.replace(tzinfo=timezone.utc)
        return expiry <= datetime.now(tz=timezone.utc)

    def get_custom_field(self, key: str, default: Any = None) -> Any:
        """
        Retrieves a value from custom fields by its key.

        :param key: The key to look up in custom fields.
        :param default: The default value to return if the key is not found.
        :return: The value associated with the key, or the default value if not found.
        """
        return self.custom_fields.get(key, default)


class CredentialsProvider(ABC):
    """
    Abstract base class for providing credentials to access a storage provider.
    """

    @abstractmethod
    def get_credentials(self) -> Credentials:
        """
        Retrieves the current credentials.

        :return: The credentials currently used for authentication.
        """
        ...

    @abstractmethod
    def refresh_credentials(self) -> None:
        """
        Refreshes the credentials if they are expired or about to expire.
        """
        ...


@dataclass
class Range:
    """
    A byte range for a partial object read.
    """

    #: Starting byte offset of the range.
    offset: int
    #: Number of bytes to read.
    size: int


class StorageProvider(ABC):
    """
    Abstract base class for interacting with a storage provider.
    """

    @abstractmethod
    def put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> None:
        """
        Uploads an object to the storage provider.

        :param path: The path where the object will be stored.
        :param body: The content of the object to store.
        :param if_match: Optional if-match value to use for a conditional write.
        :param if_none_match: Optional if-none-match value to use for a conditional write.
        :param attributes: The attributes to add to the file
        """
        pass

    @abstractmethod
    def get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Retrieves an object from the storage provider.

        :param path: The path where the object is stored.
        :param byte_range: Optional byte range (offset/size) to read instead of the whole object.

        :return: The content of the retrieved object.
        """
        pass

    @abstractmethod
    def copy_object(self, src_path: str, dest_path: str) -> None:
        """
        Copies an object from source to destination in the storage provider.

        :param src_path: The path of the source object to copy.
        :param dest_path: The path of the destination.
        """
        pass

    @abstractmethod
    def delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Deletes an object from the storage provider.

        :param path: The path of the object to delete.
        :param if_match: Optional if-match value to use for conditional deletion.
        """
        pass

    @abstractmethod
    def list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        attribute_filter_expression: Optional[str] = None,
        show_attributes: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        """
        Lists objects in the storage provider under the specified path.

        :param path: The path to list objects under. The path must be a valid file or subdirectory path, cannot be partial or just "prefix".
        :param start_after: The key to start after (i.e. exclusive). An object with this key doesn't have to exist.
        :param end_at: The key to end at (i.e. inclusive). An object with this key doesn't have to exist.
        :param include_directories: Whether to include directories in the result. When True, directories are returned alongside objects.
        :param attribute_filter_expression: The attribute filter expression to apply to the result.
        :param show_attributes: Whether to return attributes in the result. There will be performance impact if this is True as now we need to get object metadata for each object.
        :param follow_symlinks: Whether to follow symbolic links. Only applicable for POSIX file storage providers.

        :return: An iterator over objects metadata under the specified path.
        """
        pass

    @abstractmethod
    def upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> None:
        """
        Uploads a file from the local file system to the storage provider.

        :param remote_path: The path where the object will be stored.
        :param f: The source file to upload. This can either be a string representing the local
            file path, or a file-like object (e.g., an open file handle).
        :param attributes: The attributes to add to the file if a new file is created.
        """
        pass

    @abstractmethod
    def download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> None:
        """
        Downloads a file from the storage provider to the local file system.

        :param remote_path: The path of the file to download.
        :param f: The destination for the downloaded file. This can either be a string representing
            the local file path where the file will be saved, or a file-like object to write the
            downloaded content into.
        :param metadata: Metadata about the object to download.
        """
        pass

    @abstractmethod
    def glob(self, pattern: str, attribute_filter_expression: Optional[str] = None) -> list[str]:
        """
        Matches and retrieves a list of object keys in the storage provider that match the specified pattern.

        :param pattern: The pattern to match object keys against, supporting wildcards (e.g., ``*.txt``).
        :param attribute_filter_expression: The attribute filter expression to apply to the result.

        :return: A list of object keys that match the specified pattern.
        """
        pass

    @abstractmethod
    def is_file(self, path: str) -> bool:
        """
        Checks whether the specified key in the storage provider points to a file (as opposed to a folder or directory).

        :param path: The path to check.

        :return: ``True`` if the key points to a file, ``False`` if it points to a directory or folder.
        """
        pass


@dataclass
class StorageProviderConfig:
    """
    Configuration needed to initialize a storage provider.
    """

    #: The name or type of the storage provider (e.g., ``s3``, ``gcs``, ``oci``, ``azure``).
    type: str
    #: Additional options required to configure the storage provider (e.g., endpoint URLs, region, etc.).
    options: Optional[dict[str, Any]] = None


class ProviderBundle(ABC):
    """
    Abstract container for the providers (storage, credentials, and metadata)
    that interact with a storage service. The :py:class:`ProviderBundle` abstracts
    access to these providers, allowing for flexible implementations of cloud
    storage solutions.
    """

    @property
    @abstractmethod
    def storage_provider_config(self) -> StorageProviderConfig:
        """
        :return: The configuration for the storage provider, which includes the provider
            name/type and additional options.
        """
        ...

    @property
    @abstractmethod
    def credentials_provider(self) -> Optional[CredentialsProvider]:
        """
        :return: The credentials provider responsible for managing authentication credentials
            required to access the storage service.
        """
        ...

    @property
    @abstractmethod
    def metadata_provider(self) -> Optional[MetadataProvider]:
        """
        :return: The metadata provider responsible for retrieving metadata about objects in the storage service.
        """
        ...

    @property
    @abstractmethod
    def replicas(self) -> list["Replica"]:
        """
        :return: The replicas configuration for this provider bundle, if any.
        """
        ...


@dataclass
class RetryConfig:
    """
    Configuration describing a retry strategy with exponential backoff.
    """

    #: The number of attempts before giving up. Must be at least 1.
    attempts: int = DEFAULT_RETRY_ATTEMPTS
    #: The base delay (in seconds) for exponential backoff. Must be a non-negative value.
    delay: float = DEFAULT_RETRY_DELAY
    #: The backoff multiplier for exponential backoff. Must be at least 1.0.
    backoff_multiplier: float = DEFAULT_RETRY_BACKOFF_MULTIPLIER

    def __post_init__(self) -> None:
        """Validate the configured values, raising ``ValueError`` on the first violation."""
        checks = (
            (self.attempts >= 1, "Attempts must be at least 1."),
            (self.delay >= 0, "Delay must be a non-negative number."),
            (self.backoff_multiplier >= 1.0, "Backoff multiplier must be at least 1.0."),
        )
        for valid, message in checks:
            if not valid:
                raise ValueError(message)


class RetryableError(Exception):
    """
    Raised for errors that should trigger a retry.
    """
class PreconditionFailedError(Exception):
    """
    Raised when a precondition fails. e.g. if-match, if-none-match, etc.
    """
class NotModifiedError(Exception):
    """
    Raised when a conditional operation fails because the resource has not been modified.

    Typically occurs when an if-none-match condition carries a specific
    generation/etag and the resource's current generation/etag matches it.
    """
class SourceVersionCheckMode(Enum):
    """
    Controls the source-version checking behavior.
    """

    #: Inherit the behavior from configuration (cache config).
    INHERIT = "inherit"
    #: Always check the source version.
    ENABLE = "enable"
    #: Never check the source version.
    DISABLE = "disable"
@dataclass
class Replica:
    """
    A tier of storage that can be used to store data.
    """

    #: Profile name identifying this replica.
    replica_profile: str
    #: Priority used when reading from replicas.
    read_priority: int


@dataclass
class AutoCommitConfig:
    """
    A data class that represents the configuration for auto commit.
    """

    # Converted to @dataclass: the hand-written __init__ duplicated the class
    # annotations; the generated one has the identical signature and defaults,
    # and this matches the file's other config classes (plus free __eq__/__repr__).

    #: The interval in minutes for auto commit.
    interval_minutes: Optional[float] = None
    #: If True, commit on program exit.
    at_exit: bool = False


class ExecutionMode(Enum):
    """
    Execution mode used by sync operations.
    """

    #: Execute in the local process.
    LOCAL = "local"
    #: Execute via Ray.
    RAY = "ray"
class PatternType(Enum):
    """
    Type of pattern operation for include/exclude filtering.
    """

    INCLUDE = "include"
    EXCLUDE = "exclude"


# Type alias for pattern matching: an ordered list of (operation, pattern) pairs.
# Uses the builtin `tuple` generic for consistency with the builtin generics
# (`list[...]`, `dict[...]`) used throughout this module, instead of the
# deprecated `typing.Tuple`.
PatternList = list[tuple[PatternType, str]]