Source code for multistorageclient.providers.s3

   1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
   2# SPDX-License-Identifier: Apache-2.0
   3#
   4# Licensed under the Apache License, Version 2.0 (the "License");
   5# you may not use this file except in compliance with the License.
   6# You may obtain a copy of the License at
   7#
   8# http://www.apache.org/licenses/LICENSE-2.0
   9#
  10# Unless required by applicable law or agreed to in writing, software
  11# distributed under the License is distributed on an "AS IS" BASIS,
  12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13# See the License for the specific language governing permissions and
  14# limitations under the License.
  15
  16import codecs
  17import io
  18import os
  19import tempfile
  20from collections.abc import Callable, Iterator
  21from typing import IO, Any, Optional, TypeVar, Union
  22
  23import boto3
  24import botocore
  25from boto3.s3.transfer import TransferConfig
  26from botocore.credentials import RefreshableCredentials
  27from botocore.exceptions import ClientError, IncompleteReadError, ReadTimeoutError, ResponseStreamingError
  28from botocore.session import get_session
  29from dateutil.parser import parse as dateutil_parse
  30
  31from multistorageclient_rust import RustClient, RustClientError, RustRetryableError
  32
  33from ..constants import DEFAULT_CONNECT_TIMEOUT, DEFAULT_MAX_POOL_CONNECTIONS, DEFAULT_READ_TIMEOUT
  34from ..rust_utils import parse_retry_config, run_async_rust_client_method
  35from ..signers import CloudFrontURLSigner, URLSigner
  36from ..telemetry import Telemetry
  37from ..types import (
  38    AWARE_DATETIME_MIN,
  39    Credentials,
  40    CredentialsProvider,
  41    ObjectMetadata,
  42    PreconditionFailedError,
  43    Range,
  44    RetryableError,
  45    SignerType,
  46    SymlinkHandling,
  47)
  48from ..utils import (
  49    safe_makedirs,
  50    split_path,
  51    validate_attributes,
  52)
  53from .base import BaseStorageProvider
  54
  55_T = TypeVar("_T")
  56
  57MiB = 1024 * 1024
  58
  59# Python and Rust share the same multipart_threshold to keep the code simple.
  60MULTIPART_THRESHOLD = 64 * MiB
  61MULTIPART_CHUNKSIZE = 32 * MiB
  62IO_CHUNKSIZE = 32 * MiB
  63PYTHON_MAX_CONCURRENCY = 8
  64MAX_POOL_CONNECTIONS = DEFAULT_MAX_POOL_CONNECTIONS
  65
  66PROVIDER = "s3"
  67
  68EXPRESS_ONEZONE_STORAGE_CLASS = "EXPRESS_ONEZONE"
  69
  70# Source: https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html
  71SUPPORTED_CHECKSUM_ALGORITHMS = frozenset({"CRC32", "CRC32C", "SHA1", "SHA256", "CRC64NVME"})
  72
  73
class StaticS3CredentialsProvider(CredentialsProvider):
    """
    A :py:class:`multistorageclient.types.CredentialsProvider` that always returns the same fixed S3
    credentials supplied at construction time.
    """

    _access_key: str
    _secret_key: str
    _session_token: Optional[str]

    def __init__(self, access_key: str, secret_key: str, session_token: Optional[str] = None):
        """
        Stores the static S3 credentials.

        :param access_key: The access key for S3 authentication.
        :param secret_key: The secret key for S3 authentication.
        :param session_token: An optional session token for temporary credentials.
        """
        self._access_key, self._secret_key, self._session_token = access_key, secret_key, session_token

    def get_credentials(self) -> Credentials:
        """
        Packages the stored values as :py:class:`Credentials`.

        The credentials are static, so ``expiration`` is always ``None``.
        """
        credentials = Credentials(
            access_key=self._access_key,
            secret_key=self._secret_key,
            token=self._session_token,
            expiration=None,
        )
        return credentials

    def refresh_credentials(self) -> None:
        """No-op: static credentials have nothing to refresh."""
        return None
# Default lifetime, in seconds, of URLs produced by S3URLSigner (one hour).
DEFAULT_PRESIGN_EXPIRES_IN = 60 * 60

# Maps an upper-case HTTP method to the boto3 S3 client method that is presigned for it.
_S3_METHOD_MAPPING: dict[str, str] = dict(GET="get_object", PUT="put_object")
class S3URLSigner(URLSigner):
    """Generates pre-signed URLs using the boto3 S3 client.

    When the underlying credentials are temporary (STS, IAM role, EC2 instance
    profile), the effective URL lifetime is the **shorter** of ``expires_in``
    and the remaining credential lifetime — boto3 will not warn if the
    credential expires before ``expires_in``.

    See https://docs.aws.amazon.com/AmazonS3/latest/userguide/using-presigned-url.html
    """

    def __init__(self, s3_client: Any, bucket: str, expires_in: int = DEFAULT_PRESIGN_EXPIRES_IN) -> None:
        """
        :param s3_client: boto3 S3 client whose ``generate_presigned_url`` does the signing.
        :param bucket: Bucket that every generated URL targets.
        :param expires_in: Requested URL lifetime, in seconds.
        """
        self._s3_client = s3_client
        self._bucket = bucket
        self._expires_in = expires_in

    def generate_presigned_url(self, path: str, *, method: str = "GET") -> str:
        """
        Presigns a request for the object ``path`` in this signer's bucket.

        :param path: Object key within the bucket.
        :param method: HTTP method (case-insensitive); ``"GET"`` or ``"PUT"``.
        :return: The presigned URL.
        :raises ValueError: If ``method`` is not supported.
        """
        normalized_method = method.upper()
        if normalized_method not in _S3_METHOD_MAPPING:
            raise ValueError(f"Unsupported method for S3 presigning: {method!r}")
        request_params = {"Bucket": self._bucket, "Key": path}
        return self._s3_client.generate_presigned_url(
            ClientMethod=_S3_METHOD_MAPPING[normalized_method],
            Params=request_params,
            ExpiresIn=self._expires_in,
        )
[docs] 143class S3StorageProvider(BaseStorageProvider): 144 """ 145 A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Amazon S3 or S3-compatible object stores. 146 """ 147 148 def __init__( 149 self, 150 region_name: str = "", 151 endpoint_url: str = "", 152 base_path: str = "", 153 credentials_provider: Optional[CredentialsProvider] = None, 154 profile_name: Optional[str] = None, 155 config_dict: Optional[dict[str, Any]] = None, 156 telemetry_provider: Optional[Callable[[], Telemetry]] = None, 157 verify: Optional[Union[bool, str]] = None, 158 **kwargs: Any, 159 ) -> None: 160 """ 161 Initializes the :py:class:`S3StorageProvider` with the region, endpoint URL, and optional credentials provider. 162 163 :param region_name: The AWS region where the S3 bucket is located. 164 :param endpoint_url: The custom endpoint URL for the S3 service. 165 :param base_path: The root prefix path within the S3 bucket where all operations will be scoped. 166 :param credentials_provider: The provider to retrieve S3 credentials. 167 :param profile_name: AWS shared configuration + credentials files profile. For :py:class:`boto3.session.Session`. Ignored if ``credentials_provider`` is set. 168 :param config_dict: Resolved MSC config. 169 :param telemetry_provider: A function that provides a telemetry instance. 170 :param verify: Controls SSL certificate verification. 171 Can be ``True`` (verify using system CA bundle, default), ``False`` (skip verification), or a string path to a custom CA certificate bundle. 172 :param request_checksum_calculation: For :py:class:`botocore.config.Config`. 173 When the underlying S3 client should calculate request checksums. 174 See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_. 175 :param response_checksum_validation: For :py:class:`botocore.config.Config`. 
176 When the underlying S3 client should validate response checksums. 177 See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_. 178 :param max_pool_connections: For :py:class:`botocore.config.Config`. 179 The maximum number of connections to keep in a connection pool. 180 :param connect_timeout: For :py:class:`botocore.config.Config`. 181 The time in seconds till a timeout exception is thrown when attempting to make a connection. 182 :param read_timeout: For :py:class:`botocore.config.Config`. 183 The time in seconds till a timeout exception is thrown when attempting to read from a connection. 184 :param retries: For :py:class:`botocore.config.Config`. 185 A dictionary for configuration related to retry behavior. 186 :param s3: For :py:class:`botocore.config.Config`. 187 A dictionary of S3 specific configurations. 188 :param checksum_algorithm: Upload-only object integrity algorithm. 189 One of ``"CRC32"``, ``"CRC32C"``, ``"SHA1"``, ``"SHA256"``, ``"CRC64NVME"`` (case-insensitive). 190 When ``rust_client`` is enabled, only ``"SHA256"`` is accepted. 
191 """ 192 super().__init__( 193 base_path=base_path, 194 provider_name=PROVIDER, 195 config_dict=config_dict, 196 telemetry_provider=telemetry_provider, 197 ) 198 199 self._region_name = region_name 200 self._endpoint_url = endpoint_url 201 self._credentials_provider = credentials_provider 202 self._profile_name = profile_name 203 self._verify = verify 204 205 self._signature_version = kwargs.get("signature_version", "s3v4") 206 self._s3_client = self._create_s3_client( 207 request_checksum_calculation=kwargs.get("request_checksum_calculation"), 208 response_checksum_validation=kwargs.get("response_checksum_validation"), 209 max_pool_connections=kwargs.get("max_pool_connections", MAX_POOL_CONNECTIONS), 210 connect_timeout=kwargs.get("connect_timeout", DEFAULT_CONNECT_TIMEOUT), 211 read_timeout=kwargs.get("read_timeout", DEFAULT_READ_TIMEOUT), 212 retries=kwargs.get("retries"), 213 s3=kwargs.get("s3"), 214 ) 215 self._transfer_config = TransferConfig( 216 multipart_threshold=int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD)), 217 max_concurrency=int(kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)), 218 multipart_chunksize=int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE)), 219 io_chunksize=int(kwargs.get("io_chunksize", IO_CHUNKSIZE)), 220 use_threads=True, 221 ) 222 self._multipart_threshold = self._transfer_config.multipart_threshold 223 224 self._signer_cache: dict[tuple, URLSigner] = {} 225 226 self._checksum_algorithm: Optional[str] = self._validate_checksum_algorithm(kwargs.get("checksum_algorithm")) 227 228 self._rust_client = None 229 if "rust_client" in kwargs: 230 # Inherit the rust client options from the kwargs 231 rust_client_options = kwargs["rust_client"] 232 if "max_pool_connections" in kwargs: 233 rust_client_options["max_pool_connections"] = kwargs["max_pool_connections"] 234 if "max_concurrency" in kwargs: 235 rust_client_options["max_concurrency"] = kwargs["max_concurrency"] 236 if "multipart_chunksize" in kwargs: 237 
rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"] 238 if "read_timeout" in kwargs: 239 rust_client_options["read_timeout"] = kwargs["read_timeout"] 240 if "connect_timeout" in kwargs: 241 rust_client_options["connect_timeout"] = kwargs["connect_timeout"] 242 if "checksum_algorithm" in kwargs: 243 rust_client_options["checksum_algorithm"] = kwargs["checksum_algorithm"] 244 if self._signature_version == "UNSIGNED": 245 rust_client_options["skip_signature"] = True 246 247 # Rust client only supports SHA256 today (object_store limitation). 248 rust_checksum = rust_client_options.get("checksum_algorithm") 249 if rust_checksum is not None: 250 if str(rust_checksum).upper() != "SHA256": 251 raise ValueError( 252 f"Rust client only supports checksum_algorithm='SHA256' today " 253 f"(object_store limitation), got '{rust_checksum!r}'. " 254 f"Disable rust_client or use SHA256." 255 ) 256 rust_client_options["checksum_algorithm"] = "sha256" 257 self._rust_client = self._create_rust_client(rust_client_options) 258 259 @staticmethod 260 def _validate_checksum_algorithm(value: Optional[str]) -> Optional[str]: 261 """ 262 Normalizes the optional ``checksum_algorithm`` to uppercase, or returns ``None`` if disabled. 263 """ 264 if value is None: 265 return None 266 if not isinstance(value, str) or not value: 267 raise ValueError( 268 f"checksum_algorithm must be one of {sorted(SUPPORTED_CHECKSUM_ALGORITHMS)} or None, got {value!r}." 269 ) 270 normalized = value.upper() 271 if normalized not in SUPPORTED_CHECKSUM_ALGORITHMS: 272 raise ValueError( 273 f"Unsupported checksum_algorithm '{value}'. Supported: {sorted(SUPPORTED_CHECKSUM_ALGORITHMS)}." 274 ) 275 return normalized 276 277 def _is_directory_bucket(self, bucket: str) -> bool: 278 """ 279 Determines if the bucket is a directory bucket based on bucket name. 
280 """ 281 # S3 Express buckets have a specific naming convention 282 return "--x-s3" in bucket 283 284 def _create_s3_client( 285 self, 286 request_checksum_calculation: Optional[str] = None, 287 response_checksum_validation: Optional[str] = None, 288 max_pool_connections: int = MAX_POOL_CONNECTIONS, 289 connect_timeout: Union[float, int, None] = None, 290 read_timeout: Union[float, int, None] = None, 291 retries: Optional[dict[str, Any]] = None, 292 s3: Optional[dict[str, Any]] = None, 293 ): 294 """ 295 Creates and configures the boto3 S3 client, using refreshable credentials if possible. 296 297 :param request_checksum_calculation: For :py:class:`botocore.config.Config`. When the underlying S3 client should calculate request checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_. 298 :param response_checksum_validation: For :py:class:`botocore.config.Config`. When the underlying S3 client should validate response checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_. 299 :param max_pool_connections: For :py:class:`botocore.config.Config`. The maximum number of connections to keep in a connection pool. 300 :param connect_timeout: For :py:class:`botocore.config.Config`. The time in seconds till a timeout exception is thrown when attempting to make a connection. 301 :param read_timeout: For :py:class:`botocore.config.Config`. The time in seconds till a timeout exception is thrown when attempting to read from a connection. 302 :param retries: For :py:class:`botocore.config.Config`. A dictionary for configuration related to retry behavior. 303 :param s3: For :py:class:`botocore.config.Config`. A dictionary of S3 specific configurations. 304 305 :return: The configured S3 client. 
306 """ 307 options = { 308 # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html 309 "config": boto3.session.Config( # pyright: ignore [reportAttributeAccessIssue] 310 max_pool_connections=max_pool_connections, 311 connect_timeout=connect_timeout, 312 read_timeout=read_timeout, 313 retries=retries or {"mode": "standard"}, 314 request_checksum_calculation=request_checksum_calculation, 315 response_checksum_validation=response_checksum_validation, 316 s3=s3, 317 ), 318 } 319 320 if self._region_name: 321 options["region_name"] = self._region_name 322 323 if self._endpoint_url: 324 options["endpoint_url"] = self._endpoint_url 325 326 if self._verify is not None: 327 options["verify"] = self._verify 328 329 if self._credentials_provider: 330 creds = self._fetch_credentials() 331 if "expiry_time" in creds and creds["expiry_time"]: 332 # Use RefreshableCredentials if expiry_time provided. 333 refreshable_credentials = RefreshableCredentials.create_from_metadata( 334 metadata=creds, refresh_using=self._fetch_credentials, method="custom-refresh" 335 ) 336 337 botocore_session = get_session() 338 botocore_session._credentials = refreshable_credentials 339 340 boto3_session = boto3.Session(botocore_session=botocore_session) 341 342 return boto3_session.client("s3", **options) 343 else: 344 # Add static credentials to the options dictionary 345 options["aws_access_key_id"] = creds["access_key"] 346 options["aws_secret_access_key"] = creds["secret_key"] 347 if creds["token"]: 348 options["aws_session_token"] = creds["token"] 349 350 if self._signature_version: 351 signature_config = botocore.config.Config( # pyright: ignore[reportAttributeAccessIssue] 352 signature_version=botocore.UNSIGNED 353 if self._signature_version == "UNSIGNED" 354 else self._signature_version 355 ) 356 options["config"] = options["config"].merge(signature_config) 357 358 # Fallback to standard credential chain. 
359 session = boto3.Session(profile_name=self._profile_name) 360 return session.client("s3", **options) 361 362 def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None): 363 """ 364 Creates and configures the rust client, using refreshable credentials if possible. 365 """ 366 configs = dict(rust_client_options) if rust_client_options else {} 367 368 # Extract and parse retry configuration 369 retry_config = parse_retry_config(configs) 370 371 if self._region_name and "region_name" not in configs: 372 configs["region_name"] = self._region_name 373 374 if self._endpoint_url and "endpoint_url" not in configs: 375 configs["endpoint_url"] = self._endpoint_url 376 377 if self._profile_name and "profile_name" not in configs: 378 configs["profile_name"] = self._profile_name 379 380 # If the user specifies a bucket, use it. Otherwise, use the base path. 381 if "bucket" not in configs: 382 bucket, _ = split_path(self._base_path) 383 configs["bucket"] = bucket 384 385 if "max_pool_connections" not in configs: 386 configs["max_pool_connections"] = MAX_POOL_CONNECTIONS 387 388 return RustClient( 389 provider=PROVIDER, 390 configs=configs, 391 credentials_provider=self._credentials_provider, 392 retry=retry_config, 393 ) 394 395 def _fetch_credentials(self) -> dict: 396 """ 397 Refreshes the S3 client if the current credentials are expired. 
398 """ 399 if not self._credentials_provider: 400 raise RuntimeError("Cannot fetch credentials if no credential provider configured.") 401 self._credentials_provider.refresh_credentials() 402 credentials = self._credentials_provider.get_credentials() 403 return { 404 "access_key": credentials.access_key, 405 "secret_key": credentials.secret_key, 406 "token": credentials.token, 407 "expiry_time": credentials.expiration, 408 } 409 410 def _translate_errors( 411 self, 412 func: Callable[[], _T], 413 operation: str, 414 bucket: str, 415 key: str, 416 ) -> _T: 417 """ 418 Translates errors like timeouts and client errors. 419 420 :param func: The function that performs the actual S3 operation. 421 :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE"). 422 :param bucket: The name of the S3 bucket involved in the operation. 423 :param key: The key of the object within the S3 bucket. 424 425 :return: The result of the S3 operation, typically the return value of the `func` callable. 426 """ 427 try: 428 return func() 429 except ClientError as error: 430 status_code = error.response["ResponseMetadata"]["HTTPStatusCode"] 431 request_id = error.response["ResponseMetadata"].get("RequestId") 432 host_id = error.response["ResponseMetadata"].get("HostId") 433 error_code = error.response["Error"]["Code"] 434 error_info = f"request_id: {request_id}, host_id: {host_id}, status_code: {status_code}" 435 436 if status_code == 404: 437 if error_code == "NoSuchUpload": 438 error_message = error.response["Error"]["Message"] 439 raise RetryableError(f"Multipart upload failed for {bucket}/{key}: {error_message}") from error 440 raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}") # pylint: disable=raise-missing-from 441 elif status_code == 412: # Precondition Failed 442 raise PreconditionFailedError( 443 f"ETag mismatch for {operation} operation on {bucket}/{key}. 
{error_info}" 444 ) from error 445 elif status_code == 429: 446 raise RetryableError( 447 f"Too many request to {operation} object(s) at {bucket}/{key}. {error_info}" 448 ) from error 449 elif status_code == 503: 450 raise RetryableError( 451 f"Service unavailable when {operation} object(s) at {bucket}/{key}. {error_info}" 452 ) from error 453 elif status_code == 501: 454 raise NotImplementedError( 455 f"Operation {operation} not implemented for object(s) at {bucket}/{key}. {error_info}" 456 ) from error 457 elif status_code == 408: 458 # 408 Request Timeout is from Google Cloud Storage 459 raise RetryableError( 460 f"Request timeout when {operation} object(s) at {bucket}/{key}. {error_info}" 461 ) from error 462 else: 463 raise RuntimeError( 464 f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}, " 465 f"error_type: {type(error).__name__}" 466 ) from error 467 except RustClientError as error: 468 message = error.args[0] 469 status_code = error.args[1] 470 if status_code == 404: 471 raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {message}") from error 472 elif status_code == 403: 473 raise PermissionError( 474 f"Permission denied to {operation} object(s) at {bucket}/{key}. {message}" 475 ) from error 476 else: 477 raise RetryableError( 478 f"Failed to {operation} object(s) at {bucket}/{key}. {message}. status_code: {status_code}" 479 ) from error 480 except (ReadTimeoutError, IncompleteReadError, ResponseStreamingError) as error: 481 raise RetryableError( 482 f"Failed to {operation} object(s) at {bucket}/{key} due to network timeout or incomplete read. " 483 f"error_type: {type(error).__name__}" 484 ) from error 485 except RustRetryableError as error: 486 raise RetryableError( 487 f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. 
" 488 f"error_type: {type(error).__name__}" 489 ) from error 490 except Exception as error: 491 raise RuntimeError( 492 f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}" 493 ) from error 494 495 def _put_object( 496 self, 497 path: str, 498 body: bytes, 499 if_match: Optional[str] = None, 500 if_none_match: Optional[str] = None, 501 attributes: Optional[dict[str, str]] = None, 502 content_type: Optional[str] = None, 503 ) -> int: 504 """ 505 Uploads an object to the specified S3 path. 506 507 :param path: The S3 path where the object will be uploaded. 508 :param body: The content of the object as bytes. 509 :param if_match: Optional If-Match header value. Use "*" to only upload if the object doesn't exist. 510 :param if_none_match: Optional If-None-Match header value. Use "*" to only upload if the object doesn't exist. 511 :param attributes: Optional attributes to attach to the object. 512 :param content_type: Optional Content-Type header value. 
513 """ 514 bucket, key = split_path(path) 515 516 def _invoke_api() -> int: 517 kwargs = {"Bucket": bucket, "Key": key, "Body": body} 518 if content_type: 519 kwargs["ContentType"] = content_type 520 if self._is_directory_bucket(bucket): 521 kwargs["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS 522 if if_match: 523 kwargs["IfMatch"] = if_match 524 if if_none_match: 525 kwargs["IfNoneMatch"] = if_none_match 526 validated_attributes = validate_attributes(attributes) 527 if validated_attributes: 528 kwargs["Metadata"] = validated_attributes 529 530 # TODO(NGCDP-5804): Add support to update ContentType header in Rust client 531 rust_unsupported_feature_keys = {"Metadata", "StorageClass", "IfMatch", "IfNoneMatch", "ContentType"} 532 if ( 533 self._rust_client 534 # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026 535 and not path.endswith("/") 536 and all(key not in kwargs for key in rust_unsupported_feature_keys) 537 ): 538 run_async_rust_client_method(self._rust_client, "put", key, body) 539 else: 540 if self._checksum_algorithm: 541 kwargs["ChecksumAlgorithm"] = self._checksum_algorithm 542 self._s3_client.put_object(**kwargs) 543 544 return len(body) 545 546 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key) 547 548 def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes: 549 bucket, key = split_path(path) 550 551 def _invoke_api() -> bytes: 552 if byte_range: 553 bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}" 554 if self._rust_client: 555 response = run_async_rust_client_method( 556 self._rust_client, 557 "get", 558 key, 559 byte_range, 560 ) 561 return response 562 else: 563 response = self._s3_client.get_object(Bucket=bucket, Key=key, Range=bytes_range) 564 else: 565 if self._rust_client: 566 response = run_async_rust_client_method(self._rust_client, "get", key) 567 return response 568 else: 569 response = 
self._s3_client.get_object(Bucket=bucket, Key=key) 570 571 return response["Body"].read() 572 573 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key) 574 575 def _copy_object(self, src_path: str, dest_path: str) -> int: 576 src_bucket, src_key = split_path(src_path) 577 dest_bucket, dest_key = split_path(dest_path) 578 579 src_object = self._get_object_metadata(src_path) 580 581 def _invoke_api() -> int: 582 self._s3_client.copy( 583 CopySource={"Bucket": src_bucket, "Key": src_key}, 584 Bucket=dest_bucket, 585 Key=dest_key, 586 Config=self._transfer_config, 587 ) 588 589 return src_object.content_length 590 591 return self._translate_errors(_invoke_api, operation="COPY", bucket=dest_bucket, key=dest_key) 592 593 def _delete_object(self, path: str, if_match: Optional[str] = None) -> None: 594 bucket, key = split_path(path) 595 596 def _invoke_api() -> None: 597 # Delete conditionally when if_match (etag) is provided; otherwise delete unconditionally 598 if if_match: 599 self._s3_client.delete_object(Bucket=bucket, Key=key, IfMatch=if_match) 600 else: 601 self._s3_client.delete_object(Bucket=bucket, Key=key) 602 603 return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key) 604 605 def _delete_objects(self, paths: list[str]) -> None: 606 if not paths: 607 return 608 609 by_bucket: dict[str, list[str]] = {} 610 for p in paths: 611 bucket, key = split_path(p) 612 by_bucket.setdefault(bucket, []).append(key) 613 614 S3_BATCH_LIMIT = 1000 615 616 def _invoke_api() -> None: 617 all_errors: list[str] = [] 618 for bucket, keys in by_bucket.items(): 619 for i in range(0, len(keys), S3_BATCH_LIMIT): 620 chunk = keys[i : i + S3_BATCH_LIMIT] 621 response = self._s3_client.delete_objects( 622 Bucket=bucket, Delete={"Objects": [{"Key": k} for k in chunk]} 623 ) 624 errors = response.get("Errors") or [] 625 for e in errors: 626 all_errors.append(f"{bucket}/{e.get('Key', '?')}: {e.get('Code', '')} {e.get('Message', '')}") 
627 if all_errors: 628 raise RuntimeError(f"DeleteObjects reported errors: {'; '.join(all_errors)}") 629 630 bucket_desc = "(" + "|".join(by_bucket) + ")" 631 key_desc = "(" + "|".join(str(len(keys)) for keys in by_bucket.values()) + " keys)" 632 self._translate_errors(_invoke_api, operation="DELETE_MANY", bucket=bucket_desc, key=key_desc) 633 634 def _is_dir(self, path: str) -> bool: 635 # Ensure the path ends with '/' to mimic a directory 636 path = self._append_delimiter(path) 637 638 bucket, key = split_path(path) 639 640 def _invoke_api() -> bool: 641 # List objects with the given prefix 642 response = self._s3_client.list_objects_v2(Bucket=bucket, Prefix=key, MaxKeys=1, Delimiter="/") 643 644 # Check if there are any contents or common prefixes 645 return bool(response.get("Contents", []) or response.get("CommonPrefixes", [])) 646 647 return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key) 648 649 def _make_symlink(self, path: str, target: str) -> None: 650 bucket, key = split_path(path) 651 target_bucket, target_key = split_path(target) 652 if bucket != target_bucket: 653 raise ValueError(f"Cannot create cross-bucket symlink: '{bucket}' -> '{target_bucket}'.") 654 relative_target = ObjectMetadata.encode_symlink_target(key, target_key) 655 656 def _invoke_api() -> None: 657 self._s3_client.put_object( 658 Bucket=bucket, 659 Key=key, 660 Body=b"", 661 Metadata={"msc-symlink-target": relative_target}, 662 ) 663 664 self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key) 665 666 def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata: 667 bucket, key = split_path(path) 668 if path.endswith("/") or (bucket and not key): 669 # If path ends with "/" or empty key name is provided, then assume it's a "directory", 670 # which metadata is not guaranteed to exist for cases such as 671 # "virtual prefix" that was never explicitly created. 
672 if self._is_dir(path): 673 return ObjectMetadata( 674 key=path, 675 type="directory", 676 content_length=0, 677 last_modified=AWARE_DATETIME_MIN, 678 ) 679 else: 680 raise FileNotFoundError(f"Directory {path} does not exist.") 681 else: 682 683 def _invoke_api() -> ObjectMetadata: 684 response = self._s3_client.head_object(Bucket=bucket, Key=key) 685 user_metadata = response.get("Metadata") 686 symlink_target = user_metadata.get("msc-symlink-target") if user_metadata else None 687 688 return ObjectMetadata( 689 key=path, 690 type="file", 691 content_length=response["ContentLength"], 692 content_type=response.get("ContentType"), 693 last_modified=response.get("LastModified", AWARE_DATETIME_MIN), 694 etag=response["ETag"].strip('"') if "ETag" in response else None, 695 storage_class=response.get("StorageClass"), 696 metadata=user_metadata, 697 symlink_target=symlink_target, 698 ) 699 700 try: 701 return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key) 702 except FileNotFoundError as error: 703 if strict: 704 # If the object does not exist on the given path, we will append a trailing slash and 705 # check if the path is a directory. 706 path = self._append_delimiter(path) 707 if self._is_dir(path): 708 return ObjectMetadata( 709 key=path, 710 type="directory", 711 content_length=0, 712 last_modified=AWARE_DATETIME_MIN, 713 ) 714 raise error 715 716 def _list_objects( 717 self, 718 path: str, 719 start_after: Optional[str] = None, 720 end_at: Optional[str] = None, 721 include_directories: bool = False, 722 symlink_handling: SymlinkHandling = SymlinkHandling.FOLLOW, 723 ) -> Iterator[ObjectMetadata]: 724 bucket, prefix = split_path(path) 725 726 # Get the prefix of the start_after and end_at paths relative to the bucket. 
727 if start_after: 728 _, start_after = split_path(start_after) 729 if end_at: 730 _, end_at = split_path(end_at) 731 732 def _invoke_api() -> Iterator[ObjectMetadata]: 733 paginator = self._s3_client.get_paginator("list_objects_v2") 734 if include_directories: 735 page_iterator = paginator.paginate( 736 Bucket=bucket, Prefix=prefix, Delimiter="/", StartAfter=(start_after or "") 737 ) 738 else: 739 page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, StartAfter=(start_after or "")) 740 741 for page in page_iterator: 742 for item in page.get("CommonPrefixes", []): 743 prefix_key = item["Prefix"].rstrip("/") 744 # Filter by start_after and end_at - S3's StartAfter doesn't filter CommonPrefixes 745 if (start_after is None or start_after < prefix_key) and (end_at is None or prefix_key <= end_at): 746 yield ObjectMetadata( 747 key=os.path.join(bucket, prefix_key), 748 type="directory", 749 content_length=0, 750 last_modified=AWARE_DATETIME_MIN, 751 ) 752 elif end_at is not None and end_at < prefix_key: 753 return 754 755 # S3 guarantees lexicographical order for general purpose buckets (for 756 # normal S3) but not directory buckets (for S3 Express One Zone). 
757 for response_object in page.get("Contents", []): 758 key = response_object["Key"] 759 if end_at is None or key <= end_at: 760 if key.endswith("/"): 761 if include_directories: 762 yield ObjectMetadata( 763 key=os.path.join(bucket, key.rstrip("/")), 764 type="directory", 765 content_length=0, 766 last_modified=response_object["LastModified"], 767 ) 768 else: 769 symlink_target = None 770 if response_object.get("Size", 0) == 0: 771 try: 772 meta = self._get_object_metadata(os.path.join(bucket, key)) 773 symlink_target = meta.symlink_target 774 except Exception: 775 symlink_target = None 776 yield ObjectMetadata( 777 key=os.path.join(bucket, key), 778 type="file", 779 content_length=response_object["Size"], 780 last_modified=response_object["LastModified"], 781 etag=response_object["ETag"].strip('"'), 782 storage_class=response_object.get("StorageClass"), 783 symlink_target=symlink_target, 784 ) 785 else: 786 return 787 788 return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix) 789 790 @property 791 def supports_parallel_listing(self) -> bool: 792 """ 793 S3 supports parallel listing via delimiter-based prefix discovery. 794 795 Note: Directory bucket handling is done in list_objects_recursive(). 796 """ 797 return True 798
def list_objects_recursive(
    self,
    path: str = "",
    start_after: Optional[str] = None,
    end_at: Optional[str] = None,
    max_workers: int = 32,
    look_ahead: int = 2,
    symlink_handling: SymlinkHandling = SymlinkHandling.FOLLOW,
) -> Iterator[ObjectMetadata]:
    """
    Recursively list every object under ``path``, yielding files only, in lexicographic order.

    One of three strategies is chosen:

    * Directory buckets delegate to :meth:`list_objects` with ``include_directories=False``.
    * When a Rust client is configured, :meth:`_list_objects_recursive_rust` is used and wrapped
      in LIST metrics emission.
    * Otherwise the base-class parallel prefix-discovery implementation is used.

    :param path: Path to list, relative to the provider's base path.
    :param start_after: Exclusive lower bound on returned keys.
    :param end_at: Inclusive upper bound on returned keys.
    :param max_workers: Listing parallelism. Passed to the Rust path or the base-class implementation.
    :param look_ahead: Forwarded only to the base-class implementation.
    :param symlink_handling: How to handle symbolic links during listing. Forwarded to the
        directory-bucket and base-class paths. The Rust path does not receive it.
    :raises ValueError: If both bounds are given and ``start_after`` is not strictly before
        ``end_at``. Because this method is a generator, the check runs when iteration starts.
    """
    if start_after is not None and end_at is not None and start_after >= end_at:
        raise ValueError(f"start_after ({start_after}) must be before end_at ({end_at})!")

    resolved_path = self._prepend_base_path(path)
    bucket, _ = split_path(resolved_path)

    if self._is_directory_bucket(bucket):
        # Directory buckets skip parallel discovery and use the plain (file-only) listing.
        yield from self.list_objects(
            path, start_after, end_at, include_directories=False, symlink_handling=symlink_handling
        )
    elif self._rust_client:

        def _rust_listing() -> Iterator[ObjectMetadata]:
            return self._list_objects_recursive_rust(path, resolved_path, bucket, start_after, end_at, max_workers)

        yield from self._emit_metrics(operation=BaseStorageProvider._Operation.LIST, f=_rust_listing)
    else:
        yield from super().list_objects_recursive(
            path, start_after, end_at, max_workers, look_ahead, symlink_handling
        )
def _list_objects_recursive_rust(
    self,
    path: str,
    full_path: str,
    bucket: str,
    start_after: Optional[str],
    end_at: Optional[str],
    max_workers: int,
) -> Iterator[ObjectMetadata]:
    """
    Use the Rust client's ``list_recursive`` for parallel listing.

    The Rust client already performs the listing in parallel. This method only applies the
    ``start_after``/``end_at`` window and converts results to :class:`ObjectMetadata`.
    Returns files only, in lexicographic order.

    :param path: Not used here. Kept for the caller's positional signature.
    :param full_path: ``path`` with the provider base path prepended (split into bucket and prefix).
    :param bucket: Bucket name of ``full_path``.
    :param start_after: Exclusive lower bound, relative to the base path.
    :param end_at: Inclusive upper bound, relative to the base path.
    :param max_workers: Passed to the Rust client as ``max_concurrency``.
    """
    _, prefix = split_path(full_path)

    def _invoke_api() -> Iterator[ObjectMetadata]:
        result = run_async_rust_client_method(
            self._rust_client,
            "list_recursive",
            [prefix] if prefix else [""],
            max_concurrency=max_workers,
        )

        # The bounds are relative to the base path, but the comparisons below use full keys.
        start_after_full = self._prepend_base_path(start_after) if start_after else None
        end_at_full = self._prepend_base_path(end_at) if end_at else None

        for obj in result.objects:
            full_key = os.path.join(bucket, obj.key)

            if start_after_full and full_key <= start_after_full:
                continue
            # Stopping early relies on the Rust client returning keys in lexicographic order.
            if end_at_full and full_key > end_at_full:
                break

            relative_key = full_key.removeprefix(self._base_path).lstrip("/")

            yield ObjectMetadata(
                key=relative_key,
                content_length=obj.content_length,
                last_modified=dateutil_parse(obj.last_modified),
                type="file" if obj.object_type == "object" else obj.object_type,
                etag=obj.etag,
            )

    yield from self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)


def _upload_file(
    self,
    remote_path: str,
    f: Union[str, IO],
    attributes: Optional[dict[str, str]] = None,
    content_type: Optional[str] = None,
) -> int:
    """
    Upload a local file path or a file-like object to ``remote_path``.

    Payloads at or below ``self._multipart_threshold`` are sent with a single PUT.
    Larger payloads use a multipart transfer: the Rust client when no extra arguments are
    needed, otherwise boto3 with ``self._transfer_config``.

    :param remote_path: Destination ``bucket/key``.
    :param f: Local file path or file-like object. ``io.StringIO`` content is UTF-8 encoded for
        single-PUT uploads.
    :param attributes: Optional user metadata for the object.
    :param content_type: Optional ``Content-Type`` for the object.
    :return: Payload size measured locally: ``os.path.getsize`` for paths, or the stream's end
        position for file-like objects.
    """
    file_size: int = 0

    if isinstance(f, str):
        bucket, key = split_path(remote_path)
        file_size = os.path.getsize(f)

        # Upload small files
        if file_size <= self._multipart_threshold:
            if self._rust_client and not attributes and not content_type:
                # Translate Rust failures the same way the multipart branch below does.
                self._translate_errors(
                    lambda: run_async_rust_client_method(self._rust_client, "upload", f, key),
                    operation="PUT",
                    bucket=bucket,
                    key=key,
                )
            else:
                with open(f, "rb") as fp:
                    self._put_object(remote_path, fp.read(), attributes=attributes, content_type=content_type)
            return file_size

        # Upload large files using TransferConfig
        def _invoke_api() -> int:
            extra_args = {}
            if content_type:
                extra_args["ContentType"] = content_type
            if self._is_directory_bucket(bucket):
                extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                extra_args["Metadata"] = validated_attributes
            if self._rust_client and not extra_args:
                run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
            else:
                if self._checksum_algorithm:
                    extra_args["ChecksumAlgorithm"] = self._checksum_algorithm
                self._s3_client.upload_file(
                    Filename=f,
                    Bucket=bucket,
                    Key=key,
                    Config=self._transfer_config,
                    ExtraArgs=extra_args,
                )

            return file_size

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
    else:
        # Upload small files
        f.seek(0, io.SEEK_END)
        file_size = f.tell()
        f.seek(0)

        if file_size <= self._multipart_threshold:
            if isinstance(f, io.StringIO):
                self._put_object(
                    remote_path, f.read().encode("utf-8"), attributes=attributes, content_type=content_type
                )
            else:
                self._put_object(remote_path, f.read(), attributes=attributes, content_type=content_type)
            return file_size

        # Upload large files using TransferConfig
        bucket, key = split_path(remote_path)

        def _invoke_api() -> int:
            extra_args = {}
            if content_type:
                extra_args["ContentType"] = content_type
            if self._is_directory_bucket(bucket):
                extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                extra_args["Metadata"] = validated_attributes

            if self._rust_client and isinstance(f, io.BytesIO) and not extra_args:
                data = f.getbuffer()
                run_async_rust_client_method(self._rust_client, "upload_multipart_from_bytes", key, data)
            else:
                if self._checksum_algorithm:
                    extra_args["ChecksumAlgorithm"] = self._checksum_algorithm
                self._s3_client.upload_fileobj(
                    Fileobj=f,
                    Bucket=bucket,
                    Key=key,
                    Config=self._transfer_config,
                    ExtraArgs=extra_args,
                )

            return file_size

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)


def _download_via_temp_file(self, dest: str, write: Callable[[IO, str], None]) -> None:
    """
    Populate ``dest`` through a hidden temporary file created in the same directory.

    ``write`` receives the open binary temp file and its path, and may write through either.
    On success the temp file is moved onto ``dest`` with :func:`os.replace`. Unlike
    :func:`os.rename`, this also overwrites an existing ``dest`` on Windows.
    If ``write`` or the move fails, the temp file is removed and the original exception is
    re-raised, so failed downloads leave no stray hidden files behind.
    """
    temp_file = tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(dest), prefix=".")
    temp_file_path = temp_file.name
    try:
        with temp_file:
            write(temp_file, temp_file_path)
        os.replace(temp_file_path, dest)
    except BaseException:
        try:
            os.remove(temp_file_path)
        except OSError:
            pass  # best-effort cleanup; the original failure is re-raised below
        raise


def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
    """
    Download ``remote_path`` into a local file path or a file-like object.

    Payloads at or below ``self._multipart_threshold`` are fetched in one request. Larger
    payloads use a multipart transfer: the Rust client when available, otherwise boto3 with
    ``self._transfer_config``.

    When ``f`` is a path, missing parent directories are created. Downloads that go through a
    temporary file clean it up if they fail.

    :param remote_path: Source ``bucket/key``.
    :param f: Destination path, or a writable file-like object. For ``io.StringIO`` in the
        single-request path, content is decoded as UTF-8.
    :param metadata: Object metadata if already known; otherwise it is fetched first.
    :return: ``metadata.content_length``. This is not a count of bytes actually written.
    """
    if metadata is None:
        metadata = self._get_object_metadata(remote_path)

    if isinstance(f, str):
        bucket, key = split_path(remote_path)
        if os.path.dirname(f):
            safe_makedirs(os.path.dirname(f))

        # Download small files
        if metadata.content_length <= self._multipart_threshold:
            if self._rust_client:
                # Translate Rust failures the same way the multipart branch below does.
                self._translate_errors(
                    lambda: run_async_rust_client_method(self._rust_client, "download", key, f),
                    operation="GET",
                    bucket=bucket,
                    key=key,
                )
            else:

                def _write_small(fp: IO, temp_path: str) -> None:
                    fp.write(self._get_object(remote_path))

                self._download_via_temp_file(f, _write_small)
            return metadata.content_length

        # Download large files using TransferConfig
        def _write_large(fp: IO, temp_path: str) -> None:
            if self._rust_client:
                run_async_rust_client_method(self._rust_client, "download_multipart_to_file", key, temp_path)
            else:
                self._s3_client.download_fileobj(
                    Bucket=bucket,
                    Key=key,
                    Fileobj=fp,
                    Config=self._transfer_config,
                )

        def _invoke_api() -> int:
            self._download_via_temp_file(f, _write_large)
            return metadata.content_length

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
    else:
        # Download small files
        if metadata.content_length <= self._multipart_threshold:
            response = self._get_object(remote_path)
            # Python client returns `bytes`, but Rust client returns an object that implements the buffer protocol,
            # so we need to check whether `.decode()` is available.
            if isinstance(f, io.StringIO):
                if hasattr(response, "decode"):
                    f.write(response.decode("utf-8"))
                else:
                    f.write(codecs.decode(memoryview(response), "utf-8"))
            else:
                f.write(response)
            return metadata.content_length

        # Download large files using TransferConfig
        bucket, key = split_path(remote_path)

        def _invoke_api() -> int:
            self._s3_client.download_fileobj(
                Bucket=bucket,
                Key=key,
                Fileobj=f,
                Config=self._transfer_config,
            )

            return metadata.content_length

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)


def _generate_presigned_url(
    self,
    path: str,
    *,
    method: str = "GET",
    signer_type: Optional[SignerType] = None,
    signer_options: Optional[dict[str, Any]] = None,
) -> str:
    """
    Generate a presigned URL for ``path`` with an S3 (default) or CloudFront signer.

    Signers are cached in ``self._signer_cache``:

    * S3 signers are cached per ``(bucket, expires_in)``.
    * CloudFront signers are cached per the full set of ``signer_options``, so option values
      must be hashable.

    :param path: ``bucket/key`` to sign.
    :param method: HTTP method passed to the signer.
    :param signer_type: ``None`` or ``SignerType.S3`` for S3 signing, or ``SignerType.CLOUDFRONT``.
    :param signer_options: For S3, an optional ``expires_in``. For CloudFront, keyword
        arguments for :class:`CloudFrontURLSigner`.
    :raises ValueError: For any other signer type.
    """
    options = signer_options or {}
    bucket, key = split_path(path)

    if signer_type is None or signer_type == SignerType.S3:
        expires_in = int(options.get("expires_in", DEFAULT_PRESIGN_EXPIRES_IN))
        cache_key: tuple = (SignerType.S3, bucket, expires_in)
        if cache_key not in self._signer_cache:
            self._signer_cache[cache_key] = S3URLSigner(self._s3_client, bucket, expires_in=expires_in)
    elif signer_type == SignerType.CLOUDFRONT:
        cache_key = (SignerType.CLOUDFRONT, frozenset(options.items()))
        if cache_key not in self._signer_cache:
            self._signer_cache[cache_key] = CloudFrontURLSigner(**options)
    else:
        raise ValueError(f"Unsupported signer type for S3 provider: {signer_type!r}")

    return self._signer_cache[cache_key].generate_presigned_url(key, method=method)