Source code for multistorageclient.providers.s3

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import io
import os
import tempfile
import time
from typing import IO, Any, Callable, Iterator, Optional, Union

import boto3
from boto3.s3.transfer import TransferConfig
from botocore.credentials import RefreshableCredentials
from botocore.exceptions import ClientError, ReadTimeoutError
from botocore.session import get_session

from ..types import (
    Credentials,
    CredentialsProvider,
    ObjectMetadata,
    Range,
    RetryableError,
)
from ..utils import split_path
from .base import BaseStorageProvider

BOTO3_MAX_POOL_CONNECTIONS = 32
BOTO3_CONNECT_TIMEOUT = 10
BOTO3_READ_TIMEOUT = 10

MB = 1024 * 1024

MULTIPART_THRESHOLD = 512 * MB
MULTIPART_CHUNK_SIZE = 256 * MB
IO_CHUNK_SIZE = 128 * MB
MAX_CONCURRENCY = 16
PROVIDER = "s3"


class StaticS3CredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides
    static S3 credentials.
    """

    _access_key: str
    _secret_key: str
    _session_token: Optional[str]

    def __init__(self, access_key: str, secret_key: str, session_token: Optional[str] = None):
        """
        Initializes the :py:class:`StaticS3CredentialsProvider` with the provided access key, secret key, and
        optional session token.

        :param access_key: The access key for S3 authentication.
        :param secret_key: The secret key for S3 authentication.
        :param session_token: An optional session token for temporary credentials.
        """
        self._access_key = access_key
        self._secret_key = secret_key
        self._session_token = session_token

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key=self._access_key,
            secret_key=self._secret_key,
            token=self._session_token,
            expiration=None,
        )

    def refresh_credentials(self) -> None:
        pass
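
# Illustrative sketch (not part of this module): a StaticS3CredentialsProvider simply
# echoes back the values it was constructed with, so refresh_credentials() is a no-op
# and the returned credentials never carry an expiration. The key values below are
# placeholders, not real credentials.
#
#     static_creds = StaticS3CredentialsProvider(access_key="AKIA-PLACEHOLDER", secret_key="placeholder")
#     static_creds.refresh_credentials()                      # no-op for static credentials
#     assert static_creds.get_credentials().expiration is None
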

class S3StorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with
    Amazon S3 or SwiftStack.
    """

    def __init__(
        self,
        region_name: str = "",
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the :py:class:`S3StorageProvider` with the region, endpoint URL, and optional credentials provider.

        :param region_name: The AWS region where the S3 bucket is located.
        :param endpoint_url: The custom endpoint URL for the S3 service.
        :param base_path: The root prefix path within the S3 bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve S3 credentials.
        """
        super().__init__(base_path=base_path, provider_name=PROVIDER)

        self._region_name = region_name
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._s3_client = self._create_s3_client()
        self._transfer_config = TransferConfig(
            multipart_threshold=int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD)),
            max_concurrency=int(kwargs.get("max_concurrency", MAX_CONCURRENCY)),
            multipart_chunksize=int(kwargs.get("multipart_chunksize", MULTIPART_CHUNK_SIZE)),
            io_chunksize=int(kwargs.get("io_chunk_size", IO_CHUNK_SIZE)),
            use_threads=True,
        )

    def _create_s3_client(self):
        """
        Creates and configures the boto3 S3 client, using refreshable credentials if possible.

        :return: The configured S3 client.
        """
        options = {
            "region_name": self._region_name,
            "config": boto3.session.Config(  # pyright: ignore [reportAttributeAccessIssue]
                max_pool_connections=BOTO3_MAX_POOL_CONNECTIONS,
                connect_timeout=BOTO3_CONNECT_TIMEOUT,
                read_timeout=BOTO3_READ_TIMEOUT,
                retries=dict(mode="standard"),
            ),
        }
        if self._endpoint_url:
            options["endpoint_url"] = self._endpoint_url

        if self._credentials_provider:
            creds = self._fetch_credentials()
            if "expiry_time" in creds and creds["expiry_time"]:
                # Use RefreshableCredentials if expiry_time provided.
                refreshable_credentials = RefreshableCredentials.create_from_metadata(
                    metadata=creds, refresh_using=self._fetch_credentials, method="custom-refresh"
                )

                botocore_session = get_session()
                botocore_session._credentials = refreshable_credentials

                boto3_session = boto3.Session(botocore_session=botocore_session)

                return boto3_session.client("s3", **options)
            else:
                # Add static credentials to the options dictionary
                options["aws_access_key_id"] = creds["access_key"]
                options["aws_secret_access_key"] = creds["secret_key"]
                if creds["token"]:
                    options["aws_session_token"] = creds["token"]

        # Fallback to standard credential chain.
        return boto3.client("s3", **options)

    def _fetch_credentials(self) -> dict:
        """
        Refreshes the credentials via the configured credentials provider and returns them in the
        metadata format expected by botocore.
        """
        if not self._credentials_provider:
            raise RuntimeError("Cannot fetch credentials if no credential provider configured.")
        self._credentials_provider.refresh_credentials()
        credentials = self._credentials_provider.get_credentials()
        return {
            "access_key": credentials.access_key,
            "secret_key": credentials.secret_key,
            "token": credentials.token,
            "expiry_time": credentials.expiration,
        }

    def _collect_metrics(
        self,
        func: Callable,
        operation: str,
        bucket: str,
        key: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> Any:
        """
        Collects and records performance metrics around S3 operations such as PUT, GET, DELETE, etc.

        This method wraps an S3 operation and measures the time it takes to complete, along with recording
        the size of the object if applicable. It handles errors like timeouts and client errors and ensures
        proper logging of duration and object size.

        :param func: The function that performs the actual S3 operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the S3 bucket involved in the operation.
        :param key: The key of the object within the S3 bucket.
        :param put_object_size: The size of the object being uploaded, if applicable (for PUT operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).

        :return: The result of the S3 operation, typically the return value of the `func` callable.
        """
        start_time = time.time()
        status_code = 200

        object_size = None
        if operation == "PUT":
            object_size = put_object_size
        elif operation == "GET" and get_object_size:
            object_size = get_object_size

        try:
            result = func()
            if operation == "GET" and object_size is None:
                object_size = len(result)
            return result
        except ClientError as error:
            status_code = error.response["ResponseMetadata"]["HTTPStatusCode"]
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            elif status_code == 429:
                raise RetryableError(f"Too many requests to {operation} object(s) at {bucket}/{key}") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}") from error
        except FileNotFoundError as error:
            status_code = -1
            raise error
        except ReadTimeoutError as error:
            status_code = -1
            raise RetryableError(f"Failed to {operation} object(s) at {bucket}/{key}") from error
        except Exception as error:
            status_code = -1
            raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}") from error
        finally:
            elapsed_time = time.time() - start_time
            self._metric_helper.record_duration(
                elapsed_time, provider=PROVIDER, operation=operation, bucket=bucket, status_code=status_code
            )
            if object_size:
                self._metric_helper.record_object_size(
                    object_size, provider=PROVIDER, operation=operation, bucket=bucket, status_code=status_code
                )

    def _put_object(self, path: str, body: bytes) -> None:
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            self._s3_client.put_object(Bucket=bucket, Key=key, Body=body)

        return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=len(body))

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)

        def _invoke_api() -> bytes:
            if byte_range:
                bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
                response = self._s3_client.get_object(Bucket=bucket, Key=key, Range=bytes_range)
            else:
                response = self._s3_client.get_object(Bucket=bucket, Key=key)
            return response["Body"].read()

        return self._collect_metrics(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _delete_object(self, path: str) -> None:
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            self._s3_client.delete_object(Bucket=bucket, Key=key)

        return self._collect_metrics(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._s3_client.list_objects_v2(Bucket=bucket, Prefix=key, MaxKeys=1, Delimiter="/")
            # Check if there are any contents or common prefixes
            return bool(response.get("Contents", []) or response.get("CommonPrefixes", []))

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str) -> ObjectMetadata:
        if path.endswith("/"):
            # If path is a "directory", then metadata is not guaranteed to exist if
            # it is a "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=self._append_delimiter(path),
                    type="directory",
                    content_length=0,
                    last_modified=datetime.datetime.min,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            bucket, key = split_path(path)

            def _invoke_api() -> ObjectMetadata:
                response = self._s3_client.head_object(Bucket=bucket, Key=key)
                return ObjectMetadata(
                    key=path,
                    type="file",
                    content_length=response["ContentLength"],
                    content_type=response["ContentType"],
                    last_modified=response["LastModified"],
                    etag=response["ETag"].strip('"'),
                )

            try:
                return self._collect_metrics(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                # If the object does not exist on the given path, we will append a trailing slash and
                # check if the path is a directory.
                path = self._append_delimiter(path)
                if self._is_dir(path):
                    return ObjectMetadata(
                        key=self._append_delimiter(path),
                        type="directory",
                        content_length=0,
                        last_modified=datetime.datetime.min,
                    )
                else:
                    raise error

    def _list_objects(
        self, prefix: str, start_after: Optional[str] = None, end_at: Optional[str] = None
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(prefix)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            paginator = self._s3_client.get_paginator("list_objects_v2")
            page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, StartAfter=(start_after or ""))

            for page in page_iterator:
                # S3 guarantees lexicographical order for general purpose buckets (normal S3)
                # but not for directory buckets (S3 Express One Zone).
                for response_object in page.get("Contents", []):
                    key = response_object["Key"]
                    if end_at is None or key <= end_at:
                        yield ObjectMetadata(
                            key=key,
                            type="file",
                            content_length=response_object["Size"],
                            last_modified=response_object["LastModified"],
                            etag=response_object["ETag"].strip('"'),
                        )
                    else:
                        return

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO]) -> None:
        if isinstance(f, str):
            filesize = os.path.getsize(f)

            # Upload small files
            if filesize <= self._transfer_config.multipart_threshold:
                with open(f, "rb") as fp:
                    self._put_object(remote_path, fp.read())
                return

            # Upload large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> None:
                self._s3_client.upload_file(
                    Filename=f,
                    Bucket=bucket,
                    Key=key,
                    Config=self._transfer_config,
                )

            return self._collect_metrics(
                _invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=filesize
            )
        else:
            # Upload small files
            f.seek(0, io.SEEK_END)
            filesize = f.tell()
            f.seek(0)

            if filesize <= self._transfer_config.multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(remote_path, f.read().encode("utf-8"))
                else:
                    self._put_object(remote_path, f.read())
                return

            # Upload large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> None:
                self._s3_client.upload_fileobj(
                    Fileobj=f,
                    Bucket=bucket,
                    Key=key,
                    Config=self._transfer_config,
                )

            return self._collect_metrics(
                _invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=filesize
            )

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> None:
        if not metadata:
            metadata = self._get_object_metadata(remote_path)

        if isinstance(f, str):
            os.makedirs(os.path.dirname(f), exist_ok=True)

            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    fp.write(self._get_object(remote_path))
                os.rename(src=temp_file_path, dst=f)
                return

            # Download large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> None:
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    self._s3_client.download_fileobj(
                        Bucket=bucket,
                        Key=key,
                        Fileobj=fp,
                        Config=self._transfer_config,
                    )
                os.rename(src=temp_file_path, dst=f)

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )
        else:
            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                if isinstance(f, io.StringIO):
                    f.write(self._get_object(remote_path).decode("utf-8"))
                else:
                    f.write(self._get_object(remote_path))
                return

            # Download large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> None:
                self._s3_client.download_fileobj(
                    Bucket=bucket,
                    Key=key,
                    Fileobj=f,
                    Config=self._transfer_config,
                )

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )
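
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library API). It wires the
# two classes above together; the region, bucket, and prefix are assumed
# placeholders, and credentials are read from the conventional AWS environment
# variables if present.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    static_credentials = StaticS3CredentialsProvider(
        access_key=os.environ.get("AWS_ACCESS_KEY_ID", "placeholder-access-key"),
        secret_key=os.environ.get("AWS_SECRET_ACCESS_KEY", "placeholder-secret-key"),
    )
    storage_provider = S3StorageProvider(
        region_name="us-east-1",                  # assumed region
        base_path="example-bucket/datasets",      # assumed "<bucket>/<prefix>" scope
        credentials_provider=static_credentials,
    )
    # The internal helpers above take "<bucket>/<key>" style paths (see split_path), e.g.:
    #     storage_provider._put_object("example-bucket/datasets/hello.txt", b"hello world")
    #     data = storage_provider._get_object("example-bucket/datasets/hello.txt")
    # They are left commented out here to avoid issuing network calls when run.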