# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import io
import os
import tempfile
from collections.abc import Callable, Iterator
from typing import IO, Any, Optional, TypeVar, Union

import boto3
import botocore
from boto3.s3.transfer import TransferConfig
from botocore.credentials import RefreshableCredentials
from botocore.exceptions import ClientError, IncompleteReadError, ReadTimeoutError, ResponseStreamingError
from botocore.session import get_session

from multistorageclient_rust import RustClient, RustClientError, RustRetryableError

from ..constants import DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT
from ..rust_utils import parse_retry_config, run_async_rust_client_method
from ..telemetry import Telemetry
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import (
    get_available_cpu_count,
    safe_makedirs,
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

# Default connection pool size scales with CPU count or MSC Sync Threads count (minimum 64)
MAX_POOL_CONNECTIONS = max(
    64,
    get_available_cpu_count(),
    int(os.getenv("MSC_NUM_THREADS_PER_PROCESS", "0")),
)

MiB = 1024 * 1024

# Python and Rust share the same multipart_threshold to keep the code simple.
MULTIPART_THRESHOLD = 64 * MiB
MULTIPART_CHUNKSIZE = 32 * MiB
IO_CHUNKSIZE = 32 * MiB
PYTHON_MAX_CONCURRENCY = 8

PROVIDER = "s3"

EXPRESS_ONEZONE_STORAGE_CLASS = "EXPRESS_ONEZONE"


class StaticS3CredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static S3 credentials.
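
    Example (a minimal sketch; the key values are placeholders)::

        provider = StaticS3CredentialsProvider(access_key="AKIA...", secret_key="...")
        credentials = provider.get_credentials()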
77 """
78
79 _access_key: str
80 _secret_key: str
81 _session_token: Optional[str]
82
83 def __init__(self, access_key: str, secret_key: str, session_token: Optional[str] = None):
84 """
85 Initializes the :py:class:`StaticS3CredentialsProvider` with the provided access key, secret key, and optional
86 session token.
87
88 :param access_key: The access key for S3 authentication.
89 :param secret_key: The secret key for S3 authentication.
90 :param session_token: An optional session token for temporary credentials.
91 """
92 self._access_key = access_key
93 self._secret_key = secret_key
94 self._session_token = session_token
95
    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key=self._access_key,
            secret_key=self._secret_key,
            token=self._session_token,
            expiration=None,
        )

    def refresh_credentials(self) -> None:
        pass


class S3StorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Amazon S3 or S3-compatible object stores.
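
    Example (a minimal sketch; the region and bucket values are placeholders)::

        provider = S3StorageProvider(
            region_name="us-east-1",
            base_path="my-bucket/prefix",
        )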
111 """
112
113 def __init__(
114 self,
115 region_name: str = "",
116 endpoint_url: str = "",
117 base_path: str = "",
118 credentials_provider: Optional[CredentialsProvider] = None,
119 config_dict: Optional[dict[str, Any]] = None,
120 telemetry_provider: Optional[Callable[[], Telemetry]] = None,
121 verify: Optional[Union[bool, str]] = None,
122 **kwargs: Any,
123 ) -> None:
124 """
125 Initializes the :py:class:`S3StorageProvider` with the region, endpoint URL, and optional credentials provider.
126
127 :param region_name: The AWS region where the S3 bucket is located.
128 :param endpoint_url: The custom endpoint URL for the S3 service.
129 :param base_path: The root prefix path within the S3 bucket where all operations will be scoped.
130 :param credentials_provider: The provider to retrieve S3 credentials.
131 :param config_dict: Resolved MSC config.
132 :param telemetry_provider: A function that provides a telemetry instance.
133 :param verify: Controls SSL certificate verification. Can be ``True`` (verify using system CA bundle, default),
134 ``False`` (skip verification), or a string path to a custom CA certificate bundle.
135 """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._region_name = region_name
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._verify = verify

        self._signature_version = kwargs.get("signature_version", "")
        self._s3_client = self._create_s3_client(
            request_checksum_calculation=kwargs.get("request_checksum_calculation"),
            response_checksum_validation=kwargs.get("response_checksum_validation"),
            max_pool_connections=kwargs.get("max_pool_connections", MAX_POOL_CONNECTIONS),
            connect_timeout=kwargs.get("connect_timeout", DEFAULT_CONNECT_TIMEOUT),
            read_timeout=kwargs.get("read_timeout", DEFAULT_READ_TIMEOUT),
            retries=kwargs.get("retries"),
        )
        self._transfer_config = TransferConfig(
            multipart_threshold=int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD)),
            max_concurrency=int(kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)),
            multipart_chunksize=int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE)),
            io_chunksize=int(kwargs.get("io_chunksize", IO_CHUNKSIZE)),
            use_threads=True,
        )

        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the rust client options from the kwargs
            rust_client_options = kwargs["rust_client"]
            if "max_pool_connections" in kwargs:
                rust_client_options["max_pool_connections"] = kwargs["max_pool_connections"]
            if "max_concurrency" in kwargs:
                rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
            if "multipart_chunksize" in kwargs:
                rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
            if "read_timeout" in kwargs:
                rust_client_options["read_timeout"] = kwargs["read_timeout"]
            if "connect_timeout" in kwargs:
                rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
            if self._signature_version == "UNSIGNED":
                rust_client_options["skip_signature"] = True
            self._rust_client = self._create_rust_client(rust_client_options)

    def _is_directory_bucket(self, bucket: str) -> bool:
        """
        Determines if the bucket is a directory bucket based on bucket name.
        """
        # S3 Express buckets have a specific naming convention
        return "--x-s3" in bucket

    def _create_s3_client(
        self,
        request_checksum_calculation: Optional[str] = None,
        response_checksum_validation: Optional[str] = None,
        max_pool_connections: int = MAX_POOL_CONNECTIONS,
        connect_timeout: Union[float, int, None] = None,
        read_timeout: Union[float, int, None] = None,
        retries: Optional[dict[str, Any]] = None,
    ):
199 """
200 Creates and configures the boto3 S3 client, using refreshable credentials if possible.
201
202 :param request_checksum_calculation: When the underlying S3 client should calculate request checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
203 :param response_checksum_validation: When the underlying S3 client should validate response checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
204 :param max_pool_connections: The maximum number of connections to keep in a connection pool.
205 :param connect_timeout: The time in seconds till a timeout exception is thrown when attempting to make a connection.
206 :param read_timeout: The time in seconds till a timeout exception is thrown when attempting to read from a connection.
207 :param retries: A dictionary for configuration related to retry behavior.
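            For example, ``{"max_attempts": 5, "mode": "adaptive"}``; the value is passed through to the botocore ``Config`` ``retries`` option.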

        :return: The configured S3 client.
        """
        options = {
            # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
            "config": boto3.session.Config(  # pyright: ignore [reportAttributeAccessIssue]
                max_pool_connections=max_pool_connections,
                connect_timeout=connect_timeout,
                read_timeout=read_timeout,
                retries=retries or {"mode": "standard"},
                request_checksum_calculation=request_checksum_calculation,
                response_checksum_validation=response_checksum_validation,
            ),
        }

        if self._region_name:
            options["region_name"] = self._region_name

        if self._endpoint_url:
            options["endpoint_url"] = self._endpoint_url

        if self._verify is not None:
            options["verify"] = self._verify

        if self._credentials_provider:
            creds = self._fetch_credentials()
            if "expiry_time" in creds and creds["expiry_time"]:
                # Use RefreshableCredentials if expiry_time provided.
                refreshable_credentials = RefreshableCredentials.create_from_metadata(
                    metadata=creds, refresh_using=self._fetch_credentials, method="custom-refresh"
                )

                botocore_session = get_session()
                botocore_session._credentials = refreshable_credentials

                boto3_session = boto3.Session(botocore_session=botocore_session)

                return boto3_session.client("s3", **options)
            else:
                # Add static credentials to the options dictionary
                options["aws_access_key_id"] = creds["access_key"]
                options["aws_secret_access_key"] = creds["secret_key"]
                if creds["token"]:
                    options["aws_session_token"] = creds["token"]

        if self._signature_version:
            signature_config = botocore.config.Config(  # pyright: ignore[reportAttributeAccessIssue]
                signature_version=botocore.UNSIGNED
                if self._signature_version == "UNSIGNED"
                else self._signature_version
            )
            options["config"] = options["config"].merge(signature_config)

        # Fallback to standard credential chain.
        return boto3.client("s3", **options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        """
        Creates and configures the rust client, using refreshable credentials if possible.
        """
        configs = dict(rust_client_options) if rust_client_options else {}

        # Extract and parse retry configuration
        retry_config = parse_retry_config(configs)

        if self._region_name and "region_name" not in configs:
            configs["region_name"] = self._region_name

        if self._endpoint_url and "endpoint_url" not in configs:
            configs["endpoint_url"] = self._endpoint_url

        # If the user specifies a bucket, use it. Otherwise, use the base path.
        if "bucket" not in configs:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        if "max_pool_connections" not in configs:
            configs["max_pool_connections"] = MAX_POOL_CONNECTIONS

        return RustClient(
            provider=PROVIDER,
            configs=configs,
            credentials_provider=self._credentials_provider,
            retry=retry_config,
        )

    def _fetch_credentials(self) -> dict:
        """
        Refreshes the credentials via the configured credentials provider and returns them in the
        metadata format expected by botocore's refreshable credentials.
        """
        if not self._credentials_provider:
            raise RuntimeError("Cannot fetch credentials if no credential provider configured.")
        self._credentials_provider.refresh_credentials()
        credentials = self._credentials_provider.get_credentials()
        return {
            "access_key": credentials.access_key,
            "secret_key": credentials.secret_key,
            "token": credentials.token,
            "expiry_time": credentials.expiration,
        }

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Executes ``func`` and translates errors such as timeouts and client errors into standard or MSC-specific exceptions.

        :param func: The function that performs the actual S3 operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the S3 bucket involved in the operation.
        :param key: The key of the object within the S3 bucket.

        :return: The result of the S3 operation, typically the return value of the `func` callable.
        """
        try:
            return func()
        except ClientError as error:
            status_code = error.response["ResponseMetadata"]["HTTPStatusCode"]
            request_id = error.response["ResponseMetadata"].get("RequestId")
            host_id = error.response["ResponseMetadata"].get("HostId")
            error_code = error.response["Error"]["Code"]
            error_info = f"request_id: {request_id}, host_id: {host_id}, status_code: {status_code}"

            if status_code == 404:
                if error_code == "NoSuchUpload":
                    error_message = error.response["Error"]["Message"]
                    raise RetryableError(f"Multipart upload failed for {bucket}/{key}: {error_message}") from error
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}")  # pylint: disable=raise-missing-from
            elif status_code == 412:  # Precondition Failed
                raise PreconditionFailedError(
                    f"ETag mismatch for {operation} operation on {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 429:
                raise RetryableError(
                    f"Too many requests to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 503:
                raise RetryableError(
                    f"Service unavailable when {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 501:
                raise NotImplementedError(
                    f"Operation {operation} not implemented for object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 408:
                # 408 Request Timeout is from Google Cloud Storage
                raise RetryableError(
                    f"Request timeout when {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            else:
                raise RuntimeError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}, "
                    f"error_type: {type(error).__name__}"
                ) from error
        except RustClientError as error:
            message = error.args[0]
            status_code = error.args[1]
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {message}") from error
            elif status_code == 403:
                raise PermissionError(
                    f"Permission denied to {operation} object(s) at {bucket}/{key}. {message}"
                ) from error
            else:
                raise RetryableError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {message}. status_code: {status_code}"
                ) from error
        except (ReadTimeoutError, IncompleteReadError, ResponseStreamingError) as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to network timeout or incomplete read. "
                f"error_type: {type(error).__name__}"
            ) from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except Exception as error:
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
        content_type: Optional[str] = None,
    ) -> int:
        """
        Uploads an object to the specified S3 path.

        :param path: The S3 path where the object will be uploaded.
        :param body: The content of the object as bytes.
        :param if_match: Optional If-Match header value. Use an ETag to only overwrite the object if it still matches, or "*" to only upload if the object already exists.
        :param if_none_match: Optional If-None-Match header value. Use "*" to only upload if the object doesn't exist.
        :param attributes: Optional attributes to attach to the object.
        :param content_type: Optional Content-Type header value.
        """
        bucket, key = split_path(path)

        def _invoke_api() -> int:
            kwargs = {"Bucket": bucket, "Key": key, "Body": body}
            if content_type:
                kwargs["ContentType"] = content_type
            if self._is_directory_bucket(bucket):
                kwargs["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
            if if_match:
                kwargs["IfMatch"] = if_match
            if if_none_match:
                kwargs["IfNoneMatch"] = if_none_match
            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                kwargs["Metadata"] = validated_attributes

            # TODO(NGCDP-5804): Add support to update ContentType header in Rust client
            rust_unsupported_feature_keys = {"Metadata", "StorageClass", "IfMatch", "IfNoneMatch", "ContentType"}
            if (
                self._rust_client
                # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
                and not path.endswith("/")
                and all(key not in kwargs for key in rust_unsupported_feature_keys)
            ):
                run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                self._s3_client.put_object(**kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
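        """
        Reads an object (or a byte range of it) from S3, preferring the Rust client when it is configured.

        :param path: The S3 path of the object to read.
        :param byte_range: Optional byte range to read instead of the full object.

        :return: The object content as bytes.
        """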
        bucket, key = split_path(path)

        def _invoke_api() -> bytes:
            if byte_range:
                bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
                if self._rust_client:
                    response = run_async_rust_client_method(
                        self._rust_client,
                        "get",
                        key,
                        byte_range,
                    )
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key, Range=bytes_range)
            else:
                if self._rust_client:
                    response = run_async_rust_client_method(self._rust_client, "get", key)
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key)

            return response["Body"].read()

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
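        """
        Copies an object within S3 using the boto3 managed copy (multipart when needed).

        :param src_path: The S3 path of the source object.
        :param dest_path: The S3 path of the destination object.

        :return: The content length of the copied object in bytes.
        """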
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            self._s3_client.copy(
                CopySource={"Bucket": src_bucket, "Key": src_key},
                Bucket=dest_bucket,
                Key=dest_key,
                Config=self._transfer_config,
            )

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=dest_bucket, key=dest_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            # Delete the object conditionally if if_match (ETag) is provided; otherwise, delete it unconditionally.
            if if_match:
                self._s3_client.delete_object(Bucket=bucket, Key=key, IfMatch=if_match)
            else:
                self._s3_client.delete_object(Bucket=bucket, Key=key)

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._s3_client.list_objects_v2(Bucket=bucket, Prefix=key, MaxKeys=1, Delimiter="/")

            # Check if there are any contents or common prefixes
            return bool(response.get("Contents", []) or response.get("CommonPrefixes", []))

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If the path ends with "/" or an empty key name is provided, assume it's a "directory",
            # whose metadata is not guaranteed to exist, e.g. a "virtual prefix" that was never
            # explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:

            def _invoke_api() -> ObjectMetadata:
                response = self._s3_client.head_object(Bucket=bucket, Key=key)

                return ObjectMetadata(
                    key=path,
                    type="file",
                    content_length=response["ContentLength"],
                    content_type=response.get("ContentType"),
                    last_modified=response["LastModified"],
                    etag=response["ETag"].strip('"') if "ETag" in response else None,
                    storage_class=response.get("StorageClass"),
                    metadata=response.get("Metadata"),
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
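        """
        Lists objects under the given prefix, optionally yielding common prefixes as directory entries.

        :param path: The S3 prefix to list.
        :param start_after: Start listing after this key (exclusive), relative to the same bucket.
        :param end_at: Stop listing after this key (inclusive), relative to the same bucket.
        :param include_directories: If True, list with a "/" delimiter and yield directory entries.
        :param follow_symlinks: Not used by this provider.
        """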
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            paginator = self._s3_client.get_paginator("list_objects_v2")
            if include_directories:
                page_iterator = paginator.paginate(
                    Bucket=bucket, Prefix=prefix, Delimiter="/", StartAfter=(start_after or "")
                )
            else:
                page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, StartAfter=(start_after or ""))

            for page in page_iterator:
                for item in page.get("CommonPrefixes", []):
                    yield ObjectMetadata(
                        key=os.path.join(bucket, item["Prefix"].rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )

                # S3 guarantees lexicographical order for general purpose buckets (for
                # normal S3) but not directory buckets (for S3 Express One Zone).
                for response_object in page.get("Contents", []):
                    key = response_object["Key"]
                    if end_at is None or key <= end_at:
                        if key.endswith("/"):
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(bucket, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=response_object["LastModified"],
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key),
                                type="file",
                                content_length=response_object["Size"],
                                last_modified=response_object["LastModified"],
                                etag=response_object["ETag"].strip('"'),
                                storage_class=response_object.get("StorageClass"),  # Pass storage_class
                            )
                    else:
                        return

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(
        self,
        remote_path: str,
        f: Union[str, IO],
        attributes: Optional[dict[str, str]] = None,
        content_type: Optional[str] = None,
    ) -> int:
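        """
        Uploads a local file (by path) or a file-like object to the specified S3 path.

        Files at or below the multipart threshold are uploaded with a single PUT; larger files use the
        boto3 transfer manager (or the Rust client when no extra arguments are required).

        :param remote_path: The S3 path where the file will be uploaded.
        :param f: A local file path or a file-like object to upload.
        :param attributes: Optional attributes to attach to the object.
        :param content_type: Optional Content-Type header value.

        :return: The size of the uploaded file in bytes.
        """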
        file_size: int = 0

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._transfer_config.multipart_threshold:
                if self._rust_client and not attributes and not content_type:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes, content_type=content_type)
                return file_size

            # Upload large files using TransferConfig
            def _invoke_api() -> int:
                extra_args = {}
                if content_type:
                    extra_args["ContentType"] = content_type
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes
                if self._rust_client and not extra_args:
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    self._s3_client.upload_file(
                        Filename=f,
                        Bucket=bucket,
                        Key=key,
                        Config=self._transfer_config,
                        ExtraArgs=extra_args,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            # Upload small files
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            if file_size <= self._transfer_config.multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(
                        remote_path, f.read().encode("utf-8"), attributes=attributes, content_type=content_type
                    )
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes, content_type=content_type)
                return file_size

            # Upload large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                extra_args = {}
                if content_type:
                    extra_args["ContentType"] = content_type
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes

                if self._rust_client and isinstance(f, io.BytesIO) and not extra_args:
                    data = f.getbuffer()
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_bytes", key, data)
                else:
                    self._s3_client.upload_fileobj(
                        Fileobj=f,
                        Bucket=bucket,
                        Key=key,
                        Config=self._transfer_config,
                        ExtraArgs=extra_args,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
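        """
        Downloads an object to a local file path or a file-like object.

        Objects at or below the multipart threshold are fetched with a single GET; larger objects use the
        boto3 transfer manager (or the Rust client when downloading to a local path).

        :param remote_path: The S3 path of the object to download.
        :param f: A local file path or a file-like object to write to.
        :param metadata: Optional pre-fetched object metadata; fetched via HEAD when not provided.

        :return: The content length of the downloaded object in bytes.
        """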
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            if os.path.dirname(f):
                safe_makedirs(os.path.dirname(f))

            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using TransferConfig
            def _invoke_api() -> int:
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    if self._rust_client:
                        run_async_rust_client_method(
                            self._rust_client, "download_multipart_to_file", key, temp_file_path
                        )
                    else:
                        self._s3_client.download_fileobj(
                            Bucket=bucket,
                            Key=key,
                            Fileobj=fp,
                            Config=self._transfer_config,
                        )

                os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:
            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                response = self._get_object(remote_path)
                # The Python client returns `bytes`, but the Rust client returns an object that implements the
                # buffer protocol, so we need to check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                self._s3_client.download_fileobj(
                    Bucket=bucket,
                    Key=key,
                    Fileobj=f,
                    Config=self._transfer_config,
                )

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)