Source code for multistorageclient.providers.ais

  1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16import io
 17import os
 18from collections.abc import Callable, Iterator
 19from datetime import datetime, timezone
 20from typing import IO, Any, Optional, TypeVar, Union
 21
 22from aistore.sdk import Client
 23from aistore.sdk.authn import AuthNClient
 24from aistore.sdk.errors import AISError
 25from aistore.sdk.obj.object_props import ObjectProps
 26from dateutil.parser import parse as dateutil_parser
 27from requests.exceptions import HTTPError
 28from urllib3.util import Retry
 29
 30from ..constants import DEFAULT_READ_TIMEOUT
 31from ..telemetry import Telemetry
 32from ..types import (
 33    AWARE_DATETIME_MIN,
 34    Credentials,
 35    CredentialsProvider,
 36    ObjectMetadata,
 37    Range,
 38)
 39from ..utils import split_path, validate_attributes
 40from .base import BaseStorageProvider
 41
 42_T = TypeVar("_T")
 43
 44PROVIDER = "ais"
 45DEFAULT_PAGE_SIZE = 1000
 46
 47
class StaticAISCredentialProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider`
    that provides static AIStore credentials (either a pre-issued token, or a token obtained
    from an AuthN endpoint using a username/password pair).
    """

    _username: Optional[str]
    _password: Optional[str]
    _authn_endpoint: Optional[str]
    _token: Optional[str]
    _skip_verify: bool
    _ca_cert: Optional[str]

    def __init__(
        self,
        username: Optional[str] = None,
        password: Optional[str] = None,
        authn_endpoint: Optional[str] = None,
        token: Optional[str] = None,
        skip_verify: bool = True,
        ca_cert: Optional[str] = None,
    ):
        """
        Initializes the :py:class:`StaticAISCredentialProvider` with the given credentials.

        :param username: The username for the AIStore authentication.
        :param password: The password for the AIStore authentication.
        :param authn_endpoint: The AIStore authentication endpoint.
        :param token: The AIStore authentication token. This is used for authentication if username,
            password and authn_endpoint are not provided.
        :param skip_verify: If true, skip SSL certificate verification.
        :param ca_cert: Path to a CA certificate file for SSL verification.
        """
        self._username = username
        self._password = password
        self._authn_endpoint = authn_endpoint
        self._token = token
        self._skip_verify = skip_verify
        self._ca_cert = ca_cert

    def get_credentials(self) -> Credentials:
        """
        Return AIStore credentials.

        When username, password and authn_endpoint are all provided, a token is obtained
        from the AuthN service via ``AuthNClient.login`` on every call (no caching of the
        login beyond storing the last token); otherwise the static token supplied at
        construction time is returned. Only the ``token`` field of the returned
        :py:class:`Credentials` is meaningful for AIStore.
        """
        if self._username and self._password and self._authn_endpoint:
            authn_client = AuthNClient(self._authn_endpoint, self._skip_verify, self._ca_cert)
            self._token = authn_client.login(self._username, self._password)
        return Credentials(token=self._token, access_key="", secret_key="", expiration=None)

    def refresh_credentials(self) -> None:
        """No-op: static credentials have no refresh mechanism."""
        pass
class AIStoreStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider`
    for interacting with NVIDIA AIStore.
    """

    def __init__(
        self,
        endpoint: str = os.getenv("AIS_ENDPOINT", ""),
        provider: str = PROVIDER,
        skip_verify: bool = True,
        ca_cert: Optional[str] = None,
        timeout: Optional[Union[float, tuple[float, float]]] = None,
        retry: Optional[dict[str, Any]] = None,
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ) -> None:
        """
        AIStore client for managing buckets, objects, and ETL jobs.

        :param endpoint: The AIStore endpoint. Defaults to the ``AIS_ENDPOINT`` environment
            variable. NOTE: the default is evaluated once, at import time.
        :param provider: The AIStore backend provider name (defaults to the native ``ais`` provider).
        :param skip_verify: Whether to skip SSL certificate verification.
        :param ca_cert: Path to a CA certificate file for SSL verification.
        :param timeout: Request timeout in seconds; a single float
            for both connect/read timeouts (e.g., ``5.0``), a tuple for separate connect/read
            timeouts (e.g., ``(3.0, 10.0)``), or ``None`` to use ``DEFAULT_READ_TIMEOUT``.
        :param retry: ``urllib3.util.Retry`` parameters.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve AIStore credentials; its token
            (if any) is passed to the client as the authorization token.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        # https://aistore.nvidia.com/docs/python-sdk#client.Client
        client_retry = None if retry is None else Retry(**retry)
        token = None
        # ``None`` means "no explicit timeout requested" -> fall back to the library default.
        if timeout is None:
            timeout = float(DEFAULT_READ_TIMEOUT)
        if credentials_provider:
            token = credentials_provider.get_credentials().token
            self.client = Client(
                endpoint=endpoint,
                retry=client_retry,
                skip_verify=skip_verify,
                ca_cert=ca_cert,
                timeout=timeout,
                token=token,
            )
        else:
            self.client = Client(
                endpoint=endpoint, retry=client_retry, timeout=timeout, skip_verify=skip_verify, ca_cert=ca_cert
            )
        self.provider = provider

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual object storage operation.
        :param operation: The type of operation being performed (e.g., ``PUT``, ``GET``, ``DELETE``).
        :param bucket: The name of the object storage bucket involved in the operation.
        :param key: The key of the object within the object storage bucket.

        :return: The result of the object storage operation, typically the return value of the `func` callable.

        :raises FileNotFoundError: If the underlying error is a 404.
        :raises RuntimeError: For any other AIStore / HTTP / unexpected error, chained to the cause.
        """

        try:
            return func()
        except AISError as error:
            status_code = error.status_code
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            error_info = f"status_code: {status_code}, message: {error.message}"
            raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except HTTPError as error:
            status_code = error.response.status_code
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            else:
                raise RuntimeError(
                    f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}"
                ) from error
        except Exception as error:
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """Upload ``body`` to ``path``; returns the number of bytes written.

        AIStore does not support if_match / if_none_match preconditions, so those
        arguments are accepted for interface compatibility but ignored.
        """
        bucket, key = split_path(path)

        def _invoke_api() -> int:
            obj = self.client.bucket(bucket, self.provider).object(obj_name=key)
            obj.put_content(body)
            # Custom metadata is set in a second call after the upload; note this is
            # not atomic with the PUT itself.
            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                obj.set_custom_props(custom_metadata=validated_attributes, replace_existing=True)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """Read an object (optionally a byte range) and return its content as bytes."""
        bucket, key = split_path(path)
        if byte_range:
            # HTTP Range header format: inclusive end offset.
            bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
        else:
            bytes_range = None

        def _invoke_api() -> bytes:
            obj = self.client.bucket(bucket, self.provider).object(obj_name=key)
            if byte_range:
                reader = obj.get(byte_range=bytes_range)  # pyright: ignore [reportArgumentType]
            else:
                reader = obj.get()
            return reader.read_all()

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        """Server-side copy; returns the size of the source object in bytes."""
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)

        def _invoke_api() -> int:
            src_obj = self.client.bucket(bck_name=src_bucket, provider=self.provider).object(obj_name=src_key)
            dest_obj = self.client.bucket(bck_name=dest_bucket, provider=self.provider).object(obj_name=dest_key)

            # Get source size before copying
            src_headers = src_obj.head()
            src_props = ObjectProps(src_headers)

            # Server-side copy (preserves custom metadata automatically)
            src_obj.copy(to_obj=dest_obj)  # type: ignore[attr-defined]

            return int(src_props.size)

        return self._translate_errors(
            _invoke_api, operation="COPY", bucket=f"{src_bucket}->{dest_bucket}", key=f"{src_key}->{dest_key}"
        )

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """Delete the object at ``path``.

        :raises NotImplementedError: If ``if_match`` is given — AIStore has no
            conditional delete, and failing loudly is safer than deleting unconditionally.
        """
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            obj = self.client.bucket(bucket, self.provider).object(obj_name=key)
            # AIStore doesn't support if-match deletion; refuse rather than silently
            # performing an unconditional delete.
            if if_match:
                raise NotImplementedError("AIStore does not support if-match deletion")
            # Perform deletion
            obj.delete()

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        """Return True if ``path`` behaves like a directory, i.e. at least one object exists under it."""
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, prefix = split_path(path)

        def _invoke_api() -> bool:
            # List objects with the given prefix (limit to 1 for efficiency)
            objects = self.client.bucket(bck_name=bucket, provider=self.provider).list_objects_iter(
                prefix=prefix, page_size=1
            )
            # Check if there are any objects with this prefix
            return any(True for _ in objects)

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """HEAD an object (or treat ``path`` as a directory) and return its metadata.

        :raises FileNotFoundError: If neither an object nor a virtual directory exists at ``path``.
        """
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If path ends with "/" or empty key name is provided, then assume it's a "directory",
            # which metadata is not guaranteed to exist for cases such as
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:

            def _invoke_api() -> ObjectMetadata:
                obj = self.client.bucket(bck_name=bucket, provider=self.provider).object(obj_name=key)
                try:
                    headers = obj.head()
                    props = ObjectProps(headers)

                    # The access time is not always present in the response.
                    if props.access_time:
                        # access_time is reported in nanoseconds since the epoch.
                        last_modified = datetime.fromtimestamp(int(props.access_time) / 1e9).astimezone(timezone.utc)
                    else:
                        last_modified = AWARE_DATETIME_MIN

                    return ObjectMetadata(
                        key=key,
                        content_length=int(props.size),  # pyright: ignore [reportArgumentType]
                        last_modified=last_modified,
                        etag=props.checksum_value,
                        metadata=props.custom_metadata,
                    )
                except (AISError, HTTPError) as e:
                    # Check if this might be a virtual directory (prefix with objects under it)
                    status_code = None
                    if isinstance(e, AISError):
                        status_code = e.status_code
                    elif isinstance(e, HTTPError):
                        status_code = e.response.status_code

                    if status_code == 404:
                        if self._is_dir(path):
                            return ObjectMetadata(
                                key=path + "/",
                                type="directory",
                                content_length=0,
                                last_modified=AWARE_DATETIME_MIN,
                            )
                    # Re-raise to be handled by _translate_errors
                    raise

            return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        """Iterate objects under ``path``, optionally bounded by ``start_after`` (exclusive)
        and ``end_at`` (inclusive). ``include_directories`` and ``follow_symlinks`` are
        accepted for interface compatibility and not used here.
        """
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            # AIS has no start key option like other object stores, so bounds are
            # filtered client-side below.
            # NOTE(review): "cone" in the props list looks unusual — confirm against the
            # AIStore SDK's supported object properties.
            all_objects = self.client.bucket(bck_name=bucket, provider=self.provider).list_objects_iter(
                prefix=prefix, props="name,size,atime,checksum,cone", page_size=DEFAULT_PAGE_SIZE
            )

            # Assume AIS guarantees lexicographical order.
            for bucket_entry in all_objects:
                obj = bucket_entry.object
                key = obj.name
                props = bucket_entry.generate_object_props()

                # The access time is not always present in the response.
                if props.access_time:
                    last_modified = dateutil_parser(props.access_time).astimezone(timezone.utc)
                else:
                    last_modified = AWARE_DATETIME_MIN

                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    yield ObjectMetadata(
                        key=key, content_length=int(props.size), last_modified=last_modified, etag=props.checksum_value
                    )
                elif end_at is not None and end_at < key:
                    # Past the inclusive upper bound; lexicographic ordering means no
                    # further keys can match.
                    return

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        """Upload a local file path or a file-like object; returns the uploaded size.

        NOTE: for text-mode file objects other than ``io.StringIO``, the ``str`` payload
        is passed through unchanged and the returned size is a character count.
        """
        if isinstance(f, str):
            with open(f, "rb") as fp:
                body = fp.read()
        elif isinstance(f, io.StringIO):
            body = f.read().encode("utf-8")
        else:
            body = f.read()

        file_size = len(body)
        self._put_object(remote_path, body, attributes=attributes)
        return file_size

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """Download an object to a local path or file-like object; returns the object's size.

        The returned length comes from (possibly pre-fetched) metadata, not from the
        number of bytes actually written.
        """
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        if isinstance(f, str):
            # Create parent directories as needed before writing.
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)
            with open(f, "wb") as fp:
                fp.write(self._get_object(remote_path))
        else:
            if isinstance(f, io.StringIO):
                f.write(self._get_object(remote_path).decode("utf-8"))
            else:
                f.write(self._get_object(remote_path))

        return metadata.content_length