Source code for multistorageclient.providers.s3

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import io
import os
import tempfile
import time
from collections.abc import Callable, Iterator, Sequence, Sized
from typing import IO, Any, Optional, TypeVar, Union

import boto3
import botocore
import opentelemetry.metrics as api_metrics
from boto3.s3.transfer import TransferConfig
from botocore.credentials import RefreshableCredentials
from botocore.exceptions import ClientError, IncompleteReadError, ReadTimeoutError, ResponseStreamingError
from botocore.session import get_session

from ..instrumentation.utils import set_span_attribute
from ..rust_utils import run_async_rust_client_method
from ..telemetry import Telemetry
from ..telemetry.attributes.base import AttributesProvider
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import (
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

BOTO3_MAX_POOL_CONNECTIONS = 32

MiB = 1024 * 1024

# Python and Rust share the same multipart_threshold to keep the code simple.
MULTIPART_THRESHOLD = 64 * MiB
MULTIPART_CHUNKSIZE = 32 * MiB
IO_CHUNKSIZE = 32 * MiB
# Python uses a lower default concurrency due to the GIL limiting true parallelism in threads.
PYTHON_MAX_CONCURRENCY = 16
RUST_MAX_CONCURRENCY = 32
PROVIDER = "s3"

EXPRESS_ONEZONE_STORAGE_CLASS = "EXPRESS_ONEZONE"
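
# Worked example (hypothetical object size, not part of the provider logic): with the
# defaults above, an object at or below MULTIPART_THRESHOLD (64 MiB) is transferred in a
# single put_object/get_object call, while anything larger takes the multipart path.
# A 256 MiB object would be split into roughly 256 / 32 = 8 parts of MULTIPART_CHUNKSIZE
# (32 MiB), moved by up to PYTHON_MAX_CONCURRENCY (16) boto3 transfer threads or
# RUST_MAX_CONCURRENCY (32) tasks when the optional Rust client is enabled.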


def _extract_x_trans_id(response: Any) -> None:
    """Extract x-trans-id from boto3 response and set it as span attribute."""
    try:
        if response and isinstance(response, dict):
            headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
            if headers and isinstance(headers, dict) and "x-trans-id" in headers:
                set_span_attribute("x_trans_id", headers["x-trans-id"])
    except (KeyError, AttributeError, TypeError):
        # Silently ignore any errors in extraction
        pass

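# For reference, a sketch (hypothetical values) of the boto3 response shape the helper
# above inspects; the "x-trans-id" header is only present on S3-compatible backends that
# emit it:
#
#   response = {
#       "ResponseMetadata": {
#           "HTTPStatusCode": 200,
#           "HTTPHeaders": {"x-trans-id": "tx000000000000000000001"},
#       },
#   }
#   _extract_x_trans_id(response)  # records "x_trans_id" on the current span
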
class StaticS3CredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static S3 credentials.
    """

    _access_key: str
    _secret_key: str
    _session_token: Optional[str]

    def __init__(self, access_key: str, secret_key: str, session_token: Optional[str] = None):
        """
        Initializes the :py:class:`StaticS3CredentialsProvider` with the provided access key, secret key, and optional
        session token.

        :param access_key: The access key for S3 authentication.
        :param secret_key: The secret key for S3 authentication.
        :param session_token: An optional session token for temporary credentials.
        """
        self._access_key = access_key
        self._secret_key = secret_key
        self._session_token = session_token

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key=self._access_key,
            secret_key=self._secret_key,
            token=self._session_token,
            expiration=None,
        )

    def refresh_credentials(self) -> None:
        pass

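# A minimal usage sketch (bucket name and credential values are hypothetical):
#
#   credentials_provider = StaticS3CredentialsProvider(
#       access_key="AKIA...", secret_key="...", session_token=None
#   )
#   provider = S3StorageProvider(
#       region_name="us-east-1",
#       base_path="my-bucket/datasets",
#       credentials_provider=credentials_provider,
#   )
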
class S3StorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Amazon S3 or S3-compatible object stores.
    """

    def __init__(
        self,
        region_name: str = "",
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        metric_counters: dict[Telemetry.CounterName, api_metrics.Counter] = {},
        metric_gauges: dict[Telemetry.GaugeName, api_metrics._Gauge] = {},
        metric_attributes_providers: Sequence[AttributesProvider] = (),
        **kwargs: Any,
    ) -> None:
        """
        Initializes the :py:class:`S3StorageProvider` with the region, endpoint URL, and optional credentials provider.

        :param region_name: The AWS region where the S3 bucket is located.
        :param endpoint_url: The custom endpoint URL for the S3 service.
        :param base_path: The root prefix path within the S3 bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve S3 credentials.
        :param metric_counters: Metric counters.
        :param metric_gauges: Metric gauges.
        :param metric_attributes_providers: Metric attributes providers.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            metric_counters=metric_counters,
            metric_gauges=metric_gauges,
            metric_attributes_providers=metric_attributes_providers,
        )

        self._region_name = region_name
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider

        self._signature_version = kwargs.get("signature_version", "")
        self._s3_client = self._create_s3_client(
            request_checksum_calculation=kwargs.get("request_checksum_calculation"),
            response_checksum_validation=kwargs.get("response_checksum_validation"),
            max_pool_connections=kwargs.get("max_pool_connections", BOTO3_MAX_POOL_CONNECTIONS),
            connect_timeout=kwargs.get("connect_timeout"),
            read_timeout=kwargs.get("read_timeout"),
            retries=kwargs.get("retries"),
        )
        self._transfer_config = TransferConfig(
            multipart_threshold=int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD)),
            max_concurrency=int(kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)),
            multipart_chunksize=int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE)),
            io_chunksize=int(kwargs.get("io_chunksize", IO_CHUNKSIZE)),
            use_threads=True,
        )

        self._rust_client = None
        if "rust_client" in kwargs:
            self._rust_client = self._create_rust_client(kwargs.get("rust_client"))

    def _is_directory_bucket(self, bucket: str) -> bool:
        """
        Determines if the bucket is a directory bucket based on the bucket name.
        """
        # S3 Express buckets have a specific naming convention
        return "--x-s3" in bucket

    def _create_s3_client(
        self,
        request_checksum_calculation: Optional[str] = None,
        response_checksum_validation: Optional[str] = None,
        max_pool_connections: int = BOTO3_MAX_POOL_CONNECTIONS,
        connect_timeout: Union[float, int, None] = None,
        read_timeout: Union[float, int, None] = None,
        retries: Optional[dict[str, Any]] = None,
    ):
        """
        Creates and configures the boto3 S3 client, using refreshable credentials if possible.

        :param request_checksum_calculation: When the underlying S3 client should calculate request checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param response_checksum_validation: When the underlying S3 client should validate response checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param max_pool_connections: The maximum number of connections to keep in a connection pool.
        :param connect_timeout: The time in seconds until a timeout exception is thrown when attempting to make a connection.
        :param read_timeout: The time in seconds until a timeout exception is thrown when attempting to read from a connection.
        :param retries: A dictionary for configuration related to retry behavior.

        :return: The configured S3 client.
        """
        options = {
            # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
            "config": boto3.session.Config(  # pyright: ignore [reportAttributeAccessIssue]
                max_pool_connections=max_pool_connections,
                connect_timeout=connect_timeout,
                read_timeout=read_timeout,
                retries=retries or {"mode": "standard"},
                request_checksum_calculation=request_checksum_calculation,
                response_checksum_validation=response_checksum_validation,
            ),
        }

        if self._region_name:
            options["region_name"] = self._region_name

        if self._endpoint_url:
            options["endpoint_url"] = self._endpoint_url

        if self._credentials_provider:
            creds = self._fetch_credentials()
            if "expiry_time" in creds and creds["expiry_time"]:
                # Use RefreshableCredentials if expiry_time provided.
                refreshable_credentials = RefreshableCredentials.create_from_metadata(
                    metadata=creds, refresh_using=self._fetch_credentials, method="custom-refresh"
                )

                botocore_session = get_session()
                botocore_session._credentials = refreshable_credentials

                boto3_session = boto3.Session(botocore_session=botocore_session)

                return boto3_session.client("s3", **options)
            else:
                # Add static credentials to the options dictionary
                options["aws_access_key_id"] = creds["access_key"]
                options["aws_secret_access_key"] = creds["secret_key"]
                if creds["token"]:
                    options["aws_session_token"] = creds["token"]

        if self._signature_version:
            signature_config = botocore.config.Config(  # pyright: ignore[reportAttributeAccessIssue]
                signature_version=botocore.UNSIGNED
                if self._signature_version == "UNSIGNED"
                else self._signature_version
            )
            options["config"] = options["config"].merge(signature_config)

        # Fallback to standard credential chain.
        return boto3.client("s3", **options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        """
        Creates and configures the rust client, using refreshable credentials if possible.
        """
        from multistorageclient_rust import RustClient

        configs = {}
        if self._region_name:
            configs["region_name"] = self._region_name

        # If the user specifies a bucket, use it. Otherwise, use the base path.
        if rust_client_options and "bucket" in rust_client_options:
            configs["bucket"] = rust_client_options["bucket"]
        else:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        if self._endpoint_url:
            configs["endpoint_url"] = self._endpoint_url

        if self._credentials_provider:
            creds = self._fetch_credentials()
            if "expiry_time" in creds and creds["expiry_time"]:
                # TODO: Implement refreshable credentials
                raise NotImplementedError("Refreshable credentials are not yet implemented for the rust client.")
            else:
                # Add static credentials to the configs dictionary
                configs["aws_access_key_id"] = creds["access_key"]
                configs["aws_secret_access_key"] = creds["secret_key"]
                if creds["token"]:
                    configs["aws_session_token"] = creds["token"]

        if rust_client_options:
            if rust_client_options.get("allow_http", False):
                configs["allow_http"] = True
            configs["max_concurrency"] = rust_client_options.get("max_concurrency", RUST_MAX_CONCURRENCY)
            configs["multipart_chunksize"] = rust_client_options.get("multipart_chunksize", MULTIPART_CHUNKSIZE)

        return RustClient(
            provider=PROVIDER,
            configs=configs,
        )

    def _fetch_credentials(self) -> dict:
        """
        Refreshes the credentials via the configured credentials provider and returns them
        in the mapping format expected by botocore.
        """
        if not self._credentials_provider:
            raise RuntimeError("Cannot fetch credentials if no credential provider configured.")
        self._credentials_provider.refresh_credentials()
        credentials = self._credentials_provider.get_credentials()
        return {
            "access_key": credentials.access_key,
            "secret_key": credentials.secret_key,
            "token": credentials.token,
            "expiry_time": credentials.expiration,
        }

    def _collect_metrics(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> _T:
        """
        Collects and records performance metrics around S3 operations such as PUT, GET, DELETE, etc.

        This method wraps an S3 operation and measures the time it takes to complete, along with recording
        the size of the object if applicable. It handles errors like timeouts and client errors and ensures
        proper logging of duration and object size.

        :param func: The function that performs the actual S3 operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the S3 bucket involved in the operation.
        :param key: The key of the object within the S3 bucket.
        :param put_object_size: The size of the object being uploaded, if applicable (for PUT operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).

        :return: The result of the S3 operation, typically the return value of the `func` callable.
        """
        # Import the span attribute helper
        from ..instrumentation.utils import set_span_attribute

        # Set basic operation attributes
        set_span_attribute("s3_operation", operation)
        set_span_attribute("s3_bucket", bucket)
        set_span_attribute("s3_key", key)

        start_time = time.time()
        status_code = 200

        object_size = None
        if operation == "PUT":
            object_size = put_object_size
        elif operation == "GET" and get_object_size:
            object_size = get_object_size

        try:
            result = func()
            if operation == "GET" and object_size is None and isinstance(result, Sized):
                object_size = len(result)
            return result
        except ClientError as error:
            status_code = error.response["ResponseMetadata"]["HTTPStatusCode"]
            request_id = error.response["ResponseMetadata"].get("RequestId")
            host_id = error.response["ResponseMetadata"].get("HostId")
            header = error.response["ResponseMetadata"].get("HTTPHeaders", {})
            error_code = error.response["Error"]["Code"]

            # Ensure header is a dictionary before trying to get from it
            x_trans_id = header.get("x-trans-id") if isinstance(header, dict) else None

            # Record error details in span
            set_span_attribute("request_id", request_id)
            set_span_attribute("host_id", host_id)

            error_info = f"request_id: {request_id}, host_id: {host_id}, status_code: {status_code}"
            if x_trans_id:
                error_info += f", x-trans-id: {x_trans_id}"
                set_span_attribute("x_trans_id", x_trans_id)

            if status_code == 404:
                if error_code == "NoSuchUpload":
                    error_message = error.response["Error"]["Message"]
                    raise RetryableError(f"Multipart upload failed for {bucket}/{key}: {error_message}") from error
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}")  # pylint: disable=raise-missing-from
            elif status_code == 412:  # Precondition Failed
                raise PreconditionFailedError(
                    f"ETag mismatch for {operation} operation on {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 429:
                raise RetryableError(
                    f"Too many requests to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 503:
                raise RetryableError(
                    f"Service unavailable when {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 501:
                raise NotImplementedError(
                    f"Operation {operation} not implemented for object(s) at {bucket}/{key}. {error_info}"
                ) from error
            else:
                raise RuntimeError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}, "
                    f"error_type: {type(error).__name__}"
                ) from error
        except FileNotFoundError as error:
            status_code = -1
            raise error
        except (ReadTimeoutError, IncompleteReadError, ResponseStreamingError) as error:
            status_code = -1
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to network timeout or incomplete read. "
                f"error_type: {type(error).__name__}"
            ) from error
        except Exception as error:
            status_code = -1
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error
        finally:
            elapsed_time = time.time() - start_time

            set_span_attribute("status_code", status_code)

            # Record metrics
            self._metric_helper.record_duration(
                elapsed_time, provider=self._provider_name, operation=operation, bucket=bucket, status_code=status_code
            )
            if object_size:
                self._metric_helper.record_object_size(
                    object_size,
                    provider=self._provider_name,
                    operation=operation,
                    bucket=bucket,
                    status_code=status_code,
                )

            set_span_attribute("object_size", object_size)

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to the specified S3 path.

        :param path: The S3 path where the object will be uploaded.
        :param body: The content of the object as bytes.
        :param if_match: Optional If-Match header value. Pass an ETag to only overwrite the object if the existing object's ETag matches.
        :param if_none_match: Optional If-None-Match header value. Use "*" to only upload if the object doesn't exist.
        :param attributes: Optional attributes to attach to the object.
        """
        bucket, key = split_path(path)

        def _invoke_api() -> int:
            kwargs = {"Bucket": bucket, "Key": key, "Body": body}
            if self._is_directory_bucket(bucket):
                kwargs["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
            if if_match:
                kwargs["IfMatch"] = if_match
            if if_none_match:
                kwargs["IfNoneMatch"] = if_none_match
            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                kwargs["Metadata"] = validated_attributes

            rust_unsupported_feature_keys = {"Metadata", "StorageClass", "IfMatch", "IfNoneMatch"}
            if (
                self._rust_client
                # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
                and not path.endswith("/")
                and all(key not in kwargs for key in rust_unsupported_feature_keys)
            ):
                response = run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                # Capture the response from put_object
                response = self._s3_client.put_object(**kwargs)

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

            return len(body)

        return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=len(body))

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)

        def _invoke_api() -> bytes:
            if byte_range:
                bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
                if self._rust_client:
                    response = run_async_rust_client_method(
                        self._rust_client, "get", key, byte_range.offset, byte_range.offset + byte_range.size - 1
                    )
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key, Range=bytes_range)
            else:
                if self._rust_client:
                    response = run_async_rust_client_method(self._rust_client, "get", key)
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key)

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

            return response["Body"].read()

        return self._collect_metrics(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            response = self._s3_client.copy(
                CopySource={"Bucket": src_bucket, "Key": src_key},
                Bucket=dest_bucket,
                Key=dest_key,
                Config=self._transfer_config,
            )

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

            return src_object.content_length

        return self._collect_metrics(
            _invoke_api,
            operation="COPY",
            bucket=dest_bucket,
            key=dest_key,
            put_object_size=src_object.content_length,
        )

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            # Conditionally delete the object if if_match (ETag) is provided; otherwise delete it unconditionally.
            if if_match:
                response = self._s3_client.delete_object(Bucket=bucket, Key=key, IfMatch=if_match)
            else:
                response = self._s3_client.delete_object(Bucket=bucket, Key=key)

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

        return self._collect_metrics(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._s3_client.list_objects_v2(Bucket=bucket, Prefix=key, MaxKeys=1, Delimiter="/")

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

            # Check if there are any contents or common prefixes
            return bool(response.get("Contents", []) or response.get("CommonPrefixes", []))

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If the path ends with "/" or an empty key name is provided, then assume it's a "directory",
            # whose metadata is not guaranteed to exist for cases such as a
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:

            def _invoke_api() -> ObjectMetadata:
                response = self._s3_client.head_object(Bucket=bucket, Key=key)

                # Extract and set x-trans-id if present
                _extract_x_trans_id(response)

                return ObjectMetadata(
                    key=path,
                    type="file",
                    content_length=response["ContentLength"],
                    content_type=response["ContentType"],
                    last_modified=response["LastModified"],
                    etag=response["ETag"].strip('"'),
                    storage_class=response.get("StorageClass"),
                    metadata=response.get("Metadata"),
                )

            try:
                return self._collect_metrics(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(path)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            paginator = self._s3_client.get_paginator("list_objects_v2")
            if include_directories:
                page_iterator = paginator.paginate(
                    Bucket=bucket, Prefix=prefix, Delimiter="/", StartAfter=(start_after or "")
                )
            else:
                page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, StartAfter=(start_after or ""))

            for page in page_iterator:
                for item in page.get("CommonPrefixes", []):
                    yield ObjectMetadata(
                        key=os.path.join(bucket, item["Prefix"].rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )

                # S3 guarantees lexicographical order for general purpose buckets (for
                # normal S3) but not directory buckets (for S3 Express One Zone).
                for response_object in page.get("Contents", []):
                    key = response_object["Key"]
                    if end_at is None or key <= end_at:
                        if key.endswith("/"):
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(bucket, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=response_object["LastModified"],
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key),
                                type="file",
                                content_length=response_object["Size"],
                                last_modified=response_object["LastModified"],
                                etag=response_object["ETag"].strip('"'),
                                storage_class=response_object.get("StorageClass"),  # Pass storage_class
                            )
                    else:
                        return

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        file_size: int = 0

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._transfer_config.multipart_threshold:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes)
                return file_size

            # Upload large files using TransferConfig
            def _invoke_api() -> int:
                extra_args = {}
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes
                if self._rust_client and not extra_args:
                    response = run_async_rust_client_method(self._rust_client, "upload_multipart", f, key)
                else:
                    response = self._s3_client.upload_file(
                        Filename=f,
                        Bucket=bucket,
                        Key=key,
                        Config=self._transfer_config,
                        ExtraArgs=extra_args,
                    )

                # Extract and set x-trans-id if present
                _extract_x_trans_id(response)

                return file_size

            return self._collect_metrics(
                _invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=file_size
            )
        else:
            # Upload small files
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            if file_size <= self._transfer_config.multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(remote_path, f.read().encode("utf-8"), attributes=attributes)
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes)
                return file_size

            # Upload large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                extra_args = {}
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes
                self._s3_client.upload_fileobj(
                    Fileobj=f,
                    Bucket=bucket,
                    Key=key,
                    Config=self._transfer_config,
                    ExtraArgs=extra_args,
                )

                return file_size

            return self._collect_metrics(
                _invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=file_size
            )

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)

            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using TransferConfig
            def _invoke_api() -> int:
                response = None
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    if self._rust_client:
                        response = run_async_rust_client_method(
                            self._rust_client, "download_multipart", key, temp_file_path
                        )
                    else:
                        response = self._s3_client.download_fileobj(
                            Bucket=bucket,
                            Key=key,
                            Fileobj=fp,
                            Config=self._transfer_config,
                        )

                # Extract and set x-trans-id if present
                _extract_x_trans_id(response)
                os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )
        else:
            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                response = self._get_object(remote_path)
                # The Python client returns `bytes`, but the Rust client returns an object that implements the
                # buffer protocol, so we need to check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                response = self._s3_client.download_fileobj(
                    Bucket=bucket,
                    Key=key,
                    Fileobj=f,
                    Config=self._transfer_config,
                )

                # Extract and set x-trans-id if present
                _extract_x_trans_id(response)

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )
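
# A minimal configuration sketch (bucket name, endpoint, and option values are
# hypothetical; only keyword arguments that appear in __init__ above are used):
#
#   provider = S3StorageProvider(
#       region_name="us-east-1",
#       endpoint_url="https://s3.example.com",
#       base_path="my-bucket/prefix",
#       multipart_threshold=128 * MiB,
#       max_concurrency=8,
#       rust_client={"allow_http": False, "max_concurrency": 32},
#   )
#
# Objects larger than the multipart threshold are then transferred via boto3's
# TransferConfig (or the optional Rust client), while smaller ones use single
# put_object/get_object calls.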