Source code for multistorageclient.providers.ais

  1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16import io
 17import os
 18from collections.abc import Callable, Iterator
 19from datetime import datetime, timezone
 20from typing import IO, Any, Optional, TypeVar, Union
 21
 22from aistore.sdk import Client
 23from aistore.sdk.authn import AuthNClient
 24from aistore.sdk.errors import AISError
 25from aistore.sdk.obj.object_props import ObjectProps
 26from dateutil.parser import parse as dateutil_parser
 27from requests.exceptions import HTTPError
 28from urllib3.util import Retry
 29
 30from ..constants import DEFAULT_READ_TIMEOUT
 31from ..telemetry import Telemetry
 32from ..types import (
 33    AWARE_DATETIME_MIN,
 34    Credentials,
 35    CredentialsProvider,
 36    ObjectMetadata,
 37    Range,
 38    SymlinkHandling,
 39)
 40from ..utils import safe_makedirs, split_path, validate_attributes
 41from .base import BaseStorageProvider
 42
# Generic return type of the callable handed to ``AIStoreStorageProvider._translate_errors``.
_T = TypeVar("_T")

# Default AIStore bucket provider name; also reported as the storage provider name to the base class.
PROVIDER = "ais"
# Page size requested from ``list_objects_iter`` when listing objects.
DEFAULT_PAGE_SIZE = 1000
 47
 48
class StaticAISCredentialProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides
    static AIStore credentials.

    Either a ready-made token is supplied, or a username/password/AuthN endpoint triple is supplied, in which
    case the token is obtained by logging in each time :py:meth:`get_credentials` is called.
    """

    _username: Optional[str]
    _password: Optional[str]
    _authn_endpoint: Optional[str]
    _token: Optional[str]
    _skip_verify: bool
    _ca_cert: Optional[str]

    def __init__(
        self,
        username: Optional[str] = None,
        password: Optional[str] = None,
        authn_endpoint: Optional[str] = None,
        token: Optional[str] = None,
        skip_verify: bool = True,
        ca_cert: Optional[str] = None,
    ):
        """
        Store the given credentials; no network call is made here.

        :param username: The username for the AIStore authentication.
        :param password: The password for the AIStore authentication.
        :param authn_endpoint: The AIStore authentication endpoint.
        :param token: The AIStore authentication token. Returned as-is unless username, password and
            authn_endpoint are all provided.
        :param skip_verify: If true, skip SSL certificate verification.
        :param ca_cert: Path to a CA certificate file for SSL verification.
        """
        self._username, self._password = username, password
        self._authn_endpoint, self._token = authn_endpoint, token
        self._skip_verify, self._ca_cert = skip_verify, ca_cert

    def get_credentials(self) -> Credentials:
        """
        Return the credentials holding the AIStore token.

        When username, password and authn_endpoint are all set, a login is performed against the AuthN
        endpoint and the resulting token replaces the stored one before it is returned.
        """
        if self._username and self._password and self._authn_endpoint:
            self._token = AuthNClient(self._authn_endpoint, self._skip_verify, self._ca_cert).login(
                self._username, self._password
            )
        return Credentials(access_key="", secret_key="", token=self._token, expiration=None)

    def refresh_credentials(self) -> None:
        """No-op: nothing is refreshed here."""
        return None
96 97
class AIStoreStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with NVIDIA AIStore.
    """

    def __init__(
        self,
        endpoint: str = os.getenv("AIS_ENDPOINT", ""),
        provider: str = PROVIDER,
        skip_verify: bool = True,
        ca_cert: Optional[str] = None,
        timeout: Optional[Union[float, tuple[float, float]]] = None,
        retry: Optional[dict[str, Any]] = None,
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ) -> None:
        """
        AIStore client for managing buckets, objects, and ETL jobs.

        :param endpoint: The AIStore endpoint. NOTE: the default is read from the ``AIS_ENDPOINT`` environment
            variable once, when this class is defined, not on every instantiation.
        :param provider: The bucket provider name passed to the AIStore client when resolving buckets.
        :param skip_verify: Whether to skip SSL certificate verification.
        :param ca_cert: Path to a CA certificate file for SSL verification.
        :param timeout: Request timeout in seconds; a single float
            for both connect/read timeouts (e.g., ``5.0``) or a tuple for separate connect/read
            timeouts (e.g., ``(3.0, 10.0)``). ``None`` falls back to ``DEFAULT_READ_TIMEOUT``.
        :param retry: ``urllib3.util.Retry`` parameters.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve AIStore credentials. Its token is passed to the
            AIStore client. When omitted, no token is passed (per the AIStore SDK documentation the client then
            uses the ``AIS_AUTHN_TOKEN`` environment variable).
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        :param kwargs: Accepted for signature compatibility; not used by this provider.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        # https://aistore.nvidia.com/docs/python-sdk#client.Client
        client_kwargs: dict[str, Any] = {
            "endpoint": endpoint,
            "retry": None if retry is None else Retry(**retry),
            "skip_verify": skip_verify,
            "ca_cert": ca_cert,
            "timeout": float(DEFAULT_READ_TIMEOUT) if timeout is None else timeout,
        }
        # Only pass a token when a credentials provider was supplied, so the SDK default applies otherwise.
        if credentials_provider:
            client_kwargs["token"] = credentials_provider.get_credentials().token
        self.client = Client(**client_kwargs)
        self.provider = provider

    @staticmethod
    def _status_code_of(error: BaseException) -> Optional[int]:
        """
        Extracts the HTTP status code carried by an SDK or HTTP error.

        :param error: An ``AISError``, an ``HTTPError``, or any other exception.

        :return: The status code, or ``None`` when the error carries none (e.g. an ``HTTPError`` without a response).
        """
        if isinstance(error, AISError):
            return error.status_code
        response = getattr(error, "response", None)
        return getattr(response, "status_code", None)

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual object storage operation.
        :param operation: The type of operation being performed (e.g., ``PUT``, ``GET``, ``DELETE``).
        :param bucket: The name of the object storage bucket involved in the operation.
        :param key: The key of the object within the object storage bucket.

        :return: The result of the object storage operation, typically the return value of the `func` callable.

        :raises FileNotFoundError: If the backend answers with a 404 status.
        :raises RuntimeError: For every other failure.
        """

        try:
            return func()
        except FileNotFoundError:
            # Already translated by a nested call made from inside ``func``; do not re-wrap it.
            raise
        except AISError as error:
            status_code = error.status_code
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            error_info = f"status_code: {status_code}, message: {error.message}"
            raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except HTTPError as error:
            # ``error.response`` may be ``None``; the helper returns ``None`` in that case.
            if self._status_code_of(error) == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}"
            ) from error
        except Exception as error:
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads ``body`` to ``path`` and, if attributes are given, stores them as custom object properties.

        :param path: ``bucket/key`` path of the object.
        :param body: The object content.
        :param if_match: Ignored; see the comment below.
        :param if_none_match: Ignored; see the comment below.
        :param attributes: Optional custom metadata to attach to the object.

        :return: The number of bytes in ``body``.
        """
        # ais does not support if_match and if_none_match
        bucket, key = split_path(path)

        def _invoke_api() -> int:
            # Validate first so that a rejection (should ``validate_attributes`` raise) happens before any
            # data is written.
            validated_attributes = validate_attributes(attributes)
            obj = self.client.bucket(bucket, self.provider).object(obj_name=key)
            obj.put_content(body)
            if validated_attributes:
                obj.set_custom_props(custom_metadata=validated_attributes, replace_existing=True)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Reads an object, or the requested byte range of it, fully into memory.

        :param path: ``bucket/key`` path of the object.
        :param byte_range: Optional offset/size window to read.

        :return: The bytes read.
        """
        bucket, key = split_path(path)
        if byte_range:
            # HTTP range header value; the end position is inclusive.
            bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
        else:
            bytes_range = None

        def _invoke_api() -> bytes:
            obj = self.client.bucket(bucket, self.provider).object(obj_name=key)
            if byte_range:
                reader = obj.get(byte_range=bytes_range)  # pyright: ignore [reportArgumentType]
            else:
                reader = obj.get()
            return reader.read_all()

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        """
        Copies an object through the SDK's object copy call.

        :param src_path: ``bucket/key`` path of the source object.
        :param dest_path: ``bucket/key`` path of the destination object.

        :return: The size of the source object as reported by a HEAD request issued before the copy.
        """
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)

        def _invoke_api() -> int:
            src_obj = self.client.bucket(bck_name=src_bucket, provider=self.provider).object(obj_name=src_key)
            dest_obj = self.client.bucket(bck_name=dest_bucket, provider=self.provider).object(obj_name=dest_key)

            # Get source size before copying
            src_headers = src_obj.head()
            src_props = ObjectProps(src_headers)

            # Server-side copy (preserves custom metadata automatically)
            src_obj.copy(to_obj=dest_obj)  # type: ignore[attr-defined]

            return int(src_props.size)

        return self._translate_errors(
            _invoke_api, operation="COPY", bucket=f"{src_bucket}->{dest_bucket}", key=f"{src_key}->{dest_key}"
        )

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Deletes an object.

        :param path: ``bucket/key`` path of the object.
        :param if_match: Not supported. A truthy value makes the call fail; because the check runs inside the
            translated call, the failure reaches the caller as a ``RuntimeError`` wrapping ``NotImplementedError``.
        """
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            # AIS doesn't support if-match deletion, so reject the request before touching the object.
            if if_match:
                raise NotImplementedError("AIStore does not support if-match deletion")
            self.client.bucket(bucket, self.provider).object(obj_name=key).delete()

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        """
        Tells whether at least one object exists under ``path`` treated as a directory prefix.

        :param path: ``bucket/prefix`` path to check.

        :return: ``True`` if listing the prefix yields any object.
        """
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, prefix = split_path(path)

        def _invoke_api() -> bool:
            # List objects with the given prefix (limit to 1 for efficiency)
            objects = self.client.bucket(bck_name=bucket, provider=self.provider).list_objects_iter(
                prefix=prefix, page_size=1
            )
            # Check if there are any objects with this prefix
            return any(True for _ in objects)

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _make_symlink(self, path: str, target: str) -> None:
        """
        Creates a symlink as an empty object whose ``msc-symlink-target`` custom property holds the target.

        :param path: ``bucket/key`` path of the symlink object.
        :param target: ``bucket/key`` path the symlink points to; must be in the same bucket.

        :raises ValueError: If ``path`` and ``target`` are in different buckets.
        """
        bucket, key = split_path(path)
        target_bucket, target_key = split_path(target)
        if bucket != target_bucket:
            raise ValueError(f"Cannot create cross-bucket symlink: '{bucket}' -> '{target_bucket}'.")
        relative_target = ObjectMetadata.encode_symlink_target(key, target_key)

        def _invoke_api() -> None:
            obj = self.client.bucket(bucket, self.provider).object(obj_name=key)
            obj.put_content(b"")
            obj.set_custom_props(custom_metadata={"msc-symlink-target": relative_target}, replace_existing=True)

        self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Returns metadata for an object, or directory metadata for a prefix that has objects under it.

        :param path: ``bucket/key`` path of the object or directory.
        :param strict: Not used by this provider.

        :return: The object (or directory) metadata.

        :raises FileNotFoundError: If neither an object nor a non-empty prefix exists at ``path``.
        """
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If path ends with "/" or empty key name is provided, then assume it's a "directory",
            # which metadata is not guaranteed to exist for cases such as
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            raise FileNotFoundError(f"Directory {path} does not exist.")

        def _invoke_api() -> ObjectMetadata:
            obj = self.client.bucket(bck_name=bucket, provider=self.provider).object(obj_name=key)
            try:
                headers = obj.head()
                props = ObjectProps(headers)

                # The access time is not always present in the response.
                if props.access_time:
                    # The value is divided by 1e9 to obtain seconds; build the aware UTC datetime directly
                    # instead of going through a naive local time.
                    last_modified = datetime.fromtimestamp(int(props.access_time) / 1e9, tz=timezone.utc)
                else:
                    last_modified = AWARE_DATETIME_MIN

                user_metadata = props.custom_metadata
                symlink_target = user_metadata.get("msc-symlink-target") if user_metadata else None
                return ObjectMetadata(
                    key=key,
                    content_length=int(props.size),  # pyright: ignore [reportArgumentType]
                    last_modified=last_modified,
                    etag=props.checksum_value,
                    metadata=user_metadata,
                    symlink_target=symlink_target,
                )
            except (AISError, HTTPError) as e:
                # Check if this might be a virtual directory (prefix with objects under it)
                if self._status_code_of(e) == 404 and self._is_dir(path):
                    return ObjectMetadata(
                        key=path + "/",
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )
                # Re-raise to be handled by _translate_errors
                raise

        return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        symlink_handling: SymlinkHandling = SymlinkHandling.FOLLOW,
    ) -> Iterator[ObjectMetadata]:
        """
        Lazily lists objects under ``path`` whose keys fall in ``(start_after, end_at]``.

        :param path: ``bucket/prefix`` path to list.
        :param start_after: Optional path; only keys strictly greater than its key part are yielded.
        :param end_at: Optional path; only keys less than or equal to its key part are yielded.
        :param include_directories: Not used by this provider.
        :param symlink_handling: Not used by this provider.

        :return: An iterator of object metadata. Errors raised while iterating are translated the same way
            as for the other operations.
        """
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        def _iter_entries() -> Iterator[ObjectMetadata]:
            # AIS has no start key option like other object stores.
            # NOTE(review): "cone" is not an obviously valid property name — confirm against the AIStore docs.
            all_objects = self.client.bucket(bck_name=bucket, provider=self.provider).list_objects_iter(
                prefix=prefix, props="name,size,atime,checksum,cone", page_size=DEFAULT_PAGE_SIZE
            )

            # Assume AIS guarantees lexicographical order.
            for bucket_entry in all_objects:
                obj = bucket_entry.object
                key = obj.name
                props = bucket_entry.generate_object_props()

                # The access time is not always present in the response.
                if props.access_time:
                    last_modified = dateutil_parser(props.access_time).astimezone(timezone.utc)
                else:
                    last_modified = AWARE_DATETIME_MIN

                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    yield ObjectMetadata(
                        key=key, content_length=int(props.size), last_modified=last_modified, etag=props.checksum_value
                    )
                elif end_at is not None and end_at < key:
                    # Past the upper bound; given the ordering assumption nothing further can match.
                    return

        def _translated() -> Iterator[ObjectMetadata]:
            # Calling a generator function only creates the generator; the SDK requests happen while it is
            # being advanced. Translate errors around every step so that they are actually caught.
            entries = _iter_entries()
            exhausted = object()

            def _next_entry() -> Any:
                # A sentinel is used because ``StopIteration`` would be treated as a failure.
                return next(entries, exhausted)

            while True:
                entry = self._translate_errors(_next_entry, operation="LIST", bucket=bucket, key=prefix)
                if entry is exhausted:
                    return
                yield entry

        return _translated()

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        """
        Uploads a local file or the remaining content of a file-like object.

        :param remote_path: ``bucket/key`` path of the destination object.
        :param f: A local file path, or a readable binary or text stream. Text is encoded as UTF-8.
        :param attributes: Optional custom metadata to attach to the object.

        :return: The number of bytes uploaded.
        """
        if isinstance(f, str):
            with open(f, "rb") as fp:
                body = fp.read()
        else:
            body = f.read()
            # Any text stream (not only ``io.StringIO``) yields ``str``; the upload needs bytes.
            if isinstance(body, str):
                body = body.encode("utf-8")

        self._put_object(remote_path, body, attributes=attributes)
        return len(body)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """
        Downloads an object into a local file or a writable file-like object.

        :param remote_path: ``bucket/key`` path of the object.
        :param f: A local file path (parent directories are created), or a writable binary or text stream.
            For text streams the content is decoded as UTF-8.
        :param metadata: Already known metadata of the object; fetched when omitted.

        :return: The content length recorded in the object's metadata.
        """
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        # Fetch before touching the destination so a failed download does not create or truncate a local file.
        data = self._get_object(remote_path)

        if isinstance(f, str):
            parent_dir = os.path.dirname(f)
            if parent_dir:
                safe_makedirs(parent_dir)
            with open(f, "wb") as fp:
                fp.write(data)
        elif isinstance(f, io.TextIOBase):
            f.write(data.decode("utf-8"))
        else:
            f.write(data)

        return metadata.content_length