# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import io
import os
import tempfile
from collections.abc import Callable, Iterator
from typing import IO, Any, Optional, TypeVar, Union

import boto3
import botocore
from boto3.s3.transfer import TransferConfig
from botocore.credentials import RefreshableCredentials
from botocore.exceptions import ClientError, IncompleteReadError, ReadTimeoutError, ResponseStreamingError
from botocore.session import get_session

from multistorageclient_rust import RustClient, RustClientError, RustRetryableError

from ..rust_utils import run_async_rust_client_method
from ..telemetry import Telemetry
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import (
    get_available_cpu_count,
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

# Default connection pool size scales with the CPU count or the MSC thread count
# (MSC_NUM_THREADS_PER_PROCESS), with a minimum of 64.
MAX_POOL_CONNECTIONS = max(
    64,
    get_available_cpu_count(),
    int(os.getenv("MSC_NUM_THREADS_PER_PROCESS", "0")),
)

MiB = 1024 * 1024

# Python and Rust share the same multipart_threshold to keep the code simple.
MULTIPART_THRESHOLD = 64 * MiB
MULTIPART_CHUNKSIZE = 32 * MiB
IO_CHUNKSIZE = 32 * MiB
PYTHON_MAX_CONCURRENCY = 8

PROVIDER = "s3"

EXPRESS_ONEZONE_STORAGE_CLASS = "EXPRESS_ONEZONE"


class StaticS3CredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static S3 credentials.
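
    Example (a minimal sketch; the key values are placeholders):

    .. code-block:: python

        provider = StaticS3CredentialsProvider(
            access_key="AKIAEXAMPLE",
            secret_key="example-secret",
            session_token=None,  # optional, for temporary credentials
        )
        credentials = provider.get_credentials()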
75 """
76
77 _access_key: str
78 _secret_key: str
79 _session_token: Optional[str]
80
81 def __init__(self, access_key: str, secret_key: str, session_token: Optional[str] = None):
82 """
83 Initializes the :py:class:`StaticS3CredentialsProvider` with the provided access key, secret key, and optional
84 session token.
85
86 :param access_key: The access key for S3 authentication.
87 :param secret_key: The secret key for S3 authentication.
88 :param session_token: An optional session token for temporary credentials.
89 """
90 self._access_key = access_key
91 self._secret_key = secret_key
92 self._session_token = session_token
93
    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key=self._access_key,
            secret_key=self._secret_key,
            token=self._session_token,
            expiration=None,
        )

    def refresh_credentials(self) -> None:
        pass


class S3StorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Amazon S3 or S3-compatible object stores.
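
    Example (a minimal sketch; the region, bucket, and prefix are placeholders):

    .. code-block:: python

        provider = S3StorageProvider(
            region_name="us-east-1",
            base_path="my-bucket/prefix",
        )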
109 """
110
111 def __init__(
112 self,
113 region_name: str = "",
114 endpoint_url: str = "",
115 base_path: str = "",
116 credentials_provider: Optional[CredentialsProvider] = None,
117 config_dict: Optional[dict[str, Any]] = None,
118 telemetry_provider: Optional[Callable[[], Telemetry]] = None,
119 verify: Optional[Union[bool, str]] = None,
120 **kwargs: Any,
121 ) -> None:
122 """
123 Initializes the :py:class:`S3StorageProvider` with the region, endpoint URL, and optional credentials provider.
124
125 :param region_name: The AWS region where the S3 bucket is located.
126 :param endpoint_url: The custom endpoint URL for the S3 service.
127 :param base_path: The root prefix path within the S3 bucket where all operations will be scoped.
128 :param credentials_provider: The provider to retrieve S3 credentials.
129 :param config_dict: Resolved MSC config.
130 :param telemetry_provider: A function that provides a telemetry instance.
131 :param verify: Controls SSL certificate verification. Can be ``True`` (verify using system CA bundle, default),
132 ``False`` (skip verification), or a string path to a custom CA certificate bundle.
133 """
134 super().__init__(
135 base_path=base_path,
136 provider_name=PROVIDER,
137 config_dict=config_dict,
138 telemetry_provider=telemetry_provider,
139 )
140
141 self._region_name = region_name
142 self._endpoint_url = endpoint_url
143 self._credentials_provider = credentials_provider
144 self._verify = verify
145
146 self._signature_version = kwargs.get("signature_version", "")
147 self._s3_client = self._create_s3_client(
148 request_checksum_calculation=kwargs.get("request_checksum_calculation"),
149 response_checksum_validation=kwargs.get("response_checksum_validation"),
150 max_pool_connections=kwargs.get("max_pool_connections", MAX_POOL_CONNECTIONS),
151 connect_timeout=kwargs.get("connect_timeout"),
152 read_timeout=kwargs.get("read_timeout"),
153 retries=kwargs.get("retries"),
154 )
155 self._transfer_config = TransferConfig(
156 multipart_threshold=int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD)),
157 max_concurrency=int(kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)),
158 multipart_chunksize=int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE)),
159 io_chunksize=int(kwargs.get("io_chunksize", IO_CHUNKSIZE)),
160 use_threads=True,
161 )
162
163 self._rust_client = None
164 if "rust_client" in kwargs:
165 # Inherit the rust client options from the kwargs
166 rust_client_options = kwargs["rust_client"]
167 if "max_pool_connections" in kwargs:
168 rust_client_options["max_pool_connections"] = kwargs["max_pool_connections"]
169 if "max_concurrency" in kwargs:
170 rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
171 if "multipart_chunksize" in kwargs:
172 rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
173 if "read_timeout" in kwargs:
174 rust_client_options["read_timeout"] = kwargs["read_timeout"]
175 if "connect_timeout" in kwargs:
176 rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
177 if self._signature_version == "UNSIGNED":
178 rust_client_options["skip_signature"] = True
179 self._rust_client = self._create_rust_client(rust_client_options)
180
181 def _is_directory_bucket(self, bucket: str) -> bool:
182 """
183 Determines if the bucket is a directory bucket based on bucket name.
184 """
185 # S3 Express buckets have a specific naming convention
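        # e.g., "my-bucket--usw2-az1--x-s3" (illustrative name); the "--x-s3" suffix marks
        # S3 Express One Zone directory buckets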
186 return "--x-s3" in bucket
187
188 def _create_s3_client(
189 self,
190 request_checksum_calculation: Optional[str] = None,
191 response_checksum_validation: Optional[str] = None,
192 max_pool_connections: int = MAX_POOL_CONNECTIONS,
193 connect_timeout: Union[float, int, None] = None,
194 read_timeout: Union[float, int, None] = None,
195 retries: Optional[dict[str, Any]] = None,
196 ):
197 """
198 Creates and configures the boto3 S3 client, using refreshable credentials if possible.
199
200 :param request_checksum_calculation: When the underlying S3 client should calculate request checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
201 :param response_checksum_validation: When the underlying S3 client should validate response checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
202 :param max_pool_connections: The maximum number of connections to keep in a connection pool.
203 :param connect_timeout: The time in seconds till a timeout exception is thrown when attempting to make a connection.
204 :param read_timeout: The time in seconds till a timeout exception is thrown when attempting to read from a connection.
205 :param retries: A dictionary for configuration related to retry behavior.
206
207 :return: The configured S3 client.
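
        Example ``retries`` value (a sketch; see botocore's retry configuration docs for the full schema):

        .. code-block:: python

            retries = {"mode": "adaptive", "max_attempts": 5}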
208 """
209 options = {
210 # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
211 "config": boto3.session.Config( # pyright: ignore [reportAttributeAccessIssue]
212 max_pool_connections=max_pool_connections,
213 connect_timeout=connect_timeout,
214 read_timeout=read_timeout,
215 retries=retries or {"mode": "standard"},
216 request_checksum_calculation=request_checksum_calculation,
217 response_checksum_validation=response_checksum_validation,
218 ),
219 }
220
221 if self._region_name:
222 options["region_name"] = self._region_name
223
224 if self._endpoint_url:
225 options["endpoint_url"] = self._endpoint_url
226
227 if self._verify is not None:
228 options["verify"] = self._verify
229
230 if self._credentials_provider:
231 creds = self._fetch_credentials()
232 if "expiry_time" in creds and creds["expiry_time"]:
233 # Use RefreshableCredentials if expiry_time provided.
234 refreshable_credentials = RefreshableCredentials.create_from_metadata(
235 metadata=creds, refresh_using=self._fetch_credentials, method="custom-refresh"
236 )
237
238 botocore_session = get_session()
239 botocore_session._credentials = refreshable_credentials
240
241 boto3_session = boto3.Session(botocore_session=botocore_session)
242
243 return boto3_session.client("s3", **options)
244 else:
245 # Add static credentials to the options dictionary
246 options["aws_access_key_id"] = creds["access_key"]
247 options["aws_secret_access_key"] = creds["secret_key"]
248 if creds["token"]:
249 options["aws_session_token"] = creds["token"]
250
251 if self._signature_version:
252 signature_config = botocore.config.Config( # pyright: ignore[reportAttributeAccessIssue]
253 signature_version=botocore.UNSIGNED
254 if self._signature_version == "UNSIGNED"
255 else self._signature_version
256 )
257 options["config"] = options["config"].merge(signature_config)
258
259 # Fallback to standard credential chain.
260 return boto3.client("s3", **options)
261
262 def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
263 """
264 Creates and configures the rust client, using refreshable credentials if possible.
265 """
266 configs = dict(rust_client_options) if rust_client_options else {}
267
268 if self._region_name and "region_name" not in configs:
269 configs["region_name"] = self._region_name
270
271 if self._endpoint_url and "endpoint_url" not in configs:
272 configs["endpoint_url"] = self._endpoint_url
273
274 # If the user specifies a bucket, use it. Otherwise, use the base path.
275 if "bucket" not in configs:
276 bucket, _ = split_path(self._base_path)
277 configs["bucket"] = bucket
278
279 if "max_pool_connections" not in configs:
280 configs["max_pool_connections"] = MAX_POOL_CONNECTIONS
281
282 return RustClient(
283 provider=PROVIDER,
284 configs=configs,
285 credentials_provider=self._credentials_provider,
286 )
287
288 def _fetch_credentials(self) -> dict:
289 """
        Refreshes the credentials via the configured credentials provider and returns them in the
        metadata format expected by botocore's :py:class:`RefreshableCredentials`.
291 """
292 if not self._credentials_provider:
293 raise RuntimeError("Cannot fetch credentials if no credential provider configured.")
294 self._credentials_provider.refresh_credentials()
295 credentials = self._credentials_provider.get_credentials()
296 return {
297 "access_key": credentials.access_key,
298 "secret_key": credentials.secret_key,
299 "token": credentials.token,
300 "expiry_time": credentials.expiration,
301 }
302
303 def _translate_errors(
304 self,
305 func: Callable[[], _T],
306 operation: str,
307 bucket: str,
308 key: str,
309 ) -> _T:
310 """
311 Translates errors like timeouts and client errors.
312
313 :param func: The function that performs the actual S3 operation.
314 :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
315 :param bucket: The name of the S3 bucket involved in the operation.
316 :param key: The key of the object within the S3 bucket.
317
318 :return: The result of the S3 operation, typically the return value of the `func` callable.
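
        Example (a sketch mirroring how callers in this module wrap boto3 calls):

        .. code-block:: python

            response = self._translate_errors(
                lambda: self._s3_client.head_object(Bucket=bucket, Key=key),
                operation="HEAD",
                bucket=bucket,
                key=key,
            )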
319 """
320 try:
321 return func()
322 except ClientError as error:
323 status_code = error.response["ResponseMetadata"]["HTTPStatusCode"]
324 request_id = error.response["ResponseMetadata"].get("RequestId")
325 host_id = error.response["ResponseMetadata"].get("HostId")
326 error_code = error.response["Error"]["Code"]
327 error_info = f"request_id: {request_id}, host_id: {host_id}, status_code: {status_code}"
328
329 if status_code == 404:
330 if error_code == "NoSuchUpload":
331 error_message = error.response["Error"]["Message"]
332 raise RetryableError(f"Multipart upload failed for {bucket}/{key}: {error_message}") from error
333 raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}") # pylint: disable=raise-missing-from
334 elif status_code == 412: # Precondition Failed
335 raise PreconditionFailedError(
336 f"ETag mismatch for {operation} operation on {bucket}/{key}. {error_info}"
337 ) from error
338 elif status_code == 429:
339 raise RetryableError(
340 f"Too many request to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 503:
                raise RetryableError(
                    f"Service unavailable when {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 501:
                raise NotImplementedError(
                    f"Operation {operation} not implemented for object(s) at {bucket}/{key}. {error_info}"
                ) from error
            else:
                raise RuntimeError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}, "
                    f"error_type: {type(error).__name__}"
                ) from error
        except RustClientError as error:
            message = error.args[0]
            status_code = error.args[1]
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {message}") from error
            elif status_code == 403:
                raise PermissionError(
                    f"Permission denied to {operation} object(s) at {bucket}/{key}. {message}"
                ) from error
            else:
                raise RetryableError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {message}. status_code: {status_code}"
                ) from error
        except (ReadTimeoutError, IncompleteReadError, ResponseStreamingError) as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to network timeout or incomplete read. "
                f"error_type: {type(error).__name__}"
            ) from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except Exception as error:
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
        content_type: Optional[str] = None,
    ) -> int:
        """
        Uploads an object to the specified S3 path.

        :param path: The S3 path where the object will be uploaded.
        :param body: The content of the object as bytes.
        :param if_match: Optional If-Match header value. Pass the expected ETag to only overwrite if the existing object matches.
        :param if_none_match: Optional If-None-Match header value. Use "*" to only upload if the object doesn't exist.
        :param attributes: Optional attributes to attach to the object.
        :param content_type: Optional Content-Type header value.
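
        Example (a sketch; the path and body are placeholders):

        .. code-block:: python

            # Create-only write: fails with PreconditionFailedError if the key already exists.
            provider._put_object("bucket/key", b"data", if_none_match="*")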
401 """
402 bucket, key = split_path(path)
403
404 def _invoke_api() -> int:
405 kwargs = {"Bucket": bucket, "Key": key, "Body": body}
406 if content_type:
407 kwargs["ContentType"] = content_type
408 if self._is_directory_bucket(bucket):
409 kwargs["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
410 if if_match:
411 kwargs["IfMatch"] = if_match
412 if if_none_match:
413 kwargs["IfNoneMatch"] = if_none_match
414 validated_attributes = validate_attributes(attributes)
415 if validated_attributes:
416 kwargs["Metadata"] = validated_attributes
417
418 # TODO(NGCDP-5804): Add support to update ContentType header in Rust client
419 rust_unsupported_feature_keys = {"Metadata", "StorageClass", "IfMatch", "IfNoneMatch", "ContentType"}
420 if (
421 self._rust_client
422 # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
423 and not path.endswith("/")
424 and all(key not in kwargs for key in rust_unsupported_feature_keys)
425 ):
426 run_async_rust_client_method(self._rust_client, "put", key, body)
427 else:
428 self._s3_client.put_object(**kwargs)
429
430 return len(body)
431
432 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
433
434 def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
435 bucket, key = split_path(path)
436
437 def _invoke_api() -> bytes:
438 if byte_range:
439 bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
440 if self._rust_client:
441 response = run_async_rust_client_method(
442 self._rust_client,
443 "get",
444 key,
445 byte_range,
446 )
447 return response
448 else:
449 response = self._s3_client.get_object(Bucket=bucket, Key=key, Range=bytes_range)
450 else:
451 if self._rust_client:
452 response = run_async_rust_client_method(self._rust_client, "get", key)
453 return response
454 else:
455 response = self._s3_client.get_object(Bucket=bucket, Key=key)
456
457 return response["Body"].read()
458
459 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
460
461 def _copy_object(self, src_path: str, dest_path: str) -> int:
462 src_bucket, src_key = split_path(src_path)
463 dest_bucket, dest_key = split_path(dest_path)
464
465 src_object = self._get_object_metadata(src_path)
466
467 def _invoke_api() -> int:
468 self._s3_client.copy(
469 CopySource={"Bucket": src_bucket, "Key": src_key},
470 Bucket=dest_bucket,
471 Key=dest_key,
472 Config=self._transfer_config,
473 )
474
475 return src_object.content_length
476
477 return self._translate_errors(_invoke_api, operation="COPY", bucket=dest_bucket, key=dest_key)
478
479 def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
480 bucket, key = split_path(path)
481
482 def _invoke_api() -> None:
            # Conditionally delete the object if if_match (ETag) is provided; otherwise, delete it unconditionally.
            if if_match:
                self._s3_client.delete_object(Bucket=bucket, Key=key, IfMatch=if_match)
            else:
                self._s3_client.delete_object(Bucket=bucket, Key=key)

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._s3_client.list_objects_v2(Bucket=bucket, Prefix=key, MaxKeys=1, Delimiter="/")

            # Check if there are any contents or common prefixes
            return bool(response.get("Contents", []) or response.get("CommonPrefixes", []))

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If the path ends with "/" or an empty key is provided, assume it's a "directory",
            # whose metadata is not guaranteed to exist (e.g., a "virtual prefix" that was
            # never explicitly created).
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:

            def _invoke_api() -> ObjectMetadata:
                response = self._s3_client.head_object(Bucket=bucket, Key=key)

                return ObjectMetadata(
                    key=path,
                    type="file",
                    content_length=response["ContentLength"],
                    content_type=response.get("ContentType"),
                    last_modified=response["LastModified"],
                    etag=response["ETag"].strip('"') if "ETag" in response else None,
                    storage_class=response.get("StorageClass"),
                    metadata=response.get("Metadata"),
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist at the given path, append a trailing slash and
                    # check whether the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            paginator = self._s3_client.get_paginator("list_objects_v2")
            if include_directories:
                page_iterator = paginator.paginate(
                    Bucket=bucket, Prefix=prefix, Delimiter="/", StartAfter=(start_after or "")
                )
            else:
                page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, StartAfter=(start_after or ""))

            for page in page_iterator:
                for item in page.get("CommonPrefixes", []):
                    yield ObjectMetadata(
                        key=os.path.join(bucket, item["Prefix"].rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )

            # S3 guarantees lexicographical order for general purpose buckets, but not for
            # directory buckets (S3 Express One Zone).
            for response_object in page.get("Contents", []):
                key = response_object["Key"]
                if end_at is None or key <= end_at:
                    if key.endswith("/"):
                        if include_directories:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key.rstrip("/")),
                                type="directory",
                                content_length=0,
                                last_modified=response_object["LastModified"],
                            )
                    else:
                        yield ObjectMetadata(
                            key=os.path.join(bucket, key),
                            type="file",
                            content_length=response_object["Size"],
                            last_modified=response_object["LastModified"],
                            etag=response_object["ETag"].strip('"'),
                            storage_class=response_object.get("StorageClass"),  # Pass storage_class
                        )
                else:
                    return

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(
        self,
        remote_path: str,
        f: Union[str, IO],
        attributes: Optional[dict[str, str]] = None,
        content_type: Optional[str] = None,
    ) -> int:
        file_size: int = 0

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._transfer_config.multipart_threshold:
                if self._rust_client and not attributes and not content_type:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes, content_type=content_type)
                return file_size

            # Upload large files using TransferConfig
            def _invoke_api() -> int:
                extra_args = {}
                if content_type:
                    extra_args["ContentType"] = content_type
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes
                if self._rust_client and not extra_args:
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    self._s3_client.upload_file(
                        Filename=f,
                        Bucket=bucket,
                        Key=key,
                        Config=self._transfer_config,
                        ExtraArgs=extra_args,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            # Upload small files
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            if file_size <= self._transfer_config.multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(
                        remote_path, f.read().encode("utf-8"), attributes=attributes, content_type=content_type
                    )
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes, content_type=content_type)
                return file_size

            # Upload large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                extra_args = {}
                if content_type:
                    extra_args["ContentType"] = content_type
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes

                if self._rust_client and isinstance(f, io.BytesIO) and not extra_args:
                    data = f.getbuffer()
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_bytes", key, data)
                else:
                    self._s3_client.upload_fileobj(
                        Fileobj=f,
                        Bucket=bucket,
                        Key=key,
                        Config=self._transfer_config,
                        ExtraArgs=extra_args,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)

            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using TransferConfig
            def _invoke_api() -> int:
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    if self._rust_client:
                        run_async_rust_client_method(
                            self._rust_client, "download_multipart_to_file", key, temp_file_path
                        )
                    else:
                        self._s3_client.download_fileobj(
                            Bucket=bucket,
                            Key=key,
                            Fileobj=fp,
                            Config=self._transfer_config,
                        )

                os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:
            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                response = self._get_object(remote_path)
                # The Python client returns `bytes`, but the Rust client returns an object that
                # implements the buffer protocol, so check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                self._s3_client.download_fileobj(
                    Bucket=bucket,
                    Key=key,
                    Fileobj=f,
                    Config=self._transfer_config,
                )

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)