# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
16from abc import ABC, abstractmethod
17from dataclasses import asdict, dataclass, field
18from datetime import datetime, timezone
19from typing import IO, Any, Dict, Iterator, List, Optional, Tuple, Union
20
21from dateutil.parser import parse as dateutil_parser
22
# Protocol name and the matching "<name>://" prefix built from it.
MSC_PROTOCOL_NAME = "msc"
MSC_PROTOCOL = f"{MSC_PROTOCOL_NAME}://"

# Built-in profile: a single "file" storage provider whose base path is "/".
DEFAULT_POSIX_PROFILE_NAME = "default"
DEFAULT_POSIX_PROFILE = {
    "profiles": {
        DEFAULT_POSIX_PROFILE_NAME: {
            "storage_provider": {
                "type": "file",
                "options": {"base_path": "/"},
            },
        },
    },
}

# Defaults for RetryConfig: number of attempts and delay (seconds) between them.
DEFAULT_RETRY_ATTEMPTS = 3
DEFAULT_RETRY_DELAY = 1.0
33
34
[docs]
@dataclass
class Credentials:
    """
    A data class representing the credentials needed to access a storage provider.
    """

    #: The access key for authentication.
    access_key: str
    #: The secret key for authentication.
    secret_key: str
    #: An optional security token for temporary credentials.
    token: Optional[str]
    #: The expiration time of the credentials in ISO 8601 format.
    expiration: Optional[str]

    def is_expired(self) -> bool:
        """
        Checks if the credentials are expired based on the expiration time.

        Credentials without an expiration (``None`` or an empty string) never expire.
        An expiration timestamp that carries no UTC offset is interpreted as UTC.
        A string the date parser cannot handle propagates the parser's exception.

        :return: ``True`` if the credentials are expired, ``False`` otherwise.
        """
        if not self.expiration:
            return False

        expiry = dateutil_parser(self.expiration)
        # The parser yields a naive datetime when the string has no offset; comparing
        # that with the aware "now" below would raise TypeError, so pin it to UTC first.
        if expiry.utcoffset() is None:
            expiry = expiry.replace(tzinfo=timezone.utc)

        return expiry <= datetime.now(tz=timezone.utc)
60
61
106
107
[docs]
class CredentialsProvider(ABC):
    """
    Interface for objects that supply :py:class:`Credentials` for a storage provider.

    Both methods are abstract, so a concrete subclass must implement each of them
    before it can be instantiated.
    """

    @abstractmethod
    def get_credentials(self) -> Credentials:
        """
        Return the credentials currently in effect.

        :return: The current credentials used for authentication.
        """

    @abstractmethod
    def refresh_credentials(self) -> None:
        """
        Renew the credentials when they have expired or are close to expiring.
        """
128
129
[docs]
@dataclass
class Range:
    """
    Describes a byte-range read by its starting offset and its size.
    """

    #: Offset of the range (presumably in bytes from the start of the object -- confirm with callers).
    offset: int
    #: Size of the range (presumably the number of bytes to read -- confirm with callers).
    size: int
138
139
[docs]
class StorageProvider(ABC):
    """
    Interface that a storage backend implements.

    Every method declared here is abstract, so a concrete subclass must implement
    all of them before it can be instantiated.
    """

    @abstractmethod
    def put_object(self, path: str, body: bytes) -> None:
        """
        Store ``body`` as an object in the storage provider.

        :param path: The path where the object will be stored.
        :param body: The content of the object to store.
        """

    @abstractmethod
    def get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Fetch the content of an object from the storage provider.

        :param path: The path where the object is stored.
        :param byte_range: Optional :py:class:`Range` (offset and size) describing a byte-range read.
            Presumably the whole object is returned when this is ``None`` -- confirm against implementations.

        :return: The content of the retrieved object.
        """

    @abstractmethod
    def copy_object(self, src_path: str, dest_path: str) -> None:
        """
        Copy an object from one path to another within the storage provider.

        :param src_path: The path of the source object to copy.
        :param dest_path: The path of the destination.
        """

    @abstractmethod
    def delete_object(self, path: str) -> None:
        """
        Remove an object from the storage provider.

        :param path: The path of the object to delete.
        """

    @abstractmethod
    def list_objects(
        self,
        prefix: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        """
        Iterate over the objects stored under ``prefix``.

        :param prefix: The prefix or path to list objects under.
        :param start_after: Exclusive lower bound: listing starts after this key. No object with this key has to exist.
        :param end_at: Inclusive upper bound: listing ends at this key. No object with this key has to exist.
        :param include_directories: When ``True``, directories are returned alongside objects.

        :return: An iterator over objects metadata under the specified prefix.
        """

    @abstractmethod
    def upload_file(self, remote_path: str, f: Union[str, IO]) -> None:
        """
        Upload a local file to the storage provider.

        :param remote_path: The path where the object will be stored.
        :param f: The source to upload: either a string holding the local file path,
            or a file-like object (e.g., an open file handle).
        """

    @abstractmethod
    def download_file(
        self,
        remote_path: str,
        f: Union[str, IO],
        metadata: Optional[ObjectMetadata] = None,
    ) -> None:
        """
        Download a file from the storage provider to the local file system.

        :param remote_path: The path of the file to download.
        :param f: The destination: either a string holding the local file path to save to,
            or a file-like object the downloaded content is written into.
        :param metadata: Metadata about the object to download.
        """

    @abstractmethod
    def glob(self, pattern: str) -> List[str]:
        """
        Find the object keys in the storage provider that match ``pattern``.

        :param pattern: The pattern to match object keys against, supporting wildcards (e.g., ``*.txt``).

        :return: A list of object keys that match the specified pattern.
        """

    @abstractmethod
    def is_file(self, path: str) -> bool:
        """
        Tell whether ``path`` refers to a file rather than a folder or directory.

        :param path: The path to check.

        :return: ``True`` if the key points to a file, ``False`` if it points to a directory or folder.
        """
261
262
353
354
[docs]
355@dataclass
356class StorageProviderConfig:
357 """
358 A data class that represents the configuration needed to initialize a storage provider.
359 """
360
361 #: The name or type of the storage provider (e.g., ``s3``, ``gcs``, ``oci``, ``azure``).
362 type: str
363 #: Additional options required to configure the storage provider (e.g., endpoint URLs, region, etc.).
364 options: Optional[Dict[str, Any]] = None
365
366
[docs]
class ProviderBundle(ABC):
    """
    Abstract container grouping the providers (storage, credentials, and metadata) that interact
    with a storage service. A :py:class:`ProviderBundle` exposes each of them through a read-only
    property, allowing for flexible implementations of cloud storage solutions.

    All three properties are abstract and must be implemented by concrete subclasses.
    """

    @property
    @abstractmethod
    def storage_provider_config(self) -> StorageProviderConfig:
        """
        :return: The storage provider configuration, i.e. the provider name/type
            together with its additional options.
        """

    @property
    @abstractmethod
    def credentials_provider(self) -> Optional[CredentialsProvider]:
        """
        :return: The credentials provider that manages the authentication credentials
            required to access the storage service. The return type allows ``None``.
        """

    @property
    @abstractmethod
    def metadata_provider(self) -> Optional[MetadataProvider]:
        """
        :return: The metadata provider that retrieves metadata about objects in the
            storage service. The return type allows ``None``.
        """
399
400
[docs]
401@dataclass
402class RetryConfig:
403 """
404 A data class that represents the configuration for retry strategy.
405 """
406
407 #: The number of attempts before giving up. Must be at least 1.
408 attempts: int = DEFAULT_RETRY_ATTEMPTS
409 #: The delay (in seconds) between retry attempts. Must be a non-negative value.
410 delay: float = DEFAULT_RETRY_DELAY
411
412 def __post_init__(self) -> None:
413 if self.attempts < 1:
414 raise ValueError("Attempts must be at least 1.")
415 if self.delay < 0:
416 raise ValueError("Delay must be a non-negative number.")
417
418
[docs]
class RetryableError(Exception):
    """
    Raised for errors that should trigger a retry of the failed operation.
    """