Source code for multistorageclient.providers.s3

  1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16import codecs
 17import io
 18import os
 19import tempfile
 20from collections.abc import Callable, Iterator
 21from typing import IO, Any, Optional, TypeVar, Union
 22
 23import boto3
 24import botocore
 25from boto3.s3.transfer import TransferConfig
 26from botocore.credentials import RefreshableCredentials
 27from botocore.exceptions import ClientError, IncompleteReadError, ReadTimeoutError, ResponseStreamingError
 28from botocore.session import get_session
 29from dateutil.parser import parse as dateutil_parse
 30
 31from multistorageclient_rust import RustClient, RustClientError, RustRetryableError
 32
 33from ..constants import DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT
 34from ..rust_utils import parse_retry_config, run_async_rust_client_method
 35from ..telemetry import Telemetry
 36from ..types import (
 37    AWARE_DATETIME_MIN,
 38    Credentials,
 39    CredentialsProvider,
 40    ObjectMetadata,
 41    PreconditionFailedError,
 42    Range,
 43    RetryableError,
 44)
 45from ..utils import (
 46    get_available_cpu_count,
 47    safe_makedirs,
 48    split_path,
 49    validate_attributes,
 50)
 51from .base import BaseStorageProvider
 52
_T = TypeVar("_T")

# Default connection pool size scales with CPU count or MSC Sync Threads count (minimum 64)
MAX_POOL_CONNECTIONS = max(
    64,
    get_available_cpu_count(),
    int(os.getenv("MSC_NUM_THREADS_PER_PROCESS", "0")),
)

# 1 MiB in bytes.
MiB = 1024 * 1024

# Python and Rust share the same multipart_threshold to keep the code simple.
MULTIPART_THRESHOLD = 64 * MiB
MULTIPART_CHUNKSIZE = 32 * MiB
IO_CHUNKSIZE = 32 * MiB
# Max concurrent transfer threads for the Python (boto3) transfer path.
PYTHON_MAX_CONCURRENCY = 8

# Provider name reported to BaseStorageProvider and used for the Rust client.
PROVIDER = "s3"

# Storage class applied when writing to S3 Express One Zone (directory) buckets.
EXPRESS_ONEZONE_STORAGE_CLASS = "EXPRESS_ONEZONE"
 74
class StaticS3CredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static S3 credentials.
    """

    _access_key: str
    _secret_key: str
    _session_token: Optional[str]

    def __init__(self, access_key: str, secret_key: str, session_token: Optional[str] = None):
        """
        Initializes the :py:class:`StaticS3CredentialsProvider` with the provided access key, secret key, and optional
        session token.

        :param access_key: The access key for S3 authentication.
        :param secret_key: The secret key for S3 authentication.
        :param session_token: An optional session token for temporary credentials.
        """
        self._access_key = access_key
        self._secret_key = secret_key
        self._session_token = session_token

    def get_credentials(self) -> Credentials:
        """Return the fixed credentials; ``expiration`` is always ``None`` since they never rotate."""
        return Credentials(
            access_key=self._access_key,
            secret_key=self._secret_key,
            token=self._session_token,
            expiration=None,
        )

    def refresh_credentials(self) -> None:
        """No-op: static credentials cannot be refreshed."""
        pass
class S3StorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Amazon S3 or S3-compatible object stores.
    """

    def __init__(
        self,
        region_name: str = "",
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        verify: Optional[Union[bool, str]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the :py:class:`S3StorageProvider` with the region, endpoint URL, and optional credentials provider.

        :param region_name: The AWS region where the S3 bucket is located.
        :param endpoint_url: The custom endpoint URL for the S3 service.
        :param base_path: The root prefix path within the S3 bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve S3 credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        :param verify: Controls SSL certificate verification.
            Can be ``True`` (verify using system CA bundle, default), ``False`` (skip verification), or a string path to a custom CA certificate bundle.
        :param request_checksum_calculation: For :py:class:`botocore.config.Config`.
            When the underlying S3 client should calculate request checksums.
            See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param response_checksum_validation: For :py:class:`botocore.config.Config`.
            When the underlying S3 client should validate response checksums.
            See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param max_pool_connections: For :py:class:`botocore.config.Config`.
            The maximum number of connections to keep in a connection pool.
        :param connect_timeout: For :py:class:`botocore.config.Config`.
            The time in seconds till a timeout exception is thrown when attempting to make a connection.
        :param read_timeout: For :py:class:`botocore.config.Config`.
            The time in seconds till a timeout exception is thrown when attempting to read from a connection.
        :param retries: For :py:class:`botocore.config.Config`.
            A dictionary for configuration related to retry behavior.
        :param s3: For :py:class:`botocore.config.Config`.
            A dictionary of S3 specific configurations.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._region_name = region_name
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._verify = verify

        self._signature_version = kwargs.get("signature_version", "")
        self._s3_client = self._create_s3_client(
            request_checksum_calculation=kwargs.get("request_checksum_calculation"),
            response_checksum_validation=kwargs.get("response_checksum_validation"),
            max_pool_connections=kwargs.get("max_pool_connections", MAX_POOL_CONNECTIONS),
            connect_timeout=kwargs.get("connect_timeout", DEFAULT_CONNECT_TIMEOUT),
            read_timeout=kwargs.get("read_timeout", DEFAULT_READ_TIMEOUT),
            retries=kwargs.get("retries"),
            s3=kwargs.get("s3"),
        )
        self._transfer_config = TransferConfig(
            multipart_threshold=int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD)),
            max_concurrency=int(kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)),
            multipart_chunksize=int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE)),
            io_chunksize=int(kwargs.get("io_chunksize", IO_CHUNKSIZE)),
            use_threads=True,
        )

        # The Rust client is optional: it is only built when the MSC config supplies
        # a "rust_client" options dict. Shared transfer options are mirrored into it
        # so both clients behave consistently.
        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the rust client options from the kwargs
            rust_client_options = kwargs["rust_client"]
            if "max_pool_connections" in kwargs:
                rust_client_options["max_pool_connections"] = kwargs["max_pool_connections"]
            if "max_concurrency" in kwargs:
                rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
            if "multipart_chunksize" in kwargs:
                rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
            if "read_timeout" in kwargs:
                rust_client_options["read_timeout"] = kwargs["read_timeout"]
            if "connect_timeout" in kwargs:
                rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
            if self._signature_version == "UNSIGNED":
                rust_client_options["skip_signature"] = True
            self._rust_client = self._create_rust_client(rust_client_options)

    def _is_directory_bucket(self, bucket: str) -> bool:
        """
        Determines if the bucket is a directory bucket based on bucket name.
        """
        # S3 Express buckets have a specific naming convention
        return "--x-s3" in bucket

    def _create_s3_client(
        self,
        request_checksum_calculation: Optional[str] = None,
        response_checksum_validation: Optional[str] = None,
        max_pool_connections: int = MAX_POOL_CONNECTIONS,
        connect_timeout: Union[float, int, None] = None,
        read_timeout: Union[float, int, None] = None,
        retries: Optional[dict[str, Any]] = None,
        s3: Optional[dict[str, Any]] = None,
    ):
        """
        Creates and configures the boto3 S3 client, using refreshable credentials if possible.

        :param request_checksum_calculation: For :py:class:`botocore.config.Config`. When the underlying S3 client should calculate request checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param response_checksum_validation: For :py:class:`botocore.config.Config`. When the underlying S3 client should validate response checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param max_pool_connections: For :py:class:`botocore.config.Config`. The maximum number of connections to keep in a connection pool.
        :param connect_timeout: For :py:class:`botocore.config.Config`. The time in seconds till a timeout exception is thrown when attempting to make a connection.
        :param read_timeout: For :py:class:`botocore.config.Config`. The time in seconds till a timeout exception is thrown when attempting to read from a connection.
        :param retries: For :py:class:`botocore.config.Config`. A dictionary for configuration related to retry behavior.
        :param s3: For :py:class:`botocore.config.Config`. A dictionary of S3 specific configurations.

        :return: The configured S3 client.
        """
        options = {
            # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
            "config": boto3.session.Config(  # pyright: ignore [reportAttributeAccessIssue]
                max_pool_connections=max_pool_connections,
                connect_timeout=connect_timeout,
                read_timeout=read_timeout,
                retries=retries or {"mode": "standard"},
                request_checksum_calculation=request_checksum_calculation,
                response_checksum_validation=response_checksum_validation,
                s3=s3,
            ),
        }

        if self._region_name:
            options["region_name"] = self._region_name

        if self._endpoint_url:
            options["endpoint_url"] = self._endpoint_url

        if self._verify is not None:
            options["verify"] = self._verify

        if self._credentials_provider:
            creds = self._fetch_credentials()
            if "expiry_time" in creds and creds["expiry_time"]:
                # Use RefreshableCredentials if expiry_time provided.
                refreshable_credentials = RefreshableCredentials.create_from_metadata(
                    metadata=creds, refresh_using=self._fetch_credentials, method="custom-refresh"
                )

                botocore_session = get_session()
                botocore_session._credentials = refreshable_credentials

                boto3_session = boto3.Session(botocore_session=botocore_session)

                return boto3_session.client("s3", **options)
            else:
                # Add static credentials to the options dictionary
                options["aws_access_key_id"] = creds["access_key"]
                options["aws_secret_access_key"] = creds["secret_key"]
                if creds["token"]:
                    options["aws_session_token"] = creds["token"]

        if self._signature_version:
            signature_config = botocore.config.Config(  # pyright: ignore[reportAttributeAccessIssue]
                signature_version=botocore.UNSIGNED
                if self._signature_version == "UNSIGNED"
                else self._signature_version
            )
            options["config"] = options["config"].merge(signature_config)

        # Fallback to standard credential chain.
        return boto3.client("s3", **options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        """
        Creates and configures the rust client, using refreshable credentials if possible.
        """
        configs = dict(rust_client_options) if rust_client_options else {}

        # Extract and parse retry configuration
        retry_config = parse_retry_config(configs)

        if self._region_name and "region_name" not in configs:
            configs["region_name"] = self._region_name

        if self._endpoint_url and "endpoint_url" not in configs:
            configs["endpoint_url"] = self._endpoint_url

        # If the user specifies a bucket, use it. Otherwise, use the base path.
        if "bucket" not in configs:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        if "max_pool_connections" not in configs:
            configs["max_pool_connections"] = MAX_POOL_CONNECTIONS

        return RustClient(
            provider=PROVIDER,
            configs=configs,
            credentials_provider=self._credentials_provider,
            retry=retry_config,
        )

    def _fetch_credentials(self) -> dict:
        """
        Refreshes the S3 client if the current credentials are expired.
        """
        if not self._credentials_provider:
            raise RuntimeError("Cannot fetch credentials if no credential provider configured.")
        self._credentials_provider.refresh_credentials()
        credentials = self._credentials_provider.get_credentials()
        return {
            "access_key": credentials.access_key,
            "secret_key": credentials.secret_key,
            "token": credentials.token,
            "expiry_time": credentials.expiration,
        }

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual S3 operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the S3 bucket involved in the operation.
        :param key: The key of the object within the S3 bucket.

        :return: The result of the S3 operation, typically the return value of the `func` callable.
        """
        try:
            return func()
        except ClientError as error:
            status_code = error.response["ResponseMetadata"]["HTTPStatusCode"]
            request_id = error.response["ResponseMetadata"].get("RequestId")
            host_id = error.response["ResponseMetadata"].get("HostId")
            error_code = error.response["Error"]["Code"]
            error_info = f"request_id: {request_id}, host_id: {host_id}, status_code: {status_code}"

            if status_code == 404:
                # NoSuchUpload is retryable: a multipart upload id can be invalidated by
                # a competing abort/complete, so retrying re-creates the upload.
                if error_code == "NoSuchUpload":
                    error_message = error.response["Error"]["Message"]
                    raise RetryableError(f"Multipart upload failed for {bucket}/{key}: {error_message}") from error
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}")  # pylint: disable=raise-missing-from
            elif status_code == 412:  # Precondition Failed
                raise PreconditionFailedError(
                    f"ETag mismatch for {operation} operation on {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 429:
                raise RetryableError(
                    f"Too many request to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 503:
                raise RetryableError(
                    f"Service unavailable when {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 501:
                raise NotImplementedError(
                    f"Operation {operation} not implemented for object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 408:
                # 408 Request Timeout is from Google Cloud Storage
                raise RetryableError(
                    f"Request timeout when {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            else:
                raise RuntimeError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}, "
                    f"error_type: {type(error).__name__}"
                ) from error
        except RustClientError as error:
            # Rust client errors carry (message, status_code) positional args.
            message = error.args[0]
            status_code = error.args[1]
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {message}") from error
            elif status_code == 403:
                raise PermissionError(
                    f"Permission denied to {operation} object(s) at {bucket}/{key}. {message}"
                ) from error
            else:
                raise RetryableError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {message}. status_code: {status_code}"
                ) from error
        except (ReadTimeoutError, IncompleteReadError, ResponseStreamingError) as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to network timeout or incomplete read. "
                f"error_type: {type(error).__name__}"
            ) from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except Exception as error:
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
        content_type: Optional[str] = None,
    ) -> int:
        """
        Uploads an object to the specified S3 path.

        :param path: The S3 path where the object will be uploaded.
        :param body: The content of the object as bytes.
        :param if_match: Optional If-Match header value. Use "*" to only upload if the object doesn't exist.
        :param if_none_match: Optional If-None-Match header value. Use "*" to only upload if the object doesn't exist.
        :param attributes: Optional attributes to attach to the object.
        :param content_type: Optional Content-Type header value.

        :return: The number of bytes written (``len(body)``).
        """
        bucket, key = split_path(path)

        def _invoke_api() -> int:
            kwargs = {"Bucket": bucket, "Key": key, "Body": body}
            if content_type:
                kwargs["ContentType"] = content_type
            if self._is_directory_bucket(bucket):
                kwargs["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
            if if_match:
                kwargs["IfMatch"] = if_match
            if if_none_match:
                kwargs["IfNoneMatch"] = if_none_match
            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                kwargs["Metadata"] = validated_attributes

            # TODO(NGCDP-5804): Add support to update ContentType header in Rust client
            rust_unsupported_feature_keys = {"Metadata", "StorageClass", "IfMatch", "IfNoneMatch", "ContentType"}
            if (
                self._rust_client
                # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
                and not path.endswith("/")
                and all(key not in kwargs for key in rust_unsupported_feature_keys)
            ):
                run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                self._s3_client.put_object(**kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Reads an object (or an optional byte range of it) and returns its content as bytes.

        Prefers the Rust client when configured; otherwise falls back to boto3 ``get_object``.
        """
        bucket, key = split_path(path)

        def _invoke_api() -> bytes:
            if byte_range:
                bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
                if self._rust_client:
                    response = run_async_rust_client_method(
                        self._rust_client,
                        "get",
                        key,
                        byte_range,
                    )
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key, Range=bytes_range)
            else:
                if self._rust_client:
                    response = run_async_rust_client_method(self._rust_client, "get", key)
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key)

            return response["Body"].read()

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        """
        Server-side copies an object and returns the copied object's content length in bytes.
        """
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)

        # HEAD the source first so a missing source surfaces as FileNotFoundError.
        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            self._s3_client.copy(
                CopySource={"Bucket": src_bucket, "Key": src_key},
                Bucket=dest_bucket,
                Key=dest_key,
                Config=self._transfer_config,
            )

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=dest_bucket, key=dest_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Deletes a single object, optionally conditioned on its ETag via ``if_match``.
        """
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            # Delete conditionally when if_match (etag) is provided; otherwise delete unconditionally
            if if_match:
                self._s3_client.delete_object(Bucket=bucket, Key=key, IfMatch=if_match)
            else:
                self._s3_client.delete_object(Bucket=bucket, Key=key)

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _delete_objects(self, paths: list[str]) -> None:
        """
        Bulk-deletes objects, grouping by bucket and batching to the S3 DeleteObjects
        limit of 1000 keys per request. Raises RuntimeError if any per-key errors are reported.
        """
        if not paths:
            return

        by_bucket: dict[str, list[str]] = {}
        for p in paths:
            bucket, key = split_path(p)
            by_bucket.setdefault(bucket, []).append(key)

        S3_BATCH_LIMIT = 1000

        def _invoke_api() -> None:
            all_errors: list[str] = []
            for bucket, keys in by_bucket.items():
                for i in range(0, len(keys), S3_BATCH_LIMIT):
                    chunk = keys[i : i + S3_BATCH_LIMIT]
                    response = self._s3_client.delete_objects(
                        Bucket=bucket, Delete={"Objects": [{"Key": k} for k in chunk]}
                    )
                    errors = response.get("Errors") or []
                    for e in errors:
                        all_errors.append(f"{bucket}/{e.get('Key', '?')}: {e.get('Code', '')} {e.get('Message', '')}")
            if all_errors:
                raise RuntimeError(f"DeleteObjects reported errors: {'; '.join(all_errors)}")

        # Synthesize descriptive bucket/key labels for error translation/telemetry.
        bucket_desc = "(" + "|".join(by_bucket) + ")"
        key_desc = "(" + "|".join(str(len(keys)) for keys in by_bucket.values()) + " keys)"
        self._translate_errors(_invoke_api, operation="DELETE_MANY", bucket=bucket_desc, key=key_desc)

    def _is_dir(self, path: str) -> bool:
        """
        Returns True if the path behaves like a directory, i.e. listing the prefix
        (with a trailing "/") yields at least one object or common prefix.
        """
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._s3_client.list_objects_v2(Bucket=bucket, Prefix=key, MaxKeys=1, Delimiter="/")

            # Check if there are any contents or common prefixes
            return bool(response.get("Contents", []) or response.get("CommonPrefixes", []))

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Returns metadata for an object, or a synthetic "directory" entry for prefix paths.

        :param path: The object path; a trailing "/" (or empty key) is treated as a directory.
        :param strict: When True, a missing object falls back to checking whether the
            path is a directory prefix before re-raising FileNotFoundError.
        """
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If path ends with "/" or empty key name is provided, then assume it's a "directory",
            # which metadata is not guaranteed to exist for cases such as
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:

            def _invoke_api() -> ObjectMetadata:
                response = self._s3_client.head_object(Bucket=bucket, Key=key)

                return ObjectMetadata(
                    key=path,
                    type="file",
                    content_length=response["ContentLength"],
                    content_type=response.get("ContentType"),
                    last_modified=response["LastModified"],
                    etag=response["ETag"].strip('"') if "ETag" in response else None,
                    storage_class=response.get("StorageClass"),
                    metadata=response.get("Metadata"),
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                # NOTE(review): re-raise placement reconstructed at the except level, so
                # non-strict lookups also raise FileNotFoundError — confirm against upstream.
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        """
        Lists objects under ``path`` via list_objects_v2 pagination, optionally bounded
        by ``start_after``/``end_at`` and optionally yielding directory (CommonPrefixes) entries.
        """
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            paginator = self._s3_client.get_paginator("list_objects_v2")
            if include_directories:
                page_iterator = paginator.paginate(
                    Bucket=bucket, Prefix=prefix, Delimiter="/", StartAfter=(start_after or "")
                )
            else:
                page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, StartAfter=(start_after or ""))

            for page in page_iterator:
                for item in page.get("CommonPrefixes", []):
                    prefix_key = item["Prefix"].rstrip("/")
                    # Filter by start_after and end_at - S3's StartAfter doesn't filter CommonPrefixes
                    if (start_after is None or start_after < prefix_key) and (end_at is None or prefix_key <= end_at):
                        yield ObjectMetadata(
                            key=os.path.join(bucket, prefix_key),
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                    elif end_at is not None and end_at < prefix_key:
                        return

                # S3 guarantees lexicographical order for general purpose buckets (for
                # normal S3) but not directory buckets (for S3 Express One Zone).
                for response_object in page.get("Contents", []):
                    key = response_object["Key"]
                    if end_at is None or key <= end_at:
                        if key.endswith("/"):
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(bucket, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=response_object["LastModified"],
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key),
                                type="file",
                                content_length=response_object["Size"],
                                last_modified=response_object["LastModified"],
                                etag=response_object["ETag"].strip('"'),
                                storage_class=response_object.get("StorageClass"),  # Pass storage_class
                            )
                    else:
                        return

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    @property
    def supports_parallel_listing(self) -> bool:
        """
        S3 supports parallel listing via delimiter-based prefix discovery.

        Note: Directory bucket handling is done in list_objects_recursive().
        """
        return True

    def list_objects_recursive(
        self,
        path: str = "",
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        max_workers: int = 32,
        look_ahead: int = 2,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        """
        List all objects recursively using parallel prefix discovery for improved performance.

        For S3, uses the Rust client's list_recursive when available for maximum performance.
        Falls back to Python implementation otherwise.

        Returns files only (no directories), in lexicographic order.

        :param follow_symlinks: Whether to follow symbolic links (POSIX providers only).
        """
        if (start_after is not None) and (end_at is not None) and not (start_after < end_at):
            raise ValueError(f"start_after ({start_after}) must be before end_at ({end_at})!")

        full_path = self._prepend_base_path(path)
        bucket, prefix = split_path(full_path)

        # Directory buckets do not guarantee lexicographic order, so fall back to
        # a plain (non-parallel) listing for them.
        if self._is_directory_bucket(bucket):
            yield from self.list_objects(
                path, start_after, end_at, include_directories=False, follow_symlinks=follow_symlinks
            )
            return

        if self._rust_client:
            yield from self._emit_metrics(
                operation=BaseStorageProvider._Operation.LIST,
                f=lambda: self._list_objects_recursive_rust(path, full_path, bucket, start_after, end_at, max_workers),
            )
        else:
            yield from super().list_objects_recursive(
                path, start_after, end_at, max_workers, look_ahead, follow_symlinks
            )

    def _list_objects_recursive_rust(
        self,
        path: str,
        full_path: str,
        bucket: str,
        start_after: Optional[str],
        end_at: Optional[str],
        max_workers: int,
    ) -> Iterator[ObjectMetadata]:
        """
        Use Rust client's list_recursive for parallel listing.

        The Rust client already handles parallel listing internally.
        Returns files only in lexicographic order.
        """
        _, prefix = split_path(full_path)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            result = run_async_rust_client_method(
                self._rust_client,
                "list_recursive",
                [prefix] if prefix else [""],
                max_concurrency=max_workers,
            )

            # Bounds are compared against bucket-qualified keys, so base-path-qualify them.
            start_after_full = self._prepend_base_path(start_after) if start_after else None
            end_at_full = self._prepend_base_path(end_at) if end_at else None

            for obj in result.objects:
                full_key = os.path.join(bucket, obj.key)

                if start_after_full and full_key <= start_after_full:
                    continue
                if end_at_full and full_key > end_at_full:
                    break

                relative_key = full_key.removeprefix(self._base_path).lstrip("/")

                yield ObjectMetadata(
                    key=relative_key,
                    content_length=obj.content_length,
                    last_modified=dateutil_parse(obj.last_modified),
                    type="file" if obj.object_type == "object" else obj.object_type,
                    etag=obj.etag,
                )

        yield from self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(
        self,
        remote_path: str,
        f: Union[str, IO],
        attributes: Optional[dict[str, str]] = None,
        content_type: Optional[str] = None,
    ) -> int:
        """
        Uploads a file (by path or file-like object) to ``remote_path``.

        Small payloads (<= multipart_threshold) go through ``_put_object``/Rust "upload";
        larger ones use multipart transfer (boto3 TransferConfig or Rust multipart).
        Returns the number of bytes uploaded.
        """
        file_size: int = 0

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._transfer_config.multipart_threshold:
                # Rust "upload" does not support custom metadata/content-type headers.
                if self._rust_client and not attributes and not content_type:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes, content_type=content_type)
                return file_size

            # Upload large files using TransferConfig
            def _invoke_api() -> int:
                extra_args = {}
                if content_type:
                    extra_args["ContentType"] = content_type
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes
                if self._rust_client and not extra_args:
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    self._s3_client.upload_file(
                        Filename=f,
                        Bucket=bucket,
                        Key=key,
                        Config=self._transfer_config,
                        ExtraArgs=extra_args,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            # Upload small files
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            if file_size <= self._transfer_config.multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(
                        remote_path, f.read().encode("utf-8"), attributes=attributes, content_type=content_type
                    )
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes, content_type=content_type)
                return file_size

            # Upload large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                extra_args = {}
                if content_type:
                    extra_args["ContentType"] = content_type
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes

                if self._rust_client and isinstance(f, io.BytesIO) and not extra_args:
                    data = f.getbuffer()
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_bytes", key, data)
                else:
                    self._s3_client.upload_fileobj(
                        Fileobj=f,
                        Bucket=bucket,
                        Key=key,
                        Config=self._transfer_config,
                        ExtraArgs=extra_args,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """
        Downloads an object to a local path or file-like object.

        Path downloads are written to a hidden temp file in the destination directory
        and atomically renamed into place. Returns the object's content length in bytes.
        """
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            if os.path.dirname(f):
                safe_makedirs(os.path.dirname(f))

            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using TransferConfig
            def _invoke_api() -> int:
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    if self._rust_client:
                        run_async_rust_client_method(
                            self._rust_client, "download_multipart_to_file", key, temp_file_path
                        )
                    else:
                        self._s3_client.download_fileobj(
                            Bucket=bucket,
                            Key=key,
                            Fileobj=fp,
                            Config=self._transfer_config,
                        )

                os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:
            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                response = self._get_object(remote_path)
                # Python client returns `bytes`, but Rust client returns an object that implements the buffer protocol,
                # so we need to check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                self._s3_client.download_fileobj(
                    Bucket=bucket,
                    Key=key,
                    Fileobj=f,
                    Config=self._transfer_config,
                )

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)