# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import io
import os
import tempfile
from collections.abc import Callable, Iterator
from typing import IO, Any, Optional, TypeVar, Union

import boto3
import botocore
from boto3.s3.transfer import TransferConfig
from botocore.credentials import RefreshableCredentials
from botocore.exceptions import ClientError, IncompleteReadError, ReadTimeoutError, ResponseStreamingError
from botocore.session import get_session

from multistorageclient_rust import RustClient, RustClientError, RustRetryableError

from ..rust_utils import run_async_rust_client_method
from ..telemetry import Telemetry
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import (
    get_available_cpu_count,
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

# Default connection pool size scales with CPU count (minimum 32 for backward compatibility)
MAX_POOL_CONNECTIONS = max(32, get_available_cpu_count())

MiB = 1024 * 1024

# Python and Rust share the same multipart_threshold to keep the code simple.
MULTIPART_THRESHOLD = 64 * MiB
MULTIPART_CHUNKSIZE = 32 * MiB
IO_CHUNKSIZE = 32 * MiB
PYTHON_MAX_CONCURRENCY = 8

PROVIDER = "s3"

EXPRESS_ONEZONE_STORAGE_CLASS = "EXPRESS_ONEZONE"


class StaticS3CredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static S3 credentials.
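
    A minimal usage sketch (the access key and secret key values below are hypothetical placeholders, not real
    credentials):

    .. code-block:: python

        provider = StaticS3CredentialsProvider(
            access_key="AKIAEXAMPLE",
            secret_key="example-secret-key",
        )
        credentials = provider.get_credentials()  # Credentials with expiration=None; static keys never refresh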
    """

    _access_key: str
    _secret_key: str
    _session_token: Optional[str]

    def __init__(self, access_key: str, secret_key: str, session_token: Optional[str] = None):
        """
        Initializes the :py:class:`StaticS3CredentialsProvider` with the provided access key, secret key, and optional
        session token.

        :param access_key: The access key for S3 authentication.
        :param secret_key: The secret key for S3 authentication.
        :param session_token: An optional session token for temporary credentials.
        """
        self._access_key = access_key
        self._secret_key = secret_key
        self._session_token = session_token

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key=self._access_key,
            secret_key=self._secret_key,
            token=self._session_token,
            expiration=None,
        )

    def refresh_credentials(self) -> None:
        pass


class S3StorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Amazon S3 or S3-compatible object stores.
    """

    def __init__(
        self,
        region_name: str = "",
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        verify: Optional[Union[bool, str]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the :py:class:`S3StorageProvider` with the region, endpoint URL, and optional credentials provider.

        :param region_name: The AWS region where the S3 bucket is located.
        :param endpoint_url: The custom endpoint URL for the S3 service.
        :param base_path: The root prefix path within the S3 bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve S3 credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        :param verify: Controls SSL certificate verification. Can be ``True`` (verify using system CA bundle, default),
            ``False`` (skip verification), or a string path to a custom CA certificate bundle.
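
        A minimal construction sketch (bucket, prefix, region, and credential values are hypothetical; in practice this
        provider is usually built from a resolved MSC config rather than instantiated directly):

        .. code-block:: python

            provider = S3StorageProvider(
                region_name="us-east-1",
                base_path="example-bucket/datasets",
                credentials_provider=StaticS3CredentialsProvider("AKIAEXAMPLE", "example-secret-key"),
            )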
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._region_name = region_name
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._verify = verify

        self._signature_version = kwargs.get("signature_version", "")
        self._s3_client = self._create_s3_client(
            request_checksum_calculation=kwargs.get("request_checksum_calculation"),
            response_checksum_validation=kwargs.get("response_checksum_validation"),
            max_pool_connections=kwargs.get("max_pool_connections", MAX_POOL_CONNECTIONS),
            connect_timeout=kwargs.get("connect_timeout"),
            read_timeout=kwargs.get("read_timeout"),
            retries=kwargs.get("retries"),
        )
        self._transfer_config = TransferConfig(
            multipart_threshold=int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD)),
            max_concurrency=int(kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)),
            multipart_chunksize=int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE)),
            io_chunksize=int(kwargs.get("io_chunksize", IO_CHUNKSIZE)),
            use_threads=True,
        )

        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the Rust client options from the kwargs
            rust_client_options = kwargs["rust_client"]
            if "max_pool_connections" in kwargs:
                rust_client_options["max_pool_connections"] = kwargs["max_pool_connections"]
            if "max_concurrency" in kwargs:
                rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
            if "multipart_chunksize" in kwargs:
                rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
            if "read_timeout" in kwargs:
                rust_client_options["read_timeout"] = kwargs["read_timeout"]
            if "connect_timeout" in kwargs:
                rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
            if self._signature_version == "UNSIGNED":
                rust_client_options["skip_signature"] = True
            self._rust_client = self._create_rust_client(rust_client_options)

    def _is_directory_bucket(self, bucket: str) -> bool:
        """
        Determines whether the bucket is a directory bucket based on its name.
        """
        # S3 Express buckets have a specific naming convention
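        # e.g. a hypothetical directory bucket name "example-bucket--usw2-az1--x-s3" contains the "--x-s3" suffix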
        return "--x-s3" in bucket

    def _create_s3_client(
        self,
        request_checksum_calculation: Optional[str] = None,
        response_checksum_validation: Optional[str] = None,
        max_pool_connections: int = MAX_POOL_CONNECTIONS,
        connect_timeout: Union[float, int, None] = None,
        read_timeout: Union[float, int, None] = None,
        retries: Optional[dict[str, Any]] = None,
    ):
        """
        Creates and configures the boto3 S3 client, using refreshable credentials if possible.

        :param request_checksum_calculation: When the underlying S3 client should calculate request checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param response_checksum_validation: When the underlying S3 client should validate response checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param max_pool_connections: The maximum number of connections to keep in a connection pool.
        :param connect_timeout: The time in seconds till a timeout exception is thrown when attempting to make a connection.
        :param read_timeout: The time in seconds till a timeout exception is thrown when attempting to read from a connection.
        :param retries: A dictionary for configuration related to retry behavior.

        :return: The configured S3 client.
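
        For example, a retries mapping might look like this sketch (``max_attempts`` and ``mode`` are standard botocore
        retry options; the values are illustrative):

        .. code-block:: python

            client = self._create_s3_client(retries={"max_attempts": 5, "mode": "adaptive"})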
        """
        options = {
            # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
            "config": boto3.session.Config(  # pyright: ignore [reportAttributeAccessIssue]
                max_pool_connections=max_pool_connections,
                connect_timeout=connect_timeout,
                read_timeout=read_timeout,
                retries=retries or {"mode": "standard"},
                request_checksum_calculation=request_checksum_calculation,
                response_checksum_validation=response_checksum_validation,
            ),
        }

        if self._region_name:
            options["region_name"] = self._region_name

        if self._endpoint_url:
            options["endpoint_url"] = self._endpoint_url

        if self._verify is not None:
            options["verify"] = self._verify

        if self._credentials_provider:
            creds = self._fetch_credentials()
            if "expiry_time" in creds and creds["expiry_time"]:
                # Use RefreshableCredentials if expiry_time provided.
                refreshable_credentials = RefreshableCredentials.create_from_metadata(
                    metadata=creds, refresh_using=self._fetch_credentials, method="custom-refresh"
                )

                botocore_session = get_session()
                botocore_session._credentials = refreshable_credentials

                boto3_session = boto3.Session(botocore_session=botocore_session)

                return boto3_session.client("s3", **options)
            else:
                # Add static credentials to the options dictionary
                options["aws_access_key_id"] = creds["access_key"]
                options["aws_secret_access_key"] = creds["secret_key"]
                if creds["token"]:
                    options["aws_session_token"] = creds["token"]

        if self._signature_version:
            signature_config = botocore.config.Config(  # pyright: ignore[reportAttributeAccessIssue]
                signature_version=botocore.UNSIGNED
                if self._signature_version == "UNSIGNED"
                else self._signature_version
            )
            options["config"] = options["config"].merge(signature_config)

        # Fall back to the standard credential chain if no static credentials were added above.
        return boto3.client("s3", **options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        """
        Creates and configures the Rust client, using refreshable credentials if possible.
        """
        configs = dict(rust_client_options) if rust_client_options else {}

        if self._region_name and "region_name" not in configs:
            configs["region_name"] = self._region_name

        if self._endpoint_url and "endpoint_url" not in configs:
            configs["endpoint_url"] = self._endpoint_url

        # If the user specifies a bucket, use it. Otherwise, derive the bucket from the base path.
        if "bucket" not in configs:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        if "max_pool_connections" not in configs:
            configs["max_pool_connections"] = MAX_POOL_CONNECTIONS

        return RustClient(
            provider=PROVIDER,
            configs=configs,
            credentials_provider=self._credentials_provider,
        )

    def _fetch_credentials(self) -> dict:
        """
        Refreshes the credentials via the configured provider and returns them in the mapping format expected by
        botocore's refreshable credentials.
        """
        if not self._credentials_provider:
            raise RuntimeError("Cannot fetch credentials if no credential provider configured.")
        self._credentials_provider.refresh_credentials()
        credentials = self._credentials_provider.get_credentials()
        return {
            "access_key": credentials.access_key,
            "secret_key": credentials.secret_key,
            "token": credentials.token,
            "expiry_time": credentials.expiration,
        }

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates low-level errors (network timeouts, boto3 ``ClientError``, Rust client errors) into MSC exception
        types such as :py:class:`multistorageclient.types.RetryableError`.

        :param func: The function that performs the actual S3 operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the S3 bucket involved in the operation.
        :param key: The key of the object within the S3 bucket.

        :return: The result of the S3 operation, typically the return value of the `func` callable.
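
        Callers in this module wrap the underlying client call in a closure and pass it in; a sketch mirroring
        ``_get_object``:

        .. code-block:: python

            def _invoke_api() -> bytes:
                return self._s3_client.get_object(Bucket=bucket, Key=key)["Body"].read()

            data = self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)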
        """
        try:
            return func()
        except ClientError as error:
            status_code = error.response["ResponseMetadata"]["HTTPStatusCode"]
            request_id = error.response["ResponseMetadata"].get("RequestId")
            host_id = error.response["ResponseMetadata"].get("HostId")
            error_code = error.response["Error"]["Code"]
            error_info = f"request_id: {request_id}, host_id: {host_id}, status_code: {status_code}"

            if status_code == 404:
                if error_code == "NoSuchUpload":
                    error_message = error.response["Error"]["Message"]
                    raise RetryableError(f"Multipart upload failed for {bucket}/{key}: {error_message}") from error
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}")  # pylint: disable=raise-missing-from
            elif status_code == 412:  # Precondition Failed
                raise PreconditionFailedError(
                    f"ETag mismatch for {operation} operation on {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 429:
                raise RetryableError(
                    f"Too many requests to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 503:
                raise RetryableError(
                    f"Service unavailable when {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 501:
                raise NotImplementedError(
                    f"Operation {operation} not implemented for object(s) at {bucket}/{key}. {error_info}"
                ) from error
            else:
                raise RuntimeError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}, "
                    f"error_type: {type(error).__name__}"
                ) from error
        except RustClientError as error:
            message = error.args[0]
            status_code = error.args[1]
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {message}") from error
            elif status_code == 403:
                raise PermissionError(
                    f"Permission denied to {operation} object(s) at {bucket}/{key}. {message}"
                ) from error
            else:
                raise RetryableError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {message}. status_code: {status_code}"
                ) from error
        except (ReadTimeoutError, IncompleteReadError, ResponseStreamingError) as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to network timeout or incomplete read. "
                f"error_type: {type(error).__name__}"
            ) from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except Exception as error:
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
        content_type: Optional[str] = None,
    ) -> int:
        """
        Uploads an object to the specified S3 path.

        :param path: The S3 path where the object will be uploaded.
        :param body: The content of the object as bytes.
        :param if_match: Optional If-Match header value. The upload only succeeds if the object's current ETag matches the provided value.
        :param if_none_match: Optional If-None-Match header value. Use "*" to only upload if the object doesn't already exist.
        :param attributes: Optional attributes to attach to the object.
        :param content_type: Optional Content-Type header value.

        :return: The number of bytes written.
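
        For example (a sketch with a hypothetical bucket and key), to create an object only if it does not already
        exist:

        .. code-block:: python

            self._put_object("example-bucket/config.json", b"{}", if_none_match="*")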
        """
        bucket, key = split_path(path)

        def _invoke_api() -> int:
            kwargs = {"Bucket": bucket, "Key": key, "Body": body}
            if content_type:
                kwargs["ContentType"] = content_type
            if self._is_directory_bucket(bucket):
                kwargs["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
            if if_match:
                kwargs["IfMatch"] = if_match
            if if_none_match:
                kwargs["IfNoneMatch"] = if_none_match
            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                kwargs["Metadata"] = validated_attributes

            # TODO(NGCDP-5804): Add support to update ContentType header in Rust client
            rust_unsupported_feature_keys = {"Metadata", "StorageClass", "IfMatch", "IfNoneMatch", "ContentType"}
            if (
                self._rust_client
                # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
                and not path.endswith("/")
                and all(key not in kwargs for key in rust_unsupported_feature_keys)
            ):
                run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                self._s3_client.put_object(**kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)

        def _invoke_api() -> bytes:
            if byte_range:
                bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
                if self._rust_client:
                    response = run_async_rust_client_method(
                        self._rust_client,
                        "get",
                        key,
                        byte_range,
                    )
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key, Range=bytes_range)
            else:
                if self._rust_client:
                    response = run_async_rust_client_method(self._rust_client, "get", key)
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key)

            return response["Body"].read()

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            self._s3_client.copy(
                CopySource={"Bucket": src_bucket, "Key": src_key},
                Bucket=dest_bucket,
                Key=dest_key,
                Config=self._transfer_config,
            )

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=dest_bucket, key=dest_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            # Delete conditionally on the provided ETag (if_match); otherwise delete unconditionally.
            if if_match:
                self._s3_client.delete_object(Bucket=bucket, Key=key, IfMatch=if_match)
            else:
                self._s3_client.delete_object(Bucket=bucket, Key=key)

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._s3_client.list_objects_v2(Bucket=bucket, Prefix=key, MaxKeys=1, Delimiter="/")

            # Check if there are any contents or common prefixes
            return bool(response.get("Contents", []) or response.get("CommonPrefixes", []))

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If the path ends with "/" or the key is empty, assume it's a "directory", whose metadata is not
            # guaranteed to exist (e.g. a "virtual prefix" that was never explicitly created).
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:

            def _invoke_api() -> ObjectMetadata:
                response = self._s3_client.head_object(Bucket=bucket, Key=key)

                return ObjectMetadata(
                    key=path,
                    type="file",
                    content_length=response["ContentLength"],
                    content_type=response.get("ContentType"),
                    last_modified=response["LastModified"],
                    etag=response["ETag"].strip('"') if "ETag" in response else None,
                    storage_class=response.get("StorageClass"),
                    metadata=response.get("Metadata"),
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist at the given path, append a trailing slash and
                    # check whether the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            paginator = self._s3_client.get_paginator("list_objects_v2")
            if include_directories:
                page_iterator = paginator.paginate(
                    Bucket=bucket, Prefix=prefix, Delimiter="/", StartAfter=(start_after or "")
                )
            else:
                page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, StartAfter=(start_after or ""))

            for page in page_iterator:
                for item in page.get("CommonPrefixes", []):
                    yield ObjectMetadata(
                        key=os.path.join(bucket, item["Prefix"].rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )

                # S3 guarantees lexicographical order for general purpose buckets (for
                # normal S3) but not directory buckets (for S3 Express One Zone).
                for response_object in page.get("Contents", []):
                    key = response_object["Key"]
                    if end_at is None or key <= end_at:
                        if key.endswith("/"):
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(bucket, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=response_object["LastModified"],
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key),
                                type="file",
                                content_length=response_object["Size"],
                                last_modified=response_object["LastModified"],
                                etag=response_object["ETag"].strip('"'),
                                storage_class=response_object.get("StorageClass"),  # Pass storage_class
                            )
                    else:
                        return

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(
        self,
        remote_path: str,
        f: Union[str, IO],
        attributes: Optional[dict[str, str]] = None,
        content_type: Optional[str] = None,
    ) -> int:
        file_size: int = 0

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._transfer_config.multipart_threshold:
                if self._rust_client and not attributes and not content_type:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes, content_type=content_type)
                return file_size

            # Upload large files using TransferConfig
            def _invoke_api() -> int:
                extra_args = {}
                if content_type:
                    extra_args["ContentType"] = content_type
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes
                if self._rust_client and not extra_args:
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    self._s3_client.upload_file(
                        Filename=f,
                        Bucket=bucket,
                        Key=key,
                        Config=self._transfer_config,
                        ExtraArgs=extra_args,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            # Upload small files
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            if file_size <= self._transfer_config.multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(
                        remote_path, f.read().encode("utf-8"), attributes=attributes, content_type=content_type
                    )
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes, content_type=content_type)
                return file_size

            # Upload large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                extra_args = {}
                if content_type:
                    extra_args["ContentType"] = content_type
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes

                if self._rust_client and isinstance(f, io.BytesIO) and not extra_args:
                    data = f.getbuffer()
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_bytes", key, data)
                else:
                    self._s3_client.upload_fileobj(
                        Fileobj=f,
                        Bucket=bucket,
                        Key=key,
                        Config=self._transfer_config,
                        ExtraArgs=extra_args,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)

            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using TransferConfig
            def _invoke_api() -> int:
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    if self._rust_client:
                        run_async_rust_client_method(
                            self._rust_client, "download_multipart_to_file", key, temp_file_path
                        )
                    else:
                        self._s3_client.download_fileobj(
                            Bucket=bucket,
                            Key=key,
                            Fileobj=fp,
                            Config=self._transfer_config,
                        )

                os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:
            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                response = self._get_object(remote_path)
                # The Python client returns `bytes`, but the Rust client returns an object that implements the buffer
                # protocol, so check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                self._s3_client.download_fileobj(
                    Bucket=bucket,
                    Key=key,
                    Fileobj=f,
                    Config=self._transfer_config,
                )

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)