1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import io
17import os
18from collections.abc import Callable, Iterator
19from datetime import datetime, timezone
20from typing import IO, Any, Optional, TypeVar, Union
21
22from aistore.sdk import Client
23from aistore.sdk.authn import AuthNClient
24from aistore.sdk.errors import AISError
25from aistore.sdk.obj.object_props import ObjectProps
26from dateutil.parser import parse as dateutil_parser
27from requests.exceptions import HTTPError
28from urllib3.util import Retry
29
30from ..telemetry import Telemetry
31from ..types import (
32 AWARE_DATETIME_MIN,
33 Credentials,
34 CredentialsProvider,
35 ObjectMetadata,
36 Range,
37)
38from ..utils import split_path, validate_attributes
39from .base import BaseStorageProvider
40
# Generic return type for the error-translation wrapper (_translate_errors).
_T = TypeVar("_T")

# AIStore bucket provider identifier passed to the SDK's bucket() calls.
PROVIDER = "ais"
# Objects fetched per page by list_objects_iter when enumerating a bucket.
DEFAULT_PAGE_SIZE = 1000
45
46
class StaticAISCredentialProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static AIStore credentials.

    Credentials are token-based: either a pre-issued token is supplied directly,
    or a username/password pair is exchanged for a token via an AIStore AuthN
    endpoint on each :py:meth:`get_credentials` call.
    """

    _username: Optional[str]
    _password: Optional[str]
    _authn_endpoint: Optional[str]
    _token: Optional[str]
    _skip_verify: bool
    _ca_cert: Optional[str]

    def __init__(
        self,
        username: Optional[str] = None,
        password: Optional[str] = None,
        authn_endpoint: Optional[str] = None,
        token: Optional[str] = None,
        skip_verify: bool = True,
        ca_cert: Optional[str] = None,
    ):
        """
        Initializes the :py:class:`StaticAISCredentialProvider` with the given credentials.

        :param username: The username for the AIStore authentication.
        :param password: The password for the AIStore authentication.
        :param authn_endpoint: The AIStore authentication endpoint.
        :param token: The AIStore authentication token. This is used for authentication if username,
            password and authn_endpoint are not provided.
        :param skip_verify: If true, skip SSL certificate verification.
        :param ca_cert: Path to a CA certificate file for SSL verification.
        """
        self._username = username
        self._password = password
        self._authn_endpoint = authn_endpoint
        self._token = token
        self._skip_verify = skip_verify
        self._ca_cert = ca_cert

    def get_credentials(self) -> Credentials:
        """
        Return AIStore credentials as a :py:class:`Credentials` object.

        When username, password, and authn_endpoint are all set, a fresh token
        is obtained from the AuthN service on every call; otherwise the static
        token provided at construction time (possibly ``None``) is returned.
        """
        if self._username and self._password and self._authn_endpoint:
            authn_client = AuthNClient(self._authn_endpoint, self._skip_verify, self._ca_cert)
            self._token = authn_client.login(self._username, self._password)
        # AIStore authentication is token-based; access/secret keys are unused.
        return Credentials(token=self._token, access_key="", secret_key="", expiration=None)

    def refresh_credentials(self) -> None:
        """No-op: static credentials cannot be refreshed."""
        pass
94
95
class AIStoreStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with NVIDIA AIStore.
    """

    def __init__(
        self,
        # NOTE(review): this default is evaluated once, when the module is imported;
        # later changes to AIS_ENDPOINT in the environment are not picked up — confirm intended.
        endpoint: str = os.getenv("AIS_ENDPOINT", ""),
        provider: str = PROVIDER,
        skip_verify: bool = True,
        ca_cert: Optional[str] = None,
        timeout: Optional[Union[float, tuple[float, float]]] = None,
        retry: Optional[dict[str, Any]] = None,
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ) -> None:
        """
        AIStore client for managing buckets, objects, and ETL jobs.

        :param endpoint: The AIStore endpoint.
        :param skip_verify: Whether to skip SSL certificate verification.
        :param ca_cert: Path to a CA certificate file for SSL verification.
        :param timeout: Request timeout in seconds; a single float
            for both connect/read timeouts (e.g., ``5.0``), a tuple for separate connect/read
            timeouts (e.g., ``(3.0, 10.0)``), or ``None`` to disable timeout.
        :param retry: ``urllib3.util.Retry`` parameters.
        :param token: Authorization token. If not provided, the ``AIS_AUTHN_TOKEN`` environment variable will be used.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve AIStore credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        # https://aistore.nvidia.com/docs/python-sdk#client.Client
        client_retry = None if retry is None else Retry(**retry)
        token = None
        if credentials_provider:
            token = credentials_provider.get_credentials().token
            self.client = Client(
                endpoint=endpoint,
                retry=client_retry,
                skip_verify=skip_verify,
                ca_cert=ca_cert,
                timeout=timeout,
                token=token,
            )
        else:
            # NOTE(review): skip_verify, ca_cert, and timeout are only forwarded on the
            # authenticated path above; the unauthenticated client ignores them — confirm intended.
            self.client = Client(endpoint=endpoint, retry=client_retry)
        self.provider = provider

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        AIStore/HTTP 404 responses are surfaced as :py:exc:`FileNotFoundError`;
        every other failure is wrapped in :py:exc:`RuntimeError` with the
        operation, bucket, and key included in the message.

        :param func: The function that performs the actual object storage operation.
        :param operation: The type of operation being performed (e.g., ``PUT``, ``GET``, ``DELETE``).
        :param bucket: The name of the object storage bucket involved in the operation.
        :param key: The key of the object within the object storage bucket.

        :return: The result of the object storage operation, typically the return value of the `func` callable.
        """

        try:
            return func()
        except AISError as error:
            status_code = error.status_code
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            error_info = f"status_code: {status_code}, message: {error.message}"
            raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except HTTPError as error:
            status_code = error.response.status_code
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            else:
                raise RuntimeError(
                    f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}"
                ) from error
        except Exception as error:
            # Catch-all boundary: any other failure is re-raised with operation context.
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Upload ``body`` to ``path`` and optionally attach custom metadata.

        :param path: ``bucket/key`` path of the destination object.
        :param body: Raw object content to write.
        :param if_match: Ignored — AIStore does not support conditional writes.
        :param if_none_match: Ignored — AIStore does not support conditional writes.
        :param attributes: Optional custom metadata set on the object after upload.
        :return: The number of bytes written (``len(body)``).
        """
        # ais does not support if_match and if_none_match
        bucket, key = split_path(path)

        def _invoke_api() -> int:
            obj = self.client.bucket(bucket, self.provider).object(obj_name=key)
            obj.put_content(body)
            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                # replace_existing=True: custom props are replaced, not merged.
                obj.set_custom_props(custom_metadata=validated_attributes, replace_existing=True)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Read an object's content, optionally restricted to ``byte_range``.

        :param path: ``bucket/key`` path of the object.
        :param byte_range: Optional offset/size range; translated into an
            inclusive HTTP ``bytes=start-end`` Range header value.
        :return: The object (or range) content as bytes.
        """
        bucket, key = split_path(path)
        if byte_range:
            bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
        else:
            bytes_range = None

        def _invoke_api() -> bytes:
            obj = self.client.bucket(bucket, self.provider).object(obj_name=key)
            if byte_range:
                reader = obj.get(byte_range=bytes_range)  # pyright: ignore [reportArgumentType]
            else:
                reader = obj.get()
            return reader.read_all()

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        """
        Server-side copy of ``src_path`` to ``dest_path``.

        :param src_path: ``bucket/key`` path of the source object.
        :param dest_path: ``bucket/key`` path of the destination object.
        :return: Size of the source object in bytes (taken from its HEAD response).
        :raises FileNotFoundError: If the source object does not exist.
        """
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)

        def _invoke_api() -> int:
            src_obj = self.client.bucket(bck_name=src_bucket, provider=self.provider).object(obj_name=src_key)
            dest_obj = self.client.bucket(bck_name=dest_bucket, provider=self.provider).object(obj_name=dest_key)

            # Get source size before copying
            src_headers = src_obj.head()
            src_props = ObjectProps(src_headers)

            # Server-side copy (preserves custom metadata automatically)
            src_obj.copy(to_obj=dest_obj)  # type: ignore[attr-defined]

            return int(src_props.size)

        return self._translate_errors(
            _invoke_api, operation="COPY", bucket=f"{src_bucket}->{dest_bucket}", key=f"{src_key}->{dest_key}"
        )

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Delete the object at ``path``.

        :param path: ``bucket/key`` path of the object to delete.
        :param if_match: Unsupported — passing a value raises ``NotImplementedError``
            because AIStore cannot perform conditional (ETag-matched) deletes.
        """
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            obj = self.client.bucket(bucket, self.provider).object(obj_name=key)
            # AIS doesn't support if-match deletion, so we implement a fallback mechanism
            if if_match:
                raise NotImplementedError("AIStore does not support if-match deletion")
            # Perform deletion
            obj.delete()

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        """
        Return ``True`` if ``path`` behaves like a directory, i.e. at least one
        object exists under the ``path + "/"`` prefix.

        :param path: Path to test; a trailing delimiter is appended if missing.
        """
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, prefix = split_path(path)

        def _invoke_api() -> bool:
            # List objects with the given prefix (limit to 1 for efficiency)
            objects = self.client.bucket(bck_name=bucket, provider=self.provider).list_objects_iter(
                prefix=prefix, page_size=1
            )
            # Check if there are any objects with this prefix
            return any(True for _ in objects)

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        HEAD an object (or resolve a directory-like prefix) and return its metadata.

        Paths ending in ``/`` — or a bucket with no key — are treated as
        directories and validated via a prefix listing. For regular keys a HEAD
        request is issued; on a 404 the path is retried as a virtual directory
        before the error is propagated.

        :param path: ``bucket/key`` path, or a ``/``-terminated directory path.
        :param strict: Unused in this implementation.
        :return: The object's metadata, or a synthetic ``type="directory"`` entry.
        :raises FileNotFoundError: If neither an object nor a prefix exists at ``path``.
        """
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If path ends with "/" or empty key name is provided, then assume it's a "directory",
            # which metadata is not guaranteed to exist for cases such as
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:

            def _invoke_api() -> ObjectMetadata:
                obj = self.client.bucket(bck_name=bucket, provider=self.provider).object(obj_name=key)
                try:
                    headers = obj.head()
                    props = ObjectProps(headers)

                    # The access time is not always present in the response.
                    if props.access_time:
                        # access_time is in nanoseconds since epoch; convert to aware UTC datetime.
                        last_modified = datetime.fromtimestamp(int(props.access_time) / 1e9).astimezone(timezone.utc)
                    else:
                        last_modified = AWARE_DATETIME_MIN

                    return ObjectMetadata(
                        key=key,
                        content_length=int(props.size),  # pyright: ignore [reportArgumentType]
                        last_modified=last_modified,
                        etag=props.checksum_value,
                        metadata=props.custom_metadata,
                    )
                except (AISError, HTTPError) as e:
                    # Check if this might be a virtual directory (prefix with objects under it)
                    status_code = None
                    if isinstance(e, AISError):
                        status_code = e.status_code
                    elif isinstance(e, HTTPError):
                        status_code = e.response.status_code

                    if status_code == 404:
                        if self._is_dir(path):
                            # NOTE(review): this branch appends "/" to the key while the
                            # directory branch above returns `path` verbatim — confirm the
                            # inconsistency is intended.
                            return ObjectMetadata(
                                key=path + "/",
                                type="directory",
                                content_length=0,
                                last_modified=AWARE_DATETIME_MIN,
                            )
                    # Re-raise to be handled by _translate_errors
                    raise

            return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        """
        Iterate objects under ``path``, bounded by an exclusive ``start_after``
        key and an inclusive ``end_at`` key.

        AIStore has no server-side start-key option, so the full prefix listing
        is streamed and filtered client-side; iteration stops early once a key
        exceeds ``end_at`` (listing order is assumed lexicographical).

        :param path: ``bucket/prefix`` path to list under.
        :param start_after: Exclusive lower bound, as a full path (re-split per bucket).
        :param end_at: Inclusive upper bound, as a full path (re-split per bucket).
        :param include_directories: Unused in this implementation.
        :param follow_symlinks: Unused in this implementation.
        :return: Iterator of :py:class:`ObjectMetadata` for matching objects.
        """
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            # AIS has no start key option like other object stores.
            # NOTE(review): "cone" in the props list looks unusual (typo for "custom"?) —
            # confirm against the AIStore object-properties list.
            all_objects = self.client.bucket(bck_name=bucket, provider=self.provider).list_objects_iter(
                prefix=prefix, props="name,size,atime,checksum,cone", page_size=DEFAULT_PAGE_SIZE
            )

            # Assume AIS guarantees lexicographical order.
            for bucket_entry in all_objects:
                obj = bucket_entry.object
                key = obj.name
                props = bucket_entry.generate_object_props()

                # The access time is not always present in the response.
                if props.access_time:
                    # Here access_time is a string timestamp; parse and normalize to UTC.
                    last_modified = dateutil_parser(props.access_time).astimezone(timezone.utc)
                else:
                    last_modified = AWARE_DATETIME_MIN

                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    yield ObjectMetadata(
                        key=key, content_length=int(props.size), last_modified=last_modified, etag=props.checksum_value
                    )
                elif end_at is not None and end_at < key:
                    # Past the inclusive upper bound; listing is ordered, so stop.
                    return

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        """
        Upload a local file (by path) or an open file-like object to ``remote_path``.

        ``io.StringIO`` content is encoded as UTF-8 before upload; other
        file-like objects are read as-is. The whole body is buffered in memory.

        :param remote_path: Destination ``bucket/key`` path.
        :param f: Local filesystem path or an open file object to read from.
        :param attributes: Optional custom metadata to set on the object.
        :return: The number of bytes uploaded.
        """
        file_size: int = 0

        if isinstance(f, str):
            with open(f, "rb") as fp:
                body = fp.read()
                file_size = len(body)
                self._put_object(remote_path, body, attributes=attributes)
        else:
            if isinstance(f, io.StringIO):
                body = f.read().encode("utf-8")
                file_size = len(body)
                self._put_object(remote_path, body, attributes=attributes)
            else:
                body = f.read()
                file_size = len(body)
                self._put_object(remote_path, body, attributes=attributes)

        return file_size

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """
        Download ``remote_path`` into a local file (by path) or an open file-like object.

        When ``f`` is a path, parent directories are created as needed.
        ``io.StringIO`` targets receive UTF-8-decoded text; other file objects
        receive raw bytes.

        :param remote_path: Source ``bucket/key`` path.
        :param f: Local filesystem path or an open file object to write to.
        :param metadata: Optional pre-fetched metadata; HEADed if not supplied.
        :return: The object's content length as reported by its metadata.
        """
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        if isinstance(f, str):
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)
            with open(f, "wb") as fp:
                fp.write(self._get_object(remote_path))
        else:
            if isinstance(f, io.StringIO):
                f.write(self._get_object(remote_path).decode("utf-8"))
            else:
                f.write(self._get_object(remote_path))

        return metadata.content_length