Source code for multistorageclient.providers.s3

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os
import tempfile
import time
from collections.abc import Callable, Iterator, Sequence, Sized
from typing import IO, Any, Optional, TypeVar, Union

import boto3
import botocore
import opentelemetry.metrics as api_metrics
from boto3.s3.transfer import TransferConfig
from botocore.credentials import RefreshableCredentials
from botocore.exceptions import ClientError, IncompleteReadError, ReadTimeoutError, ResponseStreamingError
from botocore.session import get_session

from ..instrumentation.utils import set_span_attribute
from ..telemetry import Telemetry
from ..telemetry.attributes.base import AttributesProvider
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import split_path, validate_attributes
from .base import BaseStorageProvider

_T = TypeVar("_T")

BOTO3_MAX_POOL_CONNECTIONS = 32

MB = 1024 * 1024

MULTIPART_THRESHOLD = 512 * MB
MULTIPART_CHUNK_SIZE = 256 * MB
IO_CHUNK_SIZE = 128 * MB
MAX_CONCURRENCY = 16
PROVIDER = "s3"

EXPRESS_ONEZONE_STORAGE_CLASS = "EXPRESS_ONEZONE"


def _extract_x_trans_id(response: Any) -> None:
    """Extract x-trans-id from boto3 response and set it as span attribute."""
    try:
        if response and isinstance(response, dict):
            headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
            if headers and isinstance(headers, dict) and "x-trans-id" in headers:
                set_span_attribute("x_trans_id", headers["x-trans-id"])
    except (KeyError, AttributeError, TypeError):
        # Silently ignore any errors in extraction
        pass


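# Illustrative sketch, not part of the original module: _extract_x_trans_id only inspects the
# nested dict shape that boto3 responses use, so a hand-built fake response is enough to show
# what it looks for. The header value below is a made-up placeholder.
def _example_extract_x_trans_id() -> None:
    fake_response = {
        "ResponseMetadata": {
            "HTTPStatusCode": 200,
            "HTTPHeaders": {"x-trans-id": "tx-example-0000"},
        }
    }
    # Sets the "x_trans_id" span attribute when tracing is active; otherwise this is a no-op.
    _extract_x_trans_id(fake_response)

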
class StaticS3CredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static S3 credentials.
    """

    _access_key: str
    _secret_key: str
    _session_token: Optional[str]

    def __init__(self, access_key: str, secret_key: str, session_token: Optional[str] = None):
        """
        Initializes the :py:class:`StaticS3CredentialsProvider` with the provided access key, secret key, and optional
        session token.

        :param access_key: The access key for S3 authentication.
        :param secret_key: The secret key for S3 authentication.
        :param session_token: An optional session token for temporary credentials.
        """
        self._access_key = access_key
        self._secret_key = secret_key
        self._session_token = session_token

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key=self._access_key,
            secret_key=self._secret_key,
            token=self._session_token,
            expiration=None,
        )

    def refresh_credentials(self) -> None:
        pass
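

# Usage sketch (assumption, not part of this module): StaticS3CredentialsProvider simply hands
# back the values it was constructed with, so wiring it up looks like this. The key values are
# placeholders.
def _example_static_credentials() -> Credentials:
    provider = StaticS3CredentialsProvider(
        access_key="EXAMPLE_ACCESS_KEY",  # placeholder
        secret_key="EXAMPLE_SECRET_KEY",  # placeholder
        session_token=None,  # only needed for temporary credentials
    )
    provider.refresh_credentials()  # no-op for static credentials
    return provider.get_credentials()  # expiration is None, so the client never refreshes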


class S3StorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Amazon S3 or S3-compatible object stores.
    """

    def __init__(
        self,
        region_name: str = "",
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        metric_counters: dict[Telemetry.CounterName, api_metrics.Counter] = {},
        metric_gauges: dict[Telemetry.GaugeName, api_metrics._Gauge] = {},
        metric_attributes_providers: Sequence[AttributesProvider] = (),
        **kwargs: Any,
    ) -> None:
        """
        Initializes the :py:class:`S3StorageProvider` with the region, endpoint URL, and optional credentials provider.

        :param region_name: The AWS region where the S3 bucket is located.
        :param endpoint_url: The custom endpoint URL for the S3 service.
        :param base_path: The root prefix path within the S3 bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve S3 credentials.
        :param metric_counters: Metric counters.
        :param metric_gauges: Metric gauges.
        :param metric_attributes_providers: Metric attributes providers.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            metric_counters=metric_counters,
            metric_gauges=metric_gauges,
            metric_attributes_providers=metric_attributes_providers,
        )

        self._region_name = region_name
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider

        self._signature_version = kwargs.get("signature_version", "")
        self._s3_client = self._create_s3_client(
            request_checksum_calculation=kwargs.get("request_checksum_calculation"),
            response_checksum_validation=kwargs.get("response_checksum_validation"),
            max_pool_connections=kwargs.get("max_pool_connections", BOTO3_MAX_POOL_CONNECTIONS),
            connect_timeout=kwargs.get("connect_timeout"),
            read_timeout=kwargs.get("read_timeout"),
            retries=kwargs.get("retries"),
        )
        self._transfer_config = TransferConfig(
            multipart_threshold=int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD)),
            max_concurrency=int(kwargs.get("max_concurrency", MAX_CONCURRENCY)),
            multipart_chunksize=int(kwargs.get("multipart_chunksize", MULTIPART_CHUNK_SIZE)),
            io_chunksize=int(kwargs.get("io_chunk_size", IO_CHUNK_SIZE)),
            use_threads=True,
        )

    def _is_directory_bucket(self, bucket: str) -> bool:
        """
        Determines if the bucket is a directory bucket based on the bucket name.
        """
        # S3 Express buckets have a specific naming convention
        return "--x-s3" in bucket

    def _create_s3_client(
        self,
        request_checksum_calculation: Optional[str] = None,
        response_checksum_validation: Optional[str] = None,
        max_pool_connections: int = BOTO3_MAX_POOL_CONNECTIONS,
        connect_timeout: Union[float, int, None] = None,
        read_timeout: Union[float, int, None] = None,
        retries: Optional[dict[str, Any]] = None,
    ):
        """
        Creates and configures the boto3 S3 client, using refreshable credentials if possible.

        :param request_checksum_calculation: When the underlying S3 client should calculate request checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param response_checksum_validation: When the underlying S3 client should validate response checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param max_pool_connections: The maximum number of connections to keep in a connection pool.
        :param connect_timeout: The time in seconds until a timeout exception is thrown when attempting to make a connection.
        :param read_timeout: The time in seconds until a timeout exception is thrown when attempting to read from a connection.
        :param retries: A dictionary for configuration related to retry behavior.

        :return: The configured S3 client.
        """
        options = {
            # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
            "config": boto3.session.Config(  # pyright: ignore [reportAttributeAccessIssue]
                max_pool_connections=max_pool_connections,
                connect_timeout=connect_timeout,
                read_timeout=read_timeout,
                retries=retries or {"mode": "standard"},
                request_checksum_calculation=request_checksum_calculation,
                response_checksum_validation=response_checksum_validation,
            ),
        }

        if self._region_name:
            options["region_name"] = self._region_name

        if self._endpoint_url:
            options["endpoint_url"] = self._endpoint_url

        if self._credentials_provider:
            creds = self._fetch_credentials()
            if "expiry_time" in creds and creds["expiry_time"]:
                # Use RefreshableCredentials if expiry_time provided.
                refreshable_credentials = RefreshableCredentials.create_from_metadata(
                    metadata=creds, refresh_using=self._fetch_credentials, method="custom-refresh"
                )

                botocore_session = get_session()
                botocore_session._credentials = refreshable_credentials

                boto3_session = boto3.Session(botocore_session=botocore_session)

                return boto3_session.client("s3", **options)
            else:
                # Add static credentials to the options dictionary
                options["aws_access_key_id"] = creds["access_key"]
                options["aws_secret_access_key"] = creds["secret_key"]
                if creds["token"]:
                    options["aws_session_token"] = creds["token"]

        if self._signature_version:
            signature_config = botocore.config.Config(  # pyright: ignore[reportAttributeAccessIssue]
                signature_version=botocore.UNSIGNED
                if self._signature_version == "UNSIGNED"
                else self._signature_version
            )
            options["config"] = options["config"].merge(signature_config)

        # Fallback to standard credential chain.
        return boto3.client("s3", **options)

    def _fetch_credentials(self) -> dict:
        """
        Refreshes the credentials via the configured credentials provider and returns them.
        """
        if not self._credentials_provider:
            raise RuntimeError("Cannot fetch credentials if no credential provider configured.")
        self._credentials_provider.refresh_credentials()
        credentials = self._credentials_provider.get_credentials()
        return {
            "access_key": credentials.access_key,
            "secret_key": credentials.secret_key,
            "token": credentials.token,
            "expiry_time": credentials.expiration,
        }

    def _collect_metrics(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> _T:
        """
        Collects and records performance metrics around S3 operations such as PUT, GET, DELETE, etc.

        This method wraps an S3 operation and measures the time it takes to complete, along with recording
        the size of the object if applicable. It handles errors like timeouts and client errors and ensures
        proper logging of duration and object size.

        :param func: The function that performs the actual S3 operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the S3 bucket involved in the operation.
        :param key: The key of the object within the S3 bucket.
        :param put_object_size: The size of the object being uploaded, if applicable (for PUT operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).

        :return: The result of the S3 operation, typically the return value of the `func` callable.
        """
        # Import the span attribute helper
        from ..instrumentation.utils import set_span_attribute

        # Set basic operation attributes
        set_span_attribute("s3_operation", operation)
        set_span_attribute("s3_bucket", bucket)
        set_span_attribute("s3_key", key)

        start_time = time.time()
        status_code = 200

        object_size = None
        if operation == "PUT":
            object_size = put_object_size
        elif operation == "GET" and get_object_size:
            object_size = get_object_size

        try:
            result = func()
            if operation == "GET" and object_size is None and isinstance(result, Sized):
                object_size = len(result)
            return result
        except ClientError as error:
            status_code = error.response["ResponseMetadata"]["HTTPStatusCode"]
            request_id = error.response["ResponseMetadata"].get("RequestId")
            host_id = error.response["ResponseMetadata"].get("HostId")
            header = error.response["ResponseMetadata"].get("HTTPHeaders", {})
            error_code = error.response["Error"]["Code"]

            # Ensure header is a dictionary before trying to get from it
            x_trans_id = header.get("x-trans-id") if isinstance(header, dict) else None

            # Record error details in span
            set_span_attribute("request_id", request_id)
            set_span_attribute("host_id", host_id)

            error_info = f"request_id: {request_id}, host_id: {host_id}, status_code: {status_code}"
            if x_trans_id:
                error_info += f", x-trans-id: {x_trans_id}"
                set_span_attribute("x_trans_id", x_trans_id)

            if status_code == 404:
                if error_code == "NoSuchUpload":
                    error_message = error.response["Error"]["Message"]
                    raise RetryableError(f"Multipart upload failed for {bucket}/{key}: {error_message}") from error
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}")  # pylint: disable=raise-missing-from
            elif status_code == 412:  # Precondition Failed
                raise PreconditionFailedError(
                    f"ETag mismatch for {operation} operation on {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 429:
                raise RetryableError(
                    f"Too many requests to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 503:
                raise RetryableError(
                    f"Service unavailable when {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 501:
                raise NotImplementedError(
                    f"Operation {operation} not implemented for object(s) at {bucket}/{key}. {error_info}"
                ) from error
            else:
                raise RuntimeError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}, "
                    f"error_type: {type(error).__name__}"
                ) from error
        except FileNotFoundError as error:
            status_code = -1
            raise error
        except (ReadTimeoutError, IncompleteReadError, ResponseStreamingError) as error:
            status_code = -1
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to network timeout or incomplete read. "
                f"error_type: {type(error).__name__}"
            ) from error
        except Exception as error:
            status_code = -1
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error
        finally:
            elapsed_time = time.time() - start_time

            set_span_attribute("status_code", status_code)

            # Record metrics
            self._metric_helper.record_duration(
                elapsed_time, provider=self._provider_name, operation=operation, bucket=bucket, status_code=status_code
            )
            if object_size:
                self._metric_helper.record_object_size(
                    object_size,
                    provider=self._provider_name,
                    operation=operation,
                    bucket=bucket,
                    status_code=status_code,
                )

                set_span_attribute("object_size", object_size)

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to the specified S3 path.

        :param path: The S3 path where the object will be uploaded.
        :param body: The content of the object as bytes.
        :param if_match: Optional If-Match header value. Use "*" to only upload if the object already exists.
        :param if_none_match: Optional If-None-Match header value. Use "*" to only upload if the object doesn't exist.
        :param attributes: Optional attributes to attach to the object.
        """
        bucket, key = split_path(path)

        def _invoke_api() -> int:
            kwargs = {"Bucket": bucket, "Key": key, "Body": body}
            if self._is_directory_bucket(bucket):
                kwargs["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
            if if_match:
                kwargs["IfMatch"] = if_match
            if if_none_match:
                kwargs["IfNoneMatch"] = if_none_match
            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                kwargs["Metadata"] = validated_attributes

            # Capture the response from put_object
            response = self._s3_client.put_object(**kwargs)

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

            return len(body)

        return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=len(body))

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)

        def _invoke_api() -> bytes:
            if byte_range:
                bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
                response = self._s3_client.get_object(Bucket=bucket, Key=key, Range=bytes_range)
            else:
                response = self._s3_client.get_object(Bucket=bucket, Key=key)

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

            return response["Body"].read()

        return self._collect_metrics(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            response = self._s3_client.copy(
                CopySource={"Bucket": src_bucket, "Key": src_key},
                Bucket=dest_bucket,
                Key=dest_key,
                Config=self._transfer_config,
            )

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

            return src_object.content_length

        return self._collect_metrics(
            _invoke_api,
            operation="COPY",
            bucket=dest_bucket,
            key=dest_key,
            put_object_size=src_object.content_length,
        )

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            # Conditionally delete the object if if_match (etag) is provided; if not, delete the object unconditionally
            if if_match:
                response = self._s3_client.delete_object(Bucket=bucket, Key=key, IfMatch=if_match)
            else:
                response = self._s3_client.delete_object(Bucket=bucket, Key=key)

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

        return self._collect_metrics(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._s3_client.list_objects_v2(Bucket=bucket, Prefix=key, MaxKeys=1, Delimiter="/")

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

            # Check if there are any contents or common prefixes
            return bool(response.get("Contents", []) or response.get("CommonPrefixes", []))

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        if path.endswith("/"):
            # If path is a "directory", then metadata is not guaranteed to exist if
            # it is a "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            bucket, key = split_path(path)

            def _invoke_api() -> ObjectMetadata:
                response = self._s3_client.head_object(Bucket=bucket, Key=key)

                # Extract and set x-trans-id if present
                _extract_x_trans_id(response)

                return ObjectMetadata(
                    key=path,
                    type="file",
                    content_length=response["ContentLength"],
                    content_type=response["ContentType"],
                    last_modified=response["LastModified"],
                    etag=response["ETag"].strip('"'),
                    storage_class=response.get("StorageClass"),
                    metadata=response.get("Metadata"),
                )

            try:
                return self._collect_metrics(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        prefix: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(prefix)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            paginator = self._s3_client.get_paginator("list_objects_v2")
            if include_directories:
                page_iterator = paginator.paginate(
                    Bucket=bucket, Prefix=prefix, Delimiter="/", StartAfter=(start_after or "")
                )
            else:
                page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, StartAfter=(start_after or ""))

            for page in page_iterator:
                for item in page.get("CommonPrefixes", []):
                    yield ObjectMetadata(
                        key=os.path.join(bucket, item["Prefix"].rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )

                # S3 guarantees lexicographical order for general purpose buckets (for
                # normal S3) but not directory buckets (for S3 Express One Zone).
                for response_object in page.get("Contents", []):
                    key = response_object["Key"]
                    if end_at is None or key <= end_at:
                        if key.endswith("/"):
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(bucket, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=response_object["LastModified"],
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key),
                                type="file",
                                content_length=response_object["Size"],
                                last_modified=response_object["LastModified"],
                                etag=response_object["ETag"].strip('"'),
                                storage_class=response_object.get("StorageClass"),  # Pass storage_class
                            )
                    else:
                        return

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        file_size: int = 0

        if isinstance(f, str):
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._transfer_config.multipart_threshold:
                with open(f, "rb") as fp:
                    self._put_object(remote_path, fp.read(), attributes=attributes)
                return file_size

            # Upload large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                extra_args = {}
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes
                response = self._s3_client.upload_file(
                    Filename=f,
                    Bucket=bucket,
                    Key=key,
                    Config=self._transfer_config,
                    ExtraArgs=extra_args,
                )

                # Extract and set x-trans-id if present
                _extract_x_trans_id(response)

                return file_size

            return self._collect_metrics(
                _invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=file_size
            )
        else:
            # Upload small files
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            if file_size <= self._transfer_config.multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(remote_path, f.read().encode("utf-8"), attributes=attributes)
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes)
                return file_size

            # Upload large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                extra_args = {}
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes
                self._s3_client.upload_fileobj(
                    Fileobj=f,
                    Bucket=bucket,
                    Key=key,
                    Config=self._transfer_config,
                    ExtraArgs=extra_args,
                )

                return file_size

            return self._collect_metrics(
                _invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=file_size
            )

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        if isinstance(f, str):
            os.makedirs(os.path.dirname(f), exist_ok=True)
            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    fp.write(self._get_object(remote_path))
                os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                response = None
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    response = self._s3_client.download_fileobj(
                        Bucket=bucket,
                        Key=key,
                        Fileobj=fp,
                        Config=self._transfer_config,
                    )

                # Extract and set x-trans-id if present
                _extract_x_trans_id(response)
                os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )
        else:
            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                if isinstance(f, io.StringIO):
                    f.write(self._get_object(remote_path).decode("utf-8"))
                else:
                    f.write(self._get_object(remote_path))
                return metadata.content_length

            # Download large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                response = self._s3_client.download_fileobj(
                    Bucket=bucket,
                    Key=key,
                    Fileobj=f,
                    Config=self._transfer_config,
                )

                # Extract and set x-trans-id if present
                _extract_x_trans_id(response)

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )
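

# Construction sketch (assumption: direct instantiation is shown only to illustrate the
# parameters defined above; in normal use the library builds providers from its configuration).
# The endpoint, bucket, and tuning values below are placeholders.
def _example_s3_storage_provider() -> S3StorageProvider:
    credentials = StaticS3CredentialsProvider(access_key="EXAMPLE_ACCESS_KEY", secret_key="EXAMPLE_SECRET_KEY")
    return S3StorageProvider(
        region_name="us-east-1",
        endpoint_url="https://s3.example.com",  # leave empty to use the default AWS endpoint
        base_path="my-bucket/datasets",  # all keys are scoped under this prefix
        credentials_provider=credentials,
        # Optional tuning knobs read from **kwargs in __init__:
        multipart_threshold=64 * MB,
        multipart_chunksize=32 * MB,
        max_pool_connections=16,
        read_timeout=60,
        retries={"mode": "standard", "max_attempts": 5},
    )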