Source code for multistorageclient.providers.s3

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import io
import os
import tempfile
from collections.abc import Callable, Iterator
from typing import IO, Any, Optional, TypeVar, Union

import boto3
import botocore
from boto3.s3.transfer import TransferConfig
from botocore.credentials import RefreshableCredentials
from botocore.exceptions import ClientError, IncompleteReadError, ReadTimeoutError, ResponseStreamingError
from botocore.session import get_session

from multistorageclient_rust import RustClient, RustClientError, RustRetryableError

from ..rust_utils import run_async_rust_client_method
from ..telemetry import Telemetry
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import (
    get_available_cpu_count,
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

# Default connection pool size scales with the CPU count or the MSC thread count
# (MSC_NUM_THREADS_PER_PROCESS), with a minimum of 64.
MAX_POOL_CONNECTIONS = max(
    64,
    get_available_cpu_count(),
    int(os.getenv("MSC_NUM_THREADS_PER_PROCESS", "0")),
)

MiB = 1024 * 1024

# Python and Rust share the same multipart_threshold to keep the code simple.
MULTIPART_THRESHOLD = 64 * MiB
MULTIPART_CHUNKSIZE = 32 * MiB
IO_CHUNKSIZE = 32 * MiB
PYTHON_MAX_CONCURRENCY = 8

PROVIDER = "s3"

EXPRESS_ONEZONE_STORAGE_CLASS = "EXPRESS_ONEZONE"

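# Note: transfers at or below MULTIPART_THRESHOLD use a single PUT/GET request,
# while larger files take the multipart path in MULTIPART_CHUNKSIZE parts (see
# _upload_file and _download_file below).
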
class StaticS3CredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static S3 credentials.
    """

    _access_key: str
    _secret_key: str
    _session_token: Optional[str]

    def __init__(self, access_key: str, secret_key: str, session_token: Optional[str] = None):
        """
        Initializes the :py:class:`StaticS3CredentialsProvider` with the provided access key, secret key, and optional
        session token.

        :param access_key: The access key for S3 authentication.
        :param secret_key: The secret key for S3 authentication.
        :param session_token: An optional session token for temporary credentials.
        """
        self._access_key = access_key
        self._secret_key = secret_key
        self._session_token = session_token

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key=self._access_key,
            secret_key=self._secret_key,
            token=self._session_token,
            expiration=None,
        )

    def refresh_credentials(self) -> None:
        pass

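# Example (illustrative, not part of the module): constructing the provider with
# placeholder keys and reading them back. The values shown are not real credentials.
#
#     provider = StaticS3CredentialsProvider(access_key="<access-key>", secret_key="<secret-key>")
#     provider.refresh_credentials()  # no-op for static credentials
#     creds = provider.get_credentials()
#     assert creds.token is None and creds.expiration is None
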
class S3StorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Amazon S3 or S3-compatible object stores.
    """

    def __init__(
        self,
        region_name: str = "",
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        verify: Optional[Union[bool, str]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the :py:class:`S3StorageProvider` with the region, endpoint URL, and optional credentials provider.

        :param region_name: The AWS region where the S3 bucket is located.
        :param endpoint_url: The custom endpoint URL for the S3 service.
        :param base_path: The root prefix path within the S3 bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve S3 credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        :param verify: Controls SSL certificate verification. Can be ``True`` (verify using system CA bundle, default),
            ``False`` (skip verification), or a string path to a custom CA certificate bundle.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._region_name = region_name
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._verify = verify

        self._signature_version = kwargs.get("signature_version", "")
        self._s3_client = self._create_s3_client(
            request_checksum_calculation=kwargs.get("request_checksum_calculation"),
            response_checksum_validation=kwargs.get("response_checksum_validation"),
            max_pool_connections=kwargs.get("max_pool_connections", MAX_POOL_CONNECTIONS),
            connect_timeout=kwargs.get("connect_timeout"),
            read_timeout=kwargs.get("read_timeout"),
            retries=kwargs.get("retries"),
        )
        self._transfer_config = TransferConfig(
            multipart_threshold=int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD)),
            max_concurrency=int(kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)),
            multipart_chunksize=int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE)),
            io_chunksize=int(kwargs.get("io_chunksize", IO_CHUNKSIZE)),
            use_threads=True,
        )

        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the rust client options from the kwargs
            rust_client_options = kwargs["rust_client"]
            if "max_pool_connections" in kwargs:
                rust_client_options["max_pool_connections"] = kwargs["max_pool_connections"]
            if "max_concurrency" in kwargs:
                rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
            if "multipart_chunksize" in kwargs:
                rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
            if "read_timeout" in kwargs:
                rust_client_options["read_timeout"] = kwargs["read_timeout"]
            if "connect_timeout" in kwargs:
                rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
            if self._signature_version == "UNSIGNED":
                rust_client_options["skip_signature"] = True
            self._rust_client = self._create_rust_client(rust_client_options)

    def _is_directory_bucket(self, bucket: str) -> bool:
        """
        Determines if the bucket is a directory bucket based on bucket name.
        """
        # S3 Express buckets have a specific naming convention
        return "--x-s3" in bucket

    def _create_s3_client(
        self,
        request_checksum_calculation: Optional[str] = None,
        response_checksum_validation: Optional[str] = None,
        max_pool_connections: int = MAX_POOL_CONNECTIONS,
        connect_timeout: Union[float, int, None] = None,
        read_timeout: Union[float, int, None] = None,
        retries: Optional[dict[str, Any]] = None,
    ):
        """
        Creates and configures the boto3 S3 client, using refreshable credentials if possible.

        :param request_checksum_calculation: When the underlying S3 client should calculate request checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param response_checksum_validation: When the underlying S3 client should validate response checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param max_pool_connections: The maximum number of connections to keep in a connection pool.
        :param connect_timeout: The time in seconds till a timeout exception is thrown when attempting to make a connection.
        :param read_timeout: The time in seconds till a timeout exception is thrown when attempting to read from a connection.
        :param retries: A dictionary for configuration related to retry behavior.

        :return: The configured S3 client.
        """
        options = {
            # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
            "config": boto3.session.Config(  # pyright: ignore [reportAttributeAccessIssue]
                max_pool_connections=max_pool_connections,
                connect_timeout=connect_timeout,
                read_timeout=read_timeout,
                retries=retries or {"mode": "standard"},
                request_checksum_calculation=request_checksum_calculation,
                response_checksum_validation=response_checksum_validation,
            ),
        }

        if self._region_name:
            options["region_name"] = self._region_name

        if self._endpoint_url:
            options["endpoint_url"] = self._endpoint_url

        if self._verify is not None:
            options["verify"] = self._verify

        if self._credentials_provider:
            creds = self._fetch_credentials()
            if "expiry_time" in creds and creds["expiry_time"]:
                # Use RefreshableCredentials if expiry_time provided.
                refreshable_credentials = RefreshableCredentials.create_from_metadata(
                    metadata=creds, refresh_using=self._fetch_credentials, method="custom-refresh"
                )

                botocore_session = get_session()
                botocore_session._credentials = refreshable_credentials

                boto3_session = boto3.Session(botocore_session=botocore_session)

                return boto3_session.client("s3", **options)
            else:
                # Add static credentials to the options dictionary
                options["aws_access_key_id"] = creds["access_key"]
                options["aws_secret_access_key"] = creds["secret_key"]
                if creds["token"]:
                    options["aws_session_token"] = creds["token"]

        if self._signature_version:
            signature_config = botocore.config.Config(  # pyright: ignore[reportAttributeAccessIssue]
                signature_version=botocore.UNSIGNED
                if self._signature_version == "UNSIGNED"
                else self._signature_version
            )
            options["config"] = options["config"].merge(signature_config)

        # Fallback to standard credential chain.
        return boto3.client("s3", **options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        """
        Creates and configures the rust client, using refreshable credentials if possible.
        """
        configs = dict(rust_client_options) if rust_client_options else {}

        if self._region_name and "region_name" not in configs:
            configs["region_name"] = self._region_name

        if self._endpoint_url and "endpoint_url" not in configs:
            configs["endpoint_url"] = self._endpoint_url

        # If the user specifies a bucket, use it. Otherwise, use the base path.
        if "bucket" not in configs:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        if "max_pool_connections" not in configs:
            configs["max_pool_connections"] = MAX_POOL_CONNECTIONS

        return RustClient(
            provider=PROVIDER,
            configs=configs,
            credentials_provider=self._credentials_provider,
        )

    def _fetch_credentials(self) -> dict:
        """
        Refreshes the credentials via the configured credentials provider and returns them in the
        metadata format expected by botocore.
        """
        if not self._credentials_provider:
            raise RuntimeError("Cannot fetch credentials if no credential provider configured.")
        self._credentials_provider.refresh_credentials()
        credentials = self._credentials_provider.get_credentials()
        return {
            "access_key": credentials.access_key,
            "secret_key": credentials.secret_key,
            "token": credentials.token,
            "expiry_time": credentials.expiration,
        }

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates low-level errors, such as timeouts and client errors, into MSC exception types.

        :param func: The function that performs the actual S3 operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the S3 bucket involved in the operation.
        :param key: The key of the object within the S3 bucket.

        :return: The result of the S3 operation, typically the return value of the `func` callable.
        """
        try:
            return func()
        except ClientError as error:
            status_code = error.response["ResponseMetadata"]["HTTPStatusCode"]
            request_id = error.response["ResponseMetadata"].get("RequestId")
            host_id = error.response["ResponseMetadata"].get("HostId")
            error_code = error.response["Error"]["Code"]
            error_info = f"request_id: {request_id}, host_id: {host_id}, status_code: {status_code}"

            if status_code == 404:
                if error_code == "NoSuchUpload":
                    error_message = error.response["Error"]["Message"]
                    raise RetryableError(f"Multipart upload failed for {bucket}/{key}: {error_message}") from error
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}")  # pylint: disable=raise-missing-from
            elif status_code == 412:  # Precondition Failed
                raise PreconditionFailedError(
                    f"ETag mismatch for {operation} operation on {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 429:
                raise RetryableError(
                    f"Too many requests to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 503:
                raise RetryableError(
                    f"Service unavailable when {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 501:
                raise NotImplementedError(
                    f"Operation {operation} not implemented for object(s) at {bucket}/{key}. {error_info}"
                ) from error
            else:
                raise RuntimeError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}, "
                    f"error_type: {type(error).__name__}"
                ) from error
        except RustClientError as error:
            message = error.args[0]
            status_code = error.args[1]
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {message}") from error
            elif status_code == 403:
                raise PermissionError(
                    f"Permission denied to {operation} object(s) at {bucket}/{key}. {message}"
                ) from error
            else:
                raise RetryableError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {message}. status_code: {status_code}"
                ) from error
        except (ReadTimeoutError, IncompleteReadError, ResponseStreamingError) as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to network timeout or incomplete read. "
                f"error_type: {type(error).__name__}"
            ) from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except Exception as error:
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
        content_type: Optional[str] = None,
    ) -> int:
        """
        Uploads an object to the specified S3 path.

        :param path: The S3 path where the object will be uploaded.
        :param body: The content of the object as bytes.
        :param if_match: Optional If-Match header value (an ETag). The upload succeeds only if the existing object's ETag matches.
        :param if_none_match: Optional If-None-Match header value. Use "*" to only upload if the object doesn't exist.
        :param attributes: Optional attributes to attach to the object.
        :param content_type: Optional Content-Type header value.
        """
        bucket, key = split_path(path)

        def _invoke_api() -> int:
            kwargs = {"Bucket": bucket, "Key": key, "Body": body}
            if content_type:
                kwargs["ContentType"] = content_type
            if self._is_directory_bucket(bucket):
                kwargs["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
            if if_match:
                kwargs["IfMatch"] = if_match
            if if_none_match:
                kwargs["IfNoneMatch"] = if_none_match
            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                kwargs["Metadata"] = validated_attributes

            # TODO(NGCDP-5804): Add support to update ContentType header in Rust client
            rust_unsupported_feature_keys = {"Metadata", "StorageClass", "IfMatch", "IfNoneMatch", "ContentType"}
            if (
                self._rust_client
                # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
                and not path.endswith("/")
                and all(key not in kwargs for key in rust_unsupported_feature_keys)
            ):
                run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                self._s3_client.put_object(**kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)

        def _invoke_api() -> bytes:
            if byte_range:
                bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
                if self._rust_client:
                    response = run_async_rust_client_method(
                        self._rust_client,
                        "get",
                        key,
                        byte_range,
                    )
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key, Range=bytes_range)
            else:
                if self._rust_client:
                    response = run_async_rust_client_method(self._rust_client, "get", key)
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key)

            return response["Body"].read()

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            self._s3_client.copy(
                CopySource={"Bucket": src_bucket, "Key": src_key},
                Bucket=dest_bucket,
                Key=dest_key,
                Config=self._transfer_config,
            )

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=dest_bucket, key=dest_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            # Conditionally delete the object if if_match (ETag) is provided; otherwise, delete it unconditionally.
            if if_match:
                self._s3_client.delete_object(Bucket=bucket, Key=key, IfMatch=if_match)
            else:
                self._s3_client.delete_object(Bucket=bucket, Key=key)

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._s3_client.list_objects_v2(Bucket=bucket, Prefix=key, MaxKeys=1, Delimiter="/")

            # Check if there are any contents or common prefixes
            return bool(response.get("Contents", []) or response.get("CommonPrefixes", []))

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If the path ends with "/" or an empty key name is provided, then assume it's a "directory",
            # whose metadata is not guaranteed to exist for cases such as a
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:

            def _invoke_api() -> ObjectMetadata:
                response = self._s3_client.head_object(Bucket=bucket, Key=key)

                return ObjectMetadata(
                    key=path,
                    type="file",
                    content_length=response["ContentLength"],
                    content_type=response.get("ContentType"),
                    last_modified=response["LastModified"],
                    etag=response["ETag"].strip('"') if "ETag" in response else None,
                    storage_class=response.get("StorageClass"),
                    metadata=response.get("Metadata"),
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist at the given path, append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            paginator = self._s3_client.get_paginator("list_objects_v2")
            if include_directories:
                page_iterator = paginator.paginate(
                    Bucket=bucket, Prefix=prefix, Delimiter="/", StartAfter=(start_after or "")
                )
            else:
                page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, StartAfter=(start_after or ""))

            for page in page_iterator:
                for item in page.get("CommonPrefixes", []):
                    yield ObjectMetadata(
                        key=os.path.join(bucket, item["Prefix"].rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )

                # S3 guarantees lexicographical order for general purpose buckets (for
                # normal S3) but not directory buckets (for S3 Express One Zone).
                for response_object in page.get("Contents", []):
                    key = response_object["Key"]
                    if end_at is None or key <= end_at:
                        if key.endswith("/"):
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(bucket, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=response_object["LastModified"],
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key),
                                type="file",
                                content_length=response_object["Size"],
                                last_modified=response_object["LastModified"],
                                etag=response_object["ETag"].strip('"'),
                                storage_class=response_object.get("StorageClass"),  # Pass storage_class
                            )
                    else:
                        return

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(
        self,
        remote_path: str,
        f: Union[str, IO],
        attributes: Optional[dict[str, str]] = None,
        content_type: Optional[str] = None,
    ) -> int:
        file_size: int = 0

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._transfer_config.multipart_threshold:
                if self._rust_client and not attributes and not content_type:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes, content_type=content_type)
                return file_size

            # Upload large files using TransferConfig
            def _invoke_api() -> int:
                extra_args = {}
                if content_type:
                    extra_args["ContentType"] = content_type
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes
                if self._rust_client and not extra_args:
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    self._s3_client.upload_file(
                        Filename=f,
                        Bucket=bucket,
                        Key=key,
                        Config=self._transfer_config,
                        ExtraArgs=extra_args,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            # Upload small files
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            if file_size <= self._transfer_config.multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(
                        remote_path, f.read().encode("utf-8"), attributes=attributes, content_type=content_type
                    )
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes, content_type=content_type)
                return file_size

            # Upload large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                extra_args = {}
                if content_type:
                    extra_args["ContentType"] = content_type
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes

                if self._rust_client and isinstance(f, io.BytesIO) and not extra_args:
                    data = f.getbuffer()
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_bytes", key, data)
                else:
                    self._s3_client.upload_fileobj(
                        Fileobj=f,
                        Bucket=bucket,
                        Key=key,
                        Config=self._transfer_config,
                        ExtraArgs=extra_args,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)

            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using TransferConfig
            def _invoke_api() -> int:
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    if self._rust_client:
                        run_async_rust_client_method(
                            self._rust_client, "download_multipart_to_file", key, temp_file_path
                        )
                    else:
                        self._s3_client.download_fileobj(
                            Bucket=bucket,
                            Key=key,
                            Fileobj=fp,
                            Config=self._transfer_config,
                        )

                os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:
            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                response = self._get_object(remote_path)
                # The Python client returns `bytes`, but the Rust client returns an object implementing the
                # buffer protocol, so we need to check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                self._s3_client.download_fileobj(
                    Bucket=bucket,
                    Key=key,
                    Fileobj=f,
                    Config=self._transfer_config,
                )

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
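
# A minimal usage sketch (illustrative; assumes an existing bucket named
# "my-bucket" and ambient AWS credentials). The underscore-prefixed methods are
# internal helpers inherited callers normally reach through the higher-level
# MSC client API:
#
#     provider = S3StorageProvider(region_name="us-east-1", base_path="my-bucket")
#     provider._put_object("my-bucket/demo/hello.txt", b"hello", content_type="text/plain")
#     assert provider._get_object("my-bucket/demo/hello.txt") == b"hello"
#     provider._delete_object("my-bucket/demo/hello.txt")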