Source code for multistorageclient.providers.s3

   1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
   2# SPDX-License-Identifier: Apache-2.0
   3#
   4# Licensed under the Apache License, Version 2.0 (the "License");
   5# you may not use this file except in compliance with the License.
   6# You may obtain a copy of the License at
   7#
   8# http://www.apache.org/licenses/LICENSE-2.0
   9#
  10# Unless required by applicable law or agreed to in writing, software
  11# distributed under the License is distributed on an "AS IS" BASIS,
  12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13# See the License for the specific language governing permissions and
  14# limitations under the License.
  15
  16import codecs
  17import io
  18import os
  19import tempfile
  20from collections.abc import Callable, Iterator
  21from typing import IO, Any, Optional, TypeVar, Union
  22
  23import boto3
  24import botocore
  25from boto3.s3.transfer import TransferConfig
  26from botocore.credentials import RefreshableCredentials
  27from botocore.exceptions import ClientError, IncompleteReadError, ReadTimeoutError, ResponseStreamingError
  28from botocore.session import get_session
  29from dateutil.parser import parse as dateutil_parse
  30
  31from multistorageclient_rust import RustClient, RustClientError, RustRetryableError
  32
  33from ..constants import DEFAULT_CONNECT_TIMEOUT, DEFAULT_MAX_POOL_CONNECTIONS, DEFAULT_READ_TIMEOUT
  34from ..rust_utils import parse_retry_config, run_async_rust_client_method
  35from ..signers import CloudFrontURLSigner, URLSigner
  36from ..telemetry import Telemetry
  37from ..types import (
  38    AWARE_DATETIME_MIN,
  39    Credentials,
  40    CredentialsProvider,
  41    ObjectMetadata,
  42    PreconditionFailedError,
  43    Range,
  44    RetryableError,
  45    SignerType,
  46    SymlinkHandling,
  47)
  48from ..utils import (
  49    safe_makedirs,
  50    split_path,
  51    validate_attributes,
  52)
  53from .base import BaseStorageProvider
  54
  55_T = TypeVar("_T")
  56
  57MiB = 1024 * 1024
  58
  59# Python and Rust share the same multipart_threshold to keep the code simple.
  60MULTIPART_THRESHOLD = 64 * MiB
  61MULTIPART_CHUNKSIZE = 32 * MiB
  62IO_CHUNKSIZE = 32 * MiB
  63PYTHON_MAX_CONCURRENCY = 8
  64MAX_POOL_CONNECTIONS = DEFAULT_MAX_POOL_CONNECTIONS
  65
  66PROVIDER = "s3"
  67
  68EXPRESS_ONEZONE_STORAGE_CLASS = "EXPRESS_ONEZONE"
  69
  70# Source: https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html
  71SUPPORTED_CHECKSUM_ALGORITHMS = frozenset({"CRC32", "CRC32C", "SHA1", "SHA256", "CRC64NVME"})
  72
  73
[docs] 74class StaticS3CredentialsProvider(CredentialsProvider): 75 """ 76 A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static S3 credentials. 77 """ 78 79 _access_key: str 80 _secret_key: str 81 _session_token: Optional[str] 82 83 def __init__(self, access_key: str, secret_key: str, session_token: Optional[str] = None): 84 """ 85 Initializes the :py:class:`StaticS3CredentialsProvider` with the provided access key, secret key, and optional 86 session token. 87 88 :param access_key: The access key for S3 authentication. 89 :param secret_key: The secret key for S3 authentication. 90 :param session_token: An optional session token for temporary credentials. 91 """ 92 self._access_key = access_key 93 self._secret_key = secret_key 94 self._session_token = session_token 95
[docs] 96 def get_credentials(self) -> Credentials: 97 return Credentials( 98 access_key=self._access_key, 99 secret_key=self._secret_key, 100 token=self._session_token, 101 expiration=None, 102 )
103
[docs] 104 def refresh_credentials(self) -> None: 105 pass
106 107 108DEFAULT_PRESIGN_EXPIRES_IN = 3600 109 110_S3_METHOD_MAPPING: dict[str, str] = { 111 "GET": "get_object", 112 "PUT": "put_object", 113} 114 115
[docs] 116class S3URLSigner(URLSigner): 117 """Generates pre-signed URLs using the boto3 S3 client. 118 119 When the underlying credentials are temporary (STS, IAM role, EC2 instance 120 profile), the effective URL lifetime is the **shorter** of ``expires_in`` 121 and the remaining credential lifetime — boto3 will not warn if the 122 credential expires before ``expires_in``. 123 124 See https://docs.aws.amazon.com/AmazonS3/latest/userguide/using-presigned-url.html 125 """ 126 127 def __init__(self, s3_client: Any, bucket: str, expires_in: int = DEFAULT_PRESIGN_EXPIRES_IN) -> None: 128 self._s3_client = s3_client 129 self._bucket = bucket 130 self._expires_in = expires_in 131
[docs] 132 def generate_presigned_url(self, path: str, *, method: str = "GET") -> str: 133 client_method = _S3_METHOD_MAPPING.get(method.upper()) 134 if client_method is None: 135 raise ValueError(f"Unsupported method for S3 presigning: {method!r}") 136 return self._s3_client.generate_presigned_url( 137 ClientMethod=client_method, 138 Params={"Bucket": self._bucket, "Key": path}, 139 ExpiresIn=self._expires_in, 140 )
141 142
[docs] 143class S3StorageProvider(BaseStorageProvider): 144 """ 145 A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Amazon S3 or S3-compatible object stores. 146 """ 147 148 def __init__( 149 self, 150 region_name: str = "", 151 endpoint_url: str = "", 152 base_path: str = "", 153 credentials_provider: Optional[CredentialsProvider] = None, 154 profile_name: Optional[str] = None, 155 config_dict: Optional[dict[str, Any]] = None, 156 telemetry_provider: Optional[Callable[[], Telemetry]] = None, 157 verify: Optional[Union[bool, str]] = None, 158 **kwargs: Any, 159 ) -> None: 160 """ 161 Initializes the :py:class:`S3StorageProvider` with the region, endpoint URL, and optional credentials provider. 162 163 :param region_name: The AWS region where the S3 bucket is located. 164 :param endpoint_url: The custom endpoint URL for the S3 service. 165 :param base_path: The root prefix path within the S3 bucket where all operations will be scoped. 166 :param credentials_provider: The provider to retrieve S3 credentials. 167 :param profile_name: AWS shared configuration + credentials files profile. For :py:class:`boto3.session.Session`. Ignored if ``credentials_provider`` is set. 168 :param config_dict: Resolved MSC config. 169 :param telemetry_provider: A function that provides a telemetry instance. 170 :param verify: Controls SSL certificate verification. 171 Can be ``True`` (verify using system CA bundle, default), ``False`` (skip verification), or a string path to a custom CA certificate bundle. 172 :param request_checksum_calculation: For :py:class:`botocore.config.Config`. 173 When the underlying S3 client should calculate request checksums. 174 :param response_checksum_validation: For :py:class:`botocore.config.Config`. 175 When the underlying S3 client should validate response checksums. 176 :param max_pool_connections: For :py:class:`botocore.config.Config`. 177 The maximum number of connections to keep in a connection pool. 178 :param connect_timeout: For :py:class:`botocore.config.Config`. 179 The time in seconds till a timeout exception is thrown when attempting to make a connection. 180 :param read_timeout: For :py:class:`botocore.config.Config`. 181 The time in seconds till a timeout exception is thrown when attempting to read from a connection. 182 :param retries: For :py:class:`botocore.config.Config`. 183 A dictionary for configuration related to retry behavior. 184 :param s3: For :py:class:`botocore.config.Config`. 185 A dictionary of S3 specific configurations. 186 :param checksum_algorithm: Upload-only object integrity algorithm. 187 One of ``"CRC32"``, ``"CRC32C"``, ``"SHA1"``, ``"SHA256"``, ``"CRC64NVME"`` (case-insensitive). 188 When ``rust_client`` is enabled, only ``"SHA256"`` is accepted. 189 :param multipart_threshold: For :py:class:`boto3.s3.transfer.TransferConfig`. 190 The transfer size threshold for which multipart uploads, downloads, and copies will automatically be triggered. 191 :param max_concurrency: For :py:class:`boto3.s3.transfer.TransferConfig`. 192 The maximum number of threads that will be making requests to perform a transfer. 193 If ``use_threads`` is set to ``False``, the value provided is ignored as the transfer will only ever use the current thread. 194 :param multipart_chunksize: For :py:class:`boto3.s3.transfer.TransferConfig`. 195 The partition size of each part for a multipart transfer. 196 :param io_chunksize: For :py:class:`boto3.s3.transfer.TransferConfig`. 197 The max size of each chunk in the ``io`` queue. Currently, this is the size used when read is called on the downloaded stream as well. 198 Note: This value is ignored when resolved transfer manager type is CRTTransferManager. 199 """ 200 super().__init__( 201 base_path=base_path, 202 provider_name=PROVIDER, 203 config_dict=config_dict, 204 telemetry_provider=telemetry_provider, 205 ) 206 207 self._region_name = region_name 208 self._endpoint_url = endpoint_url 209 self._credentials_provider = credentials_provider 210 self._profile_name = profile_name 211 self._verify = verify 212 213 self._signature_version = kwargs.get("signature_version", "s3v4") 214 self._s3_client = self._create_s3_client( 215 request_checksum_calculation=kwargs.get("request_checksum_calculation"), 216 response_checksum_validation=kwargs.get("response_checksum_validation"), 217 max_pool_connections=kwargs.get("max_pool_connections", MAX_POOL_CONNECTIONS), 218 connect_timeout=kwargs.get("connect_timeout", DEFAULT_CONNECT_TIMEOUT), 219 read_timeout=kwargs.get("read_timeout", DEFAULT_READ_TIMEOUT), 220 retries=kwargs.get("retries"), 221 s3=kwargs.get("s3"), 222 ) 223 self._transfer_config = TransferConfig( 224 multipart_threshold=int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD)), 225 max_concurrency=int(kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)), 226 multipart_chunksize=int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE)), 227 io_chunksize=int(kwargs.get("io_chunksize", IO_CHUNKSIZE)), 228 use_threads=True, 229 ) 230 self._multipart_threshold = self._transfer_config.multipart_threshold 231 232 self._signer_cache: dict[tuple, URLSigner] = {} 233 234 self._checksum_algorithm: Optional[str] = self._validate_checksum_algorithm(kwargs.get("checksum_algorithm")) 235 236 self._rust_client = None 237 if "rust_client" in kwargs: 238 # Inherit the rust client options from the kwargs 239 rust_client_options = kwargs["rust_client"] 240 if "max_pool_connections" in kwargs: 241 rust_client_options["max_pool_connections"] = kwargs["max_pool_connections"] 242 if "max_concurrency" in kwargs: 243 rust_client_options["max_concurrency"] = kwargs["max_concurrency"] 244 if "multipart_chunksize" in kwargs: 245 rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"] 246 if "read_timeout" in kwargs: 247 rust_client_options["read_timeout"] = kwargs["read_timeout"] 248 if "connect_timeout" in kwargs: 249 rust_client_options["connect_timeout"] = kwargs["connect_timeout"] 250 if "checksum_algorithm" in kwargs: 251 rust_client_options["checksum_algorithm"] = kwargs["checksum_algorithm"] 252 if self._signature_version == "UNSIGNED": 253 rust_client_options["skip_signature"] = True 254 255 # Rust client only supports SHA256 today (object_store limitation). 256 rust_checksum = rust_client_options.get("checksum_algorithm") 257 if rust_checksum is not None: 258 if str(rust_checksum).upper() != "SHA256": 259 raise ValueError( 260 f"Rust client only supports checksum_algorithm='SHA256' today " 261 f"(object_store limitation), got '{rust_checksum!r}'. " 262 f"Disable rust_client or use SHA256." 263 ) 264 rust_client_options["checksum_algorithm"] = "sha256" 265 self._rust_client = self._create_rust_client(rust_client_options) 266 267 @staticmethod 268 def _validate_checksum_algorithm(value: Optional[str]) -> Optional[str]: 269 """ 270 Normalizes the optional ``checksum_algorithm`` to uppercase, or returns ``None`` if disabled. 271 """ 272 if value is None: 273 return None 274 if not isinstance(value, str) or not value: 275 raise ValueError( 276 f"checksum_algorithm must be one of {sorted(SUPPORTED_CHECKSUM_ALGORITHMS)} or None, got {value!r}." 277 ) 278 normalized = value.upper() 279 if normalized not in SUPPORTED_CHECKSUM_ALGORITHMS: 280 raise ValueError( 281 f"Unsupported checksum_algorithm '{value}'. Supported: {sorted(SUPPORTED_CHECKSUM_ALGORITHMS)}." 282 ) 283 return normalized 284 285 def _is_directory_bucket(self, bucket: str) -> bool: 286 """ 287 Determines if the bucket is a directory bucket based on bucket name. 288 """ 289 # S3 Express buckets have a specific naming convention 290 return "--x-s3" in bucket 291 292 def _create_s3_client( 293 self, 294 request_checksum_calculation: Optional[str] = None, 295 response_checksum_validation: Optional[str] = None, 296 max_pool_connections: int = MAX_POOL_CONNECTIONS, 297 connect_timeout: Union[float, int, None] = None, 298 read_timeout: Union[float, int, None] = None, 299 retries: Optional[dict[str, Any]] = None, 300 s3: Optional[dict[str, Any]] = None, 301 ): 302 """ 303 Creates and configures the boto3 S3 client, using refreshable credentials if possible. 304 305 :param request_checksum_calculation: For :py:class:`botocore.config.Config`. When the underlying S3 client should calculate request checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_. 306 :param response_checksum_validation: For :py:class:`botocore.config.Config`. When the underlying S3 client should validate response checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_. 307 :param max_pool_connections: For :py:class:`botocore.config.Config`. The maximum number of connections to keep in a connection pool. 308 :param connect_timeout: For :py:class:`botocore.config.Config`. The time in seconds till a timeout exception is thrown when attempting to make a connection. 309 :param read_timeout: For :py:class:`botocore.config.Config`. The time in seconds till a timeout exception is thrown when attempting to read from a connection. 310 :param retries: For :py:class:`botocore.config.Config`. A dictionary for configuration related to retry behavior. 311 :param s3: For :py:class:`botocore.config.Config`. A dictionary of S3 specific configurations. 312 313 :return: The configured S3 client. 314 """ 315 options = { 316 # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html 317 "config": boto3.session.Config( # pyright: ignore [reportAttributeAccessIssue] 318 max_pool_connections=max_pool_connections, 319 connect_timeout=connect_timeout, 320 read_timeout=read_timeout, 321 retries=retries or {"mode": "standard"}, 322 request_checksum_calculation=request_checksum_calculation, 323 response_checksum_validation=response_checksum_validation, 324 s3=s3, 325 ), 326 } 327 328 if self._region_name: 329 options["region_name"] = self._region_name 330 331 if self._endpoint_url: 332 options["endpoint_url"] = self._endpoint_url 333 334 if self._verify is not None: 335 options["verify"] = self._verify 336 337 if self._credentials_provider: 338 creds = self._fetch_credentials() 339 if "expiry_time" in creds and creds["expiry_time"]: 340 # Use RefreshableCredentials if expiry_time provided. 341 refreshable_credentials = RefreshableCredentials.create_from_metadata( 342 metadata=creds, refresh_using=self._fetch_credentials, method="custom-refresh" 343 ) 344 345 botocore_session = get_session() 346 botocore_session._credentials = refreshable_credentials 347 348 boto3_session = boto3.Session(botocore_session=botocore_session) 349 350 return boto3_session.client("s3", **options) 351 else: 352 # Add static credentials to the options dictionary 353 options["aws_access_key_id"] = creds["access_key"] 354 options["aws_secret_access_key"] = creds["secret_key"] 355 if creds["token"]: 356 options["aws_session_token"] = creds["token"] 357 358 if self._signature_version: 359 signature_config = botocore.config.Config( # pyright: ignore[reportAttributeAccessIssue] 360 signature_version=botocore.UNSIGNED 361 if self._signature_version == "UNSIGNED" 362 else self._signature_version 363 ) 364 options["config"] = options["config"].merge(signature_config) 365 366 # Fallback to standard credential chain. 367 session = boto3.Session(profile_name=self._profile_name) 368 return session.client("s3", **options) 369 370 def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None): 371 """ 372 Creates and configures the rust client, using refreshable credentials if possible. 373 """ 374 configs = dict(rust_client_options) if rust_client_options else {} 375 376 # Extract and parse retry configuration 377 retry_config = parse_retry_config(configs) 378 379 if self._region_name and "region_name" not in configs: 380 configs["region_name"] = self._region_name 381 382 if self._endpoint_url and "endpoint_url" not in configs: 383 configs["endpoint_url"] = self._endpoint_url 384 385 if self._profile_name and "profile_name" not in configs: 386 configs["profile_name"] = self._profile_name 387 388 # If the user specifies a bucket, use it. Otherwise, use the base path. 389 if "bucket" not in configs: 390 bucket, _ = split_path(self._base_path) 391 configs["bucket"] = bucket 392 393 if "max_pool_connections" not in configs: 394 configs["max_pool_connections"] = MAX_POOL_CONNECTIONS 395 396 return RustClient( 397 provider=PROVIDER, 398 configs=configs, 399 credentials_provider=self._credentials_provider, 400 retry=retry_config, 401 ) 402 403 def _fetch_credentials(self) -> dict: 404 """ 405 Refreshes the S3 client if the current credentials are expired. 406 """ 407 if not self._credentials_provider: 408 raise RuntimeError("Cannot fetch credentials if no credential provider configured.") 409 self._credentials_provider.refresh_credentials() 410 credentials = self._credentials_provider.get_credentials() 411 return { 412 "access_key": credentials.access_key, 413 "secret_key": credentials.secret_key, 414 "token": credentials.token, 415 "expiry_time": credentials.expiration, 416 } 417 418 def _translate_errors( 419 self, 420 func: Callable[[], _T], 421 operation: str, 422 bucket: str, 423 key: str, 424 ) -> _T: 425 """ 426 Translates errors like timeouts and client errors. 427 428 :param func: The function that performs the actual S3 operation. 429 :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE"). 430 :param bucket: The name of the S3 bucket involved in the operation. 431 :param key: The key of the object within the S3 bucket. 432 433 :return: The result of the S3 operation, typically the return value of the `func` callable. 434 """ 435 try: 436 return func() 437 except ClientError as error: 438 status_code = error.response["ResponseMetadata"]["HTTPStatusCode"] 439 request_id = error.response["ResponseMetadata"].get("RequestId") 440 host_id = error.response["ResponseMetadata"].get("HostId") 441 error_code = error.response["Error"]["Code"] 442 error_info = f"request_id: {request_id}, host_id: {host_id}, status_code: {status_code}" 443 444 if status_code == 404: 445 if error_code == "NoSuchUpload": 446 error_message = error.response["Error"]["Message"] 447 raise RetryableError(f"Multipart upload failed for {bucket}/{key}: {error_message}") from error 448 raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}") # pylint: disable=raise-missing-from 449 elif status_code == 412: # Precondition Failed 450 raise PreconditionFailedError( 451 f"ETag mismatch for {operation} operation on {bucket}/{key}. {error_info}" 452 ) from error 453 elif status_code == 429: 454 raise RetryableError( 455 f"Too many request to {operation} object(s) at {bucket}/{key}. {error_info}" 456 ) from error 457 elif status_code == 503: 458 raise RetryableError( 459 f"Service unavailable when {operation} object(s) at {bucket}/{key}. {error_info}" 460 ) from error 461 elif status_code == 501: 462 raise NotImplementedError( 463 f"Operation {operation} not implemented for object(s) at {bucket}/{key}. {error_info}" 464 ) from error 465 elif status_code == 408: 466 # 408 Request Timeout is from Google Cloud Storage 467 raise RetryableError( 468 f"Request timeout when {operation} object(s) at {bucket}/{key}. {error_info}" 469 ) from error 470 else: 471 raise RuntimeError( 472 f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}, " 473 f"error_type: {type(error).__name__}" 474 ) from error 475 except RustClientError as error: 476 message = error.args[0] 477 status_code = error.args[1] 478 if status_code == 404: 479 raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {message}") from error 480 elif status_code == 403: 481 raise PermissionError( 482 f"Permission denied to {operation} object(s) at {bucket}/{key}. {message}" 483 ) from error 484 else: 485 raise RetryableError( 486 f"Failed to {operation} object(s) at {bucket}/{key}. {message}. status_code: {status_code}" 487 ) from error 488 except (ReadTimeoutError, IncompleteReadError, ResponseStreamingError) as error: 489 raise RetryableError( 490 f"Failed to {operation} object(s) at {bucket}/{key} due to network timeout or incomplete read. " 491 f"error_type: {type(error).__name__}" 492 ) from error 493 except RustRetryableError as error: 494 raise RetryableError( 495 f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. " 496 f"error_type: {type(error).__name__}" 497 ) from error 498 except Exception as error: 499 raise RuntimeError( 500 f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}" 501 ) from error 502 503 def _put_object( 504 self, 505 path: str, 506 body: bytes, 507 if_match: Optional[str] = None, 508 if_none_match: Optional[str] = None, 509 attributes: Optional[dict[str, str]] = None, 510 content_type: Optional[str] = None, 511 ) -> int: 512 """ 513 Uploads an object to the specified S3 path. 514 515 :param path: The S3 path where the object will be uploaded. 516 :param body: The content of the object as bytes. 517 :param if_match: Optional If-Match header value. Use "*" to only upload if the object doesn't exist. 518 :param if_none_match: Optional If-None-Match header value. Use "*" to only upload if the object doesn't exist. 519 :param attributes: Optional attributes to attach to the object. 520 :param content_type: Optional Content-Type header value. 521 """ 522 bucket, key = split_path(path) 523 524 def _invoke_api() -> int: 525 kwargs = {"Bucket": bucket, "Key": key, "Body": body} 526 if content_type: 527 kwargs["ContentType"] = content_type 528 if self._is_directory_bucket(bucket): 529 kwargs["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS 530 if if_match: 531 kwargs["IfMatch"] = if_match 532 if if_none_match: 533 kwargs["IfNoneMatch"] = if_none_match 534 validated_attributes = validate_attributes(attributes) 535 if validated_attributes: 536 kwargs["Metadata"] = validated_attributes 537 538 # TODO(NGCDP-5804): Add support to update ContentType header in Rust client 539 rust_unsupported_feature_keys = {"Metadata", "StorageClass", "IfMatch", "IfNoneMatch", "ContentType"} 540 if ( 541 self._rust_client 542 # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026 543 and not path.endswith("/") 544 and all(key not in kwargs for key in rust_unsupported_feature_keys) 545 ): 546 run_async_rust_client_method(self._rust_client, "put", key, body) 547 else: 548 if self._checksum_algorithm: 549 kwargs["ChecksumAlgorithm"] = self._checksum_algorithm 550 self._s3_client.put_object(**kwargs) 551 552 return len(body) 553 554 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key) 555 556 def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes: 557 bucket, key = split_path(path) 558 559 def _invoke_api() -> bytes: 560 if byte_range: 561 bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}" 562 if self._rust_client: 563 response = run_async_rust_client_method( 564 self._rust_client, 565 "get", 566 key, 567 byte_range, 568 ) 569 return response 570 else: 571 response = self._s3_client.get_object(Bucket=bucket, Key=key, Range=bytes_range) 572 else: 573 if self._rust_client: 574 response = run_async_rust_client_method(self._rust_client, "get", key) 575 return response 576 else: 577 response = self._s3_client.get_object(Bucket=bucket, Key=key) 578 579 return response["Body"].read() 580 581 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key) 582 583 def _copy_object(self, src_path: str, dest_path: str) -> int: 584 src_bucket, src_key = split_path(src_path) 585 dest_bucket, dest_key = split_path(dest_path) 586 587 src_object = self._get_object_metadata(src_path) 588 589 def _invoke_api() -> int: 590 self._s3_client.copy( 591 CopySource={"Bucket": src_bucket, "Key": src_key}, 592 Bucket=dest_bucket, 593 Key=dest_key, 594 Config=self._transfer_config, 595 ) 596 597 return src_object.content_length 598 599 return self._translate_errors(_invoke_api, operation="COPY", bucket=dest_bucket, key=dest_key) 600 601 def _delete_object(self, path: str, if_match: Optional[str] = None) -> None: 602 bucket, key = split_path(path) 603 604 def _invoke_api() -> None: 605 # Delete conditionally when if_match (etag) is provided; otherwise delete unconditionally 606 if if_match: 607 self._s3_client.delete_object(Bucket=bucket, Key=key, IfMatch=if_match) 608 else: 609 self._s3_client.delete_object(Bucket=bucket, Key=key) 610 611 return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key) 612 613 def _delete_objects(self, paths: list[str]) -> None: 614 if not paths: 615 return 616 617 by_bucket: dict[str, list[str]] = {} 618 for p in paths: 619 bucket, key = split_path(p) 620 by_bucket.setdefault(bucket, []).append(key) 621 622 S3_BATCH_LIMIT = 1000 623 624 def _invoke_api() -> None: 625 all_errors: list[str] = [] 626 for bucket, keys in by_bucket.items(): 627 for i in range(0, len(keys), S3_BATCH_LIMIT): 628 chunk = keys[i : i + S3_BATCH_LIMIT] 629 response = self._s3_client.delete_objects( 630 Bucket=bucket, Delete={"Objects": [{"Key": k} for k in chunk]} 631 ) 632 errors = response.get("Errors") or [] 633 for e in errors: 634 all_errors.append(f"{bucket}/{e.get('Key', '?')}: {e.get('Code', '')} {e.get('Message', '')}") 635 if all_errors: 636 raise RuntimeError(f"DeleteObjects reported errors: {'; '.join(all_errors)}") 637 638 bucket_desc = "(" + "|".join(by_bucket) + ")" 639 key_desc = "(" + "|".join(str(len(keys)) for keys in by_bucket.values()) + " keys)" 640 self._translate_errors(_invoke_api, operation="DELETE_MANY", bucket=bucket_desc, key=key_desc) 641 642 def _is_dir(self, path: str) -> bool: 643 # Ensure the path ends with '/' to mimic a directory 644 path = self._append_delimiter(path) 645 646 bucket, key = split_path(path) 647 648 def _invoke_api() -> bool: 649 # List objects with the given prefix 650 response = self._s3_client.list_objects_v2(Bucket=bucket, Prefix=key, MaxKeys=1, Delimiter="/") 651 652 # Check if there are any contents or common prefixes 653 return bool(response.get("Contents", []) or response.get("CommonPrefixes", [])) 654 655 return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key) 656 657 def _make_symlink(self, path: str, target: str) -> None: 658 bucket, key = split_path(path) 659 target_bucket, target_key = split_path(target) 660 if bucket != target_bucket: 661 raise ValueError(f"Cannot create cross-bucket symlink: '{bucket}' -> '{target_bucket}'.") 662 relative_target = ObjectMetadata.encode_symlink_target(key, target_key) 663 664 def _invoke_api() -> None: 665 self._s3_client.put_object( 666 Bucket=bucket, 667 Key=key, 668 Body=b"", 669 Metadata={"msc-symlink-target": relative_target}, 670 ) 671 672 self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key) 673 674 def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata: 675 bucket, key = split_path(path) 676 if path.endswith("/") or (bucket and not key): 677 # If path ends with "/" or empty key name is provided, then assume it's a "directory", 678 # which metadata is not guaranteed to exist for cases such as 679 # "virtual prefix" that was never explicitly created. 680 if self._is_dir(path): 681 return ObjectMetadata( 682 key=path, 683 type="directory", 684 content_length=0, 685 last_modified=AWARE_DATETIME_MIN, 686 ) 687 else: 688 raise FileNotFoundError(f"Directory {path} does not exist.") 689 else: 690 691 def _invoke_api() -> ObjectMetadata: 692 response = self._s3_client.head_object(Bucket=bucket, Key=key) 693 user_metadata = response.get("Metadata") 694 symlink_target = user_metadata.get("msc-symlink-target") if user_metadata else None 695 696 return ObjectMetadata( 697 key=path, 698 type="file", 699 content_length=response["ContentLength"], 700 content_type=response.get("ContentType"), 701 last_modified=response.get("LastModified", AWARE_DATETIME_MIN), 702 etag=response["ETag"].strip('"') if "ETag" in response else None, 703 storage_class=response.get("StorageClass"), 704 metadata=user_metadata, 705 symlink_target=symlink_target, 706 ) 707 708 try: 709 return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key) 710 except FileNotFoundError as error: 711 if strict: 712 # If the object does not exist on the given path, we will append a trailing slash and 713 # check if the path is a directory. 714 path = self._append_delimiter(path) 715 if self._is_dir(path): 716 return ObjectMetadata( 717 key=path, 718 type="directory", 719 content_length=0, 720 last_modified=AWARE_DATETIME_MIN, 721 ) 722 raise error 723 724 def _list_objects( 725 self, 726 path: str, 727 start_after: Optional[str] = None, 728 end_at: Optional[str] = None, 729 include_directories: bool = False, 730 symlink_handling: SymlinkHandling = SymlinkHandling.FOLLOW, 731 ) -> Iterator[ObjectMetadata]: 732 bucket, prefix = split_path(path) 733 734 # Get the prefix of the start_after and end_at paths relative to the bucket. 735 if start_after: 736 _, start_after = split_path(start_after) 737 if end_at: 738 _, end_at = split_path(end_at) 739 740 def _invoke_api() -> Iterator[ObjectMetadata]: 741 paginator = self._s3_client.get_paginator("list_objects_v2") 742 if include_directories: 743 page_iterator = paginator.paginate( 744 Bucket=bucket, Prefix=prefix, Delimiter="/", StartAfter=(start_after or "") 745 ) 746 else: 747 page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, StartAfter=(start_after or "")) 748 749 for page in page_iterator: 750 for item in page.get("CommonPrefixes", []): 751 prefix_key = item["Prefix"].rstrip("/") 752 # Filter by start_after and end_at - S3's StartAfter doesn't filter CommonPrefixes 753 if (start_after is None or start_after < prefix_key) and (end_at is None or prefix_key <= end_at): 754 yield ObjectMetadata( 755 key=os.path.join(bucket, prefix_key), 756 type="directory", 757 content_length=0, 758 last_modified=AWARE_DATETIME_MIN, 759 ) 760 elif end_at is not None and end_at < prefix_key: 761 return 762 763 # S3 guarantees lexicographical order for general purpose buckets (for 764 # normal S3) but not directory buckets (for S3 Express One Zone). 765 for response_object in page.get("Contents", []): 766 key = response_object["Key"] 767 if end_at is None or key <= end_at: 768 if key.endswith("/"): 769 if include_directories: 770 yield ObjectMetadata( 771 key=os.path.join(bucket, key.rstrip("/")), 772 type="directory", 773 content_length=0, 774 last_modified=response_object["LastModified"], 775 ) 776 else: 777 symlink_target = None 778 if response_object.get("Size", 0) == 0: 779 try: 780 meta = self._get_object_metadata(os.path.join(bucket, key)) 781 symlink_target = meta.symlink_target 782 except Exception: 783 symlink_target = None 784 yield ObjectMetadata( 785 key=os.path.join(bucket, key), 786 type="file", 787 content_length=response_object["Size"], 788 last_modified=response_object["LastModified"], 789 etag=response_object["ETag"].strip('"'), 790 storage_class=response_object.get("StorageClass"), 791 symlink_target=symlink_target, 792 ) 793 else: 794 return 795 796 return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix) 797 798 @property 799 def supports_parallel_listing(self) -> bool: 800 """ 801 S3 supports parallel listing via delimiter-based prefix discovery. 802 803 Note: Directory bucket handling is done in list_objects_recursive(). 804 """ 805 return True 806
[docs] 807 def list_objects_recursive( 808 self, 809 path: str = "", 810 start_after: Optional[str] = None, 811 end_at: Optional[str] = None, 812 max_workers: int = 32, 813 look_ahead: int = 2, 814 symlink_handling: SymlinkHandling = SymlinkHandling.FOLLOW, 815 ) -> Iterator[ObjectMetadata]: 816 """ 817 List all objects recursively using parallel prefix discovery for improved performance. 818 819 For S3, uses the Rust client's list_recursive when available for maximum performance. 820 Falls back to Python implementation otherwise. 821 822 Returns files only (no directories), in lexicographic order. 823 824 :param symlink_handling: How to handle symbolic links during listing. 825 """ 826 if (start_after is not None) and (end_at is not None) and not (start_after < end_at): 827 raise ValueError(f"start_after ({start_after}) must be before end_at ({end_at})!") 828 829 full_path = self._prepend_base_path(path) 830 bucket, prefix = split_path(full_path) 831 832 if self._is_directory_bucket(bucket): 833 yield from self.list_objects( 834 path, start_after, end_at, include_directories=False, symlink_handling=symlink_handling 835 ) 836 return 837 838 if self._rust_client: 839 yield from self._emit_metrics( 840 operation=BaseStorageProvider._Operation.LIST, 841 f=lambda: self._list_objects_recursive_rust(path, full_path, bucket, start_after, end_at, max_workers), 842 ) 843 else: 844 yield from super().list_objects_recursive( 845 path, start_after, end_at, max_workers, look_ahead, symlink_handling 846 )
847 848 def _list_objects_recursive_rust( 849 self, 850 path: str, 851 full_path: str, 852 bucket: str, 853 start_after: Optional[str], 854 end_at: Optional[str], 855 max_workers: int, 856 ) -> Iterator[ObjectMetadata]: 857 """ 858 Use Rust client's list_recursive for parallel listing. 859 860 The Rust client already handles parallel listing internally. 861 Returns files only in lexicographic order. 862 """ 863 _, prefix = split_path(full_path) 864 865 def _invoke_api() -> Iterator[ObjectMetadata]: 866 result = run_async_rust_client_method( 867 self._rust_client, 868 "list_recursive", 869 [prefix] if prefix else [""], 870 max_concurrency=max_workers, 871 ) 872 873 start_after_full = self._prepend_base_path(start_after) if start_after else None 874 end_at_full = self._prepend_base_path(end_at) if end_at else None 875 876 for obj in result.objects: 877 full_key = os.path.join(bucket, obj.key) 878 879 if start_after_full and full_key <= start_after_full: 880 continue 881 if end_at_full and full_key > end_at_full: 882 break 883 884 relative_key = full_key.removeprefix(self._base_path).lstrip("/") 885 886 yield ObjectMetadata( 887 key=relative_key, 888 content_length=obj.content_length, 889 last_modified=dateutil_parse(obj.last_modified), 890 type="file" if obj.object_type == "object" else obj.object_type, 891 etag=obj.etag, 892 ) 893 894 yield from self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix) 895 896 def _upload_file( 897 self, 898 remote_path: str, 899 f: Union[str, IO], 900 attributes: Optional[dict[str, str]] = None, 901 content_type: Optional[str] = None, 902 ) -> int: 903 file_size: int = 0 904 905 if isinstance(f, str): 906 bucket, key = split_path(remote_path) 907 file_size = os.path.getsize(f) 908 909 # Upload small files 910 if file_size <= self._multipart_threshold: 911 if self._rust_client and not attributes and not content_type: 912 run_async_rust_client_method(self._rust_client, "upload", f, key) 913 else: 914 with open(f, "rb") as fp: 915 self._put_object(remote_path, fp.read(), attributes=attributes, content_type=content_type) 916 return file_size 917 918 # Upload large files using TransferConfig 919 def _invoke_api() -> int: 920 extra_args = {} 921 if content_type: 922 extra_args["ContentType"] = content_type 923 if self._is_directory_bucket(bucket): 924 extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS 925 validated_attributes = validate_attributes(attributes) 926 if validated_attributes: 927 extra_args["Metadata"] = validated_attributes 928 if self._rust_client and not extra_args: 929 run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key) 930 else: 931 if self._checksum_algorithm: 932 extra_args["ChecksumAlgorithm"] = self._checksum_algorithm 933 self._s3_client.upload_file( 934 Filename=f, 935 Bucket=bucket, 936 Key=key, 937 Config=self._transfer_config, 938 ExtraArgs=extra_args, 939 ) 940 941 return file_size 942 943 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key) 944 else: 945 # Upload small files 946 f.seek(0, io.SEEK_END) 947 file_size = f.tell() 948 f.seek(0) 949 950 if file_size <= self._multipart_threshold: 951 if isinstance(f, io.StringIO): 952 self._put_object( 953 remote_path, f.read().encode("utf-8"), attributes=attributes, content_type=content_type 954 ) 955 else: 956 self._put_object(remote_path, f.read(), attributes=attributes, content_type=content_type) 957 return file_size 958 959 # Upload large files using TransferConfig 960 bucket, key = split_path(remote_path) 961 962 def _invoke_api() -> int: 963 extra_args = {} 964 if content_type: 965 extra_args["ContentType"] = content_type 966 if self._is_directory_bucket(bucket): 967 extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS 968 validated_attributes = validate_attributes(attributes) 969 if validated_attributes: 970 extra_args["Metadata"] = validated_attributes 971 972 if self._rust_client and isinstance(f, io.BytesIO) and not extra_args: 973 data = f.getbuffer() 974 try: 975 run_async_rust_client_method(self._rust_client, "upload_multipart_from_bytes", key, data) 976 finally: 977 data.release() 978 else: 979 if self._checksum_algorithm: 980 extra_args["ChecksumAlgorithm"] = self._checksum_algorithm 981 self._s3_client.upload_fileobj( 982 Fileobj=f, 983 Bucket=bucket, 984 Key=key, 985 Config=self._transfer_config, 986 ExtraArgs=extra_args, 987 ) 988 989 return file_size 990 991 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key) 992 993 def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int: 994 if metadata is None: 995 metadata = self._get_object_metadata(remote_path) 996 997 if isinstance(f, str): 998 bucket, key = split_path(remote_path) 999 if os.path.dirname(f): 1000 safe_makedirs(os.path.dirname(f)) 1001 1002 # Download small files 1003 if metadata.content_length <= self._multipart_threshold: 1004 if self._rust_client: 1005 run_async_rust_client_method(self._rust_client, "download", key, f) 1006 else: 1007 with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp: 1008 temp_file_path = fp.name 1009 fp.write(self._get_object(remote_path)) 1010 os.rename(src=temp_file_path, dst=f) 1011 return metadata.content_length 1012 1013 # Download large files using TransferConfig 1014 def _invoke_api() -> int: 1015 with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp: 1016 temp_file_path = fp.name 1017 if self._rust_client: 1018 run_async_rust_client_method( 1019 self._rust_client, "download_multipart_to_file", key, temp_file_path 1020 ) 1021 else: 1022 self._s3_client.download_fileobj( 1023 Bucket=bucket, 1024 Key=key, 1025 Fileobj=fp, 1026 Config=self._transfer_config, 1027 ) 1028 1029 os.rename(src=temp_file_path, dst=f) 1030 1031 return metadata.content_length 1032 1033 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key) 1034 else: 1035 # Download small files 1036 if metadata.content_length <= self._multipart_threshold: 1037 response = self._get_object(remote_path) 1038 # Python client returns `bytes`, but Rust client returns an object that implements the buffer protocol, 1039 # so we need to check whether `.decode()` is available. 1040 if isinstance(f, io.StringIO): 1041 if hasattr(response, "decode"): 1042 f.write(response.decode("utf-8")) 1043 else: 1044 f.write(codecs.decode(memoryview(response), "utf-8")) 1045 else: 1046 f.write(response) 1047 return metadata.content_length 1048 1049 # Download large files using TransferConfig 1050 bucket, key = split_path(remote_path) 1051 1052 def _invoke_api() -> int: 1053 self._s3_client.download_fileobj( 1054 Bucket=bucket, 1055 Key=key, 1056 Fileobj=f, 1057 Config=self._transfer_config, 1058 ) 1059 1060 return metadata.content_length 1061 1062 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key) 1063 1064 def _generate_presigned_url( 1065 self, 1066 path: str, 1067 *, 1068 method: str = "GET", 1069 signer_type: Optional[SignerType] = None, 1070 signer_options: Optional[dict[str, Any]] = None, 1071 ) -> str: 1072 options = signer_options or {} 1073 bucket, key = split_path(path) 1074 1075 if signer_type is None or signer_type == SignerType.S3: 1076 expires_in = int(options.get("expires_in", DEFAULT_PRESIGN_EXPIRES_IN)) 1077 cache_key: tuple = (SignerType.S3, bucket, expires_in) 1078 if cache_key not in self._signer_cache: 1079 self._signer_cache[cache_key] = S3URLSigner(self._s3_client, bucket, expires_in=expires_in) 1080 elif signer_type == SignerType.CLOUDFRONT: 1081 cache_key = (SignerType.CLOUDFRONT, frozenset(options.items())) 1082 if cache_key not in self._signer_cache: 1083 self._signer_cache[cache_key] = CloudFrontURLSigner(**options) 1084 else: 1085 raise ValueError(f"Unsupported signer type for S3 provider: {signer_type!r}") 1086 1087 return self._signer_cache[cache_key].generate_presigned_url(key, method=method)