# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os
import time
from collections.abc import Callable, Iterator, Sequence, Sized
from typing import IO, Any, Optional, TypeVar, Union

import opentelemetry.metrics as api_metrics
from aistore.sdk import Client
from aistore.sdk.authn import AuthNClient
from aistore.sdk.errors import AISError
from aistore.sdk.obj.object_props import ObjectProps
from requests.exceptions import HTTPError
from urllib3.util import Retry

from ..telemetry import Telemetry
from ..telemetry.attributes.base import AttributesProvider
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    ObjectMetadata,
    Range,
)
from ..utils import split_path, validate_attributes
from .base import BaseStorageProvider

_T = TypeVar("_T")

PROVIDER = "ais"
DEFAULT_PAGE_SIZE = 1000


class StaticAISCredentialProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static AIStore credentials.
    """

    _username: Optional[str]
    _password: Optional[str]
    _authn_endpoint: Optional[str]
    _token: Optional[str]
    _skip_verify: bool
    _ca_cert: Optional[str]

    def __init__(
        self,
        username: Optional[str] = None,
        password: Optional[str] = None,
        authn_endpoint: Optional[str] = None,
        token: Optional[str] = None,
        skip_verify: bool = True,
        ca_cert: Optional[str] = None,
    ):
69 """
70 Initializes the :py:class:`StaticAISCredentialProvider` with the given credentials.
71
72 :param username: The username for the AIStore authentication.
73 :param password: The password for the AIStore authentication.
74 :param authn_endpoint: The AIStore authentication endpoint.
75 :param token: The AIStore authentication token. This is used for authentication if username,
76 password and authn_endpoint are not provided.
77 :param skip_verify: If true, skip SSL certificate verification.
78 :param ca_cert: Path to a CA certificate file for SSL verification.
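
        Example (illustrative; the endpoint and credentials below are placeholders)::

            # Username/password mode: a fresh token is fetched from the AuthN endpoint.
            provider = StaticAISCredentialProvider(
                username="admin",
                password="admin",
                authn_endpoint="https://ais-authn.example.com:52001",
            )
            token = provider.get_credentials().token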
79 """
80 self._username = username
81 self._password = password
82 self._authn_endpoint = authn_endpoint
83 self._token = token
84 self._skip_verify = skip_verify
85 self._ca_cert = ca_cert
86
    def get_credentials(self) -> Credentials:
        if self._username and self._password and self._authn_endpoint:
            authn_client = AuthNClient(self._authn_endpoint, self._skip_verify, self._ca_cert)
            self._token = authn_client.login(self._username, self._password)
        return Credentials(token=self._token, access_key="", secret_key="", expiration=None)

    def refresh_credentials(self) -> None:
        pass


class AIStoreStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with NVIDIA AIStore.
    """

    def __init__(
        self,
        endpoint: str = os.getenv("AIS_ENDPOINT", ""),
        provider: str = PROVIDER,
        skip_verify: bool = True,
        ca_cert: Optional[str] = None,
        timeout: Optional[Union[float, tuple[float, float]]] = None,
        retry: Optional[dict[str, Any]] = None,
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        metric_counters: dict[Telemetry.CounterName, api_metrics.Counter] = {},
        metric_gauges: dict[Telemetry.GaugeName, api_metrics._Gauge] = {},
        metric_attributes_providers: Sequence[AttributesProvider] = (),
        **kwargs: Any,
    ) -> None:
117 """
118 AIStore client for managing buckets, objects, and ETL jobs.
119
120 :param endpoint: The AIStore endpoint.
121 :param skip_verify: Whether to skip SSL certificate verification.
122 :param ca_cert: Path to a CA certificate file for SSL verification.
123 :param timeout: Request timeout in seconds; a single float
124 for both connect/read timeouts (e.g., ``5.0``), a tuple for separate connect/read
125 timeouts (e.g., ``(3.0, 10.0)``), or ``None`` to disable timeout.
126 :param retry: ``urllib3.util.Retry`` parameters.
127 :param token: Authorization token. If not provided, the ``AIS_AUTHN_TOKEN`` environment variable will be used.
128 :param base_path: The root prefix path within the bucket where all operations will be scoped.
129 :param credentials_provider: The provider to retrieve AIStore credentials.
130 :param metric_counters: Metric counters.
131 :param metric_gauges: Metric gauges.
132 :param metric_attributes_providers: Metric attributes providers.
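
        Example (illustrative; the endpoint and retry values are placeholders)::

            provider = AIStoreStorageProvider(
                endpoint="http://localhost:51080",
                timeout=(3.0, 10.0),
                retry={"total": 3, "backoff_factor": 0.5},
            )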
133 """
134 super().__init__(
135 base_path=base_path,
136 provider_name=PROVIDER,
137 metric_counters=metric_counters,
138 metric_gauges=metric_gauges,
139 metric_attributes_providers=metric_attributes_providers,
140 )
141
142 # https://aistore.nvidia.com/docs/python-sdk#client.Client
143 client_retry = None if retry is None else Retry(**retry)
144 token = None
145 if credentials_provider:
146 token = credentials_provider.get_credentials().token
147 self.client = Client(
148 endpoint=endpoint,
149 retry=client_retry,
150 skip_verify=skip_verify,
151 ca_cert=ca_cert,
152 timeout=timeout,
153 token=token,
154 )
155 else:
156 self.client = Client(endpoint=endpoint, retry=client_retry)
157 self.provider = provider

    def _collect_metrics(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> _T:
        """
        Collects and records performance metrics around object storage operations
        such as ``PUT``, ``GET``, ``DELETE``, etc.

        This method wraps an object storage operation and measures the time it takes to complete, along with recording
        the size of the object if applicable. It handles errors like timeouts and client errors and ensures
        proper recording of duration and object size metrics.

        :param func: The function that performs the actual object storage operation.
        :param operation: The type of operation being performed (e.g., ``PUT``, ``GET``, ``DELETE``).
        :param bucket: The name of the object storage bucket involved in the operation.
        :param key: The key of the object within the object storage bucket.
        :param put_object_size: The size of the object being uploaded, if applicable (for ``PUT`` operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for ``GET`` operations).

        :return: The result of the object storage operation, typically the return value of the ``func`` callable.
        """
        start_time = time.time()
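        # Assume success up front; on failure the status code is overwritten with the
        # provider's error status (-1 for unexpected, non-HTTP errors) before metrics are recorded.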
        status_code = 200

        object_size = None
        if operation == "PUT":
            object_size = put_object_size
        elif operation == "GET" and get_object_size:
            object_size = get_object_size

        try:
            result = func()
            if operation == "GET" and object_size is None and isinstance(result, Sized):
                object_size = len(result)
            return result
        except AISError as error:
            status_code = error.status_code
            error_info = f"status_code: {status_code}, message: {error.message}"
            raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except HTTPError as error:
            status_code = error.response.status_code
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            else:
                raise RuntimeError(
                    f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}"
                ) from error
        except Exception as error:
            status_code = -1
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error
        finally:
            elapsed_time = time.time() - start_time
            self._metric_helper.record_duration(
                elapsed_time, provider=self._provider_name, operation=operation, bucket=bucket, status_code=status_code
            )
            if object_size:
                self._metric_helper.record_object_size(
                    object_size,
                    provider=self._provider_name,
                    operation=operation,
                    bucket=bucket,
                    status_code=status_code,
                )

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        # AIStore does not support the if_match / if_none_match preconditions, so they are ignored.
        bucket, key = split_path(path)

        def _invoke_api() -> int:
            obj = self.client.bucket(bucket, self.provider).object(obj_name=key)
            obj.put_content(body)
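            # Custom metadata cannot be attached to the PUT itself; it is set with a
            # separate call after the content upload, replacing any existing custom props.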
            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                obj.set_custom_props(custom_metadata=validated_attributes, replace_existing=True)

            return len(body)

        return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=len(body))

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)
        if byte_range:
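            # HTTP Range headers are inclusive on both ends, so reading `size` bytes from
            # `offset` means ending at offset + size - 1 (e.g., offset=0, size=4 -> "bytes=0-3").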
            bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
        else:
            bytes_range = None

        def _invoke_api() -> bytes:
            obj = self.client.bucket(bucket, self.provider).object(obj_name=key)
            if byte_range:
                reader = obj.get(byte_range=bytes_range)  # pyright: ignore [reportArgumentType]
            else:
                reader = obj.get()
            return reader.read_all()

        return self._collect_metrics(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        raise AttributeError("AIStore does not support copy operations")

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            obj = self.client.bucket(bucket, self.provider).object(obj_name=key)
            # AIStore does not support conditional (if-match) deletion, so fail loudly
            # rather than deleting without the precondition check.
            if if_match:
                raise NotImplementedError("AIStore does not support if-match deletion")
            # Perform deletion
            obj.delete()

        return self._collect_metrics(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If the path ends with "/" or the key is empty, assume it refers to a "directory".
            # Directory metadata is not guaranteed to exist, e.g., for a "virtual prefix"
            # that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:

            def _invoke_api() -> ObjectMetadata:
                obj = self.client.bucket(bck_name=bucket, provider=self.provider).object(obj_name=key)
                headers = obj.head()
                props = ObjectProps(headers)

                return ObjectMetadata(
                    key=key,
                    content_length=int(props.size),  # pyright: ignore [reportArgumentType]
                    last_modified=AWARE_DATETIME_MIN,
                    etag=props.checksum_value,
                    metadata=props.custom_metadata,
                )

            return self._collect_metrics(_invoke_api, operation="HEAD", bucket=bucket, key=key)

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(path)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            # AIS has no start key option like other object stores.
            all_objects = self.client.bucket(bck_name=bucket, provider=self.provider).list_objects_iter(
                prefix=prefix, props="name,size,atime,checksum,cone", page_size=DEFAULT_PAGE_SIZE
            )

            # Assume AIS guarantees lexicographical order.
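            # Because results arrive in key order, the (start_after, end_at] window can be
            # applied client-side, and iteration can stop as soon as a key passes end_at.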
            for bucket_entry in all_objects:
                obj = bucket_entry.object
                key = obj.name
                props = bucket_entry.generate_object_props()
                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    yield ObjectMetadata(
                        key=key,
                        content_length=int(props.size),
                        last_modified=AWARE_DATETIME_MIN,
                        etag=props.checksum_value,
                    )
                elif end_at is not None and end_at < key:
                    return

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        file_size: int = 0

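        # f may be a filesystem path, a text stream (e.g., io.StringIO), or a binary stream;
        # text content is encoded as UTF-8 before upload.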
        if isinstance(f, str):
            with open(f, "rb") as fp:
                body = fp.read()
                file_size = len(body)
                self._put_object(remote_path, body, attributes=attributes)
        elif isinstance(f, io.StringIO):
            body = f.read().encode("utf-8")
            file_size = len(body)
            self._put_object(remote_path, body, attributes=attributes)
        else:
            body = f.read()
            file_size = len(body)
            self._put_object(remote_path, body, attributes=attributes)

        return file_size

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        if isinstance(f, str):
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)
            with open(f, "wb") as fp:
                fp.write(self._get_object(remote_path))
        elif isinstance(f, io.StringIO):
            f.write(self._get_object(remote_path).decode("utf-8"))
        else:
            f.write(self._get_object(remote_path))

        return metadata.content_length