1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import io
17import os
18from collections.abc import Callable, Iterator
19from datetime import datetime, timezone
20from typing import IO, Any, Optional, TypeVar, Union
21
22from aistore.sdk import Client
23from aistore.sdk.authn import AuthNClient
24from aistore.sdk.errors import AISError
25from aistore.sdk.obj.object_props import ObjectProps
26from dateutil.parser import parse as dateutil_parser
27from requests.exceptions import HTTPError
28from urllib3.util import Retry
29
30from ..telemetry import Telemetry
31from ..types import (
32 AWARE_DATETIME_MIN,
33 Credentials,
34 CredentialsProvider,
35 ObjectMetadata,
36 Range,
37)
38from ..utils import split_path, validate_attributes
39from .base import BaseStorageProvider
40
# Return type of the callables passed through AIStoreStorageProvider._translate_errors.
_T = TypeVar("_T")

# Backend provider identifier; used both as the SDK bucket provider and as provider_name.
PROVIDER = "ais"
# Number of objects fetched per page by paginated list operations.
DEFAULT_PAGE_SIZE = 1000
45
46
class StaticAISCredentialProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static AIStore credentials.
    """

    _username: Optional[str]
    _password: Optional[str]
    _authn_endpoint: Optional[str]
    _token: Optional[str]
    _skip_verify: bool
    _ca_cert: Optional[str]

    def __init__(
        self,
        username: Optional[str] = None,
        password: Optional[str] = None,
        authn_endpoint: Optional[str] = None,
        token: Optional[str] = None,
        skip_verify: bool = True,
        ca_cert: Optional[str] = None,
    ):
        """
        Initializes the :py:class:`StaticAISCredentialProvider` with the given credentials.

        :param username: The username for the AIStore authentication.
        :param password: The password for the AIStore authentication.
        :param authn_endpoint: The AIStore authentication endpoint.
        :param token: The AIStore authentication token. This is used for authentication if username,
            password and authn_endpoint are not provided.
        :param skip_verify: If true, skip SSL certificate verification.
        :param ca_cert: Path to a CA certificate file for SSL verification.
        """
        self._username = username
        self._password = password
        self._authn_endpoint = authn_endpoint
        self._token = token
        self._skip_verify = skip_verify
        self._ca_cert = ca_cert

    def get_credentials(self) -> Credentials:
        """
        Return AIStore credentials as a :py:class:`Credentials` carrying only a token.

        When ``username``, ``password``, and ``authn_endpoint`` are all set, a fresh
        token is obtained from AuthN on every call (no caching of the login result);
        otherwise the statically supplied token (possibly ``None``) is returned.
        """
        if self._username and self._password and self._authn_endpoint:
            authn_client = AuthNClient(self._authn_endpoint, self._skip_verify, self._ca_cert)
            self._token = authn_client.login(self._username, self._password)
        return Credentials(token=self._token, access_key="", secret_key="", expiration=None)

    def refresh_credentials(self) -> None:
        """No-op: a new token is acquired inside :py:meth:`get_credentials` when login details are present."""
        pass
95
[docs]
96class AIStoreStorageProvider(BaseStorageProvider):
97 """
98 A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with NVIDIA AIStore.
99 """
100
101 def __init__(
102 self,
103 endpoint: str = os.getenv("AIS_ENDPOINT", ""),
104 provider: str = PROVIDER,
105 skip_verify: bool = True,
106 ca_cert: Optional[str] = None,
107 timeout: Optional[Union[float, tuple[float, float]]] = None,
108 retry: Optional[dict[str, Any]] = None,
109 base_path: str = "",
110 credentials_provider: Optional[CredentialsProvider] = None,
111 config_dict: Optional[dict[str, Any]] = None,
112 telemetry_provider: Optional[Callable[[], Telemetry]] = None,
113 **kwargs: Any,
114 ) -> None:
115 """
116 AIStore client for managing buckets, objects, and ETL jobs.
117
118 :param endpoint: The AIStore endpoint.
119 :param skip_verify: Whether to skip SSL certificate verification.
120 :param ca_cert: Path to a CA certificate file for SSL verification.
121 :param timeout: Request timeout in seconds; a single float
122 for both connect/read timeouts (e.g., ``5.0``), a tuple for separate connect/read
123 timeouts (e.g., ``(3.0, 10.0)``), or ``None`` to disable timeout.
124 :param retry: ``urllib3.util.Retry`` parameters.
125 :param token: Authorization token. If not provided, the ``AIS_AUTHN_TOKEN`` environment variable will be used.
126 :param base_path: The root prefix path within the bucket where all operations will be scoped.
127 :param credentials_provider: The provider to retrieve AIStore credentials.
128 :param config_dict: Resolved MSC config.
129 :param telemetry_provider: A function that provides a telemetry instance.
130 """
131 super().__init__(
132 base_path=base_path,
133 provider_name=PROVIDER,
134 config_dict=config_dict,
135 telemetry_provider=telemetry_provider,
136 )
137
138 # https://aistore.nvidia.com/docs/python-sdk#client.Client
139 client_retry = None if retry is None else Retry(**retry)
140 token = None
141 if credentials_provider:
142 token = credentials_provider.get_credentials().token
143 self.client = Client(
144 endpoint=endpoint,
145 retry=client_retry,
146 skip_verify=skip_verify,
147 ca_cert=ca_cert,
148 timeout=timeout,
149 token=token,
150 )
151 else:
152 self.client = Client(endpoint=endpoint, retry=client_retry)
153 self.provider = provider
154
155 def _translate_errors(
156 self,
157 func: Callable[[], _T],
158 operation: str,
159 bucket: str,
160 key: str,
161 ) -> _T:
162 """
163 Translates errors like timeouts and client errors.
164
165 :param func: The function that performs the actual object storage operation.
166 :param operation: The type of operation being performed (e.g., ``PUT``, ``GET``, ``DELETE``).
167 :param bucket: The name of the object storage bucket involved in the operation.
168 :param key: The key of the object within the object storage bucket.
169
170 :return: The result of the object storage operation, typically the return value of the `func` callable.
171 """
172
173 try:
174 return func()
175 except AISError as error:
176 status_code = error.status_code
177 if status_code == 404:
178 raise FileNotFoundError(f"Object {bucket}/{key} does not exist.") # pylint: disable=raise-missing-from
179 error_info = f"status_code: {status_code}, message: {error.message}"
180 raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
181 except HTTPError as error:
182 status_code = error.response.status_code
183 if status_code == 404:
184 raise FileNotFoundError(f"Object {bucket}/{key} does not exist.") # pylint: disable=raise-missing-from
185 else:
186 raise RuntimeError(
187 f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}"
188 ) from error
189 except Exception as error:
190 raise RuntimeError(
191 f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
192 ) from error
193
194 def _put_object(
195 self,
196 path: str,
197 body: bytes,
198 if_match: Optional[str] = None,
199 if_none_match: Optional[str] = None,
200 attributes: Optional[dict[str, str]] = None,
201 ) -> int:
202 # ais does not support if_match and if_none_match
203 bucket, key = split_path(path)
204
205 def _invoke_api() -> int:
206 obj = self.client.bucket(bucket, self.provider).object(obj_name=key)
207 obj.put_content(body)
208 validated_attributes = validate_attributes(attributes)
209 if validated_attributes:
210 obj.set_custom_props(custom_metadata=validated_attributes, replace_existing=True)
211
212 return len(body)
213
214 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
215
216 def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
217 bucket, key = split_path(path)
218 if byte_range:
219 bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
220 else:
221 bytes_range = None
222
223 def _invoke_api() -> bytes:
224 obj = self.client.bucket(bucket, self.provider).object(obj_name=key)
225 if byte_range:
226 reader = obj.get(byte_range=bytes_range) # pyright: ignore [reportArgumentType]
227 else:
228 reader = obj.get()
229 return reader.read_all()
230
231 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
232
233 def _copy_object(self, src_path: str, dest_path: str) -> int:
234 src_bucket, src_key = split_path(src_path)
235 dest_bucket, dest_key = split_path(dest_path)
236
237 def _invoke_api() -> int:
238 src_obj = self.client.bucket(bck_name=src_bucket, provider=self.provider).object(obj_name=src_key)
239 dest_obj = self.client.bucket(bck_name=dest_bucket, provider=self.provider).object(obj_name=dest_key)
240
241 # Get source size before copying
242 src_headers = src_obj.head()
243 src_props = ObjectProps(src_headers)
244
245 # Server-side copy (preserves custom metadata automatically)
246 src_obj.copy(to_obj=dest_obj) # type: ignore[attr-defined]
247
248 return int(src_props.size)
249
250 return self._translate_errors(
251 _invoke_api, operation="COPY", bucket=f"{src_bucket}->{dest_bucket}", key=f"{src_key}->{dest_key}"
252 )
253
254 def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
255 bucket, key = split_path(path)
256
257 def _invoke_api() -> None:
258 obj = self.client.bucket(bucket, self.provider).object(obj_name=key)
259 # AIS doesn't support if-match deletion, so we implement a fallback mechanism
260 if if_match:
261 raise NotImplementedError("AIStore does not support if-match deletion")
262 # Perform deletion
263 obj.delete()
264
265 return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)
266
267 def _is_dir(self, path: str) -> bool:
268 # Ensure the path ends with '/' to mimic a directory
269 path = self._append_delimiter(path)
270
271 bucket, prefix = split_path(path)
272
273 def _invoke_api() -> bool:
274 # List objects with the given prefix (limit to 1 for efficiency)
275 objects = self.client.bucket(bck_name=bucket, provider=self.provider).list_objects_iter(
276 prefix=prefix, page_size=1
277 )
278 # Check if there are any objects with this prefix
279 return any(True for _ in objects)
280
281 return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)
282
    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Return :py:class:`ObjectMetadata` for the object or "directory" at ``path``.

        A path ending in ``/`` (or with an empty key) is treated as a directory,
        which "exists" only if at least one object shares the prefix. A 404 on the
        HEAD request also falls back to a directory check to handle virtual prefixes.

        :param path: The ``bucket/key`` path to inspect.
        :param strict: Accepted for interface compatibility; not used by this implementation.
        :raises FileNotFoundError: If neither an object nor a matching prefix exists.
        """
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If path ends with "/" or empty key name is provided, then assume it's a "directory",
            # which metadata is not guaranteed to exist for cases such as
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:

            def _invoke_api() -> ObjectMetadata:
                obj = self.client.bucket(bck_name=bucket, provider=self.provider).object(obj_name=key)
                try:
                    headers = obj.head()
                    props = ObjectProps(headers)

                    return ObjectMetadata(
                        key=key,
                        content_length=int(props.size),  # pyright: ignore [reportArgumentType]
                        # access_time is treated as nanoseconds since the epoch (divided by 1e9)
                        # — TODO confirm units against the AIStore SDK.
                        last_modified=datetime.fromtimestamp(int(props.access_time) / 1e9).astimezone(timezone.utc),
                        etag=props.checksum_value,
                        metadata=props.custom_metadata,
                    )
                except (AISError, HTTPError) as e:
                    # Check if this might be a virtual directory (prefix with objects under it)
                    status_code = None
                    if isinstance(e, AISError):
                        status_code = e.status_code
                    elif isinstance(e, HTTPError):
                        status_code = e.response.status_code

                    if status_code == 404:
                        if self._is_dir(path):
                            return ObjectMetadata(
                                key=path + "/",
                                type="directory",
                                content_length=0,
                                last_modified=AWARE_DATETIME_MIN,
                            )
                    # Re-raise to be handled by _translate_errors
                    raise

            return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
333
334 def _list_objects(
335 self,
336 path: str,
337 start_after: Optional[str] = None,
338 end_at: Optional[str] = None,
339 include_directories: bool = False,
340 follow_symlinks: bool = True,
341 ) -> Iterator[ObjectMetadata]:
342 bucket, prefix = split_path(path)
343
344 # Get the prefix of the start_after and end_at paths relative to the bucket.
345 if start_after:
346 _, start_after = split_path(start_after)
347 if end_at:
348 _, end_at = split_path(end_at)
349
350 def _invoke_api() -> Iterator[ObjectMetadata]:
351 # AIS has no start key option like other object stores.
352 all_objects = self.client.bucket(bck_name=bucket, provider=self.provider).list_objects_iter(
353 prefix=prefix, props="name,size,atime,checksum,cone", page_size=DEFAULT_PAGE_SIZE
354 )
355
356 # Assume AIS guarantees lexicographical order.
357 for bucket_entry in all_objects:
358 obj = bucket_entry.object
359 key = obj.name
360 props = bucket_entry.generate_object_props()
361 if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
362 yield ObjectMetadata(
363 key=key,
364 content_length=int(props.size),
365 last_modified=dateutil_parser(props.access_time).astimezone(timezone.utc),
366 etag=props.checksum_value,
367 )
368 elif end_at is not None and end_at < key:
369 return
370
371 return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)
372
373 def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
374 file_size: int = 0
375
376 if isinstance(f, str):
377 with open(f, "rb") as fp:
378 body = fp.read()
379 file_size = len(body)
380 self._put_object(remote_path, body, attributes=attributes)
381 else:
382 if isinstance(f, io.StringIO):
383 body = f.read().encode("utf-8")
384 file_size = len(body)
385 self._put_object(remote_path, body, attributes=attributes)
386 else:
387 body = f.read()
388 file_size = len(body)
389 self._put_object(remote_path, body, attributes=attributes)
390
391 return file_size
392
393 def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
394 if metadata is None:
395 metadata = self._get_object_metadata(remote_path)
396
397 if isinstance(f, str):
398 if os.path.dirname(f):
399 os.makedirs(os.path.dirname(f), exist_ok=True)
400 with open(f, "wb") as fp:
401 fp.write(self._get_object(remote_path))
402 else:
403 if isinstance(f, io.StringIO):
404 f.write(self._get_object(remote_path).decode("utf-8"))
405 else:
406 f.write(self._get_object(remote_path))
407
408 return metadata.content_length