Source code for multistorageclient.providers.s3

   1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
   2# SPDX-License-Identifier: Apache-2.0
   3#
   4# Licensed under the Apache License, Version 2.0 (the "License");
   5# you may not use this file except in compliance with the License.
   6# You may obtain a copy of the License at
   7#
   8# http://www.apache.org/licenses/LICENSE-2.0
   9#
  10# Unless required by applicable law or agreed to in writing, software
  11# distributed under the License is distributed on an "AS IS" BASIS,
  12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13# See the License for the specific language governing permissions and
  14# limitations under the License.
  15
  16import codecs
  17import io
  18import os
  19import tempfile
  20from collections.abc import Callable, Iterator
  21from typing import IO, Any, Optional, TypeVar, Union
  22
  23import boto3
  24import botocore
  25from boto3.s3.transfer import TransferConfig
  26from botocore.credentials import RefreshableCredentials
  27from botocore.exceptions import ClientError, IncompleteReadError, ReadTimeoutError, ResponseStreamingError
  28from botocore.session import get_session
  29from dateutil.parser import parse as dateutil_parse
  30
  31from multistorageclient_rust import RustClient, RustClientError, RustRetryableError
  32
  33from ..constants import DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT
  34from ..rust_utils import parse_retry_config, run_async_rust_client_method
  35from ..signers import CloudFrontURLSigner, URLSigner
  36from ..telemetry import Telemetry
  37from ..types import (
  38    AWARE_DATETIME_MIN,
  39    Credentials,
  40    CredentialsProvider,
  41    ObjectMetadata,
  42    PreconditionFailedError,
  43    Range,
  44    RetryableError,
  45    SignerType,
  46)
  47from ..utils import (
  48    get_available_cpu_count,
  49    safe_makedirs,
  50    split_path,
  51    validate_attributes,
  52)
  53from .base import BaseStorageProvider
  54
# Generic return type used by S3StorageProvider._translate_errors.
_T = TypeVar("_T")

# Default connection pool size scales with CPU count or MSC Sync Threads count (minimum 64)
MAX_POOL_CONNECTIONS = max(
    64,
    get_available_cpu_count(),
    int(os.getenv("MSC_NUM_THREADS_PER_PROCESS", "0")),
)

MiB = 1024 * 1024

# Python and Rust share the same multipart_threshold to keep the code simple.
MULTIPART_THRESHOLD = 64 * MiB
MULTIPART_CHUNKSIZE = 32 * MiB
IO_CHUNKSIZE = 32 * MiB
# Max concurrency for boto3's TransferConfig (Python upload/download path only).
PYTHON_MAX_CONCURRENCY = 8

# Provider name passed to BaseStorageProvider and the Rust client.
PROVIDER = "s3"

# Storage class applied to writes targeting S3 Express One Zone (directory) buckets.
EXPRESS_ONEZONE_STORAGE_CLASS = "EXPRESS_ONEZONE"
  76
class StaticS3CredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static S3 credentials.
    """

    _access_key: str
    _secret_key: str
    _session_token: Optional[str]

    def __init__(self, access_key: str, secret_key: str, session_token: Optional[str] = None):
        """
        Initializes the :py:class:`StaticS3CredentialsProvider` with a fixed access key, secret key, and an
        optional session token.

        :param access_key: The access key for S3 authentication.
        :param secret_key: The secret key for S3 authentication.
        :param session_token: An optional session token for temporary credentials.
        """
        self._access_key = access_key
        self._secret_key = secret_key
        self._session_token = session_token

    def get_credentials(self) -> Credentials:
        """Return the credentials supplied at construction time."""
        credentials = Credentials(
            access_key=self._access_key,
            secret_key=self._secret_key,
            token=self._session_token,
            expiration=None,  # static credentials carry no expiration
        )
        return credentials

    def refresh_credentials(self) -> None:
        """No-op: static credentials cannot be refreshed."""
        pass
# Default lifetime, in seconds, for generated pre-signed URLs.
DEFAULT_PRESIGN_EXPIRES_IN = 3600

# Maps the HTTP verbs accepted by S3URLSigner.generate_presigned_url to the
# boto3 client method names used for presigning.
_S3_METHOD_MAPPING: dict[str, str] = {
    "GET": "get_object",
    "PUT": "put_object",
}
class S3URLSigner(URLSigner):
    """Generates pre-signed URLs using the boto3 S3 client.

    When the underlying credentials are temporary (STS, IAM role, EC2 instance
    profile), the effective URL lifetime is the **shorter** of ``expires_in``
    and the remaining credential lifetime — boto3 will not warn if the
    credential expires before ``expires_in``.

    See https://docs.aws.amazon.com/AmazonS3/latest/userguide/using-presigned-url.html
    """

    def __init__(self, s3_client: Any, bucket: str, expires_in: int = DEFAULT_PRESIGN_EXPIRES_IN) -> None:
        """
        :param s3_client: The boto3 S3 client used to sign requests.
        :param bucket: The bucket the signed URLs will address.
        :param expires_in: URL lifetime in seconds (default ``DEFAULT_PRESIGN_EXPIRES_IN``).
        """
        self._s3_client = s3_client
        self._bucket = bucket
        self._expires_in = expires_in

    def generate_presigned_url(self, path: str, *, method: str = "GET") -> str:
        """Return a time-limited URL granting ``method`` access to ``path`` in the configured bucket."""
        verb = method.upper()
        # Only verbs listed in _S3_METHOD_MAPPING can be presigned.
        if verb not in _S3_METHOD_MAPPING:
            raise ValueError(f"Unsupported method for S3 presigning: {method!r}")
        request_params = {"Bucket": self._bucket, "Key": path}
        return self._s3_client.generate_presigned_url(
            ClientMethod=_S3_METHOD_MAPPING[verb],
            Params=request_params,
            ExpiresIn=self._expires_in,
        )
144 145
class S3StorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Amazon S3 or S3-compatible object stores.
    """

    def __init__(
        self,
        region_name: str = "",
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        verify: Optional[Union[bool, str]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the :py:class:`S3StorageProvider` with the region, endpoint URL, and optional credentials provider.

        :param region_name: The AWS region where the S3 bucket is located.
        :param endpoint_url: The custom endpoint URL for the S3 service.
        :param base_path: The root prefix path within the S3 bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve S3 credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        :param verify: Controls SSL certificate verification.
            Can be ``True`` (verify using system CA bundle, default), ``False`` (skip verification), or a string path to a custom CA certificate bundle.
        :param request_checksum_calculation: For :py:class:`botocore.config.Config`.
            When the underlying S3 client should calculate request checksums.
            See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param response_checksum_validation: For :py:class:`botocore.config.Config`.
            When the underlying S3 client should validate response checksums.
            See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param max_pool_connections: For :py:class:`botocore.config.Config`.
            The maximum number of connections to keep in a connection pool.
        :param connect_timeout: For :py:class:`botocore.config.Config`.
            The time in seconds till a timeout exception is thrown when attempting to make a connection.
        :param read_timeout: For :py:class:`botocore.config.Config`.
            The time in seconds till a timeout exception is thrown when attempting to read from a connection.
        :param retries: For :py:class:`botocore.config.Config`.
            A dictionary for configuration related to retry behavior.
        :param s3: For :py:class:`botocore.config.Config`.
            A dictionary of S3 specific configurations.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._region_name = region_name
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._verify = verify

        self._signature_version = kwargs.get("signature_version", "s3v4")
        self._s3_client = self._create_s3_client(
            request_checksum_calculation=kwargs.get("request_checksum_calculation"),
            response_checksum_validation=kwargs.get("response_checksum_validation"),
            max_pool_connections=kwargs.get("max_pool_connections", MAX_POOL_CONNECTIONS),
            connect_timeout=kwargs.get("connect_timeout", DEFAULT_CONNECT_TIMEOUT),
            read_timeout=kwargs.get("read_timeout", DEFAULT_READ_TIMEOUT),
            retries=kwargs.get("retries"),
            s3=kwargs.get("s3"),
        )
        self._transfer_config = TransferConfig(
            multipart_threshold=int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD)),
            max_concurrency=int(kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)),
            multipart_chunksize=int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE)),
            io_chunksize=int(kwargs.get("io_chunksize", IO_CHUNKSIZE)),
            use_threads=True,
        )

        # Cache of URLSigner instances keyed by a tuple (population not shown in this chunk).
        self._signer_cache: dict[tuple, URLSigner] = {}

        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the rust client options from the kwargs
            rust_client_options = kwargs["rust_client"]
            if "max_pool_connections" in kwargs:
                rust_client_options["max_pool_connections"] = kwargs["max_pool_connections"]
            if "max_concurrency" in kwargs:
                rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
            if "multipart_chunksize" in kwargs:
                rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
            if "read_timeout" in kwargs:
                rust_client_options["read_timeout"] = kwargs["read_timeout"]
            if "connect_timeout" in kwargs:
                rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
            if self._signature_version == "UNSIGNED":
                rust_client_options["skip_signature"] = True
            self._rust_client = self._create_rust_client(rust_client_options)

    def _is_directory_bucket(self, bucket: str) -> bool:
        """
        Determines if the bucket is a directory bucket based on bucket name.
        """
        # S3 Express buckets have a specific naming convention
        return "--x-s3" in bucket

    def _create_s3_client(
        self,
        request_checksum_calculation: Optional[str] = None,
        response_checksum_validation: Optional[str] = None,
        max_pool_connections: int = MAX_POOL_CONNECTIONS,
        connect_timeout: Union[float, int, None] = None,
        read_timeout: Union[float, int, None] = None,
        retries: Optional[dict[str, Any]] = None,
        s3: Optional[dict[str, Any]] = None,
    ):
        """
        Creates and configures the boto3 S3 client, using refreshable credentials if possible.

        :param request_checksum_calculation: For :py:class:`botocore.config.Config`. When the underlying S3 client should calculate request checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param response_checksum_validation: For :py:class:`botocore.config.Config`. When the underlying S3 client should validate response checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param max_pool_connections: For :py:class:`botocore.config.Config`. The maximum number of connections to keep in a connection pool.
        :param connect_timeout: For :py:class:`botocore.config.Config`. The time in seconds till a timeout exception is thrown when attempting to make a connection.
        :param read_timeout: For :py:class:`botocore.config.Config`. The time in seconds till a timeout exception is thrown when attempting to read from a connection.
        :param retries: For :py:class:`botocore.config.Config`. A dictionary for configuration related to retry behavior.
        :param s3: For :py:class:`botocore.config.Config`. A dictionary of S3 specific configurations.

        :return: The configured S3 client.
        """
        options = {
            # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
            "config": boto3.session.Config(  # pyright: ignore [reportAttributeAccessIssue]
                max_pool_connections=max_pool_connections,
                connect_timeout=connect_timeout,
                read_timeout=read_timeout,
                retries=retries or {"mode": "standard"},
                request_checksum_calculation=request_checksum_calculation,
                response_checksum_validation=response_checksum_validation,
                s3=s3,
            ),
        }

        if self._region_name:
            options["region_name"] = self._region_name

        if self._endpoint_url:
            options["endpoint_url"] = self._endpoint_url

        if self._verify is not None:
            options["verify"] = self._verify

        if self._credentials_provider:
            creds = self._fetch_credentials()
            if "expiry_time" in creds and creds["expiry_time"]:
                # Use RefreshableCredentials if expiry_time provided.
                refreshable_credentials = RefreshableCredentials.create_from_metadata(
                    metadata=creds, refresh_using=self._fetch_credentials, method="custom-refresh"
                )

                # NOTE: assigning to the private _credentials attribute is the
                # established (if undocumented) way to install refreshable
                # credentials on a botocore session.
                botocore_session = get_session()
                botocore_session._credentials = refreshable_credentials

                boto3_session = boto3.Session(botocore_session=botocore_session)

                return boto3_session.client("s3", **options)
            else:
                # Add static credentials to the options dictionary
                options["aws_access_key_id"] = creds["access_key"]
                options["aws_secret_access_key"] = creds["secret_key"]
                if creds["token"]:
                    options["aws_session_token"] = creds["token"]

        if self._signature_version:
            # "UNSIGNED" requests anonymous access; otherwise pass the version through.
            signature_config = botocore.config.Config(  # pyright: ignore[reportAttributeAccessIssue]
                signature_version=botocore.UNSIGNED
                if self._signature_version == "UNSIGNED"
                else self._signature_version
            )
            options["config"] = options["config"].merge(signature_config)

        # Fallback to standard credential chain.
        return boto3.client("s3", **options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        """
        Creates and configures the rust client, using refreshable credentials if possible.
        """
        # Copy so the caller's options dict is not mutated.
        configs = dict(rust_client_options) if rust_client_options else {}

        # Extract and parse retry configuration
        retry_config = parse_retry_config(configs)

        if self._region_name and "region_name" not in configs:
            configs["region_name"] = self._region_name

        if self._endpoint_url and "endpoint_url" not in configs:
            configs["endpoint_url"] = self._endpoint_url

        # If the user specifies a bucket, use it. Otherwise, use the base path.
        if "bucket" not in configs:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        if "max_pool_connections" not in configs:
            configs["max_pool_connections"] = MAX_POOL_CONNECTIONS

        return RustClient(
            provider=PROVIDER,
            configs=configs,
            credentials_provider=self._credentials_provider,
            retry=retry_config,
        )

    def _fetch_credentials(self) -> dict:
        """
        Refreshes the S3 client if the current credentials are expired.

        :return: A credential mapping in the shape expected by
            :py:meth:`botocore.credentials.RefreshableCredentials.create_from_metadata`.
        :raises RuntimeError: If no credentials provider is configured.
        """
        if not self._credentials_provider:
            raise RuntimeError("Cannot fetch credentials if no credential provider configured.")
        self._credentials_provider.refresh_credentials()
        credentials = self._credentials_provider.get_credentials()
        return {
            "access_key": credentials.access_key,
            "secret_key": credentials.secret_key,
            "token": credentials.token,
            "expiry_time": credentials.expiration,
        }

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual S3 operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the S3 bucket involved in the operation.
        :param key: The key of the object within the S3 bucket.

        :return: The result of the S3 operation, typically the return value of the `func` callable.
        :raises FileNotFoundError: For 404 responses (other than ``NoSuchUpload``).
        :raises PreconditionFailedError: For 412 (ETag precondition) responses.
        :raises RetryableError: For transient failures (429/503/408, network read errors,
            multipart ``NoSuchUpload``, and Rust-side retryable errors).
        :raises NotImplementedError: For 501 responses.
        :raises RuntimeError: For any other failure.
        """
        try:
            return func()
        except ClientError as error:
            status_code = error.response["ResponseMetadata"]["HTTPStatusCode"]
            request_id = error.response["ResponseMetadata"].get("RequestId")
            host_id = error.response["ResponseMetadata"].get("HostId")
            error_code = error.response["Error"]["Code"]
            error_info = f"request_id: {request_id}, host_id: {host_id}, status_code: {status_code}"

            # NOTE(review): a 403 ClientError falls through to the generic
            # RuntimeError branch below, while the RustClientError path maps
            # 403 to PermissionError — confirm whether this asymmetry is intended.
            if status_code == 404:
                if error_code == "NoSuchUpload":
                    # A lost multipart upload can be retried from scratch.
                    error_message = error.response["Error"]["Message"]
                    raise RetryableError(f"Multipart upload failed for {bucket}/{key}: {error_message}") from error
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}")  # pylint: disable=raise-missing-from
            elif status_code == 412:  # Precondition Failed
                raise PreconditionFailedError(
                    f"ETag mismatch for {operation} operation on {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 429:
                raise RetryableError(
                    f"Too many request to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 503:
                raise RetryableError(
                    f"Service unavailable when {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 501:
                raise NotImplementedError(
                    f"Operation {operation} not implemented for object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 408:
                # 408 Request Timeout is from Google Cloud Storage
                raise RetryableError(
                    f"Request timeout when {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            else:
                raise RuntimeError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}, "
                    f"error_type: {type(error).__name__}"
                ) from error
        except RustClientError as error:
            # Rust client reports (message, status_code) positional args.
            message = error.args[0]
            status_code = error.args[1]
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {message}") from error
            elif status_code == 403:
                raise PermissionError(
                    f"Permission denied to {operation} object(s) at {bucket}/{key}. {message}"
                ) from error
            else:
                raise RetryableError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {message}. status_code: {status_code}"
                ) from error
        except (ReadTimeoutError, IncompleteReadError, ResponseStreamingError) as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to network timeout or incomplete read. "
                f"error_type: {type(error).__name__}"
            ) from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except Exception as error:
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
        content_type: Optional[str] = None,
    ) -> int:
        """
        Uploads an object to the specified S3 path.

        :param path: The S3 path where the object will be uploaded.
        :param body: The content of the object as bytes.
        :param if_match: Optional If-Match header value (ETag); the write only succeeds if the
            current object matches.
        :param if_none_match: Optional If-None-Match header value. Use "*" to only upload if the object doesn't exist.
        :param attributes: Optional attributes to attach to the object.
        :param content_type: Optional Content-Type header value.
        :return: The number of bytes written (``len(body)``).
        """
        bucket, key = split_path(path)

        def _invoke_api() -> int:
            kwargs = {"Bucket": bucket, "Key": key, "Body": body}
            if content_type:
                kwargs["ContentType"] = content_type
            if self._is_directory_bucket(bucket):
                kwargs["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
            if if_match:
                kwargs["IfMatch"] = if_match
            if if_none_match:
                kwargs["IfNoneMatch"] = if_none_match
            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                kwargs["Metadata"] = validated_attributes

            # TODO(NGCDP-5804): Add support to update ContentType header in Rust client
            rust_unsupported_feature_keys = {"Metadata", "StorageClass", "IfMatch", "IfNoneMatch", "ContentType"}
            # The Rust fast path is only taken when none of the boto-only features are
            # requested. (The genexp's `key` shadows the object key, but only inside
            # the expression.)
            if (
                self._rust_client
                # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
                and not path.endswith("/")
                and all(key not in kwargs for key in rust_unsupported_feature_keys)
            ):
                run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                self._s3_client.put_object(**kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Reads an object (or a byte range of it) and returns its content as bytes.

        :param path: The S3 path of the object to read.
        :param byte_range: Optional byte range to read instead of the whole object.
        """
        bucket, key = split_path(path)

        def _invoke_api() -> bytes:
            if byte_range:
                # HTTP Range header is inclusive on both ends.
                bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
                if self._rust_client:
                    response = run_async_rust_client_method(
                        self._rust_client,
                        "get",
                        key,
                        byte_range,
                    )
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key, Range=bytes_range)
            else:
                if self._rust_client:
                    response = run_async_rust_client_method(self._rust_client, "get", key)
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key)

            # boto3 path: drain the streaming body into bytes.
            return response["Body"].read()

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        """
        Server-side copies an object and returns the number of bytes copied.
        """
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)

        # HEAD the source first so the byte count (and existence) is known up front.
        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            self._s3_client.copy(
                CopySource={"Bucket": src_bucket, "Key": src_key},
                Bucket=dest_bucket,
                Key=dest_key,
                Config=self._transfer_config,
            )

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=dest_bucket, key=dest_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Deletes a single object, optionally guarded by an ETag precondition.
        """
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            # Delete conditionally when if_match (etag) is provided; otherwise delete unconditionally
            if if_match:
                self._s3_client.delete_object(Bucket=bucket, Key=key, IfMatch=if_match)
            else:
                self._s3_client.delete_object(Bucket=bucket, Key=key)

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _delete_objects(self, paths: list[str]) -> None:
        """
        Bulk-deletes objects, batching per bucket in chunks of the S3 DeleteObjects
        limit (1000 keys per request).

        :raises RuntimeError: If the service reports per-key errors in any batch.
        """
        if not paths:
            return

        # Group keys by bucket since DeleteObjects is a per-bucket API.
        by_bucket: dict[str, list[str]] = {}
        for p in paths:
            bucket, key = split_path(p)
            by_bucket.setdefault(bucket, []).append(key)

        S3_BATCH_LIMIT = 1000

        def _invoke_api() -> None:
            all_errors: list[str] = []
            for bucket, keys in by_bucket.items():
                for i in range(0, len(keys), S3_BATCH_LIMIT):
                    chunk = keys[i : i + S3_BATCH_LIMIT]
                    response = self._s3_client.delete_objects(
                        Bucket=bucket, Delete={"Objects": [{"Key": k} for k in chunk]}
                    )
                    errors = response.get("Errors") or []
                    for e in errors:
                        all_errors.append(f"{bucket}/{e.get('Key', '?')}: {e.get('Code', '')} {e.get('Message', '')}")
            if all_errors:
                raise RuntimeError(f"DeleteObjects reported errors: {'; '.join(all_errors)}")

        # Synthetic bucket/key descriptions for error messages spanning multiple buckets.
        bucket_desc = "(" + "|".join(by_bucket) + ")"
        key_desc = "(" + "|".join(str(len(keys)) for keys in by_bucket.values()) + " keys)"
        self._translate_errors(_invoke_api, operation="DELETE_MANY", bucket=bucket_desc, key=key_desc)

    def _is_dir(self, path: str) -> bool:
        """
        Returns True if ``path`` behaves like a directory, i.e. at least one object
        or common prefix exists under it.
        """
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._s3_client.list_objects_v2(Bucket=bucket, Prefix=key, MaxKeys=1, Delimiter="/")

            # Check if there are any contents or common prefixes
            return bool(response.get("Contents", []) or response.get("CommonPrefixes", []))

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Returns metadata for an object, or a synthetic directory entry when the path
        refers to a (possibly virtual) prefix.

        :param path: The S3 path to inspect.
        :param strict: When True and the object is missing, also try treating the
            path as a directory before raising.
        :raises FileNotFoundError: If neither an object nor a directory exists at ``path``.
        """
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If path ends with "/" or empty key name is provided, then assume it's a "directory",
            # which metadata is not guaranteed to exist for cases such as
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:

            def _invoke_api() -> ObjectMetadata:
                response = self._s3_client.head_object(Bucket=bucket, Key=key)

                return ObjectMetadata(
                    key=path,
                    type="file",
                    content_length=response["ContentLength"],
                    content_type=response.get("ContentType"),
                    last_modified=response["LastModified"],
                    etag=response["ETag"].strip('"') if "ETag" in response else None,
                    storage_class=response.get("StorageClass"),
                    metadata=response.get("Metadata"),
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        """
        Lists objects under ``path``, optionally bounded by ``start_after``/``end_at``
        (exclusive/inclusive) and optionally including directory (common prefix) entries.
        Yielded keys are joined with the bucket name.
        """
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            paginator = self._s3_client.get_paginator("list_objects_v2")
            if include_directories:
                page_iterator = paginator.paginate(
                    Bucket=bucket, Prefix=prefix, Delimiter="/", StartAfter=(start_after or "")
                )
            else:
                page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, StartAfter=(start_after or ""))

            for page in page_iterator:
                for item in page.get("CommonPrefixes", []):
                    prefix_key = item["Prefix"].rstrip("/")
                    # Filter by start_after and end_at - S3's StartAfter doesn't filter CommonPrefixes
                    if (start_after is None or start_after < prefix_key) and (end_at is None or prefix_key <= end_at):
                        yield ObjectMetadata(
                            key=os.path.join(bucket, prefix_key),
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                    elif end_at is not None and end_at < prefix_key:
                        return

                # S3 guarantees lexicographical order for general purpose buckets (for
                # normal S3) but not directory buckets (for S3 Express One Zone).
                for response_object in page.get("Contents", []):
                    key = response_object["Key"]
                    if end_at is None or key <= end_at:
                        if key.endswith("/"):
                            # Zero-byte objects with a trailing slash act as explicit directory markers.
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(bucket, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=response_object["LastModified"],
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key),
                                type="file",
                                content_length=response_object["Size"],
                                last_modified=response_object["LastModified"],
                                etag=response_object["ETag"].strip('"'),
                                storage_class=response_object.get("StorageClass"),  # Pass storage_class
                            )
                    else:
                        return

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    @property
    def supports_parallel_listing(self) -> bool:
        """
        S3 supports parallel listing via delimiter-based prefix discovery.

        Note: Directory bucket handling is done in list_objects_recursive().
        """
        return True
[docs] 728 def list_objects_recursive( 729 self, 730 path: str = "", 731 start_after: Optional[str] = None, 732 end_at: Optional[str] = None, 733 max_workers: int = 32, 734 look_ahead: int = 2, 735 follow_symlinks: bool = True, 736 ) -> Iterator[ObjectMetadata]: 737 """ 738 List all objects recursively using parallel prefix discovery for improved performance. 739 740 For S3, uses the Rust client's list_recursive when available for maximum performance. 741 Falls back to Python implementation otherwise. 742 743 Returns files only (no directories), in lexicographic order. 744 745 :param follow_symlinks: Whether to follow symbolic links (POSIX providers only). 746 """ 747 if (start_after is not None) and (end_at is not None) and not (start_after < end_at): 748 raise ValueError(f"start_after ({start_after}) must be before end_at ({end_at})!") 749 750 full_path = self._prepend_base_path(path) 751 bucket, prefix = split_path(full_path) 752 753 if self._is_directory_bucket(bucket): 754 yield from self.list_objects( 755 path, start_after, end_at, include_directories=False, follow_symlinks=follow_symlinks 756 ) 757 return 758 759 if self._rust_client: 760 yield from self._emit_metrics( 761 operation=BaseStorageProvider._Operation.LIST, 762 f=lambda: self._list_objects_recursive_rust(path, full_path, bucket, start_after, end_at, max_workers), 763 ) 764 else: 765 yield from super().list_objects_recursive( 766 path, start_after, end_at, max_workers, look_ahead, follow_symlinks 767 )
    def _list_objects_recursive_rust(
        self,
        path: str,
        full_path: str,
        bucket: str,
        start_after: Optional[str],
        end_at: Optional[str],
        max_workers: int,
    ) -> Iterator[ObjectMetadata]:
        """
        Use Rust client's list_recursive for parallel listing.

        The Rust client already handles parallel listing internally.
        Returns files only in lexicographic order.

        :param path: Caller-relative prefix (currently unused here; kept for signature parity).
        :param full_path: ``path`` with the provider's base path prepended.
        :param bucket: Bucket name extracted from ``full_path``.
        :param start_after: Exclusive lower bound on keys, relative to the base path.
        :param end_at: Inclusive upper bound on keys, relative to the base path.
        :param max_workers: Concurrency passed to the Rust client.
        """
        _, prefix = split_path(full_path)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            result = run_async_rust_client_method(
                self._rust_client,
                "list_recursive",
                [prefix] if prefix else [""],
                max_concurrency=max_workers,
            )

            # Bounds are compared in "bucket/key" space, so prepend the base path first.
            start_after_full = self._prepend_base_path(start_after) if start_after else None
            end_at_full = self._prepend_base_path(end_at) if end_at else None

            for obj in result.objects:
                # NOTE(review): os.path.join would use "\\" on Windows; looks like the
                # intent is a "/"-joined object key — confirm this module is POSIX-only.
                full_key = os.path.join(bucket, obj.key)

                if start_after_full and full_key <= start_after_full:
                    continue
                # Results are in lexicographic order, so the first key past end_at ends the scan.
                if end_at_full and full_key > end_at_full:
                    break

                # Strip the base path so yielded keys are relative, matching list_objects.
                relative_key = full_key.removeprefix(self._base_path).lstrip("/")

                yield ObjectMetadata(
                    key=relative_key,
                    content_length=obj.content_length,
                    last_modified=dateutil_parse(obj.last_modified),
                    type="file" if obj.object_type == "object" else obj.object_type,
                    etag=obj.etag,
                )

        # _translate_errors maps Rust/botocore errors onto the provider's exception types.
        yield from self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(
        self,
        remote_path: str,
        f: Union[str, IO],
        attributes: Optional[dict[str, str]] = None,
        content_type: Optional[str] = None,
    ) -> int:
        """
        Upload a local file path or an open file-like object to ``remote_path``.

        Objects at or below the multipart threshold go through a single PUT
        (Rust client when available and no extra args are needed); larger objects
        use multipart upload via the Rust client or boto3's TransferConfig.

        :param remote_path: Destination ``bucket/key`` path.
        :param f: Local filename, or a binary/text file-like object positioned anywhere
            (it is rewound before reading).
        :param attributes: Optional user metadata to attach to the object.
        :param content_type: Optional Content-Type header for the object.
        :return: Number of bytes uploaded.
        """
        file_size: int = 0

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._transfer_config.multipart_threshold:
                # Rust fast path only when no metadata/content-type must be set.
                if self._rust_client and not attributes and not content_type:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes, content_type=content_type)
                return file_size

            # Upload large files using TransferConfig
            def _invoke_api() -> int:
                extra_args = {}
                if content_type:
                    extra_args["ContentType"] = content_type
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes
                # Rust multipart path only when boto3-specific ExtraArgs are not required.
                if self._rust_client and not extra_args:
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    self._s3_client.upload_file(
                        Filename=f,
                        Bucket=bucket,
                        Key=key,
                        Config=self._transfer_config,
                        ExtraArgs=extra_args,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            # Upload small files
            # Measure the stream by seeking to the end, then rewind for reading.
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            if file_size <= self._transfer_config.multipart_threshold:
                if isinstance(f, io.StringIO):
                    # Text streams are encoded to UTF-8 bytes before the PUT.
                    self._put_object(
                        remote_path, f.read().encode("utf-8"), attributes=attributes, content_type=content_type
                    )
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes, content_type=content_type)
                return file_size

            # Upload large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                extra_args = {}
                if content_type:
                    extra_args["ContentType"] = content_type
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes

                if self._rust_client and isinstance(f, io.BytesIO) and not extra_args:
                    # getbuffer() hands the Rust client a zero-copy view of the BytesIO contents.
                    data = f.getbuffer()
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_bytes", key, data)
                else:
                    # NOTE(review): a large io.StringIO lands here, and upload_fileobj expects a
                    # binary stream — unlike the small-file branch, there is no UTF-8 encode.
                    # Confirm whether large text uploads are a supported case.
                    self._s3_client.upload_fileobj(
                        Fileobj=f,
                        Bucket=bucket,
                        Key=key,
                        Config=self._transfer_config,
                        ExtraArgs=extra_args,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """
        Download ``remote_path`` into a local file path or an open file-like object.

        Small objects (at or below the multipart threshold) use a single GET; larger
        objects use multipart download via the Rust client or boto3's TransferConfig.
        When ``f`` is a path, the Python paths write to a hidden temp file in the same
        directory and ``os.rename`` it into place so readers never see a partial file.

        :param remote_path: Source ``bucket/key`` path.
        :param f: Local destination filename (parent dirs are created), or a writable
            file-like object.
        :param metadata: Pre-fetched object metadata; fetched via HEAD when ``None``.
        :return: The object's content length in bytes.
        """
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            if os.path.dirname(f):
                safe_makedirs(os.path.dirname(f))

            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                if self._rust_client:
                    # NOTE(review): the Rust path writes straight to the final path, unlike
                    # the Python path's temp-file + rename — confirm the Rust client is atomic.
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    # Write to a dot-prefixed temp file in the destination directory, then
                    # rename atomically (same filesystem) onto the target path.
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using TransferConfig
            def _invoke_api() -> int:
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    if self._rust_client:
                        # Rust client writes to the temp path itself; the handle fp stays empty.
                        run_async_rust_client_method(
                            self._rust_client, "download_multipart_to_file", key, temp_file_path
                        )
                    else:
                        self._s3_client.download_fileobj(
                            Bucket=bucket,
                            Key=key,
                            Fileobj=fp,
                            Config=self._transfer_config,
                        )

                # Rename after the temp file is closed so the target only ever sees full content.
                os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:
            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                response = self._get_object(remote_path)
                # Python client returns `bytes`, but Rust client returns an object that implements the buffer protocol,
                # so we need to check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                # NOTE(review): a large download into io.StringIO would reach download_fileobj,
                # which writes bytes — only the small-file branch decodes text. Confirm intent.
                self._s3_client.download_fileobj(
                    Bucket=bucket,
                    Key=key,
                    Fileobj=f,
                    Config=self._transfer_config,
                )

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _generate_presigned_url(
        self,
        path: str,
        *,
        method: str = "GET",
        signer_type: Optional[SignerType] = None,
        signer_options: Optional[dict[str, Any]] = None,
    ) -> str:
        """
        Generate a presigned URL for ``path`` using a cached signer.

        Signers are memoized in ``self._signer_cache`` keyed by signer type and the
        options that affect the signature, so repeated calls reuse one signer instance.

        :param path: ``bucket/key`` path to sign.
        :param method: HTTP method the URL authorizes (default ``GET``).
        :param signer_type: ``SignerType.S3`` (default when ``None``) or
            ``SignerType.CLOUDFRONT``.
        :param signer_options: Signer-specific options; for S3, ``expires_in`` (seconds);
            for CloudFront, constructor kwargs for :class:`CloudFrontURLSigner`.
        :return: The presigned URL.
        :raises ValueError: For signer types this provider does not support.
        """
        options = signer_options or {}
        bucket, key = split_path(path)

        if signer_type is None or signer_type == SignerType.S3:
            expires_in = int(options.get("expires_in", DEFAULT_PRESIGN_EXPIRES_IN))
            # Cache per (bucket, expiry): both are baked into the S3 signer instance.
            cache_key: tuple = (SignerType.S3, bucket, expires_in)
            if cache_key not in self._signer_cache:
                self._signer_cache[cache_key] = S3URLSigner(self._s3_client, bucket, expires_in=expires_in)
        elif signer_type == SignerType.CLOUDFRONT:
            # frozenset makes the option dict hashable; order-insensitive cache key.
            cache_key = (SignerType.CLOUDFRONT, frozenset(options.items()))
            if cache_key not in self._signer_cache:
                self._signer_cache[cache_key] = CloudFrontURLSigner(**options)
        else:
            raise ValueError(f"Unsupported signer type for S3 provider: {signer_type!r}")

        return self._signer_cache[cache_key].generate_presigned_url(key, method=method)