Source code for multistorageclient.providers.s3

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import io
import os
import tempfile
from collections.abc import Callable, Iterator
from typing import IO, Any, Optional, TypeVar, Union

import boto3
import botocore
from boto3.s3.transfer import TransferConfig
from botocore.credentials import RefreshableCredentials
from botocore.exceptions import ClientError, IncompleteReadError, ReadTimeoutError, ResponseStreamingError
from botocore.session import get_session

from multistorageclient_rust import RustClient, RustClientError, RustRetryableError

from ..rust_utils import run_async_rust_client_method
from ..telemetry import Telemetry
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import (
    get_available_cpu_count,
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

# Default connection pool size scales with CPU count (minimum 32 for backward compatibility)
MAX_POOL_CONNECTIONS = max(32, get_available_cpu_count())

MiB = 1024 * 1024

# Python and Rust share the same multipart_threshold to keep the code simple.
MULTIPART_THRESHOLD = 64 * MiB
MULTIPART_CHUNKSIZE = 32 * MiB
IO_CHUNKSIZE = 32 * MiB
PYTHON_MAX_CONCURRENCY = 8

PROVIDER = "s3"

EXPRESS_ONEZONE_STORAGE_CLASS = "EXPRESS_ONEZONE"

class StaticS3CredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static S3 credentials.
    """

    _access_key: str
    _secret_key: str
    _session_token: Optional[str]

    def __init__(self, access_key: str, secret_key: str, session_token: Optional[str] = None):
        """
        Initializes the :py:class:`StaticS3CredentialsProvider` with the provided access key, secret key, and optional
        session token.

        :param access_key: The access key for S3 authentication.
        :param secret_key: The secret key for S3 authentication.
        :param session_token: An optional session token for temporary credentials.
        """
        self._access_key = access_key
        self._secret_key = secret_key
        self._session_token = session_token

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key=self._access_key,
            secret_key=self._secret_key,
            token=self._session_token,
            expiration=None,
        )

    def refresh_credentials(self) -> None:
        pass

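# A minimal sketch of using the static provider above on its own. The key and
# token values here are illustrative placeholders, not real credentials.
#
#     provider = StaticS3CredentialsProvider(
#         access_key="example-access-key",
#         secret_key="example-secret-key",
#         session_token=None,
#     )
#     creds = provider.get_credentials()  # Credentials with expiration=None
#     provider.refresh_credentials()      # no-op for static credentials
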
class S3StorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Amazon S3 or S3-compatible object stores.
    """

    def __init__(
        self,
        region_name: str = "",
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        verify: Optional[Union[bool, str]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the :py:class:`S3StorageProvider` with the region, endpoint URL, and optional credentials provider.

        :param region_name: The AWS region where the S3 bucket is located.
        :param endpoint_url: The custom endpoint URL for the S3 service.
        :param base_path: The root prefix path within the S3 bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve S3 credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        :param verify: Controls SSL certificate verification. Can be ``True`` (verify using system CA bundle, default),
            ``False`` (skip verification), or a string path to a custom CA certificate bundle.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._region_name = region_name
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._verify = verify

        self._signature_version = kwargs.get("signature_version", "")
        self._s3_client = self._create_s3_client(
            request_checksum_calculation=kwargs.get("request_checksum_calculation"),
            response_checksum_validation=kwargs.get("response_checksum_validation"),
            max_pool_connections=kwargs.get("max_pool_connections", MAX_POOL_CONNECTIONS),
            connect_timeout=kwargs.get("connect_timeout"),
            read_timeout=kwargs.get("read_timeout"),
            retries=kwargs.get("retries"),
        )
        self._transfer_config = TransferConfig(
            multipart_threshold=int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD)),
            max_concurrency=int(kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)),
            multipart_chunksize=int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE)),
            io_chunksize=int(kwargs.get("io_chunksize", IO_CHUNKSIZE)),
            use_threads=True,
        )

        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the rust client options from the kwargs
            rust_client_options = kwargs["rust_client"]
            if "max_pool_connections" in kwargs:
                rust_client_options["max_pool_connections"] = kwargs["max_pool_connections"]
            if "max_concurrency" in kwargs:
                rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
            if "multipart_chunksize" in kwargs:
                rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
            if "read_timeout" in kwargs:
                rust_client_options["read_timeout"] = kwargs["read_timeout"]
            if "connect_timeout" in kwargs:
                rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
            if self._signature_version == "UNSIGNED":
                rust_client_options["skip_signature"] = True
            self._rust_client = self._create_rust_client(rust_client_options)

    def _is_directory_bucket(self, bucket: str) -> bool:
        """
        Determines if the bucket is a directory bucket based on bucket name.
        """
        # S3 Express buckets have a specific naming convention
        return "--x-s3" in bucket

    def _create_s3_client(
        self,
        request_checksum_calculation: Optional[str] = None,
        response_checksum_validation: Optional[str] = None,
        max_pool_connections: int = MAX_POOL_CONNECTIONS,
        connect_timeout: Union[float, int, None] = None,
        read_timeout: Union[float, int, None] = None,
        retries: Optional[dict[str, Any]] = None,
    ):
        """
        Creates and configures the boto3 S3 client, using refreshable credentials if possible.

        :param request_checksum_calculation: When the underlying S3 client should calculate request checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param response_checksum_validation: When the underlying S3 client should validate response checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param max_pool_connections: The maximum number of connections to keep in a connection pool.
        :param connect_timeout: The time in seconds till a timeout exception is thrown when attempting to make a connection.
        :param read_timeout: The time in seconds till a timeout exception is thrown when attempting to read from a connection.
        :param retries: A dictionary for configuration related to retry behavior.

        :return: The configured S3 client.
        """
        options = {
            # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
            "config": boto3.session.Config(  # pyright: ignore [reportAttributeAccessIssue]
                max_pool_connections=max_pool_connections,
                connect_timeout=connect_timeout,
                read_timeout=read_timeout,
                retries=retries or {"mode": "standard"},
                request_checksum_calculation=request_checksum_calculation,
                response_checksum_validation=response_checksum_validation,
            ),
        }

        if self._region_name:
            options["region_name"] = self._region_name

        if self._endpoint_url:
            options["endpoint_url"] = self._endpoint_url

        if self._verify is not None:
            options["verify"] = self._verify

        if self._credentials_provider:
            creds = self._fetch_credentials()
            if "expiry_time" in creds and creds["expiry_time"]:
                # Use RefreshableCredentials if expiry_time provided.
                refreshable_credentials = RefreshableCredentials.create_from_metadata(
                    metadata=creds, refresh_using=self._fetch_credentials, method="custom-refresh"
                )

                botocore_session = get_session()
                botocore_session._credentials = refreshable_credentials

                boto3_session = boto3.Session(botocore_session=botocore_session)

                return boto3_session.client("s3", **options)
            else:
                # Add static credentials to the options dictionary
                options["aws_access_key_id"] = creds["access_key"]
                options["aws_secret_access_key"] = creds["secret_key"]
                if creds["token"]:
                    options["aws_session_token"] = creds["token"]

        if self._signature_version:
            signature_config = botocore.config.Config(  # pyright: ignore[reportAttributeAccessIssue]
                signature_version=botocore.UNSIGNED
                if self._signature_version == "UNSIGNED"
                else self._signature_version
            )
            options["config"] = options["config"].merge(signature_config)

        # Fallback to standard credential chain.
        return boto3.client("s3", **options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        """
        Creates and configures the rust client, using refreshable credentials if possible.
        """
        configs = dict(rust_client_options) if rust_client_options else {}

        if self._region_name and "region_name" not in configs:
            configs["region_name"] = self._region_name

        if self._endpoint_url and "endpoint_url" not in configs:
            configs["endpoint_url"] = self._endpoint_url

        # If the user specifies a bucket, use it. Otherwise, use the base path.
        if "bucket" not in configs:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        if "max_pool_connections" not in configs:
            configs["max_pool_connections"] = MAX_POOL_CONNECTIONS

        return RustClient(
            provider=PROVIDER,
            configs=configs,
            credentials_provider=self._credentials_provider,
        )

    def _fetch_credentials(self) -> dict:
        """
        Refreshes credentials via the configured credentials provider and returns them in the
        mapping format expected by botocore.
        """
        if not self._credentials_provider:
            raise RuntimeError("Cannot fetch credentials if no credential provider configured.")
        self._credentials_provider.refresh_credentials()
        credentials = self._credentials_provider.get_credentials()
        return {
            "access_key": credentials.access_key,
            "secret_key": credentials.secret_key,
            "token": credentials.token,
            "expiry_time": credentials.expiration,
        }

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual S3 operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the S3 bucket involved in the operation.
        :param key: The key of the object within the S3 bucket.

        :return: The result of the S3 operation, typically the return value of the `func` callable.
        """
        try:
            return func()
        except ClientError as error:
            status_code = error.response["ResponseMetadata"]["HTTPStatusCode"]
            request_id = error.response["ResponseMetadata"].get("RequestId")
            host_id = error.response["ResponseMetadata"].get("HostId")
            error_code = error.response["Error"]["Code"]
            error_info = f"request_id: {request_id}, host_id: {host_id}, status_code: {status_code}"

            if status_code == 404:
                if error_code == "NoSuchUpload":
                    error_message = error.response["Error"]["Message"]
                    raise RetryableError(f"Multipart upload failed for {bucket}/{key}: {error_message}") from error
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}")  # pylint: disable=raise-missing-from
            elif status_code == 412:  # Precondition Failed
                raise PreconditionFailedError(
                    f"ETag mismatch for {operation} operation on {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 429:
                raise RetryableError(
                    f"Too many requests to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 503:
                raise RetryableError(
                    f"Service unavailable when {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 501:
                raise NotImplementedError(
                    f"Operation {operation} not implemented for object(s) at {bucket}/{key}. {error_info}"
                ) from error
            else:
                raise RuntimeError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}, "
                    f"error_type: {type(error).__name__}"
                ) from error
        except RustClientError as error:
            message = error.args[0]
            status_code = error.args[1]
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {message}") from error
            elif status_code == 403:
                raise PermissionError(
                    f"Permission denied to {operation} object(s) at {bucket}/{key}. {message}"
                ) from error
            else:
                raise RetryableError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {message}. status_code: {status_code}"
                ) from error
        except (ReadTimeoutError, IncompleteReadError, ResponseStreamingError) as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to network timeout or incomplete read. "
                f"error_type: {type(error).__name__}"
            ) from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except Exception as error:
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
        content_type: Optional[str] = None,
    ) -> int:
        """
        Uploads an object to the specified S3 path.

        :param path: The S3 path where the object will be uploaded.
        :param body: The content of the object as bytes.
        :param if_match: Optional If-Match header value. The upload only succeeds if the existing object's ETag matches.
        :param if_none_match: Optional If-None-Match header value. Use "*" to only upload if the object doesn't exist.
        :param attributes: Optional attributes to attach to the object.
        :param content_type: Optional Content-Type header value.
        """
        bucket, key = split_path(path)

        def _invoke_api() -> int:
            kwargs = {"Bucket": bucket, "Key": key, "Body": body}
            if content_type:
                kwargs["ContentType"] = content_type
            if self._is_directory_bucket(bucket):
                kwargs["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
            if if_match:
                kwargs["IfMatch"] = if_match
            if if_none_match:
                kwargs["IfNoneMatch"] = if_none_match
            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                kwargs["Metadata"] = validated_attributes

            # TODO(NGCDP-5804): Add support to update ContentType header in Rust client
            rust_unsupported_feature_keys = {"Metadata", "StorageClass", "IfMatch", "IfNoneMatch", "ContentType"}
            if (
                self._rust_client
                # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
                and not path.endswith("/")
                and all(key not in kwargs for key in rust_unsupported_feature_keys)
            ):
                run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                self._s3_client.put_object(**kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)

        def _invoke_api() -> bytes:
            if byte_range:
                bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
                if self._rust_client:
                    response = run_async_rust_client_method(
                        self._rust_client,
                        "get",
                        key,
                        byte_range,
                    )
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key, Range=bytes_range)
            else:
                if self._rust_client:
                    response = run_async_rust_client_method(self._rust_client, "get", key)
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key)

            return response["Body"].read()

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            self._s3_client.copy(
                CopySource={"Bucket": src_bucket, "Key": src_key},
                Bucket=dest_bucket,
                Key=dest_key,
                Config=self._transfer_config,
            )

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=dest_bucket, key=dest_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            # Conditionally delete the object if if_match (ETag) is provided; otherwise delete it unconditionally.
            if if_match:
                self._s3_client.delete_object(Bucket=bucket, Key=key, IfMatch=if_match)
            else:
                self._s3_client.delete_object(Bucket=bucket, Key=key)

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._s3_client.list_objects_v2(Bucket=bucket, Prefix=key, MaxKeys=1, Delimiter="/")

            # Check if there are any contents or common prefixes
            return bool(response.get("Contents", []) or response.get("CommonPrefixes", []))

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If the path ends with "/" or an empty key name is provided, assume it's a "directory",
            # whose metadata is not guaranteed to exist for cases such as a
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:

            def _invoke_api() -> ObjectMetadata:
                response = self._s3_client.head_object(Bucket=bucket, Key=key)

                return ObjectMetadata(
                    key=path,
                    type="file",
                    content_length=response["ContentLength"],
                    content_type=response.get("ContentType"),
                    last_modified=response["LastModified"],
                    etag=response["ETag"].strip('"') if "ETag" in response else None,
                    storage_class=response.get("StorageClass"),
                    metadata=response.get("Metadata"),
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            paginator = self._s3_client.get_paginator("list_objects_v2")
            if include_directories:
                page_iterator = paginator.paginate(
                    Bucket=bucket, Prefix=prefix, Delimiter="/", StartAfter=(start_after or "")
                )
            else:
                page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, StartAfter=(start_after or ""))

            for page in page_iterator:
                for item in page.get("CommonPrefixes", []):
                    yield ObjectMetadata(
                        key=os.path.join(bucket, item["Prefix"].rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )

                # S3 guarantees lexicographical order for general purpose buckets (for
                # normal S3) but not directory buckets (for S3 Express One Zone).
                for response_object in page.get("Contents", []):
                    key = response_object["Key"]
                    if end_at is None or key <= end_at:
                        if key.endswith("/"):
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(bucket, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=response_object["LastModified"],
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key),
                                type="file",
                                content_length=response_object["Size"],
                                last_modified=response_object["LastModified"],
                                etag=response_object["ETag"].strip('"'),
                                storage_class=response_object.get("StorageClass"),  # Pass storage_class
                            )
                    else:
                        return

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(
        self,
        remote_path: str,
        f: Union[str, IO],
        attributes: Optional[dict[str, str]] = None,
        content_type: Optional[str] = None,
    ) -> int:
        file_size: int = 0

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._transfer_config.multipart_threshold:
                if self._rust_client and not attributes and not content_type:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes, content_type=content_type)
                return file_size

            # Upload large files using TransferConfig
            def _invoke_api() -> int:
                extra_args = {}
                if content_type:
                    extra_args["ContentType"] = content_type
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes
                if self._rust_client and not extra_args:
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    self._s3_client.upload_file(
                        Filename=f,
                        Bucket=bucket,
                        Key=key,
                        Config=self._transfer_config,
                        ExtraArgs=extra_args,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            # Upload small files
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            if file_size <= self._transfer_config.multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(
                        remote_path, f.read().encode("utf-8"), attributes=attributes, content_type=content_type
                    )
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes, content_type=content_type)
                return file_size

            # Upload large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                extra_args = {}
                if content_type:
                    extra_args["ContentType"] = content_type
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes

                if self._rust_client and isinstance(f, io.BytesIO) and not extra_args:
                    data = f.getbuffer()
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_bytes", key, data)
                else:
                    self._s3_client.upload_fileobj(
                        Fileobj=f,
                        Bucket=bucket,
                        Key=key,
                        Config=self._transfer_config,
                        ExtraArgs=extra_args,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)

            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using TransferConfig
            def _invoke_api() -> int:
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    if self._rust_client:
                        run_async_rust_client_method(
                            self._rust_client, "download_multipart_to_file", key, temp_file_path
                        )
                    else:
                        self._s3_client.download_fileobj(
                            Bucket=bucket,
                            Key=key,
                            Fileobj=fp,
                            Config=self._transfer_config,
                        )

                os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:
            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                response = self._get_object(remote_path)
                # The Python client returns `bytes`, but the Rust client returns an object that implements
                # the buffer protocol, so we need to check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                self._s3_client.download_fileobj(
                    Bucket=bucket,
                    Key=key,
                    Fileobj=f,
                    Config=self._transfer_config,
                )

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
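
# A minimal end-to-end sketch of constructing the provider defined above. The endpoint,
# bucket, and credential values are illustrative placeholders; the calls below go through
# the internal helpers defined in this module, which the higher-level MSC client normally
# invokes on the caller's behalf.
#
#     credentials = StaticS3CredentialsProvider(
#         access_key="example-access-key",
#         secret_key="example-secret-key",
#     )
#     provider = S3StorageProvider(
#         region_name="us-east-1",
#         endpoint_url="https://s3.example.com",  # omit for AWS S3
#         base_path="example-bucket",
#         credentials_provider=credentials,
#     )
#
#     provider._put_object("example-bucket/data/hello.txt", b"hello world")
#     payload = provider._get_object("example-bucket/data/hello.txt")
#     for obj in provider._list_objects("example-bucket/data/"):
#         print(obj.key, obj.content_length)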