Source code for multistorageclient.providers.s3

   1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
   2# SPDX-License-Identifier: Apache-2.0
   3#
   4# Licensed under the Apache License, Version 2.0 (the "License");
   5# you may not use this file except in compliance with the License.
   6# You may obtain a copy of the License at
   7#
   8# http://www.apache.org/licenses/LICENSE-2.0
   9#
  10# Unless required by applicable law or agreed to in writing, software
  11# distributed under the License is distributed on an "AS IS" BASIS,
  12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13# See the License for the specific language governing permissions and
  14# limitations under the License.
  15
  16import codecs
  17import io
  18import os
  19import tempfile
  20from collections.abc import Callable, Iterator
  21from typing import IO, Any, Optional, TypeVar, Union
  22
  23import boto3
  24import botocore
  25from boto3.s3.transfer import TransferConfig
  26from botocore.credentials import RefreshableCredentials
  27from botocore.exceptions import ClientError, IncompleteReadError, ReadTimeoutError, ResponseStreamingError
  28from botocore.session import get_session
  29from dateutil.parser import parse as dateutil_parse
  30
  31from multistorageclient_rust import RustClient, RustClientError, RustRetryableError
  32
  33from ..constants import DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT
  34from ..rust_utils import parse_retry_config, run_async_rust_client_method
  35from ..signers import CloudFrontURLSigner, URLSigner
  36from ..telemetry import Telemetry
  37from ..types import (
  38    AWARE_DATETIME_MIN,
  39    Credentials,
  40    CredentialsProvider,
  41    ObjectMetadata,
  42    PreconditionFailedError,
  43    Range,
  44    RetryableError,
  45    SignerType,
  46    SymlinkHandling,
  47)
  48from ..utils import (
  49    get_available_cpu_count,
  50    safe_makedirs,
  51    split_path,
  52    validate_attributes,
  53)
  54from .base import BaseStorageProvider
  55
# Return-type variable threaded through ``_translate_errors`` so the wrapper
# preserves the wrapped callable's return type.
_T = TypeVar("_T")

# Default connection pool size scales with CPU count or MSC Sync Threads count (minimum 64)
MAX_POOL_CONNECTIONS = max(
    64,
    get_available_cpu_count(),
    int(os.getenv("MSC_NUM_THREADS_PER_PROCESS", "0")),
)

# One mebibyte; used to express the transfer-tuning sizes below.
MiB = 1024 * 1024

# Python and Rust share the same multipart_threshold to keep the code simple.
MULTIPART_THRESHOLD = 64 * MiB  # at or below this size, uploads use a single PUT
MULTIPART_CHUNKSIZE = 32 * MiB  # part size for multipart transfers
IO_CHUNKSIZE = 32 * MiB  # buffer size used by the boto3 transfer manager
PYTHON_MAX_CONCURRENCY = 8  # default max_concurrency for the boto3 TransferConfig

# Provider name passed to the BaseStorageProvider (used for telemetry/config).
PROVIDER = "s3"

# Storage class applied when writing to S3 Express One Zone (directory) buckets.
EXPRESS_ONEZONE_STORAGE_CLASS = "EXPRESS_ONEZONE"

# Source: https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html
SUPPORTED_CHECKSUM_ALGORITHMS = frozenset({"CRC32", "CRC32C", "SHA1", "SHA256", "CRC64NVME"})
  79
  80
class StaticS3CredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider`
    backed by a fixed access key / secret key pair (plus an optional session token).
    """

    _access_key: str
    _secret_key: str
    _session_token: Optional[str]

    def __init__(self, access_key: str, secret_key: str, session_token: Optional[str] = None):
        """
        Stores the static credential material used for all subsequent requests.

        :param access_key: The access key for S3 authentication.
        :param secret_key: The secret key for S3 authentication.
        :param session_token: An optional session token for temporary credentials.
        """
        self._access_key, self._secret_key = access_key, secret_key
        self._session_token = session_token

    def get_credentials(self) -> Credentials:
        """Return the stored credentials; ``expiration`` is ``None`` since they never rotate."""
        return Credentials(
            access_key=self._access_key,
            secret_key=self._secret_key,
            token=self._session_token,
            expiration=None,
        )

    def refresh_credentials(self) -> None:
        """No-op: static credentials have nothing to refresh."""
        pass
113 114 115DEFAULT_PRESIGN_EXPIRES_IN = 3600 116 117_S3_METHOD_MAPPING: dict[str, str] = { 118 "GET": "get_object", 119 "PUT": "put_object", 120} 121 122
class S3URLSigner(URLSigner):
    """Generates pre-signed URLs using the boto3 S3 client.

    When the underlying credentials are temporary (STS, IAM role, EC2 instance
    profile), the effective URL lifetime is the **shorter** of ``expires_in``
    and the remaining credential lifetime — boto3 will not warn if the
    credential expires before ``expires_in``.

    See https://docs.aws.amazon.com/AmazonS3/latest/userguide/using-presigned-url.html
    """

    def __init__(self, s3_client: Any, bucket: str, expires_in: int = DEFAULT_PRESIGN_EXPIRES_IN) -> None:
        self._s3_client = s3_client
        self._bucket = bucket
        self._expires_in = expires_in

    def generate_presigned_url(self, path: str, *, method: str = "GET") -> str:
        # Only the verbs present in the mapping (GET/PUT) can be presigned.
        verb = method.upper()
        if verb not in _S3_METHOD_MAPPING:
            raise ValueError(f"Unsupported method for S3 presigning: {method!r}")
        request_params = {"Bucket": self._bucket, "Key": path}
        return self._s3_client.generate_presigned_url(
            ClientMethod=_S3_METHOD_MAPPING[verb],
            Params=request_params,
            ExpiresIn=self._expires_in,
        )
148 149
class S3StorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Amazon S3 or S3-compatible object stores.
    """

    def __init__(
        self,
        region_name: str = "",
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        profile_name: Optional[str] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        verify: Optional[Union[bool, str]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the :py:class:`S3StorageProvider` with the region, endpoint URL, and optional credentials provider.

        :param region_name: The AWS region where the S3 bucket is located.
        :param endpoint_url: The custom endpoint URL for the S3 service.
        :param base_path: The root prefix path within the S3 bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve S3 credentials.
        :param profile_name: AWS shared configuration + credentials files profile. For :py:class:`boto3.session.Session`. Ignored if ``credentials_provider`` is set.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        :param verify: Controls SSL certificate verification.
            Can be ``True`` (verify using system CA bundle, default), ``False`` (skip verification), or a string path to a custom CA certificate bundle.
        :param request_checksum_calculation: For :py:class:`botocore.config.Config`.
            When the underlying S3 client should calculate request checksums.
            See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param response_checksum_validation: For :py:class:`botocore.config.Config`.
            When the underlying S3 client should validate response checksums.
            See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param max_pool_connections: For :py:class:`botocore.config.Config`.
            The maximum number of connections to keep in a connection pool.
        :param connect_timeout: For :py:class:`botocore.config.Config`.
            The time in seconds till a timeout exception is thrown when attempting to make a connection.
        :param read_timeout: For :py:class:`botocore.config.Config`.
            The time in seconds till a timeout exception is thrown when attempting to read from a connection.
        :param retries: For :py:class:`botocore.config.Config`.
            A dictionary for configuration related to retry behavior.
        :param s3: For :py:class:`botocore.config.Config`.
            A dictionary of S3 specific configurations.
        :param checksum_algorithm: Upload-only object integrity algorithm.
            One of ``"CRC32"``, ``"CRC32C"``, ``"SHA1"``, ``"SHA256"``, ``"CRC64NVME"`` (case-insensitive).
            When ``rust_client`` is enabled, only ``"SHA256"`` is accepted.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._region_name = region_name
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._profile_name = profile_name
        self._verify = verify

        self._signature_version = kwargs.get("signature_version", "s3v4")
        self._s3_client = self._create_s3_client(
            request_checksum_calculation=kwargs.get("request_checksum_calculation"),
            response_checksum_validation=kwargs.get("response_checksum_validation"),
            max_pool_connections=kwargs.get("max_pool_connections", MAX_POOL_CONNECTIONS),
            connect_timeout=kwargs.get("connect_timeout", DEFAULT_CONNECT_TIMEOUT),
            read_timeout=kwargs.get("read_timeout", DEFAULT_READ_TIMEOUT),
            retries=kwargs.get("retries"),
            s3=kwargs.get("s3"),
        )
        self._transfer_config = TransferConfig(
            multipart_threshold=int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD)),
            max_concurrency=int(kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)),
            multipart_chunksize=int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE)),
            io_chunksize=int(kwargs.get("io_chunksize", IO_CHUNKSIZE)),
            use_threads=True,
        )
        self._multipart_threshold = self._transfer_config.multipart_threshold

        # Cache of URL signers keyed by (bucket, expiry, ...) tuples.
        self._signer_cache: dict[tuple, URLSigner] = {}

        self._checksum_algorithm: Optional[str] = self._validate_checksum_algorithm(kwargs.get("checksum_algorithm"))

        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the rust client options from the kwargs
            rust_client_options = kwargs["rust_client"]
            if "max_pool_connections" in kwargs:
                rust_client_options["max_pool_connections"] = kwargs["max_pool_connections"]
            if "max_concurrency" in kwargs:
                rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
            if "multipart_chunksize" in kwargs:
                rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
            if "read_timeout" in kwargs:
                rust_client_options["read_timeout"] = kwargs["read_timeout"]
            if "connect_timeout" in kwargs:
                rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
            if "checksum_algorithm" in kwargs:
                rust_client_options["checksum_algorithm"] = kwargs["checksum_algorithm"]
            if self._signature_version == "UNSIGNED":
                rust_client_options["skip_signature"] = True

            # Rust client only supports SHA256 today (object_store limitation).
            rust_checksum = rust_client_options.get("checksum_algorithm")
            if rust_checksum is not None:
                if str(rust_checksum).upper() != "SHA256":
                    raise ValueError(
                        f"Rust client only supports checksum_algorithm='SHA256' today "
                        f"(object_store limitation), got '{rust_checksum!r}'. "
                        f"Disable rust_client or use SHA256."
                    )
                rust_client_options["checksum_algorithm"] = "sha256"
            self._rust_client = self._create_rust_client(rust_client_options)

    @staticmethod
    def _validate_checksum_algorithm(value: Optional[str]) -> Optional[str]:
        """
        Normalizes the optional ``checksum_algorithm`` to uppercase, or returns ``None`` if disabled.
        """
        if value is None:
            return None
        if not isinstance(value, str) or not value:
            raise ValueError(
                f"checksum_algorithm must be one of {sorted(SUPPORTED_CHECKSUM_ALGORITHMS)} or None, got {value!r}."
            )
        normalized = value.upper()
        if normalized not in SUPPORTED_CHECKSUM_ALGORITHMS:
            raise ValueError(
                f"Unsupported checksum_algorithm '{value}'. Supported: {sorted(SUPPORTED_CHECKSUM_ALGORITHMS)}."
            )
        return normalized

    def _is_directory_bucket(self, bucket: str) -> bool:
        """
        Determines if the bucket is a directory bucket based on bucket name.
        """
        # S3 Express buckets have a specific naming convention
        return "--x-s3" in bucket

    def _create_s3_client(
        self,
        request_checksum_calculation: Optional[str] = None,
        response_checksum_validation: Optional[str] = None,
        max_pool_connections: int = MAX_POOL_CONNECTIONS,
        connect_timeout: Union[float, int, None] = None,
        read_timeout: Union[float, int, None] = None,
        retries: Optional[dict[str, Any]] = None,
        s3: Optional[dict[str, Any]] = None,
    ):
        """
        Creates and configures the boto3 S3 client, using refreshable credentials if possible.

        :param request_checksum_calculation: For :py:class:`botocore.config.Config`. When the underlying S3 client should calculate request checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param response_checksum_validation: For :py:class:`botocore.config.Config`. When the underlying S3 client should validate response checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param max_pool_connections: For :py:class:`botocore.config.Config`. The maximum number of connections to keep in a connection pool.
        :param connect_timeout: For :py:class:`botocore.config.Config`. The time in seconds till a timeout exception is thrown when attempting to make a connection.
        :param read_timeout: For :py:class:`botocore.config.Config`. The time in seconds till a timeout exception is thrown when attempting to read from a connection.
        :param retries: For :py:class:`botocore.config.Config`. A dictionary for configuration related to retry behavior.
        :param s3: For :py:class:`botocore.config.Config`. A dictionary of S3 specific configurations.

        :return: The configured S3 client.
        """
        options = {
            # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
            "config": boto3.session.Config(  # pyright: ignore [reportAttributeAccessIssue]
                max_pool_connections=max_pool_connections,
                connect_timeout=connect_timeout,
                read_timeout=read_timeout,
                retries=retries or {"mode": "standard"},
                request_checksum_calculation=request_checksum_calculation,
                response_checksum_validation=response_checksum_validation,
                s3=s3,
            ),
        }

        if self._region_name:
            options["region_name"] = self._region_name

        if self._endpoint_url:
            options["endpoint_url"] = self._endpoint_url

        if self._verify is not None:
            options["verify"] = self._verify

        if self._credentials_provider:
            creds = self._fetch_credentials()
            if "expiry_time" in creds and creds["expiry_time"]:
                # Use RefreshableCredentials if expiry_time provided.
                refreshable_credentials = RefreshableCredentials.create_from_metadata(
                    metadata=creds, refresh_using=self._fetch_credentials, method="custom-refresh"
                )

                botocore_session = get_session()
                # NOTE(review): assigning the private ``_credentials`` attribute is the
                # established (if undocumented) way to inject refreshable credentials.
                botocore_session._credentials = refreshable_credentials

                boto3_session = boto3.Session(botocore_session=botocore_session)

                return boto3_session.client("s3", **options)
            else:
                # Add static credentials to the options dictionary
                options["aws_access_key_id"] = creds["access_key"]
                options["aws_secret_access_key"] = creds["secret_key"]
                if creds["token"]:
                    options["aws_session_token"] = creds["token"]

        # NOTE(review): the source of this block was recovered from a rendering with
        # lost indentation; this placement (applying the signature config on every
        # non-refreshable path) should be confirmed against the original file.
        if self._signature_version:
            signature_config = botocore.config.Config(  # pyright: ignore[reportAttributeAccessIssue]
                signature_version=botocore.UNSIGNED
                if self._signature_version == "UNSIGNED"
                else self._signature_version
            )
            options["config"] = options["config"].merge(signature_config)

        # Fallback to standard credential chain.
        session = boto3.Session(profile_name=self._profile_name)
        return session.client("s3", **options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        """
        Creates and configures the rust client, using refreshable credentials if possible.
        """
        # Copy so mutations below do not leak back into the caller's dict.
        configs = dict(rust_client_options) if rust_client_options else {}

        # Extract and parse retry configuration
        retry_config = parse_retry_config(configs)

        if self._region_name and "region_name" not in configs:
            configs["region_name"] = self._region_name

        if self._endpoint_url and "endpoint_url" not in configs:
            configs["endpoint_url"] = self._endpoint_url

        if self._profile_name and "profile_name" not in configs:
            configs["profile_name"] = self._profile_name

        # If the user specifies a bucket, use it. Otherwise, use the base path.
        if "bucket" not in configs:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        if "max_pool_connections" not in configs:
            configs["max_pool_connections"] = MAX_POOL_CONNECTIONS

        return RustClient(
            provider=PROVIDER,
            configs=configs,
            credentials_provider=self._credentials_provider,
            retry=retry_config,
        )

    def _fetch_credentials(self) -> dict:
        """
        Refreshes the S3 client if the current credentials are expired.

        :return: Credential metadata dict in the shape expected by
            :py:meth:`botocore.credentials.RefreshableCredentials.create_from_metadata`.
        :raises RuntimeError: If no credentials provider is configured.
        """
        if not self._credentials_provider:
            raise RuntimeError("Cannot fetch credentials if no credential provider configured.")
        self._credentials_provider.refresh_credentials()
        credentials = self._credentials_provider.get_credentials()
        return {
            "access_key": credentials.access_key,
            "secret_key": credentials.secret_key,
            "token": credentials.token,
            "expiry_time": credentials.expiration,
        }

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual S3 operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the S3 bucket involved in the operation.
        :param key: The key of the object within the S3 bucket.

        :return: The result of the S3 operation, typically the return value of the `func` callable.
        """
        try:
            return func()
        except ClientError as error:
            status_code = error.response["ResponseMetadata"]["HTTPStatusCode"]
            request_id = error.response["ResponseMetadata"].get("RequestId")
            host_id = error.response["ResponseMetadata"].get("HostId")
            error_code = error.response["Error"]["Code"]
            error_info = f"request_id: {request_id}, host_id: {host_id}, status_code: {status_code}"

            if status_code == 404:
                if error_code == "NoSuchUpload":
                    # A vanished multipart upload is treated as transient so callers can retry.
                    error_message = error.response["Error"]["Message"]
                    raise RetryableError(f"Multipart upload failed for {bucket}/{key}: {error_message}") from error
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}")  # pylint: disable=raise-missing-from
            elif status_code == 412:  # Precondition Failed
                raise PreconditionFailedError(
                    f"ETag mismatch for {operation} operation on {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 429:
                raise RetryableError(
                    f"Too many request to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 503:
                raise RetryableError(
                    f"Service unavailable when {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 501:
                raise NotImplementedError(
                    f"Operation {operation} not implemented for object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 408:
                # 408 Request Timeout is from Google Cloud Storage
                raise RetryableError(
                    f"Request timeout when {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            else:
                raise RuntimeError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}, "
                    f"error_type: {type(error).__name__}"
                ) from error
        except RustClientError as error:
            # assumes RustClientError.args is (message, status_code) — TODO confirm
            # against the multistorageclient_rust bindings.
            message = error.args[0]
            status_code = error.args[1]
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {message}") from error
            elif status_code == 403:
                raise PermissionError(
                    f"Permission denied to {operation} object(s) at {bucket}/{key}. {message}"
                ) from error
            else:
                raise RetryableError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {message}. status_code: {status_code}"
                ) from error
        except (ReadTimeoutError, IncompleteReadError, ResponseStreamingError) as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to network timeout or incomplete read. "
                f"error_type: {type(error).__name__}"
            ) from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except Exception as error:
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
        content_type: Optional[str] = None,
    ) -> int:
        """
        Uploads an object to the specified S3 path.

        :param path: The S3 path where the object will be uploaded.
        :param body: The content of the object as bytes.
        :param if_match: Optional If-Match header value. Use "*" to only upload if the object doesn't exist.
        :param if_none_match: Optional If-None-Match header value. Use "*" to only upload if the object doesn't exist.
        :param attributes: Optional attributes to attach to the object.
        :param content_type: Optional Content-Type header value.

        :return: The number of bytes written (``len(body)``).
        """
        bucket, key = split_path(path)

        def _invoke_api() -> int:
            kwargs = {"Bucket": bucket, "Key": key, "Body": body}
            if content_type:
                kwargs["ContentType"] = content_type
            if self._is_directory_bucket(bucket):
                kwargs["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
            if if_match:
                kwargs["IfMatch"] = if_match
            if if_none_match:
                kwargs["IfNoneMatch"] = if_none_match
            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                kwargs["Metadata"] = validated_attributes

            # TODO(NGCDP-5804): Add support to update ContentType header in Rust client
            rust_unsupported_feature_keys = {"Metadata", "StorageClass", "IfMatch", "IfNoneMatch", "ContentType"}
            if (
                self._rust_client
                # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
                and not path.endswith("/")
                and all(key not in kwargs for key in rust_unsupported_feature_keys)
            ):
                run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                if self._checksum_algorithm:
                    kwargs["ChecksumAlgorithm"] = self._checksum_algorithm
                self._s3_client.put_object(**kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Downloads an object (optionally a byte range of it), preferring the Rust client when configured.
        """
        bucket, key = split_path(path)

        def _invoke_api() -> bytes:
            if byte_range:
                # HTTP Range header uses an inclusive end offset.
                bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
                if self._rust_client:
                    response = run_async_rust_client_method(
                        self._rust_client,
                        "get",
                        key,
                        byte_range,
                    )
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key, Range=bytes_range)
            else:
                if self._rust_client:
                    response = run_async_rust_client_method(self._rust_client, "get", key)
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key)

            return response["Body"].read()

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        """
        Server-side copy of ``src_path`` to ``dest_path``; returns the copied object's size in bytes.
        """
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)

        # HEAD the source first: fetches the size to return and fails fast if it does not exist.
        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            self._s3_client.copy(
                CopySource={"Bucket": src_bucket, "Key": src_key},
                Bucket=dest_bucket,
                Key=dest_key,
                Config=self._transfer_config,
            )

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=dest_bucket, key=dest_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Deletes a single object, optionally guarded by an ETag precondition.
        """
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            # Delete conditionally when if_match (etag) is provided; otherwise delete unconditionally
            if if_match:
                self._s3_client.delete_object(Bucket=bucket, Key=key, IfMatch=if_match)
            else:
                self._s3_client.delete_object(Bucket=bucket, Key=key)

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _delete_objects(self, paths: list[str]) -> None:
        """
        Bulk-deletes objects, batching per bucket in chunks of the S3 DeleteObjects limit (1000).

        :raises RuntimeError: If the DeleteObjects response reports per-key errors.
        """
        if not paths:
            return

        # Group keys by bucket since DeleteObjects operates on one bucket at a time.
        by_bucket: dict[str, list[str]] = {}
        for p in paths:
            bucket, key = split_path(p)
            by_bucket.setdefault(bucket, []).append(key)

        S3_BATCH_LIMIT = 1000

        def _invoke_api() -> None:
            all_errors: list[str] = []
            for bucket, keys in by_bucket.items():
                for i in range(0, len(keys), S3_BATCH_LIMIT):
                    chunk = keys[i : i + S3_BATCH_LIMIT]
                    response = self._s3_client.delete_objects(
                        Bucket=bucket, Delete={"Objects": [{"Key": k} for k in chunk]}
                    )
                    errors = response.get("Errors") or []
                    for e in errors:
                        all_errors.append(f"{bucket}/{e.get('Key', '?')}: {e.get('Code', '')} {e.get('Message', '')}")
            if all_errors:
                raise RuntimeError(f"DeleteObjects reported errors: {'; '.join(all_errors)}")

        # Synthetic bucket/key descriptions for error messages covering many objects.
        bucket_desc = "(" + "|".join(by_bucket) + ")"
        key_desc = "(" + "|".join(str(len(keys)) for keys in by_bucket.values()) + " keys)"
        self._translate_errors(_invoke_api, operation="DELETE_MANY", bucket=bucket_desc, key=key_desc)

    def _is_dir(self, path: str) -> bool:
        """
        Returns ``True`` if any object or common prefix exists under ``path`` treated as a directory.
        """
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._s3_client.list_objects_v2(Bucket=bucket, Prefix=key, MaxKeys=1, Delimiter="/")

            # Check if there are any contents or common prefixes
            return bool(response.get("Contents", []) or response.get("CommonPrefixes", []))

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _make_symlink(self, path: str, target: str) -> None:
        """
        Creates a symlink marker: an empty object whose user metadata records the relative target.

        :raises ValueError: If ``path`` and ``target`` live in different buckets.
        """
        bucket, key = split_path(path)
        target_bucket, target_key = split_path(target)
        if bucket != target_bucket:
            raise ValueError(f"Cannot create cross-bucket symlink: '{bucket}' -> '{target_bucket}'.")
        relative_target = ObjectMetadata.encode_symlink_target(key, target_key)

        def _invoke_api() -> None:
            self._s3_client.put_object(
                Bucket=bucket,
                Key=key,
                Body=b"",
                Metadata={"msc-symlink-target": relative_target},
            )

        self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        HEADs an object, or resolves ``path`` as a directory when it denotes a prefix.

        :param strict: When ``True``, a missing object triggers a fallback directory check
            before the :py:exc:`FileNotFoundError` is re-raised.
        """
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If path ends with "/" or empty key name is provided, then assume it's a "directory",
            # which metadata is not guaranteed to exist for cases such as
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:

            def _invoke_api() -> ObjectMetadata:
                response = self._s3_client.head_object(Bucket=bucket, Key=key)
                user_metadata = response.get("Metadata")
                symlink_target = user_metadata.get("msc-symlink-target") if user_metadata else None

                return ObjectMetadata(
                    key=path,
                    type="file",
                    content_length=response["ContentLength"],
                    content_type=response.get("ContentType"),
                    last_modified=response.get("LastModified", AWARE_DATETIME_MIN),
                    etag=response["ETag"].strip('"') if "ETag" in response else None,
                    storage_class=response.get("StorageClass"),
                    metadata=user_metadata,
                    symlink_target=symlink_target,
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                # NOTE(review): indentation of this re-raise was lost in the rendering this
                # block was recovered from; placed at the except level — confirm against the original.
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        symlink_handling: SymlinkHandling = SymlinkHandling.FOLLOW,
    ) -> Iterator[ObjectMetadata]:
        """
        Lists objects (and optionally directory prefixes) under ``path`` within the
        half-open window ``(start_after, end_at]``, in lexicographic order.
        """
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            paginator = self._s3_client.get_paginator("list_objects_v2")
            if include_directories:
                page_iterator = paginator.paginate(
                    Bucket=bucket, Prefix=prefix, Delimiter="/", StartAfter=(start_after or "")
                )
            else:
                page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, StartAfter=(start_after or ""))

            for page in page_iterator:
                for item in page.get("CommonPrefixes", []):
                    prefix_key = item["Prefix"].rstrip("/")
                    # Filter by start_after and end_at - S3's StartAfter doesn't filter CommonPrefixes
                    if (start_after is None or start_after < prefix_key) and (end_at is None or prefix_key <= end_at):
                        yield ObjectMetadata(
                            key=os.path.join(bucket, prefix_key),
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                    elif end_at is not None and end_at < prefix_key:
                        return

                # S3 guarantees lexicographical order for general purpose buckets (for
                # normal S3) but not directory buckets (for S3 Express One Zone).
                for response_object in page.get("Contents", []):
                    key = response_object["Key"]
                    if end_at is None or key <= end_at:
                        if key.endswith("/"):
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(bucket, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=response_object["LastModified"],
                                )
                        else:
                            # Zero-byte objects may be symlink markers; HEAD them to recover the target.
                            symlink_target = None
                            if response_object.get("Size", 0) == 0:
                                try:
                                    meta = self._get_object_metadata(os.path.join(bucket, key))
                                    symlink_target = meta.symlink_target
                                except Exception:
                                    symlink_target = None
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key),
                                type="file",
                                content_length=response_object["Size"],
                                last_modified=response_object["LastModified"],
                                etag=response_object["ETag"].strip('"'),
                                storage_class=response_object.get("StorageClass"),
                                symlink_target=symlink_target,
                            )
                    else:
                        return

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    @property
    def supports_parallel_listing(self) -> bool:
        """
        S3 supports parallel listing via delimiter-based prefix discovery.

        Note: Directory bucket handling is done in list_objects_recursive().
        """
        return True
[docs] 806 def list_objects_recursive( 807 self, 808 path: str = "", 809 start_after: Optional[str] = None, 810 end_at: Optional[str] = None, 811 max_workers: int = 32, 812 look_ahead: int = 2, 813 symlink_handling: SymlinkHandling = SymlinkHandling.FOLLOW, 814 ) -> Iterator[ObjectMetadata]: 815 """ 816 List all objects recursively using parallel prefix discovery for improved performance. 817 818 For S3, uses the Rust client's list_recursive when available for maximum performance. 819 Falls back to Python implementation otherwise. 820 821 Returns files only (no directories), in lexicographic order. 822 823 :param symlink_handling: How to handle symbolic links during listing. 824 """ 825 if (start_after is not None) and (end_at is not None) and not (start_after < end_at): 826 raise ValueError(f"start_after ({start_after}) must be before end_at ({end_at})!") 827 828 full_path = self._prepend_base_path(path) 829 bucket, prefix = split_path(full_path) 830 831 if self._is_directory_bucket(bucket): 832 yield from self.list_objects( 833 path, start_after, end_at, include_directories=False, symlink_handling=symlink_handling 834 ) 835 return 836 837 if self._rust_client: 838 yield from self._emit_metrics( 839 operation=BaseStorageProvider._Operation.LIST, 840 f=lambda: self._list_objects_recursive_rust(path, full_path, bucket, start_after, end_at, max_workers), 841 ) 842 else: 843 yield from super().list_objects_recursive( 844 path, start_after, end_at, max_workers, look_ahead, symlink_handling 845 )
    def _list_objects_recursive_rust(
        self,
        path: str,
        full_path: str,
        bucket: str,
        start_after: Optional[str],
        end_at: Optional[str],
        max_workers: int,
    ) -> Iterator[ObjectMetadata]:
        """
        Use Rust client's list_recursive for parallel listing.

        The Rust client already handles parallel listing internally.
        Returns files only in lexicographic order.

        :param path: Caller-supplied (base-path-relative) listing path; unused here
            except as documentation of provenance — ``full_path`` is the resolved form.
        :param full_path: ``path`` with the provider base path prepended.
        :param bucket: Bucket name extracted from ``full_path``.
        :param start_after: Exclusive lower bound on returned keys (base-path-relative).
        :param end_at: Inclusive upper bound on returned keys (base-path-relative).
        :param max_workers: Concurrency passed to the Rust client.
        """
        _, prefix = split_path(full_path)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            result = run_async_rust_client_method(
                self._rust_client,
                "list_recursive",
                [prefix] if prefix else [""],
                max_concurrency=max_workers,
            )

            # Bounds are given relative to the base path; resolve them so they can
            # be compared against bucket-qualified keys below.
            start_after_full = self._prepend_base_path(start_after) if start_after else None
            end_at_full = self._prepend_base_path(end_at) if end_at else None

            for obj in result.objects:
                full_key = os.path.join(bucket, obj.key)

                # start_after is exclusive, end_at is inclusive. The early `break`
                # assumes results arrive in lexicographic order (directory buckets,
                # which do not guarantee ordering, are routed elsewhere).
                if start_after_full and full_key <= start_after_full:
                    continue
                if end_at_full and full_key > end_at_full:
                    break

                # Convert the bucket-qualified key back to a base-path-relative key.
                relative_key = full_key.removeprefix(self._base_path).lstrip("/")

                yield ObjectMetadata(
                    key=relative_key,
                    content_length=obj.content_length,
                    last_modified=dateutil_parse(obj.last_modified),
                    type="file" if obj.object_type == "object" else obj.object_type,
                    etag=obj.etag,
                )

        yield from self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(
        self,
        remote_path: str,
        f: Union[str, IO],
        attributes: Optional[dict[str, str]] = None,
        content_type: Optional[str] = None,
    ) -> int:
        """
        Upload a local file path or an open file object to ``remote_path``.

        Objects at or below ``self._multipart_threshold`` are uploaded in a single
        PUT; larger objects go through boto3's managed transfer (TransferConfig)
        or, when available and no extra arguments are required, the Rust client's
        multipart upload.

        :param remote_path: Destination ``bucket/key`` path.
        :param f: Local filesystem path (``str``) or a readable file object.
            ``io.StringIO`` content is encoded as UTF-8 before upload.
        :param attributes: Optional object metadata (validated before use).
        :param content_type: Optional Content-Type to set on the object.
        :return: Number of bytes uploaded.
        """
        file_size: int = 0

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            file_size = os.path.getsize(f)

            # Upload small files in a single request.
            if file_size <= self._multipart_threshold:
                # Rust fast path only when no metadata/content-type must be attached.
                if self._rust_client and not attributes and not content_type:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes, content_type=content_type)
                return file_size

            # Upload large files using TransferConfig
            def _invoke_api() -> int:
                extra_args = {}
                if content_type:
                    extra_args["ContentType"] = content_type
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes
                # Rust multipart path only when boto3-specific ExtraArgs are not needed.
                if self._rust_client and not extra_args:
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    # Checksum is added after the rust check on purpose: it would
                    # otherwise disqualify the rust path even with no user args.
                    if self._checksum_algorithm:
                        extra_args["ChecksumAlgorithm"] = self._checksum_algorithm
                    self._s3_client.upload_file(
                        Filename=f,
                        Bucket=bucket,
                        Key=key,
                        Config=self._transfer_config,
                        ExtraArgs=extra_args,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            # File object: determine its size by seeking, then rewind for reading.
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            # Upload small files in a single request.
            if file_size <= self._multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(
                        remote_path, f.read().encode("utf-8"), attributes=attributes, content_type=content_type
                    )
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes, content_type=content_type)
                return file_size

            # Upload large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                extra_args = {}
                if content_type:
                    extra_args["ContentType"] = content_type
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes

                # Rust multipart path needs direct buffer access (BytesIO) and no ExtraArgs.
                if self._rust_client and isinstance(f, io.BytesIO) and not extra_args:
                    # NOTE(review): getbuffer() returns a memoryview that pins the
                    # BytesIO against resizing while the upload runs — confirm the
                    # rust call releases it before the caller mutates `f`.
                    data = f.getbuffer()
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_bytes", key, data)
                else:
                    if self._checksum_algorithm:
                        extra_args["ChecksumAlgorithm"] = self._checksum_algorithm
                    self._s3_client.upload_fileobj(
                        Fileobj=f,
                        Bucket=bucket,
                        Key=key,
                        Config=self._transfer_config,
                        ExtraArgs=extra_args,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """
        Download the object at ``remote_path`` into a local path or file object.

        When ``f`` is a filesystem path, data is first written to a hidden
        temporary file in the destination directory and then moved into place
        with ``os.rename``, so readers never observe a partially written file.
        Objects at or below ``self._multipart_threshold`` are fetched with a
        single GET; larger objects use boto3's managed transfer or the Rust
        client when available.

        :param remote_path: Source ``bucket/key`` path.
        :param f: Local filesystem path (``str``) or a writable file object.
        :param metadata: Pre-fetched object metadata; fetched here when ``None``.
        :return: The object's content length in bytes (from metadata).
        """
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            if os.path.dirname(f):
                safe_makedirs(os.path.dirname(f))

            # Download small files with a single GET.
            if metadata.content_length <= self._multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    # Write to a dot-prefixed temp file in the target directory,
                    # then rename for an atomic-on-POSIX replace.
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using TransferConfig
            def _invoke_api() -> int:
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    if self._rust_client:
                        # NOTE(review): the rust client writes to temp_file_path while
                        # fp is still open — fine on POSIX, verify on other platforms.
                        run_async_rust_client_method(
                            self._rust_client, "download_multipart_to_file", key, temp_file_path
                        )
                    else:
                        self._s3_client.download_fileobj(
                            Bucket=bucket,
                            Key=key,
                            Fileobj=fp,
                            Config=self._transfer_config,
                        )

                os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:
            # Download small files directly into the provided file object.
            if metadata.content_length <= self._multipart_threshold:
                response = self._get_object(remote_path)
                # Python client returns `bytes`, but Rust client returns an object that implements the buffer protocol,
                # so we need to check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                self._s3_client.download_fileobj(
                    Bucket=bucket,
                    Key=key,
                    Fileobj=f,
                    Config=self._transfer_config,
                )

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _generate_presigned_url(
        self,
        path: str,
        *,
        method: str = "GET",
        signer_type: Optional[SignerType] = None,
        signer_options: Optional[dict[str, Any]] = None,
    ) -> str:
        """
        Generate a presigned URL for ``path`` using a cached URL signer.

        Signers are memoized in ``self._signer_cache``: S3 signers are keyed by
        (type, bucket, expiry), CloudFront signers by their full option set.

        :param path: ``bucket/key`` path to sign.
        :param method: HTTP method the URL authorizes (default ``"GET"``).
        :param signer_type: ``None`` or ``SignerType.S3`` selects native S3
            presigning; ``SignerType.CLOUDFRONT`` selects CloudFront signing.
        :param signer_options: Signer-specific options (e.g. ``expires_in``).
        :return: The presigned URL string.
        :raises ValueError: If ``signer_type`` is not supported by this provider.
        """
        options = signer_options or {}
        bucket, key = split_path(path)

        if signer_type is None or signer_type == SignerType.S3:
            expires_in = int(options.get("expires_in", DEFAULT_PRESIGN_EXPIRES_IN))
            cache_key: tuple = (SignerType.S3, bucket, expires_in)
            if cache_key not in self._signer_cache:
                self._signer_cache[cache_key] = S3URLSigner(self._s3_client, bucket, expires_in=expires_in)
        elif signer_type == SignerType.CLOUDFRONT:
            # NOTE(review): frozenset(options.items()) requires every option value
            # to be hashable — confirm callers never pass dict/list values here.
            cache_key = (SignerType.CLOUDFRONT, frozenset(options.items()))
            if cache_key not in self._signer_cache:
                self._signer_cache[cache_key] = CloudFrontURLSigner(**options)
        else:
            raise ValueError(f"Unsupported signer type for S3 provider: {signer_type!r}")

        return self._signer_cache[cache_key].generate_presigned_url(key, method=method)