Source code for multistorageclient.providers.s3

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import io
import os
import tempfile
from collections.abc import Callable, Iterator
from typing import IO, Any, Optional, TypeVar, Union

import boto3
import botocore
from boto3.s3.transfer import TransferConfig
from botocore.credentials import RefreshableCredentials
from botocore.exceptions import ClientError, IncompleteReadError, ReadTimeoutError, ResponseStreamingError
from botocore.session import get_session

from multistorageclient_rust import RustClient, RustRetryableError

from ..rust_utils import run_async_rust_client_method
from ..telemetry import Telemetry
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import (
    get_available_cpu_count,
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

# Default connection pool size scales with CPU count (minimum 32 for backward compatibility)
MAX_POOL_CONNECTIONS = max(32, get_available_cpu_count())

MiB = 1024 * 1024

# Python and Rust share the same multipart_threshold to keep the code simple.
MULTIPART_THRESHOLD = 64 * MiB
MULTIPART_CHUNKSIZE = 32 * MiB
IO_CHUNKSIZE = 32 * MiB
PYTHON_MAX_CONCURRENCY = 8

PROVIDER = "s3"

EXPRESS_ONEZONE_STORAGE_CLASS = "EXPRESS_ONEZONE"
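
# --- Illustrative note (not part of the module source) ---
# The transfer-tuning constants above are only defaults; S3StorageProvider (defined
# below) accepts per-instance overrides via keyword arguments such as
# multipart_threshold, multipart_chunksize, io_chunksize, max_concurrency, and
# max_pool_connections. A minimal sketch, assuming a bucket named "my-bucket":
#
#     provider = S3StorageProvider(
#         region_name="us-east-1",
#         base_path="my-bucket",
#         multipart_threshold=128 * MiB,  # raise the multipart-upload threshold
#         max_concurrency=16,             # allow more parallel part transfers
#     )
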
class StaticS3CredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static S3 credentials.
    """

    _access_key: str
    _secret_key: str
    _session_token: Optional[str]

    def __init__(self, access_key: str, secret_key: str, session_token: Optional[str] = None):
        """
        Initializes the :py:class:`StaticS3CredentialsProvider` with the provided access key, secret key, and optional
        session token.

        :param access_key: The access key for S3 authentication.
        :param secret_key: The secret key for S3 authentication.
        :param session_token: An optional session token for temporary credentials.
        """
        self._access_key = access_key
        self._secret_key = secret_key
        self._session_token = session_token

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key=self._access_key,
            secret_key=self._secret_key,
            token=self._session_token,
            expiration=None,
        )

    def refresh_credentials(self) -> None:
        pass
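
# --- Illustrative usage sketch (not part of the module source) ---
# A static credentials provider simply hands back the values it was constructed
# with; the key values below are placeholders, not real credentials:
#
#     credentials_provider = StaticS3CredentialsProvider(
#         access_key="EXAMPLE_ACCESS_KEY",
#         secret_key="EXAMPLE_SECRET_KEY",
#         session_token=None,  # only needed for temporary credentials
#     )
#     credentials = credentials_provider.get_credentials()
#     assert credentials.expiration is None  # static credentials never expire
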
class S3StorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Amazon S3 or S3-compatible object stores.
    """

    def __init__(
        self,
        region_name: str = "",
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        verify: Optional[Union[bool, str]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the :py:class:`S3StorageProvider` with the region, endpoint URL, and optional credentials provider.

        :param region_name: The AWS region where the S3 bucket is located.
        :param endpoint_url: The custom endpoint URL for the S3 service.
        :param base_path: The root prefix path within the S3 bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve S3 credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        :param verify: Controls SSL certificate verification. Can be ``True`` (verify using system CA bundle, default),
            ``False`` (skip verification), or a string path to a custom CA certificate bundle.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._region_name = region_name
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._verify = verify

        self._signature_version = kwargs.get("signature_version", "")
        self._s3_client = self._create_s3_client(
            request_checksum_calculation=kwargs.get("request_checksum_calculation"),
            response_checksum_validation=kwargs.get("response_checksum_validation"),
            max_pool_connections=kwargs.get("max_pool_connections", MAX_POOL_CONNECTIONS),
            connect_timeout=kwargs.get("connect_timeout"),
            read_timeout=kwargs.get("read_timeout"),
            retries=kwargs.get("retries"),
        )
        self._transfer_config = TransferConfig(
            multipart_threshold=int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD)),
            max_concurrency=int(kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)),
            multipart_chunksize=int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE)),
            io_chunksize=int(kwargs.get("io_chunksize", IO_CHUNKSIZE)),
            use_threads=True,
        )

        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the rust client options from the kwargs
            rust_client_options = kwargs["rust_client"]
            if "max_pool_connections" in kwargs:
                rust_client_options["max_pool_connections"] = kwargs["max_pool_connections"]
            if "max_concurrency" in kwargs:
                rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
            if "multipart_chunksize" in kwargs:
                rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
            if "read_timeout" in kwargs:
                rust_client_options["read_timeout"] = kwargs["read_timeout"]
            if "connect_timeout" in kwargs:
                rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
            if self._signature_version == "UNSIGNED":
                rust_client_options["skip_signature"] = True
            self._rust_client = self._create_rust_client(rust_client_options)

    def _is_directory_bucket(self, bucket: str) -> bool:
        """
        Determines if the bucket is a directory bucket based on bucket name.
        """
        # S3 Express buckets have a specific naming convention
        return "--x-s3" in bucket

    def _create_s3_client(
        self,
        request_checksum_calculation: Optional[str] = None,
        response_checksum_validation: Optional[str] = None,
        max_pool_connections: int = MAX_POOL_CONNECTIONS,
        connect_timeout: Union[float, int, None] = None,
        read_timeout: Union[float, int, None] = None,
        retries: Optional[dict[str, Any]] = None,
    ):
        """
        Creates and configures the boto3 S3 client, using refreshable credentials if possible.

        :param request_checksum_calculation: When the underlying S3 client should calculate request checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param response_checksum_validation: When the underlying S3 client should validate response checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param max_pool_connections: The maximum number of connections to keep in a connection pool.
        :param connect_timeout: The time in seconds till a timeout exception is thrown when attempting to make a connection.
        :param read_timeout: The time in seconds till a timeout exception is thrown when attempting to read from a connection.
        :param retries: A dictionary for configuration related to retry behavior.

        :return: The configured S3 client.
        """
        options = {
            # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
            "config": boto3.session.Config(  # pyright: ignore [reportAttributeAccessIssue]
                max_pool_connections=max_pool_connections,
                connect_timeout=connect_timeout,
                read_timeout=read_timeout,
                retries=retries or {"mode": "standard"},
                request_checksum_calculation=request_checksum_calculation,
                response_checksum_validation=response_checksum_validation,
            ),
        }

        if self._region_name:
            options["region_name"] = self._region_name

        if self._endpoint_url:
            options["endpoint_url"] = self._endpoint_url

        if self._verify is not None:
            options["verify"] = self._verify

        if self._credentials_provider:
            creds = self._fetch_credentials()
            if "expiry_time" in creds and creds["expiry_time"]:
                # Use RefreshableCredentials if expiry_time provided.
                refreshable_credentials = RefreshableCredentials.create_from_metadata(
                    metadata=creds, refresh_using=self._fetch_credentials, method="custom-refresh"
                )

                botocore_session = get_session()
                botocore_session._credentials = refreshable_credentials

                boto3_session = boto3.Session(botocore_session=botocore_session)

                return boto3_session.client("s3", **options)
            else:
                # Add static credentials to the options dictionary
                options["aws_access_key_id"] = creds["access_key"]
                options["aws_secret_access_key"] = creds["secret_key"]
                if creds["token"]:
                    options["aws_session_token"] = creds["token"]

        if self._signature_version:
            signature_config = botocore.config.Config(  # pyright: ignore[reportAttributeAccessIssue]
                signature_version=botocore.UNSIGNED
                if self._signature_version == "UNSIGNED"
                else self._signature_version
            )
            options["config"] = options["config"].merge(signature_config)

        # Fallback to standard credential chain.
        return boto3.client("s3", **options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        """
        Creates and configures the rust client, using refreshable credentials if possible.
        """
        configs = dict(rust_client_options) if rust_client_options else {}

        if self._region_name and "region_name" not in configs:
            configs["region_name"] = self._region_name

        if self._endpoint_url and "endpoint_url" not in configs:
            configs["endpoint_url"] = self._endpoint_url

        # If the user specifies a bucket, use it. Otherwise, use the base path.
        if "bucket" not in configs:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        if "max_pool_connections" not in configs:
            configs["max_pool_connections"] = MAX_POOL_CONNECTIONS

        return RustClient(
            provider=PROVIDER,
            configs=configs,
            credentials_provider=self._credentials_provider,
        )

    def _fetch_credentials(self) -> dict:
        """
        Refreshes the S3 client if the current credentials are expired.
        """
        if not self._credentials_provider:
            raise RuntimeError("Cannot fetch credentials if no credential provider configured.")
        self._credentials_provider.refresh_credentials()
        credentials = self._credentials_provider.get_credentials()
        return {
            "access_key": credentials.access_key,
            "secret_key": credentials.secret_key,
            "token": credentials.token,
            "expiry_time": credentials.expiration,
        }

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual S3 operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the S3 bucket involved in the operation.
        :param key: The key of the object within the S3 bucket.

        :return: The result of the S3 operation, typically the return value of the `func` callable.
        """
        try:
            return func()
        except ClientError as error:
            status_code = error.response["ResponseMetadata"]["HTTPStatusCode"]
            request_id = error.response["ResponseMetadata"].get("RequestId")
            host_id = error.response["ResponseMetadata"].get("HostId")
            error_code = error.response["Error"]["Code"]
            error_info = f"request_id: {request_id}, host_id: {host_id}, status_code: {status_code}"

            if status_code == 404:
                if error_code == "NoSuchUpload":
                    error_message = error.response["Error"]["Message"]
                    raise RetryableError(f"Multipart upload failed for {bucket}/{key}: {error_message}") from error
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}")  # pylint: disable=raise-missing-from
            elif status_code == 412:  # Precondition Failed
                raise PreconditionFailedError(
                    f"ETag mismatch for {operation} operation on {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 429:
                raise RetryableError(
                    f"Too many requests to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 503:
                raise RetryableError(
                    f"Service unavailable when {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 501:
                raise NotImplementedError(
                    f"Operation {operation} not implemented for object(s) at {bucket}/{key}. {error_info}"
                ) from error
            else:
                raise RuntimeError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}, "
                    f"error_type: {type(error).__name__}"
                ) from error
        except FileNotFoundError:
            raise
        except (ReadTimeoutError, IncompleteReadError, ResponseStreamingError) as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to network timeout or incomplete read. "
                f"error_type: {type(error).__name__}"
            ) from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except Exception as error:
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
        content_type: Optional[str] = None,
    ) -> int:
        """
        Uploads an object to the specified S3 path.

        :param path: The S3 path where the object will be uploaded.
        :param body: The content of the object as bytes.
        :param if_match: Optional If-Match header value. The upload succeeds only if the existing object's ETag matches.
        :param if_none_match: Optional If-None-Match header value. Use "*" to only upload if the object doesn't exist.
        :param attributes: Optional attributes to attach to the object.
        :param content_type: Optional Content-Type header value.
        """
        bucket, key = split_path(path)

        def _invoke_api() -> int:
            kwargs = {"Bucket": bucket, "Key": key, "Body": body}
            if content_type:
                kwargs["ContentType"] = content_type
            if self._is_directory_bucket(bucket):
                kwargs["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
            if if_match:
                kwargs["IfMatch"] = if_match
            if if_none_match:
                kwargs["IfNoneMatch"] = if_none_match
            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                kwargs["Metadata"] = validated_attributes

            # TODO(NGCDP-5804): Add support to update ContentType header in Rust client
            rust_unsupported_feature_keys = {"Metadata", "StorageClass", "IfMatch", "IfNoneMatch", "ContentType"}
            if (
                self._rust_client
                # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
                and not path.endswith("/")
                and all(key not in kwargs for key in rust_unsupported_feature_keys)
            ):
                run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                self._s3_client.put_object(**kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)

        def _invoke_api() -> bytes:
            if byte_range:
                bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
                if self._rust_client:
                    response = run_async_rust_client_method(
                        self._rust_client, "get", key, byte_range.offset, byte_range.offset + byte_range.size - 1
                    )
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key, Range=bytes_range)
            else:
                if self._rust_client:
                    response = run_async_rust_client_method(self._rust_client, "get", key)
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key)

            return response["Body"].read()

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            self._s3_client.copy(
                CopySource={"Bucket": src_bucket, "Key": src_key},
                Bucket=dest_bucket,
                Key=dest_key,
                Config=self._transfer_config,
            )

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=dest_bucket, key=dest_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            # Conditionally delete the object if if_match (etag) is provided; otherwise delete it unconditionally.
            if if_match:
                self._s3_client.delete_object(Bucket=bucket, Key=key, IfMatch=if_match)
            else:
                self._s3_client.delete_object(Bucket=bucket, Key=key)

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._s3_client.list_objects_v2(Bucket=bucket, Prefix=key, MaxKeys=1, Delimiter="/")

            # Check if there are any contents or common prefixes
            return bool(response.get("Contents", []) or response.get("CommonPrefixes", []))

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If the path ends with "/" or an empty key name is provided, assume it's a "directory",
            # for which metadata is not guaranteed to exist (e.g. a "virtual prefix" that was never
            # explicitly created).
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:

            def _invoke_api() -> ObjectMetadata:
                response = self._s3_client.head_object(Bucket=bucket, Key=key)

                return ObjectMetadata(
                    key=path,
                    type="file",
                    content_length=response["ContentLength"],
                    content_type=response.get("ContentType"),
                    last_modified=response["LastModified"],
                    etag=response["ETag"].strip('"') if "ETag" in response else None,
                    storage_class=response.get("StorageClass"),
                    metadata=response.get("Metadata"),
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist at the given path, append a trailing slash and
                    # check whether the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            paginator = self._s3_client.get_paginator("list_objects_v2")
            if include_directories:
                page_iterator = paginator.paginate(
                    Bucket=bucket, Prefix=prefix, Delimiter="/", StartAfter=(start_after or "")
                )
            else:
                page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, StartAfter=(start_after or ""))

            for page in page_iterator:
                for item in page.get("CommonPrefixes", []):
                    yield ObjectMetadata(
                        key=os.path.join(bucket, item["Prefix"].rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )

                # S3 guarantees lexicographical order for general purpose buckets (for
                # normal S3) but not directory buckets (for S3 Express One Zone).
                for response_object in page.get("Contents", []):
                    key = response_object["Key"]
                    if end_at is None or key <= end_at:
                        if key.endswith("/"):
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(bucket, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=response_object["LastModified"],
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key),
                                type="file",
                                content_length=response_object["Size"],
                                last_modified=response_object["LastModified"],
                                etag=response_object["ETag"].strip('"'),
                                storage_class=response_object.get("StorageClass"),  # Pass storage_class
                            )
                    else:
                        return

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(
        self,
        remote_path: str,
        f: Union[str, IO],
        attributes: Optional[dict[str, str]] = None,
        content_type: Optional[str] = None,
    ) -> int:
        file_size: int = 0

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._transfer_config.multipart_threshold:
                if self._rust_client and not attributes and not content_type:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes, content_type=content_type)
                return file_size

            # Upload large files using TransferConfig
            def _invoke_api() -> int:
                extra_args = {}
                if content_type:
                    extra_args["ContentType"] = content_type
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes
                if self._rust_client and not extra_args:
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    self._s3_client.upload_file(
                        Filename=f,
                        Bucket=bucket,
                        Key=key,
                        Config=self._transfer_config,
                        ExtraArgs=extra_args,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            # Upload small files
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            if file_size <= self._transfer_config.multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(
                        remote_path, f.read().encode("utf-8"), attributes=attributes, content_type=content_type
                    )
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes, content_type=content_type)
                return file_size

            # Upload large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                extra_args = {}
                if content_type:
                    extra_args["ContentType"] = content_type
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes

                if self._rust_client and isinstance(f, io.BytesIO) and not extra_args:
                    data = f.getbuffer()
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_bytes", key, data)
                else:
                    self._s3_client.upload_fileobj(
                        Fileobj=f,
                        Bucket=bucket,
                        Key=key,
                        Config=self._transfer_config,
                        ExtraArgs=extra_args,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)

            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using TransferConfig
            def _invoke_api() -> int:
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    if self._rust_client:
                        run_async_rust_client_method(
                            self._rust_client, "download_multipart_to_file", key, temp_file_path
                        )
                    else:
                        self._s3_client.download_fileobj(
                            Bucket=bucket,
                            Key=key,
                            Fileobj=fp,
                            Config=self._transfer_config,
                        )

                os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:
            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                response = self._get_object(remote_path)
                # The Python client returns `bytes`, but the Rust client returns an object that implements
                # the buffer protocol, so we need to check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                self._s3_client.download_fileobj(
                    Bucket=bucket,
                    Key=key,
                    Fileobj=f,
                    Config=self._transfer_config,
                )

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
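
# --- Illustrative usage sketch (not part of the module source) ---
# Constructing the provider against an S3-compatible endpoint. The endpoint URL,
# bucket, and prefix below are placeholders; passing the ``rust_client`` option
# enables the optional Rust transfer path, whose dict accepts keys such as
# ``max_pool_connections`` (see _create_rust_client above):
#
#     provider = S3StorageProvider(
#         endpoint_url="https://s3-compatible.example.com",
#         region_name="us-east-1",
#         base_path="my-bucket/prefix",
#         credentials_provider=credentials_provider,  # e.g. the StaticS3CredentialsProvider sketched above
#         verify=False,                               # or a path to a custom CA bundle
#         rust_client={"max_pool_connections": 64},
#     )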