Source code for multistorageclient.providers.ais

  1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16import io
 17import os
 18from collections.abc import Callable, Iterator
 19from typing import IO, Any, Optional, TypeVar, Union
 20
 21from aistore.sdk import Client
 22from aistore.sdk.authn import AuthNClient
 23from aistore.sdk.errors import AISError
 24from aistore.sdk.obj.object_props import ObjectProps
 25from requests.exceptions import HTTPError
 26from urllib3.util import Retry
 27
 28from ..telemetry import Telemetry
 29from ..types import (
 30    AWARE_DATETIME_MIN,
 31    Credentials,
 32    CredentialsProvider,
 33    ObjectMetadata,
 34    Range,
 35)
 36from ..utils import split_path, validate_attributes
 37from .base import BaseStorageProvider
 38
 39_T = TypeVar("_T")
 40
 41PROVIDER = "ais"
 42DEFAULT_PAGE_SIZE = 1000
 43
 44
class StaticAISCredentialProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides
    static AIStore credentials.

    If ``username``, ``password`` and ``authn_endpoint`` are all set, a token is obtained from the AIStore
    AuthN service on every :py:meth:`get_credentials` call. Otherwise the statically configured ``token``
    is returned as-is.
    """

    _username: Optional[str]
    _password: Optional[str]
    _authn_endpoint: Optional[str]
    _token: Optional[str]
    _skip_verify: bool
    _ca_cert: Optional[str]

    def __init__(
        self,
        username: Optional[str] = None,
        password: Optional[str] = None,
        authn_endpoint: Optional[str] = None,
        token: Optional[str] = None,
        skip_verify: bool = True,
        ca_cert: Optional[str] = None,
    ):
        """
        Initializes the :py:class:`StaticAISCredentialProvider` with the given credentials.

        :param username: The username for the AIStore authentication.
        :param password: The password for the AIStore authentication.
        :param authn_endpoint: The AIStore authentication endpoint.
        :param token: The AIStore authentication token. This is used for authentication if username,
            password and authn_endpoint are not provided.
        :param skip_verify: If true, skip SSL certificate verification when talking to the AuthN endpoint.
        :param ca_cert: Path to a CA certificate file for SSL verification of the AuthN endpoint.
        """
        self._username = username
        self._password = password
        self._authn_endpoint = authn_endpoint
        self._token = token
        self._skip_verify = skip_verify
        self._ca_cert = ca_cert

    def get_credentials(self) -> Credentials:
        """
        Returns the AIStore credentials.

        When username, password and AuthN endpoint are all configured, this logs in through the AuthN service.
        The issued token replaces the stored one. Only the ``token`` field of the result is populated;
        ``access_key`` and ``secret_key`` are empty strings and there is no expiration.

        :return: The credentials carrying the AIStore token (``None`` if no token was configured or obtained).
        """
        if self._username and self._password and self._authn_endpoint:
            # A fresh login happens on every call; the returned token overwrites any static token.
            authn_client = AuthNClient(self._authn_endpoint, self._skip_verify, self._ca_cert)
            self._token = authn_client.login(self._username, self._password)
        return Credentials(token=self._token, access_key="", secret_key="", expiration=None)

    def refresh_credentials(self) -> None:
        """
        No-op. When AuthN login is configured, :py:meth:`get_credentials` already logs in again on each call.
        """
        pass
class AIStoreStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with NVIDIA AIStore.
    """

    def __init__(
        self,
        endpoint: str = os.getenv("AIS_ENDPOINT", ""),
        provider: str = PROVIDER,
        skip_verify: bool = True,
        ca_cert: Optional[str] = None,
        timeout: Optional[Union[float, tuple[float, float]]] = None,
        retry: Optional[dict[str, Any]] = None,
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ) -> None:
        """
        AIStore client for managing buckets, objects, and ETL jobs.

        :param endpoint: The AIStore endpoint. The default is the ``AIS_ENDPOINT`` environment variable as it was
            when this module was imported (default arguments are evaluated once).
        :param provider: The bucket provider passed to every ``Client.bucket`` call. Defaults to ``"ais"``.
        :param skip_verify: Whether to skip SSL certificate verification.
        :param ca_cert: Path to a CA certificate file for SSL verification.
        :param timeout: Request timeout in seconds; a single float
            for both connect/read timeouts (e.g., ``5.0``), a tuple for separate connect/read
            timeouts (e.g., ``(3.0, 10.0)``), or ``None`` to disable timeout.
        :param retry: ``urllib3.util.Retry`` parameters.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve AIStore credentials. When set, the token from its
            ``get_credentials()`` authenticates all requests. Credentials are fetched once, during construction.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        :param kwargs: Additional keyword arguments; accepted but ignored.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        # https://aistore.nvidia.com/docs/python-sdk#client.Client
        client_retry = None if retry is None else Retry(**retry)
        token = credentials_provider.get_credentials().token if credentials_provider else None
        # The TLS and timeout settings are honored whether or not a credentials provider is configured.
        # They used to be silently dropped without one. Note this means the ``skip_verify`` default
        # (True) now applies to both paths.
        self.client = Client(
            endpoint=endpoint,
            retry=client_retry,
            skip_verify=skip_verify,
            ca_cert=ca_cert,
            timeout=timeout,
            token=token,
        )
        self.provider = provider

    @staticmethod
    def _status_code(error: Exception) -> Optional[int]:
        """
        Extracts the HTTP status code from an AIStore SDK or ``requests`` error.

        :param error: The exception raised while talking to AIStore.
        :return: The status code, or ``None`` if it cannot be determined (e.g. an ``HTTPError`` without a response).
        """
        if isinstance(error, AISError):
            return error.status_code
        if isinstance(error, HTTPError) and error.response is not None:
            return error.response.status_code
        return None

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        404 responses become :py:class:`FileNotFoundError`; every other failure becomes :py:class:`RuntimeError`
        chained to the original exception.

        :param func: The function that performs the actual object storage operation.
        :param operation: The type of operation being performed (e.g., ``PUT``, ``GET``, ``DELETE``).
        :param bucket: The name of the object storage bucket involved in the operation.
        :param key: The key of the object within the object storage bucket.

        :return: The result of the object storage operation, typically the return value of the `func` callable.
        """

        try:
            return func()
        except AISError as error:
            status_code = error.status_code
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            error_info = f"status_code: {status_code}, message: {error.message}"
            raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except HTTPError as error:
            # ``error.response`` may be None; the helper guards against that.
            if self._status_code(error) == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}"
            ) from error
        except Exception as error:
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error

    def _translate_iterator(
        self,
        iterator: Iterator[_T],
        operation: str,
        bucket: str,
        key: str,
    ) -> Iterator[_T]:
        """
        Lazily yields the items of ``iterator``, translating errors raised while advancing it.

        Wrapping a generator function in :py:meth:`_translate_errors` only guards the creation of the generator,
        not its execution. Each ``next()`` is therefore translated individually. Exceptions thrown into this
        generator at the ``yield`` are not translated.

        :param iterator: The iterator to consume.
        :param operation: The operation name used in error messages.
        :param bucket: The bucket name used in error messages.
        :param key: The key/prefix used in error messages.
        """
        exhausted = object()
        while True:
            item = self._translate_errors(
                lambda: next(iterator, exhausted), operation=operation, bucket=bucket, key=key
            )
            if item is exhausted:
                return
            yield item  # type: ignore

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads ``body`` to ``path`` and optionally attaches custom metadata.

        ``if_match`` and ``if_none_match`` are ignored because AIStore does not support conditional writes.

        :return: The number of bytes uploaded.
        """
        bucket, key = split_path(path)

        def _invoke_api() -> int:
            # Validate before uploading. Presumably validate_attributes raises on invalid input; doing it
            # first avoids leaving an object written without its metadata.
            validated_attributes = validate_attributes(attributes)
            obj = self.client.bucket(bucket, self.provider).object(obj_name=key)
            obj.put_content(body)
            if validated_attributes:
                obj.set_custom_props(custom_metadata=validated_attributes, replace_existing=True)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Reads the object at ``path``, or only ``byte_range`` of it when given.

        :return: The object content.
        """
        bucket, key = split_path(path)
        if byte_range:
            # HTTP byte ranges are inclusive on both ends.
            bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
        else:
            bytes_range = None

        def _invoke_api() -> bytes:
            obj = self.client.bucket(bucket, self.provider).object(obj_name=key)
            if byte_range:
                reader = obj.get(byte_range=bytes_range)  # pyright: ignore [reportArgumentType]
            else:
                reader = obj.get()
            return reader.read_all()

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        """
        Copies ``src_path`` to ``dest_path`` server-side.

        :return: The size of the source object in bytes, as reported before the copy.
        """
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)

        def _invoke_api() -> int:
            src_obj = self.client.bucket(bck_name=src_bucket, provider=self.provider).object(obj_name=src_key)
            dest_obj = self.client.bucket(bck_name=dest_bucket, provider=self.provider).object(obj_name=dest_key)

            # Get source size before copying
            src_headers = src_obj.head()
            src_props = ObjectProps(src_headers)

            # Server-side copy (preserves custom metadata automatically)
            src_obj.copy(to_obj=dest_obj)  # type: ignore[attr-defined]

            return int(src_props.size)

        return self._translate_errors(
            _invoke_api, operation="COPY", bucket=f"{src_bucket}->{dest_bucket}", key=f"{src_key}->{dest_key}"
        )

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Deletes the object at ``path``.

        :raises NotImplementedError: If ``if_match`` is given; AIStore has no conditional delete.
        """
        bucket, key = split_path(path)

        # Reject before invoking the API. Raising inside ``_translate_errors`` would turn this into a generic
        # RuntimeError (NotImplementedError is itself a RuntimeError subclass, so existing handlers still match).
        if if_match:
            raise NotImplementedError("AIStore does not support if-match deletion")

        def _invoke_api() -> None:
            obj = self.client.bucket(bucket, self.provider).object(obj_name=key)
            obj.delete()

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        """
        Returns whether at least one object exists under the prefix ``path`` + delimiter.
        """
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, prefix = split_path(path)

        def _invoke_api() -> bool:
            # List objects with the given prefix (limit to 1 for efficiency)
            objects = self.client.bucket(bck_name=bucket, provider=self.provider).list_objects_iter(
                prefix=prefix, page_size=1
            )
            # Check if there are any objects with this prefix
            return any(True for _ in objects)

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Returns metadata for the object or "virtual directory" at ``path``.

        ``strict`` is currently not used by this provider.

        :raises FileNotFoundError: If neither an object nor a non-empty prefix exists at ``path``.
        """
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If path ends with "/" or empty key name is provided, then assume it's a "directory",
            # which metadata is not guaranteed to exist for cases such as
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            raise FileNotFoundError(f"Directory {path} does not exist.")

        def _invoke_api() -> ObjectMetadata:
            obj = self.client.bucket(bck_name=bucket, provider=self.provider).object(obj_name=key)
            try:
                headers = obj.head()
            except (AISError, HTTPError) as error:
                # A missing object may still be a virtual directory (a prefix with objects under it).
                if self._status_code(error) == 404 and self._is_dir(path):
                    return ObjectMetadata(
                        key=path + "/",
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )
                # Re-raise to be handled by _translate_errors
                raise

            props = ObjectProps(headers)
            return ObjectMetadata(
                key=key,
                content_length=int(props.size),  # pyright: ignore [reportArgumentType]
                last_modified=AWARE_DATETIME_MIN,
                etag=props.checksum_value,
                metadata=props.custom_metadata,
            )

        return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        """
        Lazily lists objects under ``path``, restricted to keys in ``(start_after, end_at]``.

        ``include_directories`` and ``follow_symlinks`` are currently not used by this provider.
        Errors raised while iterating are translated like those of the other operations.
        """
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            # AIS has no start key option like other object stores, so the range is filtered client-side.
            # NOTE(review): "cone" is not a listing property name I can confirm in the AIS docs — verify.
            all_objects = self.client.bucket(bck_name=bucket, provider=self.provider).list_objects_iter(
                prefix=prefix, props="name,size,atime,checksum,cone", page_size=DEFAULT_PAGE_SIZE
            )

            # Assume AIS guarantees lexicographical order.
            for bucket_entry in all_objects:
                key = bucket_entry.object.name
                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    props = bucket_entry.generate_object_props()
                    yield ObjectMetadata(
                        key=key,
                        content_length=int(props.size),
                        last_modified=AWARE_DATETIME_MIN,
                        etag=props.checksum_value,
                    )
                elif end_at is not None and end_at < key:
                    # Keys are ordered, so nothing after this can be in range.
                    return

        return self._translate_iterator(_invoke_api(), operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        """
        Uploads a local file path or an open file object to ``remote_path``.

        Text streams (``io.StringIO`` or files opened in text mode) are encoded as UTF-8.

        :return: The number of bytes uploaded.
        """
        if isinstance(f, str):
            with open(f, "rb") as fp:
                body = fp.read()
        else:
            body = f.read()
            # Any text stream yields str, not just io.StringIO; AIStore needs bytes.
            if isinstance(body, str):
                body = body.encode("utf-8")

        self._put_object(remote_path, body, attributes=attributes)
        return len(body)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """
        Downloads ``remote_path`` into a local file path or an open file object.

        Text streams (any ``io.TextIOBase``) receive the content decoded as UTF-8.

        :return: The object's content length as reported by ``metadata``.
        """
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        # Fetch first: opening the destination with "wb" truncates it, so a failed download
        # must not get that far.
        data = self._get_object(remote_path)

        if isinstance(f, str):
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)
            with open(f, "wb") as fp:
                fp.write(data)
        elif isinstance(f, io.TextIOBase):
            f.write(data.decode("utf-8"))
        else:
            f.write(data)

        return metadata.content_length