# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import io
import os
import time
from datetime import datetime
from typing import IO, Any, Callable, Iterator, Optional, Tuple, Union

from aistore.sdk import Client
from aistore.sdk.authn import AuthNClient
from aistore.sdk.errors import AISError
from dateutil.parser import parse as dateutil_parser
from requests.exceptions import HTTPError

from ..types import (
    Credentials,
    CredentialsProvider,
    ObjectMetadata,
    Range,
)
from ..utils import split_path
from .base import BaseStorageProvider

PROVIDER = "ais"


class StaticAISCredentialProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider`
    that provides static AIStore credentials.
    """

    _username: Optional[str]
    _password: Optional[str]
    _authn_endpoint: Optional[str]
    _token: Optional[str]
    _skip_verify: bool
    _ca_cert: Optional[str]

    def __init__(
        self,
        username: Optional[str] = None,
        password: Optional[str] = None,
        authn_endpoint: Optional[str] = None,
        token: Optional[str] = None,
        skip_verify: bool = True,
        ca_cert: Optional[str] = None,
    ):
        """
        Initializes the :py:class:`StaticAISCredentialProvider` with the given credentials.

        :param username: The username for the AIStore authentication.
        :param password: The password for the AIStore authentication.
        :param authn_endpoint: The AIStore authentication endpoint.
        :param token: The AIStore authentication token. This is used for authentication
            if username, password and authn_endpoint are not provided.
        :param skip_verify: If true, skip SSL certificate verification.
        :param ca_cert: Path to a CA certificate file for SSL verification.
        """
        self._username = username
        self._password = password
        self._authn_endpoint = authn_endpoint
        self._token = token
        self._skip_verify = skip_verify
        self._ca_cert = ca_cert

    def get_credentials(self) -> Credentials:
        if self._username and self._password and self._authn_endpoint:
            authn_client = AuthNClient(self._authn_endpoint, self._skip_verify, self._ca_cert)
            self._token = authn_client.login(self._username, self._password)
        return Credentials(token=self._token, access_key="", secret_key="", expiration=None)

    def refresh_credentials(self) -> None:
        pass
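

# A minimal usage sketch for StaticAISCredentialProvider; the endpoint, username, and
# password below are hypothetical placeholders, not values defined by this module.
# With username/password/authn_endpoint set, get_credentials() logs in against AuthN
# and caches the resulting token; with only a token, it returns that token as-is.
#
#     credentials_provider = StaticAISCredentialProvider(
#         username="admin",
#         password="admin",
#         authn_endpoint="https://ais-authn.example.com:52001",
#     )
#     credentials = credentials_provider.get_credentials()
#     assert credentials.token is not None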


class AIStoreStorageProvider(BaseStorageProvider):
    def __init__(
        self,
        endpoint: str,
        provider: str = PROVIDER,
        skip_verify: bool = True,
        ca_cert: Optional[str] = None,
        timeout: Optional[Union[float, Tuple[float, float]]] = None,
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        **kwargs: Any,
    ) -> None:
        """
        AIStore client for managing buckets, objects, and ETL jobs.

        :param endpoint: The AIStore endpoint.
        :param provider: The name of the AIStore backend provider (defaults to ``ais``).
        :param skip_verify: Whether to skip SSL certificate verification.
        :param ca_cert: Path to a CA certificate file for SSL verification.
        :param timeout: Request timeout in seconds; a single float for both connect/read
            timeouts (e.g., ``5.0``), a tuple for separate connect/read timeouts
            (e.g., ``(3.0, 10.0)``), or ``None`` to disable timeout.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider used to retrieve an AIStore authentication token.
        """
        super().__init__(base_path=base_path, provider_name=PROVIDER)

        token = None
        if credentials_provider:
            token = credentials_provider.get_credentials().token
            self.client = Client(
                endpoint=endpoint, skip_verify=skip_verify, ca_cert=ca_cert, timeout=timeout, token=token
            )
        else:
            self.client = Client(endpoint=endpoint)
        self.provider = provider

    def _collect_metrics(
        self,
        func: Callable,
        operation: str,
        bucket: str,
        key: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> Any:
        """
        Collects and records performance metrics around object storage operations
        such as ``PUT``, ``GET``, ``DELETE``, etc.

        This method wraps an object storage operation and measures the time it takes to complete,
        along with recording the size of the object if applicable. It handles errors like timeouts
        and client errors and ensures proper logging of duration and object size.

        :param func: The function that performs the actual object storage operation.
        :param operation: The type of operation being performed (e.g., ``PUT``, ``GET``, ``DELETE``).
        :param bucket: The name of the object storage bucket involved in the operation.
        :param key: The key of the object within the object storage bucket.
        :param put_object_size: The size of the object being uploaded, if applicable (for ``PUT`` operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for ``GET`` operations).

        :return: The result of the object storage operation, typically the return value of the ``func`` callable.
""" start_time = time.time() status_code = 200 object_size = None if operation == "PUT": object_size = put_object_size elif operation == "GET" and get_object_size: object_size = get_object_size try: result = func() if operation == "GET" and object_size is None: object_size = len(result) return result except AISError as error: status_code = error.status_code raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}") from error except HTTPError as error: status_code = error.response.status_code if status_code == 404: raise FileNotFoundError(f"Object {bucket}/{key} does not exist.") # pylint: disable=raise-missing-from else: raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}") from error except Exception as error: status_code = -1 raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}") from error finally: elapsed_time = time.time() - start_time self._metric_helper.record_duration( elapsed_time, provider=PROVIDER, operation=operation, bucket=bucket, status_code=status_code ) if object_size: self._metric_helper.record_object_size( object_size, provider=PROVIDER, operation=operation, bucket=bucket, status_code=status_code ) def _put_object(self, path: str, body: bytes) -> None: bucket, key = split_path(path) def _invoke_api() -> None: obj = self.client.bucket(bucket, self.provider).object(obj_name=key) obj.put_content(body) return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=len(body)) def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes: bucket, key = split_path(path) if byte_range: bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}" else: bytes_range = None def _invoke_api() -> bytes: obj = self.client.bucket(bucket, self.provider).object(obj_name=key) if byte_range: reader = obj.get(byte_range=bytes_range) # pyright: ignore [reportArgumentType] else: reader = obj.get() return reader.read_all() return self._collect_metrics(_invoke_api, operation="GET", bucket=bucket, key=key) def _delete_object(self, path: str) -> None: bucket, key = split_path(path) def _invoke_api() -> None: obj = self.client.bucket(bucket, self.provider).object(obj_name=key) obj.delete() return self._collect_metrics(_invoke_api, operation="DELETE", bucket=bucket, key=key) def _get_object_metadata(self, path: str) -> ObjectMetadata: bucket, key = split_path(path) def _invoke_api() -> ObjectMetadata: obj = self.client.bucket(bck_name=bucket, provider=self.provider).object(obj_name=key) props = obj.head() last_modified = datetime.fromtimestamp(int(props.get("Ais-Atime")) // 1_000_000_000) # pyright: ignore [reportArgumentType] return ObjectMetadata( key=key, content_length=int(props.get("Content-Length")), # pyright: ignore [reportArgumentType] last_modified=last_modified, etag=props.get("Ais-Checksum-Value", None), ) return self._collect_metrics(_invoke_api, operation="HEAD", bucket=bucket, key=key) def _list_objects( self, prefix: str, start_after: Optional[str] = None, end_at: Optional[str] = None ) -> Iterator[ObjectMetadata]: bucket, prefix = split_path(prefix) def _invoke_api() -> Iterator[ObjectMetadata]: # AIS has no start key option like other object stores. all_objects = self.client.bucket(bck_name=bucket, provider=self.provider).list_all_objects_iter( prefix=prefix, props="name,size,atime,checksum,cone" ) # Assume AIS guarantees lexicographical order. 
            for obj in all_objects:
                key = obj.name
                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    yield ObjectMetadata(
                        key=key,
                        content_length=int(obj.props.size),
                        last_modified=dateutil_parser(obj.props.access_time),
                        etag=obj.props.checksum_value,
                    )
                elif end_at is not None and end_at < key:
                    return

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO]) -> None:
        if isinstance(f, str):
            with open(f, "rb") as fp:
                self._put_object(remote_path, fp.read())
        else:
            if isinstance(f, io.StringIO):
                self._put_object(remote_path, f.read().encode("utf-8"))
            else:
                self._put_object(remote_path, f.read())

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> None:
        if not metadata:
            metadata = self._get_object_metadata(remote_path)

        if isinstance(f, str):
            os.makedirs(os.path.dirname(f), exist_ok=True)
            with open(f, "wb") as fp:
                fp.write(self._get_object(remote_path))
        else:
            if isinstance(f, io.StringIO):
                f.write(self._get_object(remote_path).decode("utf-8"))
            else:
                f.write(self._get_object(remote_path))
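

# A minimal usage sketch for AIStoreStorageProvider; the endpoint and object path below
# are hypothetical placeholders. It wires the credential provider from the earlier sketch
# into the storage provider and performs a round trip. The underscore-prefixed methods are
# internal helpers, called here only to illustrate the code above; the ``Range`` keyword
# arguments assume the dataclass-style constructor implied by its ``offset``/``size`` fields.
#
#     provider = AIStoreStorageProvider(
#         endpoint="http://ais-proxy.example.com:51080",
#         timeout=(3.0, 10.0),  # separate connect/read timeouts
#         credentials_provider=credentials_provider,
#     )
#     provider._put_object("mybucket/data/hello.bin", b"hello world")
#     head = provider._get_object_metadata("mybucket/data/hello.bin")
#     first_five = provider._get_object("mybucket/data/hello.bin", byte_range=Range(offset=0, size=5))
#     assert first_five == b"hello"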