# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from collections.abc import Iterator
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import IO, Any, Optional, Union

from dateutil.parser import parse as dateutil_parser

MSC_PROTOCOL_NAME = "msc"
MSC_PROTOCOL = MSC_PROTOCOL_NAME + "://"

DEFAULT_RETRY_ATTEMPTS = 3
DEFAULT_RETRY_DELAY = 1.0

# datetime.min is a naive datetime.
#
# This causes issues with datetime.astimezone(timezone.utc), which assumes the local timezone for a naive datetime.
# If the local timezone is behind UTC, subtracting the offset goes below the representable limit (an underflow),
# and a `ValueError: year 0 is out of range` is raised.
AWARE_DATETIME_MIN = datetime.min.replace(tzinfo=timezone.utc)


@dataclass
class Credentials:
    """
    A data class representing the credentials needed to access a storage provider.
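
    Example (a minimal sketch; the key, token, and expiration values below are
    made-up placeholders)::

        creds = Credentials(
            access_key="AKIA...",
            secret_key="secret",
            token=None,
            expiration="2030-01-01T00:00:00Z",
        )
        assert not creds.is_expired()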
    """

    #: The access key for authentication.
    access_key: str
    #: The secret key for authentication.
    secret_key: str
    #: An optional security token for temporary credentials.
    token: Optional[str]
    #: The expiration time of the credentials in ISO 8601 format.
    expiration: Optional[str]
    #: A dictionary for storing custom key-value pairs.
    custom_fields: dict[str, Any] = field(default_factory=dict)

    def is_expired(self) -> bool:
        """
        Checks if the credentials are expired based on the expiration time.

        :return: ``True`` if the credentials are expired, ``False`` otherwise.
        """
        if not self.expiration:
            return False
        expiry = dateutil_parser(self.expiration)
        if expiry.tzinfo is None:
            # Assume UTC for a naive expiration timestamp so the comparison
            # below does not mix naive and aware datetimes.
            expiry = expiry.replace(tzinfo=timezone.utc)
        return expiry <= datetime.now(tz=timezone.utc)

    def get_custom_field(self, key: str, default: Any = None) -> Any:
        """
        Retrieves a value from custom fields by its key.

        :param key: The key to look up in custom fields.
        :param default: The default value to return if the key is not found.
        :return: The value associated with the key, or the default value if not found.
        """
        return self.custom_fields.get(key, default)


class CredentialsProvider(ABC):
    """
    Abstract base class for providing credentials to access a storage provider.
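
    Example (a minimal sketch of a subclass that serves fixed credentials; the
    class name is illustrative)::

        class StaticCredentialsProvider(CredentialsProvider):
            def __init__(self, credentials: Credentials):
                self._credentials = credentials

            def get_credentials(self) -> Credentials:
                return self._credentials

            def refresh_credentials(self) -> None:
                pass  # fixed credentials never need refreshing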
    """

    @abstractmethod
    def get_credentials(self) -> Credentials:
        """
        Retrieves the current credentials.

        :return: The current credentials used for authentication.
        """
        pass

    @abstractmethod
    def refresh_credentials(self) -> None:
        """
        Refreshes the credentials if they are expired or about to expire.
        """
        pass


@dataclass
class Range:
    """
    A byte range for a partial object read.
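
    Example (a sketch that reads the first kilobyte of an object, assuming
    ``offset`` is the starting byte and ``size`` the number of bytes, where
    ``provider`` is any :py:class:`StorageProvider` implementation)::

        first_kb = provider.get_object("bucket/key", byte_range=Range(offset=0, size=1024))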
    """

    #: The starting byte offset of the range.
    offset: int
    #: The number of bytes to read.
    size: int


class StorageProvider(ABC):
    """
    Abstract base class for interacting with a storage provider.
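
    Example (a sketch of a typical round trip, where ``provider`` stands in for
    any concrete implementation)::

        provider.put_object("bucket/data.txt", b"hello")
        assert provider.get_object("bucket/data.txt") == b"hello"
        provider.delete_object("bucket/data.txt")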
    """

    @abstractmethod
    def put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> None:
        """
        Uploads an object to the storage provider.

        :param path: The path where the object will be stored.
        :param body: The content of the object to store.
        :param if_match: Optional if-match value to use for conditional writes.
        :param if_none_match: Optional if-none-match value to use for conditional writes.
        :param attributes: The attributes to add to the file.
        """
        pass

    @abstractmethod
    def get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Retrieves an object from the storage provider.

        :param path: The path where the object is stored.
        :param byte_range: An optional byte range to read instead of the entire object.

        :return: The content of the retrieved object.
        """
        pass

    @abstractmethod
    def copy_object(self, src_path: str, dest_path: str) -> None:
        """
        Copies an object from source to destination in the storage provider.

        :param src_path: The path of the source object to copy.
        :param dest_path: The path of the destination.
        """
        pass

    @abstractmethod
    def delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Deletes an object from the storage provider.

        :param path: The path of the object to delete.
        :param if_match: Optional if-match value to use for conditional deletion.
        """
        pass

    @abstractmethod
    def list_objects(
        self,
        prefix: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        attribute_filter_expression: Optional[str] = None,
    ) -> Iterator[ObjectMetadata]:
        """
        Lists objects in the storage provider under the specified prefix.

        :param prefix: The prefix or path to list objects under.
        :param start_after: The key to start after (i.e. exclusive). An object with this key doesn't have to exist.
        :param end_at: The key to end at (i.e. inclusive). An object with this key doesn't have to exist.
        :param include_directories: Whether to include directories in the result. When ``True``, directories are returned alongside objects.
        :param attribute_filter_expression: The attribute filter expression to apply to the result.

        :return: An iterator over the metadata of objects under the specified prefix.
        """
        pass

    @abstractmethod
    def upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> None:
        """
        Uploads a file from the local file system to the storage provider.

        :param remote_path: The path where the object will be stored.
        :param f: The source file to upload. This can either be a string representing the local
            file path, or a file-like object (e.g., an open file handle).
        :param attributes: The attributes to add to the file if a new file is created.
        """
        pass

    @abstractmethod
    def download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> None:
        """
        Downloads a file from the storage provider to the local file system.

        :param remote_path: The path of the file to download.
        :param f: The destination for the downloaded file. This can either be a string representing
            the local file path where the file will be saved, or a file-like object to write the
            downloaded content into.
        :param metadata: Metadata about the object to download.
        """
        pass

    @abstractmethod
    def glob(self, pattern: str, attribute_filter_expression: Optional[str] = None) -> list[str]:
        """
        Matches and retrieves a list of object keys in the storage provider that match the specified pattern.

        :param pattern: The pattern to match object keys against, supporting wildcards (e.g., ``*.txt``).
        :param attribute_filter_expression: The attribute filter expression to apply to the result.

        :return: A list of object keys that match the specified pattern.
        """
        pass

    @abstractmethod
    def is_file(self, path: str) -> bool:
        """
        Checks whether the specified key in the storage provider points to a file (as opposed to a folder or directory).

        :param path: The path to check.

        :return: ``True`` if the key points to a file, ``False`` if it points to a directory or folder.
        """
        pass


@dataclass
class StorageProviderConfig:
    """
    A data class that represents the configuration needed to initialize a storage provider.
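
    Example (a minimal sketch; the option keys are illustrative and depend on
    the provider type)::

        config = StorageProviderConfig(type="s3", options={"region_name": "us-east-1"})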
    """

    #: The name or type of the storage provider (e.g., ``s3``, ``gcs``, ``oci``, ``azure``).
    type: str
    #: Additional options required to configure the storage provider (e.g., endpoint URLs, region, etc.).
    options: Optional[dict[str, Any]] = None


class ProviderBundle(ABC):
    """
    Abstract base class that serves as a container for various providers (storage, credentials, and metadata)
    that interact with a storage service. The :py:class:`ProviderBundle` abstracts access to these providers, allowing for
    flexible implementations of cloud storage solutions.
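
    Example (a minimal sketch of a bundle without credentials, metadata, or
    replicas; the class name is illustrative)::

        class SimpleBundle(ProviderBundle):
            @property
            def storage_provider_config(self) -> StorageProviderConfig:
                return StorageProviderConfig(type="s3")

            @property
            def credentials_provider(self) -> Optional[CredentialsProvider]:
                return None

            @property
            def metadata_provider(self) -> Optional[MetadataProvider]:
                return None

            @property
            def replicas(self) -> list["Replica"]:
                return []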
    """

    @property
    @abstractmethod
    def storage_provider_config(self) -> StorageProviderConfig:
        """
        :return: The configuration for the storage provider, which includes the provider
            name/type and additional options.
        """
        pass

    @property
    @abstractmethod
    def credentials_provider(self) -> Optional[CredentialsProvider]:
        """
        :return: The credentials provider responsible for managing authentication credentials
            required to access the storage service.
        """
        pass

    @property
    @abstractmethod
    def metadata_provider(self) -> Optional[MetadataProvider]:
        """
        :return: The metadata provider responsible for retrieving metadata about objects in the storage service.
        """
        pass

    @property
    @abstractmethod
    def replicas(self) -> list["Replica"]:
        """
        :return: The replicas configuration for this provider bundle, if any.
        """
        pass


@dataclass
class RetryConfig:
    """
    A data class that represents the configuration for the retry strategy.
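
    Example (``__post_init__`` validates the values)::

        RetryConfig()                       # the defaults: 3 attempts, 1.0s delay
        RetryConfig(attempts=5, delay=0.5)  # a custom strategy
        RetryConfig(attempts=0)             # raises ValueError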
    """

    #: The number of attempts before giving up. Must be at least 1.
    attempts: int = DEFAULT_RETRY_ATTEMPTS
    #: The delay (in seconds) between retry attempts. Must be a non-negative value.
    delay: float = DEFAULT_RETRY_DELAY

    def __post_init__(self) -> None:
        if self.attempts < 1:
            raise ValueError("Attempts must be at least 1.")
        if self.delay < 0:
            raise ValueError("Delay must be a non-negative number.")


class RetryableError(Exception):
    """
    Exception raised for errors that should trigger a retry.
    """

    pass


class PreconditionFailedError(Exception):
    """
    Exception raised when a precondition fails (e.g. if-match, if-none-match).
    """

    pass


class NotModifiedError(Exception):
    """Raised when a conditional operation fails because the resource has not been modified.

    This typically occurs when using if-none-match with a specific generation/etag
    and the resource's current generation/etag matches the specified one.
    """

    pass


class SourceVersionCheckMode(Enum):
    """Enum for controlling source version checking behavior."""

    INHERIT = "inherit"  # Inherit from configuration (cache config)
    ENABLE = "enable"  # Always check source version
    DISABLE = "disable"  # Never check source version


@dataclass
class Replica:
    """A replica storage tier that can be used to store data."""

    #: The name of the storage profile backing this replica.
    replica_profile: str
    #: The priority of this replica when reading data.
    read_priority: int


@dataclass
class AutoCommitConfig:
    """
    A data class that represents the configuration for auto commit.
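
    Example (commit every five minutes and once more at program exit)::

        AutoCommitConfig(interval_minutes=5.0, at_exit=True)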
    """

    #: The interval in minutes for auto commit.
    interval_minutes: Optional[float] = None
    #: If ``True``, commit on program exit.
    at_exit: bool = False


class ExecutionMode(Enum):
    """
    Enum for controlling execution mode in sync operations.
    """

    LOCAL = "local"
    RAY = "ray"