1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import io
17import os
18from collections.abc import Callable, Iterator
19from datetime import datetime, timezone
20from typing import IO, Any, Optional, TypeVar, Union
21
22from aistore.sdk import Client
23from aistore.sdk.authn import AuthNClient
24from aistore.sdk.errors import AISError
25from aistore.sdk.obj.object_props import ObjectProps
26from dateutil.parser import parse as dateutil_parser
27from requests.exceptions import HTTPError
28from urllib3.util import Retry
29
30from ..constants import DEFAULT_READ_TIMEOUT
31from ..telemetry import Telemetry
32from ..types import (
33 AWARE_DATETIME_MIN,
34 Credentials,
35 CredentialsProvider,
36 ObjectMetadata,
37 Range,
38)
39from ..utils import split_path, validate_attributes
40from .base import BaseStorageProvider
41
# Generic return type for the error-translation wrapper (`_translate_errors`),
# so the wrapped callable's return type is preserved.
_T = TypeVar("_T")

# Provider name registered with MSC; also the default AIS backend provider.
PROVIDER = "ais"
# Number of entries fetched per page when iterating AIS object listings.
DEFAULT_PAGE_SIZE = 1000
46
47
class StaticAISCredentialProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static S3 credentials.
    """

    _username: Optional[str]
    _password: Optional[str]
    _authn_endpoint: Optional[str]
    _token: Optional[str]
    _skip_verify: bool
    _ca_cert: Optional[str]

    def __init__(
        self,
        username: Optional[str] = None,
        password: Optional[str] = None,
        authn_endpoint: Optional[str] = None,
        token: Optional[str] = None,
        skip_verify: bool = True,
        ca_cert: Optional[str] = None,
    ):
        """
        Initializes the :py:class:`StaticAISCredentialProvider` with the given credentials.

        :param username: The username for the AIStore authentication.
        :param password: The password for the AIStore authentication.
        :param authn_endpoint: The AIStore authentication endpoint.
        :param token: The AIStore authentication token. This is used for authentication if username,
                      password and authn_endpoint are not provided.
        :param skip_verify: If true, skip SSL certificate verification.
        :param ca_cert: Path to a CA certificate file for SSL verification.
        """
        self._token = token
        self._username = username
        self._password = password
        self._authn_endpoint = authn_endpoint
        self._skip_verify = skip_verify
        self._ca_cert = ca_cert

    def get_credentials(self) -> Credentials:
        """
        Return AIStore credentials, logging in via AuthN when a full
        username/password/endpoint triple is configured; otherwise the
        statically configured token (possibly ``None``) is returned.
        """
        # A fresh login is performed on every call when login material is present;
        # the resulting token replaces any previously stored one.
        if all((self._username, self._password, self._authn_endpoint)):
            authn_client = AuthNClient(self._authn_endpoint, self._skip_verify, self._ca_cert)
            self._token = authn_client.login(self._username, self._password)
        return Credentials(token=self._token, access_key="", secret_key="", expiration=None)

    def refresh_credentials(self) -> None:
        """No-op: static credentials have nothing to refresh."""
        pass
96
class AIStoreStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with NVIDIA AIStore.
    """

    def __init__(
        self,
        endpoint: str = os.getenv("AIS_ENDPOINT", ""),
        provider: str = PROVIDER,
        skip_verify: bool = True,
        ca_cert: Optional[str] = None,
        timeout: Optional[Union[float, tuple[float, float]]] = None,
        retry: Optional[dict[str, Any]] = None,
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ) -> None:
        """
        AIStore client for managing buckets, objects, and ETL jobs.

        :param endpoint: The AIStore endpoint. NOTE: the ``AIS_ENDPOINT`` environment
            variable is captured once at import time as the default value.
        :param provider: The AIStore backend provider (defaults to ``ais``).
        :param skip_verify: Whether to skip SSL certificate verification.
        :param ca_cert: Path to a CA certificate file for SSL verification.
        :param timeout: Request timeout in seconds; a single float
            for both connect/read timeouts (e.g., ``5.0``), a tuple for separate connect/read
            timeouts (e.g., ``(3.0, 10.0)``), or ``None`` to use the default read timeout.
        :param retry: ``urllib3.util.Retry`` parameters.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve AIStore credentials; when set,
            its token is passed to the AIStore client.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        # https://aistore.nvidia.com/docs/python-sdk#client.Client
        client_retry = None if retry is None else Retry(**retry)
        if timeout is None:
            timeout = float(DEFAULT_READ_TIMEOUT)

        # The AIS SDK Client accepts token=None, so a single construction path
        # covers both the authenticated and anonymous cases.
        token: Optional[str] = None
        if credentials_provider:
            token = credentials_provider.get_credentials().token

        self.client = Client(
            endpoint=endpoint,
            retry=client_retry,
            skip_verify=skip_verify,
            ca_cert=ca_cert,
            timeout=timeout,
            token=token,
        )
        self.provider = provider

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual object storage operation.
        :param operation: The type of operation being performed (e.g., ``PUT``, ``GET``, ``DELETE``).
        :param bucket: The name of the object storage bucket involved in the operation.
        :param key: The key of the object within the object storage bucket.

        :return: The result of the object storage operation, typically the return value of the `func` callable.

        :raises FileNotFoundError: If the underlying error is a 404.
        :raises RuntimeError: For any other failure, chained to the original error.
        """
        try:
            return func()
        except AISError as error:
            status_code = error.status_code
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            error_info = f"status_code: {status_code}, message: {error.message}"
            raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except HTTPError as error:
            # error.response may be None (e.g. connection-level failures); guard the
            # dereference so we surface the original error instead of an AttributeError.
            response = error.response
            status_code = response.status_code if response is not None else None
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            else:
                raise RuntimeError(
                    f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}"
                ) from error
        except Exception as error:
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Write ``body`` to the object at ``path``, optionally attaching validated
        custom metadata. Returns the number of bytes written.
        """
        # ais does not support if_match and if_none_match
        bucket, key = split_path(path)

        def _invoke_api() -> int:
            obj = self.client.bucket(bucket, self.provider).object(obj_name=key)
            obj.put_content(body)
            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                # Custom props must be set in a separate call after the content PUT.
                obj.set_custom_props(custom_metadata=validated_attributes, replace_existing=True)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Read the object at ``path``, optionally restricted to ``byte_range``
        (offset + size, translated to an HTTP ``bytes=start-end`` header).
        """
        bucket, key = split_path(path)
        if byte_range:
            # HTTP range is inclusive on both ends.
            bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
        else:
            bytes_range = None

        def _invoke_api() -> bytes:
            obj = self.client.bucket(bucket, self.provider).object(obj_name=key)
            if byte_range:
                reader = obj.get(byte_range=bytes_range)  # pyright: ignore [reportArgumentType]
            else:
                reader = obj.get()
            return reader.read_all()

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        """
        Server-side copy from ``src_path`` to ``dest_path``; returns the size of
        the source object in bytes.
        """
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)

        def _invoke_api() -> int:
            src_obj = self.client.bucket(bck_name=src_bucket, provider=self.provider).object(obj_name=src_key)
            dest_obj = self.client.bucket(bck_name=dest_bucket, provider=self.provider).object(obj_name=dest_key)

            # Get source size before copying
            src_headers = src_obj.head()
            src_props = ObjectProps(src_headers)

            # Server-side copy (preserves custom metadata automatically)
            src_obj.copy(to_obj=dest_obj)  # type: ignore[attr-defined]

            return int(src_props.size)

        return self._translate_errors(
            _invoke_api, operation="COPY", bucket=f"{src_bucket}->{dest_bucket}", key=f"{src_key}->{dest_key}"
        )

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Delete the object at ``path``.

        :raises NotImplementedError: If ``if_match`` is provided — AIStore has no
            conditional (if-match) delete.
        """
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            obj = self.client.bucket(bucket, self.provider).object(obj_name=key)
            # AIS doesn't support if-match deletion, so we implement a fallback mechanism
            if if_match:
                raise NotImplementedError("AIStore does not support if-match deletion")
            # Perform deletion
            obj.delete()

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        """
        Return True when ``path`` behaves like a directory, i.e. at least one
        object exists under the ``path/`` prefix.
        """
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, prefix = split_path(path)

        def _invoke_api() -> bool:
            # List objects with the given prefix (limit to 1 for efficiency)
            objects = self.client.bucket(bck_name=bucket, provider=self.provider).list_objects_iter(
                prefix=prefix, page_size=1
            )
            # Check if there are any objects with this prefix
            return any(True for _ in objects)

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Return metadata for the object at ``path``. Paths that end with ``/`` (or
        name only a bucket) are treated as directories; a 404 on a HEAD also falls
        back to a virtual-directory check before failing.

        :raises FileNotFoundError: If neither an object nor a directory exists at ``path``.
        """
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If path ends with "/" or empty key name is provided, then assume it's a "directory",
            # which metadata is not guaranteed to exist for cases such as
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:

            def _invoke_api() -> ObjectMetadata:
                obj = self.client.bucket(bck_name=bucket, provider=self.provider).object(obj_name=key)
                try:
                    headers = obj.head()
                    props = ObjectProps(headers)

                    # The access time is not always present in the response.
                    if props.access_time:
                        # access_time is in nanoseconds; convert directly to an
                        # aware UTC datetime (avoids a platform-dependent
                        # local-time round trip through naive fromtimestamp()).
                        last_modified = datetime.fromtimestamp(int(props.access_time) / 1e9, tz=timezone.utc)
                    else:
                        last_modified = AWARE_DATETIME_MIN

                    return ObjectMetadata(
                        key=key,
                        content_length=int(props.size),  # pyright: ignore [reportArgumentType]
                        last_modified=last_modified,
                        etag=props.checksum_value,
                        metadata=props.custom_metadata,
                    )
                except (AISError, HTTPError) as e:
                    # Check if this might be a virtual directory (prefix with objects under it)
                    status_code = None
                    if isinstance(e, AISError):
                        status_code = e.status_code
                    elif isinstance(e, HTTPError):
                        status_code = e.response.status_code if e.response is not None else None

                    if status_code == 404:
                        if self._is_dir(path):
                            return ObjectMetadata(
                                key=path + "/",
                                type="directory",
                                content_length=0,
                                last_modified=AWARE_DATETIME_MIN,
                            )
                    # Re-raise to be handled by _translate_errors
                    raise

            return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        """
        Yield metadata for objects under ``path``, client-side filtered to the
        half-open range ``(start_after, end_at]`` since AIS has no start-key option.
        """
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            # AIS has no start key option like other object stores.
            # NOTE(review): "cone" in the props list looks unusual — confirm against
            # the AIStore SDK's supported object properties.
            all_objects = self.client.bucket(bck_name=bucket, provider=self.provider).list_objects_iter(
                prefix=prefix, props="name,size,atime,checksum,cone", page_size=DEFAULT_PAGE_SIZE
            )

            # Assume AIS guarantees lexicographical order.
            for bucket_entry in all_objects:
                obj = bucket_entry.object
                key = obj.name
                props = bucket_entry.generate_object_props()

                # The access time is not always present in the response.
                if props.access_time:
                    last_modified = dateutil_parser(props.access_time).astimezone(timezone.utc)
                else:
                    last_modified = AWARE_DATETIME_MIN

                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    yield ObjectMetadata(
                        key=key, content_length=int(props.size), last_modified=last_modified, etag=props.checksum_value
                    )
                elif end_at is not None and end_at < key:
                    # Listing is ordered, so everything past end_at can be skipped.
                    return

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        """
        Upload a local file (by path) or a readable file object to ``remote_path``.

        Any text-mode stream (not just ``io.StringIO``) is UTF-8 encoded before
        upload. Returns the number of bytes uploaded.
        """
        if isinstance(f, str):
            with open(f, "rb") as fp:
                body = fp.read()
        else:
            data = f.read()
            # Text streams yield str; encode to bytes for the object PUT.
            body = data.encode("utf-8") if isinstance(data, str) else data

        self._put_object(remote_path, body, attributes=attributes)
        return len(body)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """
        Download the object at ``remote_path`` into a local file (by path) or a
        writable file object. Text-mode streams receive UTF-8 decoded content.
        Returns the object's content length in bytes.
        """
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        if isinstance(f, str):
            # Create parent directories as needed before writing.
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)
            with open(f, "wb") as fp:
                fp.write(self._get_object(remote_path))
        elif isinstance(f, io.TextIOBase):
            # Covers io.StringIO and any other text-mode stream.
            f.write(self._get_object(remote_path).decode("utf-8"))
        else:
            f.write(self._get_object(remote_path))

        return metadata.content_length