# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import io
import os
import tempfile
from collections.abc import Callable, Iterator
from typing import IO, Any, Optional, TypeVar, Union

import boto3
import botocore
from boto3.s3.transfer import TransferConfig
from botocore.credentials import RefreshableCredentials
from botocore.exceptions import ClientError, IncompleteReadError, ReadTimeoutError, ResponseStreamingError
from botocore.session import get_session

from multistorageclient_rust import RustClient, RustClientError, RustRetryableError

from ..constants import DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT
from ..rust_utils import parse_retry_config, run_async_rust_client_method
from ..telemetry import Telemetry
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import (
    get_available_cpu_count,
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

# Default connection pool size scales with the CPU count or MSC_NUM_THREADS_PER_PROCESS (minimum 64).
MAX_POOL_CONNECTIONS = max(
    64,
    get_available_cpu_count(),
    int(os.getenv("MSC_NUM_THREADS_PER_PROCESS", "0")),
)

MiB = 1024 * 1024

# Python and Rust share the same multipart_threshold to keep the code simple.
MULTIPART_THRESHOLD = 64 * MiB
MULTIPART_CHUNKSIZE = 32 * MiB
IO_CHUNKSIZE = 32 * MiB
PYTHON_MAX_CONCURRENCY = 8

PROVIDER = "s3"

EXPRESS_ONEZONE_STORAGE_CLASS = "EXPRESS_ONEZONE"


class StaticS3CredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static S3 credentials.
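
    Example (illustrative only; the key values below are placeholders)::

        provider = StaticS3CredentialsProvider(access_key="AKIA...", secret_key="...")
        credentials = provider.get_credentials()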
    """

    _access_key: str
    _secret_key: str
    _session_token: Optional[str]

    def __init__(self, access_key: str, secret_key: str, session_token: Optional[str] = None):
        """
        Initializes the :py:class:`StaticS3CredentialsProvider` with the provided access key, secret key, and optional
        session token.

        :param access_key: The access key for S3 authentication.
        :param secret_key: The secret key for S3 authentication.
        :param session_token: An optional session token for temporary credentials.
        """
        self._access_key = access_key
        self._secret_key = secret_key
        self._session_token = session_token

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key=self._access_key,
            secret_key=self._secret_key,
            token=self._session_token,
            expiration=None,
        )

    def refresh_credentials(self) -> None:
        pass


class S3StorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Amazon S3 or S3-compatible object stores.
    """

    def __init__(
        self,
        region_name: str = "",
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        verify: Optional[Union[bool, str]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the :py:class:`S3StorageProvider` with the region, endpoint URL, and optional credentials provider.

        :param region_name: The AWS region where the S3 bucket is located.
        :param endpoint_url: The custom endpoint URL for the S3 service.
        :param base_path: The root prefix path within the S3 bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve S3 credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        :param verify: Controls SSL certificate verification. Can be ``True`` (verify using system CA bundle, default),
            ``False`` (skip verification), or a string path to a custom CA certificate bundle.
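
        Example (a minimal, illustrative instantiation; the bucket, prefix, and endpoint below are
        placeholders, and MSC normally constructs providers from its resolved config)::

            provider = S3StorageProvider(
                region_name="us-east-1",
                base_path="my-bucket/prefix",
                endpoint_url="https://s3.example.com",
            )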
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._region_name = region_name
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._verify = verify

        self._signature_version = kwargs.get("signature_version", "")
        self._s3_client = self._create_s3_client(
            request_checksum_calculation=kwargs.get("request_checksum_calculation"),
            response_checksum_validation=kwargs.get("response_checksum_validation"),
            max_pool_connections=kwargs.get("max_pool_connections", MAX_POOL_CONNECTIONS),
            connect_timeout=kwargs.get("connect_timeout", DEFAULT_CONNECT_TIMEOUT),
            read_timeout=kwargs.get("read_timeout", DEFAULT_READ_TIMEOUT),
            retries=kwargs.get("retries"),
        )
        self._transfer_config = TransferConfig(
            multipart_threshold=int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD)),
            max_concurrency=int(kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)),
            multipart_chunksize=int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE)),
            io_chunksize=int(kwargs.get("io_chunksize", IO_CHUNKSIZE)),
            use_threads=True,
        )

        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the rust client options from the kwargs
            rust_client_options = kwargs["rust_client"]
            if "max_pool_connections" in kwargs:
                rust_client_options["max_pool_connections"] = kwargs["max_pool_connections"]
            if "max_concurrency" in kwargs:
                rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
            if "multipart_chunksize" in kwargs:
                rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
            if "read_timeout" in kwargs:
                rust_client_options["read_timeout"] = kwargs["read_timeout"]
            if "connect_timeout" in kwargs:
                rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
            if self._signature_version == "UNSIGNED":
                rust_client_options["skip_signature"] = True
            self._rust_client = self._create_rust_client(rust_client_options)

    def _is_directory_bucket(self, bucket: str) -> bool:
        """
        Determines if the bucket is a directory bucket based on bucket name.
        """
        # S3 Express buckets have a specific naming convention
        return "--x-s3" in bucket

    def _create_s3_client(
        self,
        request_checksum_calculation: Optional[str] = None,
        response_checksum_validation: Optional[str] = None,
        max_pool_connections: int = MAX_POOL_CONNECTIONS,
        connect_timeout: Union[float, int, None] = None,
        read_timeout: Union[float, int, None] = None,
        retries: Optional[dict[str, Any]] = None,
    ):
        """
        Creates and configures the boto3 S3 client, using refreshable credentials if possible.

        :param request_checksum_calculation: When the underlying S3 client should calculate request checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param response_checksum_validation: When the underlying S3 client should validate response checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param max_pool_connections: The maximum number of connections to keep in a connection pool.
        :param connect_timeout: The time in seconds till a timeout exception is thrown when attempting to make a connection.
        :param read_timeout: The time in seconds till a timeout exception is thrown when attempting to read from a connection.
        :param retries: A dictionary for configuration related to retry behavior.

        :return: The configured S3 client.
        """
        options = {
            # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
            "config": boto3.session.Config(  # pyright: ignore [reportAttributeAccessIssue]
                max_pool_connections=max_pool_connections,
                connect_timeout=connect_timeout,
                read_timeout=read_timeout,
                retries=retries or {"mode": "standard"},
                request_checksum_calculation=request_checksum_calculation,
                response_checksum_validation=response_checksum_validation,
            ),
        }

        if self._region_name:
            options["region_name"] = self._region_name

        if self._endpoint_url:
            options["endpoint_url"] = self._endpoint_url

        if self._verify is not None:
            options["verify"] = self._verify

        if self._credentials_provider:
            creds = self._fetch_credentials()
            if "expiry_time" in creds and creds["expiry_time"]:
                # Use RefreshableCredentials if expiry_time provided.
                refreshable_credentials = RefreshableCredentials.create_from_metadata(
                    metadata=creds, refresh_using=self._fetch_credentials, method="custom-refresh"
                )

                botocore_session = get_session()
                botocore_session._credentials = refreshable_credentials

                boto3_session = boto3.Session(botocore_session=botocore_session)

                return boto3_session.client("s3", **options)
            else:
                # Add static credentials to the options dictionary
                options["aws_access_key_id"] = creds["access_key"]
                options["aws_secret_access_key"] = creds["secret_key"]
                if creds["token"]:
                    options["aws_session_token"] = creds["token"]

        if self._signature_version:
            signature_config = botocore.config.Config(  # pyright: ignore[reportAttributeAccessIssue]
                signature_version=botocore.UNSIGNED
                if self._signature_version == "UNSIGNED"
                else self._signature_version
            )
            options["config"] = options["config"].merge(signature_config)

        # Fallback to standard credential chain.
        return boto3.client("s3", **options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        """
        Creates and configures the rust client, using refreshable credentials if possible.
        """
        configs = dict(rust_client_options) if rust_client_options else {}

        # Extract and parse retry configuration
        retry_config = parse_retry_config(configs)

        if self._region_name and "region_name" not in configs:
            configs["region_name"] = self._region_name

        if self._endpoint_url and "endpoint_url" not in configs:
            configs["endpoint_url"] = self._endpoint_url

        # If the user specifies a bucket, use it. Otherwise, use the base path.
        if "bucket" not in configs:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        if "max_pool_connections" not in configs:
            configs["max_pool_connections"] = MAX_POOL_CONNECTIONS

        return RustClient(
            provider=PROVIDER,
            configs=configs,
            credentials_provider=self._credentials_provider,
            retry=retry_config,
        )

    def _fetch_credentials(self) -> dict:
        """
        Refreshes the credentials via the configured credentials provider and returns them in the
        metadata format expected by botocore's ``RefreshableCredentials``.
        """
        if not self._credentials_provider:
            raise RuntimeError("Cannot fetch credentials if no credential provider configured.")
        self._credentials_provider.refresh_credentials()
        credentials = self._credentials_provider.get_credentials()
        return {
            "access_key": credentials.access_key,
            "secret_key": credentials.secret_key,
            "token": credentials.token,
            "expiry_time": credentials.expiration,
        }

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual S3 operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the S3 bucket involved in the operation.
        :param key: The key of the object within the S3 bucket.

        :return: The result of the S3 operation, typically the return value of the `func` callable.
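
        Example (illustrative; mirrors how the wrapper methods below call this helper)::

            metadata = self._translate_errors(
                lambda: self._s3_client.head_object(Bucket=bucket, Key=key),
                operation="HEAD",
                bucket=bucket,
                key=key,
            )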
        """
        try:
            return func()
        except ClientError as error:
            status_code = error.response["ResponseMetadata"]["HTTPStatusCode"]
            request_id = error.response["ResponseMetadata"].get("RequestId")
            host_id = error.response["ResponseMetadata"].get("HostId")
            error_code = error.response["Error"]["Code"]
            error_info = f"request_id: {request_id}, host_id: {host_id}, status_code: {status_code}"

            if status_code == 404:
                if error_code == "NoSuchUpload":
                    error_message = error.response["Error"]["Message"]
                    raise RetryableError(f"Multipart upload failed for {bucket}/{key}: {error_message}") from error
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}")  # pylint: disable=raise-missing-from
            elif status_code == 412:  # Precondition Failed
                raise PreconditionFailedError(
                    f"ETag mismatch for {operation} operation on {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 429:
                raise RetryableError(
                    f"Too many requests to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 503:
                raise RetryableError(
                    f"Service unavailable when {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 501:
                raise NotImplementedError(
                    f"Operation {operation} not implemented for object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 408:
                # 408 Request Timeout is from Google Cloud Storage
                raise RetryableError(
                    f"Request timeout when {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            else:
                raise RuntimeError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}, "
                    f"error_type: {type(error).__name__}"
                ) from error
        except RustClientError as error:
            message = error.args[0]
            status_code = error.args[1]
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {message}") from error
            elif status_code == 403:
                raise PermissionError(
                    f"Permission denied to {operation} object(s) at {bucket}/{key}. {message}"
                ) from error
            else:
                raise RetryableError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {message}. status_code: {status_code}"
                ) from error
        except (ReadTimeoutError, IncompleteReadError, ResponseStreamingError) as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to network timeout or incomplete read. "
                f"error_type: {type(error).__name__}"
            ) from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except Exception as error:
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
        content_type: Optional[str] = None,
    ) -> int:
        """
        Uploads an object to the specified S3 path.

        :param path: The S3 path where the object will be uploaded.
        :param body: The content of the object as bytes.
        :param if_match: Optional If-Match header value for conditional writes; the upload only succeeds if the existing object's ETag matches.
        :param if_none_match: Optional If-None-Match header value. Use "*" to only upload if the object doesn't exist.
        :param attributes: Optional attributes to attach to the object.
        :param content_type: Optional Content-Type header value.
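
        Example (illustrative) -- a create-only write that raises ``PreconditionFailedError`` if the
        object already exists (via the 412 translation in ``_translate_errors``)::

            self._put_object(path, body, if_none_match="*")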
        """
        bucket, key = split_path(path)

        def _invoke_api() -> int:
            kwargs = {"Bucket": bucket, "Key": key, "Body": body}
            if content_type:
                kwargs["ContentType"] = content_type
            if self._is_directory_bucket(bucket):
                kwargs["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
            if if_match:
                kwargs["IfMatch"] = if_match
            if if_none_match:
                kwargs["IfNoneMatch"] = if_none_match
            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                kwargs["Metadata"] = validated_attributes

            # TODO(NGCDP-5804): Add support to update ContentType header in Rust client
            rust_unsupported_feature_keys = {"Metadata", "StorageClass", "IfMatch", "IfNoneMatch", "ContentType"}
            if (
                self._rust_client
                # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
                and not path.endswith("/")
                and all(key not in kwargs for key in rust_unsupported_feature_keys)
            ):
                run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                self._s3_client.put_object(**kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)

        def _invoke_api() -> bytes:
            if byte_range:
                bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
                if self._rust_client:
                    response = run_async_rust_client_method(
                        self._rust_client,
                        "get",
                        key,
                        byte_range,
                    )
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key, Range=bytes_range)
            else:
                if self._rust_client:
                    response = run_async_rust_client_method(self._rust_client, "get", key)
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key)

            return response["Body"].read()

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            self._s3_client.copy(
                CopySource={"Bucket": src_bucket, "Key": src_key},
                Bucket=dest_bucket,
                Key=dest_key,
                Config=self._transfer_config,
            )

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=dest_bucket, key=dest_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            # Conditionally delete the object if if_match (ETag) is provided; otherwise delete it unconditionally.
            if if_match:
                self._s3_client.delete_object(Bucket=bucket, Key=key, IfMatch=if_match)
            else:
                self._s3_client.delete_object(Bucket=bucket, Key=key)

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._s3_client.list_objects_v2(Bucket=bucket, Prefix=key, MaxKeys=1, Delimiter="/")

            # Check if there are any contents or common prefixes
            return bool(response.get("Contents", []) or response.get("CommonPrefixes", []))

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If the path ends with "/" or an empty key name is provided, assume it is a "directory",
            # whose metadata is not guaranteed to exist, e.g. a "virtual prefix" that was never
            # explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:

            def _invoke_api() -> ObjectMetadata:
                response = self._s3_client.head_object(Bucket=bucket, Key=key)

                return ObjectMetadata(
                    key=path,
                    type="file",
                    content_length=response["ContentLength"],
                    content_type=response.get("ContentType"),
                    last_modified=response["LastModified"],
                    etag=response["ETag"].strip('"') if "ETag" in response else None,
                    storage_class=response.get("StorageClass"),
                    metadata=response.get("Metadata"),
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist at the given path, append a trailing slash and
                    # check whether the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            paginator = self._s3_client.get_paginator("list_objects_v2")
            if include_directories:
                page_iterator = paginator.paginate(
                    Bucket=bucket, Prefix=prefix, Delimiter="/", StartAfter=(start_after or "")
                )
            else:
                page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, StartAfter=(start_after or ""))

            for page in page_iterator:
                for item in page.get("CommonPrefixes", []):
                    yield ObjectMetadata(
                        key=os.path.join(bucket, item["Prefix"].rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )

                # S3 guarantees lexicographical order for general purpose buckets (for
                # normal S3) but not directory buckets (for S3 Express One Zone).
                for response_object in page.get("Contents", []):
                    key = response_object["Key"]
                    if end_at is None or key <= end_at:
                        if key.endswith("/"):
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(bucket, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=response_object["LastModified"],
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key),
                                type="file",
                                content_length=response_object["Size"],
                                last_modified=response_object["LastModified"],
                                etag=response_object["ETag"].strip('"'),
                                storage_class=response_object.get("StorageClass"),  # Pass storage_class
                            )
                    else:
                        return

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(
        self,
        remote_path: str,
        f: Union[str, IO],
        attributes: Optional[dict[str, str]] = None,
        content_type: Optional[str] = None,
    ) -> int:
        file_size: int = 0

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._transfer_config.multipart_threshold:
                if self._rust_client and not attributes and not content_type:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes, content_type=content_type)
                return file_size

            # Upload large files using TransferConfig
            def _invoke_api() -> int:
                extra_args = {}
                if content_type:
                    extra_args["ContentType"] = content_type
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes
                if self._rust_client and not extra_args:
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    self._s3_client.upload_file(
                        Filename=f,
                        Bucket=bucket,
                        Key=key,
                        Config=self._transfer_config,
                        ExtraArgs=extra_args,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            # Upload small files
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            if file_size <= self._transfer_config.multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(
                        remote_path, f.read().encode("utf-8"), attributes=attributes, content_type=content_type
                    )
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes, content_type=content_type)
                return file_size

            # Upload large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                extra_args = {}
                if content_type:
                    extra_args["ContentType"] = content_type
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes

                if self._rust_client and isinstance(f, io.BytesIO) and not extra_args:
                    data = f.getbuffer()
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_bytes", key, data)
                else:
                    self._s3_client.upload_fileobj(
                        Fileobj=f,
                        Bucket=bucket,
                        Key=key,
                        Config=self._transfer_config,
                        ExtraArgs=extra_args,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)

            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using TransferConfig
            def _invoke_api() -> int:
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    if self._rust_client:
                        run_async_rust_client_method(
                            self._rust_client, "download_multipart_to_file", key, temp_file_path
                        )
                    else:
                        self._s3_client.download_fileobj(
                            Bucket=bucket,
                            Key=key,
                            Fileobj=fp,
                            Config=self._transfer_config,
                        )

                os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:
            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                response = self._get_object(remote_path)
                # The Python client returns `bytes`, but the Rust client returns an object that implements
                # the buffer protocol, so check whether `.decode()` is available before using it.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                self._s3_client.download_fileobj(
                    Bucket=bucket,
                    Key=key,
                    Fileobj=f,
                    Config=self._transfer_config,
                )

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)