# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import io
import os
import tempfile
from collections.abc import Callable, Iterator
from typing import IO, Any, Optional, TypeVar, Union

import boto3
import botocore
from boto3.s3.transfer import TransferConfig
from botocore.credentials import RefreshableCredentials
from botocore.exceptions import ClientError, IncompleteReadError, ReadTimeoutError, ResponseStreamingError
from botocore.session import get_session

from multistorageclient_rust import RustClient, RustRetryableError

from ..rust_utils import run_async_rust_client_method
from ..telemetry import Telemetry
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import (
    get_available_cpu_count,
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

# Default connection pool size scales with CPU count (minimum 32 for backward compatibility)
MAX_POOL_CONNECTIONS = max(32, get_available_cpu_count())

MiB = 1024 * 1024

# Python and Rust share the same multipart_threshold to keep the code simple.
MULTIPART_THRESHOLD = 64 * MiB
MULTIPART_CHUNKSIZE = 32 * MiB
IO_CHUNKSIZE = 32 * MiB
PYTHON_MAX_CONCURRENCY = 8

PROVIDER = "s3"

EXPRESS_ONEZONE_STORAGE_CLASS = "EXPRESS_ONEZONE"


class StaticS3CredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static S3 credentials.
    """

    _access_key: str
    _secret_key: str
    _session_token: Optional[str]

    def __init__(self, access_key: str, secret_key: str, session_token: Optional[str] = None):
        """
        Initializes the :py:class:`StaticS3CredentialsProvider` with the provided access key, secret key, and optional
        session token.

        :param access_key: The access key for S3 authentication.
        :param secret_key: The secret key for S3 authentication.
        :param session_token: An optional session token for temporary credentials.
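
        Example (illustrative placeholder values)::

            provider = StaticS3CredentialsProvider(access_key="AKIA...", secret_key="...")
            credentials = provider.get_credentials()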
        """
        self._access_key = access_key
        self._secret_key = secret_key
        self._session_token = session_token

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key=self._access_key,
            secret_key=self._secret_key,
            token=self._session_token,
            expiration=None,
        )

    def refresh_credentials(self) -> None:
        pass


class S3StorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Amazon S3 or S3-compatible object stores.
    """

    def __init__(
        self,
        region_name: str = "",
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        verify: Optional[Union[bool, str]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the :py:class:`S3StorageProvider` with the region, endpoint URL, and optional credentials provider.

        :param region_name: The AWS region where the S3 bucket is located.
        :param endpoint_url: The custom endpoint URL for the S3 service.
        :param base_path: The root prefix path within the S3 bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve S3 credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        :param verify: Controls SSL certificate verification. Can be ``True`` (verify using system CA bundle, default),
            ``False`` (skip verification), or a string path to a custom CA certificate bundle.
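
        Example (a minimal sketch; the region, bucket, and key values shown are illustrative)::

            provider = S3StorageProvider(
                region_name="us-east-1",
                base_path="my-bucket/datasets",
                credentials_provider=StaticS3CredentialsProvider(access_key="AKIA...", secret_key="..."),
            )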
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._region_name = region_name
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._verify = verify

        self._signature_version = kwargs.get("signature_version", "")
        self._s3_client = self._create_s3_client(
            request_checksum_calculation=kwargs.get("request_checksum_calculation"),
            response_checksum_validation=kwargs.get("response_checksum_validation"),
            max_pool_connections=kwargs.get("max_pool_connections", MAX_POOL_CONNECTIONS),
            connect_timeout=kwargs.get("connect_timeout"),
            read_timeout=kwargs.get("read_timeout"),
            retries=kwargs.get("retries"),
        )
        self._transfer_config = TransferConfig(
            multipart_threshold=int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD)),
            max_concurrency=int(kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)),
            multipart_chunksize=int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE)),
            io_chunksize=int(kwargs.get("io_chunksize", IO_CHUNKSIZE)),
            use_threads=True,
        )

        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the rust client options from the kwargs
            rust_client_options = kwargs["rust_client"]
            if "max_pool_connections" in kwargs:
                rust_client_options["max_pool_connections"] = kwargs["max_pool_connections"]
            if "max_concurrency" in kwargs:
                rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
            if "multipart_chunksize" in kwargs:
                rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
            if "read_timeout" in kwargs:
                rust_client_options["read_timeout"] = kwargs["read_timeout"]
            if "connect_timeout" in kwargs:
                rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
            if self._signature_version == "UNSIGNED":
                rust_client_options["skip_signature"] = True
            self._rust_client = self._create_rust_client(rust_client_options)

    def _is_directory_bucket(self, bucket: str) -> bool:
        """
        Determines if the bucket is a directory bucket based on bucket name.
        """
        # S3 Express buckets have a specific naming convention
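        # (Illustrative) directory bucket names follow the "<name>--<zone-id>--x-s3" pattern, e.g. "my-bucket--usw2-az1--x-s3".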
        return "--x-s3" in bucket

    def _create_s3_client(
        self,
        request_checksum_calculation: Optional[str] = None,
        response_checksum_validation: Optional[str] = None,
        max_pool_connections: int = MAX_POOL_CONNECTIONS,
        connect_timeout: Union[float, int, None] = None,
        read_timeout: Union[float, int, None] = None,
        retries: Optional[dict[str, Any]] = None,
    ):
        """
        Creates and configures the boto3 S3 client, using refreshable credentials if possible.

        :param request_checksum_calculation: When the underlying S3 client should calculate request checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param response_checksum_validation: When the underlying S3 client should validate response checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param max_pool_connections: The maximum number of connections to keep in a connection pool.
        :param connect_timeout: The time in seconds till a timeout exception is thrown when attempting to make a connection.
        :param read_timeout: The time in seconds till a timeout exception is thrown when attempting to read from a connection.
        :param retries: A dictionary for configuration related to retry behavior.

        :return: The configured S3 client.
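
        Example ``retries`` value (an illustrative sketch; see the botocore retry configuration guide for the full set of options)::

            {"mode": "adaptive", "max_attempts": 5}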
        """
        options = {
            # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
            "config": boto3.session.Config(  # pyright: ignore [reportAttributeAccessIssue]
                max_pool_connections=max_pool_connections,
                connect_timeout=connect_timeout,
                read_timeout=read_timeout,
                retries=retries or {"mode": "standard"},
                request_checksum_calculation=request_checksum_calculation,
                response_checksum_validation=response_checksum_validation,
            ),
        }

        if self._region_name:
            options["region_name"] = self._region_name

        if self._endpoint_url:
            options["endpoint_url"] = self._endpoint_url

        if self._verify is not None:
            options["verify"] = self._verify

        if self._credentials_provider:
            creds = self._fetch_credentials()
            if "expiry_time" in creds and creds["expiry_time"]:
                # Use RefreshableCredentials if expiry_time provided.
                refreshable_credentials = RefreshableCredentials.create_from_metadata(
                    metadata=creds, refresh_using=self._fetch_credentials, method="custom-refresh"
                )

                botocore_session = get_session()
                botocore_session._credentials = refreshable_credentials

                boto3_session = boto3.Session(botocore_session=botocore_session)

                return boto3_session.client("s3", **options)
            else:
                # Add static credentials to the options dictionary
                options["aws_access_key_id"] = creds["access_key"]
                options["aws_secret_access_key"] = creds["secret_key"]
                if creds["token"]:
                    options["aws_session_token"] = creds["token"]

        if self._signature_version:
            signature_config = botocore.config.Config(  # pyright: ignore[reportAttributeAccessIssue]
                signature_version=botocore.UNSIGNED
                if self._signature_version == "UNSIGNED"
                else self._signature_version
            )
            options["config"] = options["config"].merge(signature_config)

        # Use the static credentials added above, or fall back to the standard credential chain.
        return boto3.client("s3", **options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        """
        Creates and configures the rust client, using refreshable credentials if possible.
        """
        configs = dict(rust_client_options) if rust_client_options else {}

        if self._region_name and "region_name" not in configs:
            configs["region_name"] = self._region_name

        if self._endpoint_url and "endpoint_url" not in configs:
            configs["endpoint_url"] = self._endpoint_url

        # If the user specifies a bucket, use it. Otherwise, use the base path.
        if "bucket" not in configs:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        if "max_pool_connections" not in configs:
            configs["max_pool_connections"] = MAX_POOL_CONNECTIONS

        return RustClient(
            provider=PROVIDER,
            configs=configs,
            credentials_provider=self._credentials_provider,
        )

    def _fetch_credentials(self) -> dict:
        """
        Refreshes the credentials via the configured credentials provider and returns them in the metadata format expected by botocore's refreshable credentials.
        """
        if not self._credentials_provider:
            raise RuntimeError("Cannot fetch credentials if no credential provider configured.")
        self._credentials_provider.refresh_credentials()
        credentials = self._credentials_provider.get_credentials()
        return {
            "access_key": credentials.access_key,
            "secret_key": credentials.secret_key,
            "token": credentials.token,
            "expiry_time": credentials.expiration,
        }

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual S3 operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the S3 bucket involved in the operation.
        :param key: The key of the object within the S3 bucket.

        :return: The result of the S3 operation, typically the return value of the `func` callable.
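
        Example (a minimal sketch mirroring how other methods in this class wrap their boto3 calls)::

            def _invoke_api() -> bytes:
                return self._s3_client.get_object(Bucket=bucket, Key=key)["Body"].read()

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)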
        """
        try:
            return func()
        except ClientError as error:
            status_code = error.response["ResponseMetadata"]["HTTPStatusCode"]
            request_id = error.response["ResponseMetadata"].get("RequestId")
            host_id = error.response["ResponseMetadata"].get("HostId")
            error_code = error.response["Error"]["Code"]
            error_info = f"request_id: {request_id}, host_id: {host_id}, status_code: {status_code}"

            if status_code == 404:
                if error_code == "NoSuchUpload":
                    error_message = error.response["Error"]["Message"]
                    raise RetryableError(f"Multipart upload failed for {bucket}/{key}: {error_message}") from error
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}")  # pylint: disable=raise-missing-from
            elif status_code == 412:  # Precondition Failed
                raise PreconditionFailedError(
                    f"ETag mismatch for {operation} operation on {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 429:
                raise RetryableError(
                    f"Too many requests to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 503:
                raise RetryableError(
                    f"Service unavailable when {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 501:
                raise NotImplementedError(
                    f"Operation {operation} not implemented for object(s) at {bucket}/{key}. {error_info}"
                ) from error
            else:
                raise RuntimeError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}, "
                    f"error_type: {type(error).__name__}"
                ) from error
        except FileNotFoundError:
            raise
        except (ReadTimeoutError, IncompleteReadError, ResponseStreamingError) as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to network timeout or incomplete read. "
                f"error_type: {type(error).__name__}"
            ) from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except Exception as error:
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
        content_type: Optional[str] = None,
    ) -> int:
        """
        Uploads an object to the specified S3 path.

        :param path: The S3 path where the object will be uploaded.
        :param body: The content of the object as bytes.
        :param if_match: Optional If-Match header value. The upload only succeeds if the object's current ETag matches this value; use "*" to require that the object already exists.
        :param if_none_match: Optional If-None-Match header value. Use "*" to only upload if the object doesn't exist.
        :param attributes: Optional attributes to attach to the object.
        :param content_type: Optional Content-Type header value.
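
        Example (an illustrative sketch of the conditional-write parameters; ``etag`` is assumed to come from a prior read)::

            # Create the object only if it does not exist yet.
            self._put_object(path, body, if_none_match="*")

            # Overwrite only if the object is unchanged since it was last read.
            self._put_object(path, body, if_match=etag)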
        """
        bucket, key = split_path(path)

        def _invoke_api() -> int:
            kwargs = {"Bucket": bucket, "Key": key, "Body": body}
            if content_type:
                kwargs["ContentType"] = content_type
            if self._is_directory_bucket(bucket):
                kwargs["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
            if if_match:
                kwargs["IfMatch"] = if_match
            if if_none_match:
                kwargs["IfNoneMatch"] = if_none_match
            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                kwargs["Metadata"] = validated_attributes

            # TODO(NGCDP-5804): Add support to update ContentType header in Rust client
            rust_unsupported_feature_keys = {"Metadata", "StorageClass", "IfMatch", "IfNoneMatch", "ContentType"}
            if (
                self._rust_client
                # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
                and not path.endswith("/")
                and all(key not in kwargs for key in rust_unsupported_feature_keys)
            ):
                run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                self._s3_client.put_object(**kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)

        def _invoke_api() -> bytes:
            if byte_range:
                bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
                if self._rust_client:
                    response = run_async_rust_client_method(
                        self._rust_client, "get", key, byte_range.offset, byte_range.offset + byte_range.size - 1
                    )
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key, Range=bytes_range)
            else:
                if self._rust_client:
                    response = run_async_rust_client_method(self._rust_client, "get", key)
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key)

            return response["Body"].read()

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            self._s3_client.copy(
                CopySource={"Bucket": src_bucket, "Key": src_key},
                Bucket=dest_bucket,
                Key=dest_key,
                Config=self._transfer_config,
            )

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=dest_bucket, key=dest_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            # Conditionally delete the object if if_match (ETag) is provided; otherwise, delete it unconditionally.
            if if_match:
                self._s3_client.delete_object(Bucket=bucket, Key=key, IfMatch=if_match)
            else:
                self._s3_client.delete_object(Bucket=bucket, Key=key)

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._s3_client.list_objects_v2(Bucket=bucket, Prefix=key, MaxKeys=1, Delimiter="/")

            # Check if there are any contents or common prefixes
            return bool(response.get("Contents", []) or response.get("CommonPrefixes", []))

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If the path ends with "/" or an empty key is provided, assume it is a "directory",
            # whose metadata is not guaranteed to exist, e.g. for a "virtual prefix"
            # that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:

            def _invoke_api() -> ObjectMetadata:
                response = self._s3_client.head_object(Bucket=bucket, Key=key)

                return ObjectMetadata(
                    key=path,
                    type="file",
                    content_length=response["ContentLength"],
                    content_type=response.get("ContentType"),
                    last_modified=response["LastModified"],
                    etag=response["ETag"].strip('"') if "ETag" in response else None,
                    storage_class=response.get("StorageClass"),
                    metadata=response.get("Metadata"),
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            paginator = self._s3_client.get_paginator("list_objects_v2")
            if include_directories:
                page_iterator = paginator.paginate(
                    Bucket=bucket, Prefix=prefix, Delimiter="/", StartAfter=(start_after or "")
                )
            else:
                page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, StartAfter=(start_after or ""))

            for page in page_iterator:
                for item in page.get("CommonPrefixes", []):
                    yield ObjectMetadata(
                        key=os.path.join(bucket, item["Prefix"].rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )

                # S3 guarantees lexicographical order for general purpose buckets (for
                # normal S3) but not directory buckets (for S3 Express One Zone).
                for response_object in page.get("Contents", []):
                    key = response_object["Key"]
                    if end_at is None or key <= end_at:
                        if key.endswith("/"):
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(bucket, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=response_object["LastModified"],
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key),
                                type="file",
                                content_length=response_object["Size"],
                                last_modified=response_object["LastModified"],
                                etag=response_object["ETag"].strip('"'),
                                storage_class=response_object.get("StorageClass"),  # Pass storage_class
                            )
                    else:
                        return

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(
        self,
        remote_path: str,
        f: Union[str, IO],
        attributes: Optional[dict[str, str]] = None,
        content_type: Optional[str] = None,
    ) -> int:
        file_size: int = 0

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._transfer_config.multipart_threshold:
                if self._rust_client and not attributes and not content_type:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes, content_type=content_type)
                return file_size

            # Upload large files using TransferConfig
            def _invoke_api() -> int:
                extra_args = {}
                if content_type:
                    extra_args["ContentType"] = content_type
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes
                if self._rust_client and not extra_args:
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    self._s3_client.upload_file(
                        Filename=f,
                        Bucket=bucket,
                        Key=key,
                        Config=self._transfer_config,
                        ExtraArgs=extra_args,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            # Upload small files
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            if file_size <= self._transfer_config.multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(
                        remote_path, f.read().encode("utf-8"), attributes=attributes, content_type=content_type
                    )
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes, content_type=content_type)
                return file_size

            # Upload large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                extra_args = {}
                if content_type:
                    extra_args["ContentType"] = content_type
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes

                if self._rust_client and isinstance(f, io.BytesIO) and not extra_args:
                    data = f.getbuffer()
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_bytes", key, data)
                else:
                    self._s3_client.upload_fileobj(
                        Fileobj=f,
                        Bucket=bucket,
                        Key=key,
                        Config=self._transfer_config,
                        ExtraArgs=extra_args,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)

            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using TransferConfig
            def _invoke_api() -> int:
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    if self._rust_client:
                        run_async_rust_client_method(
                            self._rust_client, "download_multipart_to_file", key, temp_file_path
                        )
                    else:
                        self._s3_client.download_fileobj(
                            Bucket=bucket,
                            Key=key,
                            Fileobj=fp,
                            Config=self._transfer_config,
                        )

                os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:
            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                response = self._get_object(remote_path)
                # The Python client returns `bytes`, but the Rust client returns an object that implements the
                # buffer protocol, so we need to check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                self._s3_client.download_fileobj(
                    Bucket=bucket,
                    Key=key,
                    Fileobj=f,
                    Config=self._transfer_config,
                )

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)