Source code for multistorageclient.providers.ais

  1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16import io
 17import os
 18from collections.abc import Callable, Iterator
 19from datetime import datetime, timezone
 20from typing import IO, Any, Optional, TypeVar, Union
 21
 22from aistore.sdk import Client
 23from aistore.sdk.authn import AuthNClient
 24from aistore.sdk.errors import AISError
 25from aistore.sdk.obj.object_props import ObjectProps
 26from dateutil.parser import parse as dateutil_parser
 27from requests.exceptions import HTTPError
 28from urllib3.util import Retry
 29
 30from ..telemetry import Telemetry
 31from ..types import (
 32    AWARE_DATETIME_MIN,
 33    Credentials,
 34    CredentialsProvider,
 35    ObjectMetadata,
 36    Range,
 37)
 38from ..utils import split_path, validate_attributes
 39from .base import BaseStorageProvider
 40
 41_T = TypeVar("_T")
 42
 43PROVIDER = "ais"
 44DEFAULT_PAGE_SIZE = 1000
 45
 46
[docs] 47class StaticAISCredentialProvider(CredentialsProvider): 48 """ 49 A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static S3 credentials. 50 """ 51 52 _username: Optional[str] 53 _password: Optional[str] 54 _authn_endpoint: Optional[str] 55 _token: Optional[str] 56 _skip_verify: bool 57 _ca_cert: Optional[str] 58 59 def __init__( 60 self, 61 username: Optional[str] = None, 62 password: Optional[str] = None, 63 authn_endpoint: Optional[str] = None, 64 token: Optional[str] = None, 65 skip_verify: bool = True, 66 ca_cert: Optional[str] = None, 67 ): 68 """ 69 Initializes the :py:class:`StaticAISCredentialProvider` with the given credentials. 70 71 :param username: The username for the AIStore authentication. 72 :param password: The password for the AIStore authentication. 73 :param authn_endpoint: The AIStore authentication endpoint. 74 :param token: The AIStore authentication token. This is used for authentication if username, 75 password and authn_endpoint are not provided. 76 :param skip_verify: If true, skip SSL certificate verification. 77 :param ca_cert: Path to a CA certificate file for SSL verification. 78 """ 79 self._username = username 80 self._password = password 81 self._authn_endpoint = authn_endpoint 82 self._token = token 83 self._skip_verify = skip_verify 84 self._ca_cert = ca_cert 85
[docs] 86 def get_credentials(self) -> Credentials: 87 if self._username and self._password and self._authn_endpoint: 88 authn_client = AuthNClient(self._authn_endpoint, self._skip_verify, self._ca_cert) 89 self._token = authn_client.login(self._username, self._password) 90 return Credentials(token=self._token, access_key="", secret_key="", expiration=None)
91
[docs] 92 def refresh_credentials(self) -> None: 93 pass
94 95
[docs] 96class AIStoreStorageProvider(BaseStorageProvider): 97 """ 98 A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with NVIDIA AIStore. 99 """ 100 101 def __init__( 102 self, 103 endpoint: str = os.getenv("AIS_ENDPOINT", ""), 104 provider: str = PROVIDER, 105 skip_verify: bool = True, 106 ca_cert: Optional[str] = None, 107 timeout: Optional[Union[float, tuple[float, float]]] = None, 108 retry: Optional[dict[str, Any]] = None, 109 base_path: str = "", 110 credentials_provider: Optional[CredentialsProvider] = None, 111 config_dict: Optional[dict[str, Any]] = None, 112 telemetry_provider: Optional[Callable[[], Telemetry]] = None, 113 **kwargs: Any, 114 ) -> None: 115 """ 116 AIStore client for managing buckets, objects, and ETL jobs. 117 118 :param endpoint: The AIStore endpoint. 119 :param skip_verify: Whether to skip SSL certificate verification. 120 :param ca_cert: Path to a CA certificate file for SSL verification. 121 :param timeout: Request timeout in seconds; a single float 122 for both connect/read timeouts (e.g., ``5.0``), a tuple for separate connect/read 123 timeouts (e.g., ``(3.0, 10.0)``), or ``None`` to disable timeout. 124 :param retry: ``urllib3.util.Retry`` parameters. 125 :param token: Authorization token. If not provided, the ``AIS_AUTHN_TOKEN`` environment variable will be used. 126 :param base_path: The root prefix path within the bucket where all operations will be scoped. 127 :param credentials_provider: The provider to retrieve AIStore credentials. 128 :param config_dict: Resolved MSC config. 129 :param telemetry_provider: A function that provides a telemetry instance. 130 """ 131 super().__init__( 132 base_path=base_path, 133 provider_name=PROVIDER, 134 config_dict=config_dict, 135 telemetry_provider=telemetry_provider, 136 ) 137 138 # https://aistore.nvidia.com/docs/python-sdk#client.Client 139 client_retry = None if retry is None else Retry(**retry) 140 token = None 141 if credentials_provider: 142 token = credentials_provider.get_credentials().token 143 self.client = Client( 144 endpoint=endpoint, 145 retry=client_retry, 146 skip_verify=skip_verify, 147 ca_cert=ca_cert, 148 timeout=timeout, 149 token=token, 150 ) 151 else: 152 self.client = Client(endpoint=endpoint, retry=client_retry) 153 self.provider = provider 154 155 def _translate_errors( 156 self, 157 func: Callable[[], _T], 158 operation: str, 159 bucket: str, 160 key: str, 161 ) -> _T: 162 """ 163 Translates errors like timeouts and client errors. 164 165 :param func: The function that performs the actual object storage operation. 166 :param operation: The type of operation being performed (e.g., ``PUT``, ``GET``, ``DELETE``). 167 :param bucket: The name of the object storage bucket involved in the operation. 168 :param key: The key of the object within the object storage bucket. 169 170 :return: The result of the object storage operation, typically the return value of the `func` callable. 171 """ 172 173 try: 174 return func() 175 except AISError as error: 176 status_code = error.status_code 177 if status_code == 404: 178 raise FileNotFoundError(f"Object {bucket}/{key} does not exist.") # pylint: disable=raise-missing-from 179 error_info = f"status_code: {status_code}, message: {error.message}" 180 raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error 181 except HTTPError as error: 182 status_code = error.response.status_code 183 if status_code == 404: 184 raise FileNotFoundError(f"Object {bucket}/{key} does not exist.") # pylint: disable=raise-missing-from 185 else: 186 raise RuntimeError( 187 f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}" 188 ) from error 189 except Exception as error: 190 raise RuntimeError( 191 f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}" 192 ) from error 193 194 def _put_object( 195 self, 196 path: str, 197 body: bytes, 198 if_match: Optional[str] = None, 199 if_none_match: Optional[str] = None, 200 attributes: Optional[dict[str, str]] = None, 201 ) -> int: 202 # ais does not support if_match and if_none_match 203 bucket, key = split_path(path) 204 205 def _invoke_api() -> int: 206 obj = self.client.bucket(bucket, self.provider).object(obj_name=key) 207 obj.put_content(body) 208 validated_attributes = validate_attributes(attributes) 209 if validated_attributes: 210 obj.set_custom_props(custom_metadata=validated_attributes, replace_existing=True) 211 212 return len(body) 213 214 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key) 215 216 def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes: 217 bucket, key = split_path(path) 218 if byte_range: 219 bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}" 220 else: 221 bytes_range = None 222 223 def _invoke_api() -> bytes: 224 obj = self.client.bucket(bucket, self.provider).object(obj_name=key) 225 if byte_range: 226 reader = obj.get(byte_range=bytes_range) # pyright: ignore [reportArgumentType] 227 else: 228 reader = obj.get() 229 return reader.read_all() 230 231 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key) 232 233 def _copy_object(self, src_path: str, dest_path: str) -> int: 234 src_bucket, src_key = split_path(src_path) 235 dest_bucket, dest_key = split_path(dest_path) 236 237 def _invoke_api() -> int: 238 src_obj = self.client.bucket(bck_name=src_bucket, provider=self.provider).object(obj_name=src_key) 239 dest_obj = self.client.bucket(bck_name=dest_bucket, provider=self.provider).object(obj_name=dest_key) 240 241 # Get source size before copying 242 src_headers = src_obj.head() 243 src_props = ObjectProps(src_headers) 244 245 # Server-side copy (preserves custom metadata automatically) 246 src_obj.copy(to_obj=dest_obj) # type: ignore[attr-defined] 247 248 return int(src_props.size) 249 250 return self._translate_errors( 251 _invoke_api, operation="COPY", bucket=f"{src_bucket}->{dest_bucket}", key=f"{src_key}->{dest_key}" 252 ) 253 254 def _delete_object(self, path: str, if_match: Optional[str] = None) -> None: 255 bucket, key = split_path(path) 256 257 def _invoke_api() -> None: 258 obj = self.client.bucket(bucket, self.provider).object(obj_name=key) 259 # AIS doesn't support if-match deletion, so we implement a fallback mechanism 260 if if_match: 261 raise NotImplementedError("AIStore does not support if-match deletion") 262 # Perform deletion 263 obj.delete() 264 265 return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key) 266 267 def _is_dir(self, path: str) -> bool: 268 # Ensure the path ends with '/' to mimic a directory 269 path = self._append_delimiter(path) 270 271 bucket, prefix = split_path(path) 272 273 def _invoke_api() -> bool: 274 # List objects with the given prefix (limit to 1 for efficiency) 275 objects = self.client.bucket(bck_name=bucket, provider=self.provider).list_objects_iter( 276 prefix=prefix, page_size=1 277 ) 278 # Check if there are any objects with this prefix 279 return any(True for _ in objects) 280 281 return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix) 282 283 def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata: 284 bucket, key = split_path(path) 285 if path.endswith("/") or (bucket and not key): 286 # If path ends with "/" or empty key name is provided, then assume it's a "directory", 287 # which metadata is not guaranteed to exist for cases such as 288 # "virtual prefix" that was never explicitly created. 289 if self._is_dir(path): 290 return ObjectMetadata( 291 key=path, 292 type="directory", 293 content_length=0, 294 last_modified=AWARE_DATETIME_MIN, 295 ) 296 else: 297 raise FileNotFoundError(f"Directory {path} does not exist.") 298 else: 299 300 def _invoke_api() -> ObjectMetadata: 301 obj = self.client.bucket(bck_name=bucket, provider=self.provider).object(obj_name=key) 302 try: 303 headers = obj.head() 304 props = ObjectProps(headers) 305 306 return ObjectMetadata( 307 key=key, 308 content_length=int(props.size), # pyright: ignore [reportArgumentType] 309 last_modified=datetime.fromtimestamp(int(props.access_time) / 1e9).astimezone(timezone.utc), 310 etag=props.checksum_value, 311 metadata=props.custom_metadata, 312 ) 313 except (AISError, HTTPError) as e: 314 # Check if this might be a virtual directory (prefix with objects under it) 315 status_code = None 316 if isinstance(e, AISError): 317 status_code = e.status_code 318 elif isinstance(e, HTTPError): 319 status_code = e.response.status_code 320 321 if status_code == 404: 322 if self._is_dir(path): 323 return ObjectMetadata( 324 key=path + "/", 325 type="directory", 326 content_length=0, 327 last_modified=AWARE_DATETIME_MIN, 328 ) 329 # Re-raise to be handled by _translate_errors 330 raise 331 332 return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key) 333 334 def _list_objects( 335 self, 336 path: str, 337 start_after: Optional[str] = None, 338 end_at: Optional[str] = None, 339 include_directories: bool = False, 340 follow_symlinks: bool = True, 341 ) -> Iterator[ObjectMetadata]: 342 bucket, prefix = split_path(path) 343 344 # Get the prefix of the start_after and end_at paths relative to the bucket. 345 if start_after: 346 _, start_after = split_path(start_after) 347 if end_at: 348 _, end_at = split_path(end_at) 349 350 def _invoke_api() -> Iterator[ObjectMetadata]: 351 # AIS has no start key option like other object stores. 352 all_objects = self.client.bucket(bck_name=bucket, provider=self.provider).list_objects_iter( 353 prefix=prefix, props="name,size,atime,checksum,cone", page_size=DEFAULT_PAGE_SIZE 354 ) 355 356 # Assume AIS guarantees lexicographical order. 357 for bucket_entry in all_objects: 358 obj = bucket_entry.object 359 key = obj.name 360 props = bucket_entry.generate_object_props() 361 if (start_after is None or start_after < key) and (end_at is None or key <= end_at): 362 yield ObjectMetadata( 363 key=key, 364 content_length=int(props.size), 365 last_modified=dateutil_parser(props.access_time).astimezone(timezone.utc), 366 etag=props.checksum_value, 367 ) 368 elif end_at is not None and end_at < key: 369 return 370 371 return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix) 372 373 def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int: 374 file_size: int = 0 375 376 if isinstance(f, str): 377 with open(f, "rb") as fp: 378 body = fp.read() 379 file_size = len(body) 380 self._put_object(remote_path, body, attributes=attributes) 381 else: 382 if isinstance(f, io.StringIO): 383 body = f.read().encode("utf-8") 384 file_size = len(body) 385 self._put_object(remote_path, body, attributes=attributes) 386 else: 387 body = f.read() 388 file_size = len(body) 389 self._put_object(remote_path, body, attributes=attributes) 390 391 return file_size 392 393 def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int: 394 if metadata is None: 395 metadata = self._get_object_metadata(remote_path) 396 397 if isinstance(f, str): 398 if os.path.dirname(f): 399 os.makedirs(os.path.dirname(f), exist_ok=True) 400 with open(f, "wb") as fp: 401 fp.write(self._get_object(remote_path)) 402 else: 403 if isinstance(f, io.StringIO): 404 f.write(self._get_object(remote_path).decode("utf-8")) 405 else: 406 f.write(self._get_object(remote_path)) 407 408 return metadata.content_length