Source code for multistorageclient.providers.ais

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import io
import os
import time
from datetime import datetime
from typing import IO, Any, Callable, Iterator, Optional, Tuple, Union

from aistore.sdk import Client
from aistore.sdk.authn import AuthNClient
from aistore.sdk.errors import AISError
from dateutil.parser import parse as dateutil_parser
from requests.exceptions import HTTPError

from ..types import (
    Credentials,
    CredentialsProvider,
    ObjectMetadata,
    Range,
)
from ..utils import split_path
from .base import BaseStorageProvider

PROVIDER = "ais"

class StaticAISCredentialProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static AIStore credentials.
    """

    _username: Optional[str]
    _password: Optional[str]
    _authn_endpoint: Optional[str]
    _token: Optional[str]
    _skip_verify: bool
    _ca_cert: Optional[str]

    def __init__(
        self,
        username: Optional[str] = None,
        password: Optional[str] = None,
        authn_endpoint: Optional[str] = None,
        token: Optional[str] = None,
        skip_verify: bool = True,
        ca_cert: Optional[str] = None,
    ):
        """
        Initializes the :py:class:`StaticAISCredentialProvider` with the given credentials.

        :param username: The username for AIStore authentication.
        :param password: The password for AIStore authentication.
        :param authn_endpoint: The AIStore authentication endpoint.
        :param token: The AIStore authentication token. This is used for authentication if username,
            password, and authn_endpoint are not provided.
        :param skip_verify: If true, skip SSL certificate verification.
        :param ca_cert: Path to a CA certificate file for SSL verification.
        """
        self._username = username
        self._password = password
        self._authn_endpoint = authn_endpoint
        self._token = token
        self._skip_verify = skip_verify
        self._ca_cert = ca_cert

    def get_credentials(self) -> Credentials:
        if self._username and self._password and self._authn_endpoint:
            authn_client = AuthNClient(self._authn_endpoint, self._skip_verify, self._ca_cert)
            self._token = authn_client.login(self._username, self._password)
        return Credentials(token=self._token, access_key="", secret_key="", expiration=None)

    def refresh_credentials(self) -> None:
        pass


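# Usage sketch (illustrative only, not part of the module source): two ways to
# construct a StaticAISCredentialProvider. The AuthN endpoint and credentials
# below are hypothetical placeholders.
#
#   # Login against an AuthN service; get_credentials() performs the login.
#   with_login = StaticAISCredentialProvider(
#       username="admin",
#       password="admin",
#       authn_endpoint="https://ais-authn.example.com:52001",
#   )
#   token = with_login.get_credentials().token
#
#   # Or supply a pre-issued token directly.
#   with_token = StaticAISCredentialProvider(token="<pre-issued-jwt>")

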
class AIStoreStorageProvider(BaseStorageProvider):
    def __init__(
        self,
        endpoint: str,
        provider: str = PROVIDER,
        skip_verify: bool = True,
        ca_cert: Optional[str] = None,
        timeout: Optional[Union[float, Tuple[float, float]]] = None,
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        **kwargs: Any,
    ) -> None:
        """
        AIStore client for managing buckets, objects, and ETL jobs.

        :param endpoint: The AIStore endpoint.
        :param provider: The AIStore backend provider (defaults to ``ais``).
        :param skip_verify: Whether to skip SSL certificate verification.
        :param ca_cert: Path to a CA certificate file for SSL verification.
        :param timeout: Request timeout in seconds; a single float
            for both connect/read timeouts (e.g., ``5.0``), a tuple for separate connect/read
            timeouts (e.g., ``(3.0, 10.0)``), or ``None`` to disable timeout.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider used to retrieve an AIStore authorization token.
            If not provided, the client is created without a token.
        """
        super().__init__(base_path=base_path, provider_name=PROVIDER)

        token = None
        if credentials_provider:
            token = credentials_provider.get_credentials().token
            self.client = Client(
                endpoint=endpoint, skip_verify=skip_verify, ca_cert=ca_cert, timeout=timeout, token=token
            )
        else:
            self.client = Client(endpoint=endpoint)
        self.provider = provider

    def _collect_metrics(
        self,
        func: Callable,
        operation: str,
        bucket: str,
        key: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> Any:
        """
        Collects and records performance metrics around object storage operations
        such as ``PUT``, ``GET``, ``DELETE``, etc.

        This method wraps an object storage operation and measures the time it takes to complete, along with recording
        the size of the object if applicable. It handles errors like timeouts and client errors and ensures
        proper logging of duration and object size.

        :param func: The function that performs the actual object storage operation.
        :param operation: The type of operation being performed (e.g., ``PUT``, ``GET``, ``DELETE``).
        :param bucket: The name of the object storage bucket involved in the operation.
        :param key: The key of the object within the object storage bucket.
        :param put_object_size: The size of the object being uploaded, if applicable (for ``PUT`` operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for ``GET`` operations).

        :return: The result of the object storage operation, typically the return value of the ``func`` callable.
        """
        start_time = time.time()
        status_code = 200

        object_size = None
        if operation == "PUT":
            object_size = put_object_size
        elif operation == "GET" and get_object_size:
            object_size = get_object_size

        try:
            result = func()
            if operation == "GET" and object_size is None:
                object_size = len(result)
            return result
        except AISError as error:
            status_code = error.status_code
            raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}") from error
        except HTTPError as error:
            status_code = error.response.status_code
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}") from error
        except Exception as error:
            status_code = -1
            raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}") from error
        finally:
            elapsed_time = time.time() - start_time
            self._metric_helper.record_duration(
                elapsed_time, provider=PROVIDER, operation=operation, bucket=bucket, status_code=status_code
            )
            if object_size:
                self._metric_helper.record_object_size(
                    object_size, provider=PROVIDER, operation=operation, bucket=bucket, status_code=status_code
                )

    def _put_object(self, path: str, body: bytes) -> None:
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            obj = self.client.bucket(bucket, self.provider).object(obj_name=key)
            obj.put_content(body)

        return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=len(body))

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)
        if byte_range:
            bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
        else:
            bytes_range = None

        def _invoke_api() -> bytes:
            obj = self.client.bucket(bucket, self.provider).object(obj_name=key)
            if byte_range:
                reader = obj.get(byte_range=bytes_range)  # pyright: ignore [reportArgumentType]
            else:
                reader = obj.get()
            return reader.read_all()

        return self._collect_metrics(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> None:
        raise AttributeError("AIStore does not support copy operations")

    def _delete_object(self, path: str) -> None:
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            obj = self.client.bucket(bucket, self.provider).object(obj_name=key)
            obj.delete()

        return self._collect_metrics(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str) -> ObjectMetadata:
        bucket, key = split_path(path)

        def _invoke_api() -> ObjectMetadata:
            obj = self.client.bucket(bck_name=bucket, provider=self.provider).object(obj_name=key)
            props = obj.head()
            last_modified = datetime.fromtimestamp(int(props.get("Ais-Atime")) // 1_000_000_000)  # pyright: ignore [reportArgumentType]
            return ObjectMetadata(
                key=key,
                content_length=int(props.get("Content-Length")),  # pyright: ignore [reportArgumentType]
                last_modified=last_modified,
                etag=props.get("Ais-Checksum-Value", None),
            )

        return self._collect_metrics(_invoke_api, operation="HEAD", bucket=bucket, key=key)

    def _list_objects(
        self, prefix: str, start_after: Optional[str] = None, end_at: Optional[str] = None
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(prefix)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            # AIS has no start key option like other object stores.
            all_objects = self.client.bucket(bck_name=bucket, provider=self.provider).list_all_objects_iter(
                prefix=prefix, props="name,size,atime,checksum,cone"
            )

            # Assume AIS guarantees lexicographical order.
            for obj in all_objects:
                key = obj.name
                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    yield ObjectMetadata(
                        key=key,
                        content_length=int(obj.props.size),
                        last_modified=dateutil_parser(obj.props.access_time),
                        etag=obj.props.checksum_value,
                    )
                elif end_at is not None and end_at < key:
                    return

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO]) -> None:
        if isinstance(f, str):
            with open(f, "rb") as fp:
                self._put_object(remote_path, fp.read())
        else:
            if isinstance(f, io.StringIO):
                self._put_object(remote_path, f.read().encode("utf-8"))
            else:
                self._put_object(remote_path, f.read())

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> None:
        if not metadata:
            metadata = self._get_object_metadata(remote_path)

        if isinstance(f, str):
            os.makedirs(os.path.dirname(f), exist_ok=True)
            with open(f, "wb") as fp:
                fp.write(self._get_object(remote_path))
        else:
            if isinstance(f, io.StringIO):
                f.write(self._get_object(remote_path).decode("utf-8"))
            else:
                f.write(self._get_object(remote_path))
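

# End-to-end usage sketch (illustrative only, not part of the module source).
# The endpoint and bucket names are hypothetical; in practice this provider is
# usually constructed from a multistorageclient configuration rather than directly,
# and the underscored methods below are internal helpers, called here only to
# illustrate the "<bucket>/<key>" path format that split_path() expects.
#
#   storage_provider = AIStoreStorageProvider(
#       endpoint="http://localhost:51080",
#       credentials_provider=with_token,  # optional; omit for clusters without AuthN
#   )
#   storage_provider._put_object("my-bucket/data/hello.txt", b"hello world")
#   first_bytes = storage_provider._get_object(
#       "my-bucket/data/hello.txt", byte_range=Range(offset=0, size=5)
#   )
#   for obj in storage_provider._list_objects("my-bucket/data/"):
#       print(obj.key, obj.content_length)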