Source code for multistorageclient.providers.s3

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os
import tempfile
import time
from collections.abc import Callable, Iterator, Sequence, Sized
from typing import IO, Any, Optional, TypeVar, Union

import boto3
import botocore
import opentelemetry.metrics as api_metrics
from boto3.s3.transfer import TransferConfig
from botocore.credentials import RefreshableCredentials
from botocore.exceptions import ClientError, IncompleteReadError, ReadTimeoutError, ResponseStreamingError
from botocore.session import get_session

from ..instrumentation.utils import set_span_attribute
from ..rust_utils import run_async_rust_client_method
from ..telemetry import Telemetry
from ..telemetry.attributes.base import AttributesProvider
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import (
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

BOTO3_MAX_POOL_CONNECTIONS = 32

MB = 1024 * 1024

MULTIPART_THRESHOLD = 512 * MB
MULTIPART_CHUNK_SIZE = 256 * MB
IO_CHUNK_SIZE = 128 * MB
MAX_CONCURRENCY = 16
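# With MB = 1024 * 1024, the transfer defaults above mean: objects larger than 512 MiB are moved as
# multipart transfers in 256 MiB parts, streamed in 128 MiB I/O chunks by up to 16 concurrent threads.
# Each value can be overridden through S3StorageProvider keyword arguments (see __init__ below).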
PROVIDER = "s3"

EXPRESS_ONEZONE_STORAGE_CLASS = "EXPRESS_ONEZONE"

def _extract_x_trans_id(response: Any) -> None:
    """Extract x-trans-id from a boto3 response and set it as a span attribute."""
    try:
        if response and isinstance(response, dict):
            headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
            if headers and isinstance(headers, dict) and "x-trans-id" in headers:
                set_span_attribute("x_trans_id", headers["x-trans-id"])
    except (KeyError, AttributeError, TypeError):
        # Silently ignore any errors in extraction
        pass

class StaticS3CredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static S3 credentials.
    """

    _access_key: str
    _secret_key: str
    _session_token: Optional[str]

    def __init__(self, access_key: str, secret_key: str, session_token: Optional[str] = None):
        """
        Initializes the :py:class:`StaticS3CredentialsProvider` with the provided access key, secret key, and optional
        session token.

        :param access_key: The access key for S3 authentication.
        :param secret_key: The secret key for S3 authentication.
        :param session_token: An optional session token for temporary credentials.
        """
        self._access_key = access_key
        self._secret_key = secret_key
        self._session_token = session_token

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key=self._access_key,
            secret_key=self._secret_key,
            token=self._session_token,
            expiration=None,
        )

    def refresh_credentials(self) -> None:
        pass

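# Illustrative usage (a minimal sketch, not part of the module API surface): the key values below are
# placeholders, not real credentials.
#
#     creds_provider = StaticS3CredentialsProvider(access_key="<access-key>", secret_key="<secret-key>")
#     creds = creds_provider.get_credentials()
#     # creds is a Credentials object with token=None and expiration=None; refresh_credentials() is a no-op.
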
class S3StorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Amazon S3 or S3-compatible object stores.
    """

    def __init__(
        self,
        region_name: str = "",
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        metric_counters: dict[Telemetry.CounterName, api_metrics.Counter] = {},
        metric_gauges: dict[Telemetry.GaugeName, api_metrics._Gauge] = {},
        metric_attributes_providers: Sequence[AttributesProvider] = (),
        **kwargs: Any,
    ) -> None:
        """
        Initializes the :py:class:`S3StorageProvider` with the region, endpoint URL, and optional credentials provider.

        :param region_name: The AWS region where the S3 bucket is located.
        :param endpoint_url: The custom endpoint URL for the S3 service.
        :param base_path: The root prefix path within the S3 bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve S3 credentials.
        :param metric_counters: Metric counters.
        :param metric_gauges: Metric gauges.
        :param metric_attributes_providers: Metric attributes providers.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            metric_counters=metric_counters,
            metric_gauges=metric_gauges,
            metric_attributes_providers=metric_attributes_providers,
        )

        self._region_name = region_name
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider

        self._signature_version = kwargs.get("signature_version", "")
        self._s3_client = self._create_s3_client(
            request_checksum_calculation=kwargs.get("request_checksum_calculation"),
            response_checksum_validation=kwargs.get("response_checksum_validation"),
            max_pool_connections=kwargs.get("max_pool_connections", BOTO3_MAX_POOL_CONNECTIONS),
            connect_timeout=kwargs.get("connect_timeout"),
            read_timeout=kwargs.get("read_timeout"),
            retries=kwargs.get("retries"),
        )
        self._transfer_config = TransferConfig(
            multipart_threshold=int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD)),
            max_concurrency=int(kwargs.get("max_concurrency", MAX_CONCURRENCY)),
            multipart_chunksize=int(kwargs.get("multipart_chunksize", MULTIPART_CHUNK_SIZE)),
            io_chunksize=int(kwargs.get("io_chunk_size", IO_CHUNK_SIZE)),
            use_threads=True,
        )

        self._rust_client = None
        if "rust_client" in kwargs:
            self._rust_client = self._create_rust_client(kwargs.get("rust_client"))

    def _is_directory_bucket(self, bucket: str) -> bool:
        """
        Determines if the bucket is a directory bucket based on bucket name.
        """
        # S3 Express buckets have a specific naming convention
        return "--x-s3" in bucket

    def _create_s3_client(
        self,
        request_checksum_calculation: Optional[str] = None,
        response_checksum_validation: Optional[str] = None,
        max_pool_connections: int = BOTO3_MAX_POOL_CONNECTIONS,
        connect_timeout: Union[float, int, None] = None,
        read_timeout: Union[float, int, None] = None,
        retries: Optional[dict[str, Any]] = None,
    ):
        """
        Creates and configures the boto3 S3 client, using refreshable credentials if possible.

        :param request_checksum_calculation: When the underlying S3 client should calculate request checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param response_checksum_validation: When the underlying S3 client should validate response checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param max_pool_connections: The maximum number of connections to keep in a connection pool.
        :param connect_timeout: The time in seconds till a timeout exception is thrown when attempting to make a connection.
        :param read_timeout: The time in seconds till a timeout exception is thrown when attempting to read from a connection.
        :param retries: A dictionary for configuration related to retry behavior.

        :return: The configured S3 client.
        """
        options = {
            # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
            "config": boto3.session.Config(  # pyright: ignore [reportAttributeAccessIssue]
                max_pool_connections=max_pool_connections,
                connect_timeout=connect_timeout,
                read_timeout=read_timeout,
                retries=retries or {"mode": "standard"},
                request_checksum_calculation=request_checksum_calculation,
                response_checksum_validation=response_checksum_validation,
            ),
        }

        if self._region_name:
            options["region_name"] = self._region_name

        if self._endpoint_url:
            options["endpoint_url"] = self._endpoint_url

        if self._credentials_provider:
            creds = self._fetch_credentials()
            if "expiry_time" in creds and creds["expiry_time"]:
                # Use RefreshableCredentials if expiry_time provided.
                refreshable_credentials = RefreshableCredentials.create_from_metadata(
                    metadata=creds, refresh_using=self._fetch_credentials, method="custom-refresh"
                )

                botocore_session = get_session()
                botocore_session._credentials = refreshable_credentials

                boto3_session = boto3.Session(botocore_session=botocore_session)

                return boto3_session.client("s3", **options)
            else:
                # Add static credentials to the options dictionary
                options["aws_access_key_id"] = creds["access_key"]
                options["aws_secret_access_key"] = creds["secret_key"]
                if creds["token"]:
                    options["aws_session_token"] = creds["token"]

        if self._signature_version:
            signature_config = botocore.config.Config(  # pyright: ignore[reportAttributeAccessIssue]
                signature_version=botocore.UNSIGNED
                if self._signature_version == "UNSIGNED"
                else self._signature_version
            )
            options["config"] = options["config"].merge(signature_config)

        # Fallback to standard credential chain.
        return boto3.client("s3", **options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        """
        Creates and configures the rust client, using refreshable credentials if possible.
        """
        from multistorageclient_rust import RustClient

        configs = {}
        if self._region_name:
            configs["region_name"] = self._region_name

        bucket, _ = split_path(self._base_path)
        configs["bucket"] = bucket

        if self._endpoint_url:
            configs["endpoint_url"] = self._endpoint_url

        if self._credentials_provider:
            creds = self._fetch_credentials()
            if "expiry_time" in creds and creds["expiry_time"]:
                # TODO: Implement refreshable credentials
                raise NotImplementedError("Refreshable credentials are not yet implemented for the rust client.")
            else:
                # Add static credentials to the configs dictionary
                configs["aws_access_key_id"] = creds["access_key"]
                configs["aws_secret_access_key"] = creds["secret_key"]
                if creds["token"]:
                    configs["aws_session_token"] = creds["token"]

        if rust_client_options:
            if rust_client_options.get("allow_http", False):
                configs["allow_http"] = True

        return RustClient(
            provider=PROVIDER,
            configs=configs,
        )

    def _fetch_credentials(self) -> dict:
        """
        Refreshes credentials via the configured credentials provider and returns them in the
        metadata format expected by botocore.
        """
        if not self._credentials_provider:
            raise RuntimeError("Cannot fetch credentials if no credential provider configured.")
        self._credentials_provider.refresh_credentials()
        credentials = self._credentials_provider.get_credentials()
        return {
            "access_key": credentials.access_key,
            "secret_key": credentials.secret_key,
            "token": credentials.token,
            "expiry_time": credentials.expiration,
        }

    def _collect_metrics(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> _T:
        """
        Collects and records performance metrics around S3 operations such as PUT, GET, DELETE, etc.

        This method wraps an S3 operation and measures the time it takes to complete, along with recording
        the size of the object if applicable. It handles errors like timeouts and client errors and ensures
        proper logging of duration and object size.

        :param func: The function that performs the actual S3 operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the S3 bucket involved in the operation.
        :param key: The key of the object within the S3 bucket.
        :param put_object_size: The size of the object being uploaded, if applicable (for PUT operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).

        :return: The result of the S3 operation, typically the return value of the `func` callable.
        """
        # Import the span attribute helper
        from ..instrumentation.utils import set_span_attribute

        # Set basic operation attributes
        set_span_attribute("s3_operation", operation)
        set_span_attribute("s3_bucket", bucket)
        set_span_attribute("s3_key", key)

        start_time = time.time()
        status_code = 200

        object_size = None
        if operation == "PUT":
            object_size = put_object_size
        elif operation == "GET" and get_object_size:
            object_size = get_object_size

        try:
            result = func()
            if operation == "GET" and object_size is None and isinstance(result, Sized):
                object_size = len(result)
            return result
        except ClientError as error:
            status_code = error.response["ResponseMetadata"]["HTTPStatusCode"]
            request_id = error.response["ResponseMetadata"].get("RequestId")
            host_id = error.response["ResponseMetadata"].get("HostId")
            header = error.response["ResponseMetadata"].get("HTTPHeaders", {})
            error_code = error.response["Error"]["Code"]

            # Ensure header is a dictionary before trying to get from it
            x_trans_id = header.get("x-trans-id") if isinstance(header, dict) else None

            # Record error details in span
            set_span_attribute("request_id", request_id)
            set_span_attribute("host_id", host_id)

            error_info = f"request_id: {request_id}, host_id: {host_id}, status_code: {status_code}"
            if x_trans_id:
                error_info += f", x-trans-id: {x_trans_id}"
                set_span_attribute("x_trans_id", x_trans_id)

            if status_code == 404:
                if error_code == "NoSuchUpload":
                    error_message = error.response["Error"]["Message"]
                    raise RetryableError(f"Multipart upload failed for {bucket}/{key}: {error_message}") from error
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}")  # pylint: disable=raise-missing-from
            elif status_code == 412:  # Precondition Failed
                raise PreconditionFailedError(
                    f"ETag mismatch for {operation} operation on {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 429:
                raise RetryableError(
                    f"Too many requests to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 503:
                raise RetryableError(
                    f"Service unavailable when {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 501:
                raise NotImplementedError(
                    f"Operation {operation} not implemented for object(s) at {bucket}/{key}. {error_info}"
                ) from error
            else:
                raise RuntimeError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}, "
                    f"error_type: {type(error).__name__}"
                ) from error
        except FileNotFoundError as error:
            status_code = -1
            raise error
        except (ReadTimeoutError, IncompleteReadError, ResponseStreamingError) as error:
            status_code = -1
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to network timeout or incomplete read. "
                f"error_type: {type(error).__name__}"
            ) from error
        except Exception as error:
            status_code = -1
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error
        finally:
            elapsed_time = time.time() - start_time

            set_span_attribute("status_code", status_code)

            # Record metrics
            self._metric_helper.record_duration(
                elapsed_time, provider=self._provider_name, operation=operation, bucket=bucket, status_code=status_code
            )
            if object_size:
                self._metric_helper.record_object_size(
                    object_size,
                    provider=self._provider_name,
                    operation=operation,
                    bucket=bucket,
                    status_code=status_code,
                )

                set_span_attribute("object_size", object_size)

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to the specified S3 path.

        :param path: The S3 path where the object will be uploaded.
        :param body: The content of the object as bytes.
        :param if_match: Optional If-Match header value. Provide an ETag to only overwrite the object if its current ETag matches.
        :param if_none_match: Optional If-None-Match header value. Use "*" to only upload if the object doesn't exist.
        :param attributes: Optional attributes to attach to the object.
        """
        bucket, key = split_path(path)

        def _invoke_api() -> int:
            kwargs = {"Bucket": bucket, "Key": key, "Body": body}
            if self._is_directory_bucket(bucket):
                kwargs["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
            if if_match:
                kwargs["IfMatch"] = if_match
            if if_none_match:
                kwargs["IfNoneMatch"] = if_none_match
            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                kwargs["Metadata"] = validated_attributes

            rust_unsupported_feature_keys = {"Metadata", "StorageClass", "IfMatch", "IfNoneMatch"}
            if (
                self._rust_client
                # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
                and not path.endswith("/")
                and all(key not in kwargs for key in rust_unsupported_feature_keys)
            ):
                response = run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                # Capture the response from put_object
                response = self._s3_client.put_object(**kwargs)

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

            return len(body)

        return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=len(body))

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)

        def _invoke_api() -> bytes:
            if byte_range:
                bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
                if self._rust_client:
                    response = run_async_rust_client_method(
                        self._rust_client, "get", key, byte_range.offset, byte_range.offset + byte_range.size - 1
                    )
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key, Range=bytes_range)
            else:
                if self._rust_client:
                    response = run_async_rust_client_method(self._rust_client, "get", key)
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key)

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

            return response["Body"].read()

        return self._collect_metrics(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            response = self._s3_client.copy(
                CopySource={"Bucket": src_bucket, "Key": src_key},
                Bucket=dest_bucket,
                Key=dest_key,
                Config=self._transfer_config,
            )

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

            return src_object.content_length

        return self._collect_metrics(
            _invoke_api,
            operation="COPY",
            bucket=dest_bucket,
            key=dest_key,
            put_object_size=src_object.content_length,
        )

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            # Conditionally delete the object if if_match (an ETag) is provided; otherwise delete it unconditionally.
            if if_match:
                response = self._s3_client.delete_object(Bucket=bucket, Key=key, IfMatch=if_match)
            else:
                response = self._s3_client.delete_object(Bucket=bucket, Key=key)

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

        return self._collect_metrics(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._s3_client.list_objects_v2(Bucket=bucket, Prefix=key, MaxKeys=1, Delimiter="/")

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

            # Check if there are any contents or common prefixes
            return bool(response.get("Contents", []) or response.get("CommonPrefixes", []))

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        if path.endswith("/"):
            # If path is a "directory", then metadata is not guaranteed to exist if
            # it is a "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            bucket, key = split_path(path)

            def _invoke_api() -> ObjectMetadata:
                response = self._s3_client.head_object(Bucket=bucket, Key=key)

                # Extract and set x-trans-id if present
                _extract_x_trans_id(response)

                return ObjectMetadata(
                    key=path,
                    type="file",
                    content_length=response["ContentLength"],
                    content_type=response["ContentType"],
                    last_modified=response["LastModified"],
                    etag=response["ETag"].strip('"'),
                    storage_class=response.get("StorageClass"),
                    metadata=response.get("Metadata"),
                )

            try:
                return self._collect_metrics(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        prefix: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(prefix)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            paginator = self._s3_client.get_paginator("list_objects_v2")
            if include_directories:
                page_iterator = paginator.paginate(
                    Bucket=bucket, Prefix=prefix, Delimiter="/", StartAfter=(start_after or "")
                )
            else:
                page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, StartAfter=(start_after or ""))

            for page in page_iterator:
                for item in page.get("CommonPrefixes", []):
                    yield ObjectMetadata(
                        key=os.path.join(bucket, item["Prefix"].rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )

                # S3 guarantees lexicographical order for general purpose buckets (for
                # normal S3) but not directory buckets (for S3 Express One Zone).
                for response_object in page.get("Contents", []):
                    key = response_object["Key"]
                    if end_at is None or key <= end_at:
                        if key.endswith("/"):
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(bucket, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=response_object["LastModified"],
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key),
                                type="file",
                                content_length=response_object["Size"],
                                last_modified=response_object["LastModified"],
                                etag=response_object["ETag"].strip('"'),
                                storage_class=response_object.get("StorageClass"),  # Pass storage_class
                            )
                    else:
                        return

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        file_size: int = 0

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._transfer_config.multipart_threshold:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes)
                return file_size

            # Upload large files using TransferConfig
            def _invoke_api() -> int:
                extra_args = {}
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes
                # TODO: Add support for multipart upload of rust client
                response = self._s3_client.upload_file(
                    Filename=f,
                    Bucket=bucket,
                    Key=key,
                    Config=self._transfer_config,
                    ExtraArgs=extra_args,
                )

                # Extract and set x-trans-id if present
                _extract_x_trans_id(response)

                return file_size

            return self._collect_metrics(
                _invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=file_size
            )
        else:
            # Upload small files
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            if file_size <= self._transfer_config.multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(remote_path, f.read().encode("utf-8"), attributes=attributes)
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes)
                return file_size

            # Upload large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                extra_args = {}
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes
                self._s3_client.upload_fileobj(
                    Fileobj=f,
                    Bucket=bucket,
                    Key=key,
                    Config=self._transfer_config,
                    ExtraArgs=extra_args,
                )

                return file_size

            return self._collect_metrics(
                _invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=file_size
            )

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            os.makedirs(os.path.dirname(f), exist_ok=True)

            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using TransferConfig
            def _invoke_api() -> int:
                response = None
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    # TODO: Add support for multipart download of rust client
                    response = self._s3_client.download_fileobj(
                        Bucket=bucket,
                        Key=key,
                        Fileobj=fp,
                        Config=self._transfer_config,
                    )

                # Extract and set x-trans-id if present
                _extract_x_trans_id(response)
                os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )
        else:
            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                if isinstance(f, io.StringIO):
                    f.write(self._get_object(remote_path).decode("utf-8"))
                else:
                    f.write(self._get_object(remote_path))
                return metadata.content_length

            # Download large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                response = self._s3_client.download_fileobj(
                    Bucket=bucket,
                    Key=key,
                    Fileobj=f,
                    Config=self._transfer_config,
                )

                # Extract and set x-trans-id if present
                _extract_x_trans_id(response)

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )
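
# Illustrative usage (a minimal sketch): constructing the provider directly with static credentials and
# custom transfer tuning. The bucket, region, and numbers below are placeholders; only the keyword
# argument names are taken from __init__ above.
#
#     provider = S3StorageProvider(
#         region_name="us-east-1",
#         base_path="my-bucket/datasets",
#         credentials_provider=StaticS3CredentialsProvider("<access-key>", "<secret-key>"),
#         multipart_threshold=64 * MB,   # switch to multipart transfers above 64 MiB
#         multipart_chunksize=32 * MB,   # upload in 32 MiB parts
#         max_concurrency=8,             # cap transfer threads
#     )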