# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import io
import os
import tempfile
from collections.abc import Callable, Iterator
from typing import IO, Any, Optional, TypeVar, Union

import boto3
import botocore
from boto3.s3.transfer import TransferConfig
from botocore.credentials import RefreshableCredentials
from botocore.exceptions import ClientError, IncompleteReadError, ReadTimeoutError, ResponseStreamingError
from botocore.session import get_session

from multistorageclient_rust import RustClient, RustClientError, RustRetryableError

from ..constants import DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT
from ..rust_utils import parse_retry_config, run_async_rust_client_method
from ..telemetry import Telemetry
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import (
    get_available_cpu_count,
    safe_makedirs,
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

# Default connection pool size scales with CPU count or MSC Sync Threads count (minimum 64)
MAX_POOL_CONNECTIONS = max(
    64,
    get_available_cpu_count(),
    int(os.getenv("MSC_NUM_THREADS_PER_PROCESS", "0")),
)

MiB = 1024 * 1024

# Python and Rust share the same multipart_threshold to keep the code simple.
MULTIPART_THRESHOLD = 64 * MiB
MULTIPART_CHUNKSIZE = 32 * MiB
IO_CHUNKSIZE = 32 * MiB
PYTHON_MAX_CONCURRENCY = 8

PROVIDER = "s3"

EXPRESS_ONEZONE_STORAGE_CLASS = "EXPRESS_ONEZONE"


class StaticS3CredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static S3 credentials.
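
    Example (a minimal usage sketch with placeholder keys)::

        provider = StaticS3CredentialsProvider(
            access_key="AKIAEXAMPLE",
            secret_key="example-secret-key",
            session_token=None,  # optional, for temporary credentials
        )
        credentials = provider.get_credentials()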
    """

    _access_key: str
    _secret_key: str
    _session_token: Optional[str]

    def __init__(self, access_key: str, secret_key: str, session_token: Optional[str] = None):
        """
        Initializes the :py:class:`StaticS3CredentialsProvider` with the provided access key, secret key, and optional
        session token.

        :param access_key: The access key for S3 authentication.
        :param secret_key: The secret key for S3 authentication.
        :param session_token: An optional session token for temporary credentials.
        """
        self._access_key = access_key
        self._secret_key = secret_key
        self._session_token = session_token

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key=self._access_key,
            secret_key=self._secret_key,
            token=self._session_token,
            expiration=None,
        )

    def refresh_credentials(self) -> None:
        pass


class S3StorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Amazon S3 or S3-compatible object stores.
    """

    def __init__(
        self,
        region_name: str = "",
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        verify: Optional[Union[bool, str]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the :py:class:`S3StorageProvider` with the region, endpoint URL, and optional credentials provider.

        :param region_name: The AWS region where the S3 bucket is located.
        :param endpoint_url: The custom endpoint URL for the S3 service.
        :param base_path: The root prefix path within the S3 bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve S3 credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        :param verify: Controls SSL certificate verification. Can be ``True`` (verify using system CA bundle, default),
            ``False`` (skip verification), or a string path to a custom CA certificate bundle.
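
        Example (a minimal sketch; the bucket name, prefix, and keys are placeholders)::

            provider = S3StorageProvider(
                region_name="us-east-1",
                base_path="my-bucket/datasets",
                credentials_provider=StaticS3CredentialsProvider("AKIAEXAMPLE", "example-secret-key"),
            )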
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._region_name = region_name
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._verify = verify

        self._signature_version = kwargs.get("signature_version", "")
        self._s3_client = self._create_s3_client(
            request_checksum_calculation=kwargs.get("request_checksum_calculation"),
            response_checksum_validation=kwargs.get("response_checksum_validation"),
            max_pool_connections=kwargs.get("max_pool_connections", MAX_POOL_CONNECTIONS),
            connect_timeout=kwargs.get("connect_timeout", DEFAULT_CONNECT_TIMEOUT),
            read_timeout=kwargs.get("read_timeout", DEFAULT_READ_TIMEOUT),
            retries=kwargs.get("retries"),
        )
        self._transfer_config = TransferConfig(
            multipart_threshold=int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD)),
            max_concurrency=int(kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)),
            multipart_chunksize=int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE)),
            io_chunksize=int(kwargs.get("io_chunksize", IO_CHUNKSIZE)),
            use_threads=True,
        )

        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the rust client options from the kwargs
            rust_client_options = kwargs["rust_client"]
            if "max_pool_connections" in kwargs:
                rust_client_options["max_pool_connections"] = kwargs["max_pool_connections"]
            if "max_concurrency" in kwargs:
                rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
            if "multipart_chunksize" in kwargs:
                rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
            if "read_timeout" in kwargs:
                rust_client_options["read_timeout"] = kwargs["read_timeout"]
            if "connect_timeout" in kwargs:
                rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
            if self._signature_version == "UNSIGNED":
                rust_client_options["skip_signature"] = True
            self._rust_client = self._create_rust_client(rust_client_options)

    def _is_directory_bucket(self, bucket: str) -> bool:
        """
        Determines if the bucket is a directory bucket based on bucket name.
        """
        # S3 Express buckets have a specific naming convention
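        # (for example, a directory bucket name looks like "my-bucket--usw2-az1--x-s3")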
        return "--x-s3" in bucket

    def _create_s3_client(
        self,
        request_checksum_calculation: Optional[str] = None,
        response_checksum_validation: Optional[str] = None,
        max_pool_connections: int = MAX_POOL_CONNECTIONS,
        connect_timeout: Union[float, int, None] = None,
        read_timeout: Union[float, int, None] = None,
        retries: Optional[dict[str, Any]] = None,
    ):
        """
        Creates and configures the boto3 S3 client, using refreshable credentials if possible.

        :param request_checksum_calculation: When the underlying S3 client should calculate request checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param response_checksum_validation: When the underlying S3 client should validate response checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param max_pool_connections: The maximum number of connections to keep in a connection pool.
        :param connect_timeout: The time in seconds till a timeout exception is thrown when attempting to make a connection.
        :param read_timeout: The time in seconds till a timeout exception is thrown when attempting to read from a connection.
        :param retries: A dictionary for configuration related to retry behavior.
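            For example, ``{"max_attempts": 5, "mode": "adaptive"}`` (a sketch using botocore's standard retry options).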

        :return: The configured S3 client.
        """
        options = {
            # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
            "config": boto3.session.Config(  # pyright: ignore [reportAttributeAccessIssue]
                max_pool_connections=max_pool_connections,
                connect_timeout=connect_timeout,
                read_timeout=read_timeout,
                retries=retries or {"mode": "standard"},
                request_checksum_calculation=request_checksum_calculation,
                response_checksum_validation=response_checksum_validation,
            ),
        }

        if self._region_name:
            options["region_name"] = self._region_name

        if self._endpoint_url:
            options["endpoint_url"] = self._endpoint_url

        if self._verify is not None:
            options["verify"] = self._verify

        if self._credentials_provider:
            creds = self._fetch_credentials()
            if "expiry_time" in creds and creds["expiry_time"]:
                # Use RefreshableCredentials if expiry_time provided.
                refreshable_credentials = RefreshableCredentials.create_from_metadata(
                    metadata=creds, refresh_using=self._fetch_credentials, method="custom-refresh"
                )

                botocore_session = get_session()
                botocore_session._credentials = refreshable_credentials

                boto3_session = boto3.Session(botocore_session=botocore_session)

                return boto3_session.client("s3", **options)
            else:
                # Add static credentials to the options dictionary
                options["aws_access_key_id"] = creds["access_key"]
                options["aws_secret_access_key"] = creds["secret_key"]
                if creds["token"]:
                    options["aws_session_token"] = creds["token"]

        if self._signature_version:
            signature_config = botocore.config.Config(  # pyright: ignore[reportAttributeAccessIssue]
                signature_version=botocore.UNSIGNED
                if self._signature_version == "UNSIGNED"
                else self._signature_version
            )
            options["config"] = options["config"].merge(signature_config)

        # Fallback to standard credential chain.
        return boto3.client("s3", **options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        """
        Creates and configures the rust client, using refreshable credentials if possible.
        """
        configs = dict(rust_client_options) if rust_client_options else {}

        # Extract and parse retry configuration
        retry_config = parse_retry_config(configs)

        if self._region_name and "region_name" not in configs:
            configs["region_name"] = self._region_name

        if self._endpoint_url and "endpoint_url" not in configs:
            configs["endpoint_url"] = self._endpoint_url

        # If the user specifies a bucket, use it. Otherwise, use the base path.
        if "bucket" not in configs:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        if "max_pool_connections" not in configs:
            configs["max_pool_connections"] = MAX_POOL_CONNECTIONS

        return RustClient(
            provider=PROVIDER,
            configs=configs,
            credentials_provider=self._credentials_provider,
            retry=retry_config,
        )

    def _fetch_credentials(self) -> dict:
        """
        Refreshes and returns the credentials from the configured credentials provider.
        """
        if not self._credentials_provider:
            raise RuntimeError("Cannot fetch credentials if no credential provider configured.")
        self._credentials_provider.refresh_credentials()
        credentials = self._credentials_provider.get_credentials()
        return {
            "access_key": credentials.access_key,
            "secret_key": credentials.secret_key,
            "token": credentials.token,
            "expiry_time": credentials.expiration,
        }

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates boto3 and Rust client errors, such as timeouts and HTTP status errors, into built-in or MSC-specific exceptions.

        :param func: The function that performs the actual S3 operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the S3 bucket involved in the operation.
        :param key: The key of the object within the S3 bucket.

        :return: The result of the S3 operation, typically the return value of the `func` callable.
        """
        try:
            return func()
        except ClientError as error:
            status_code = error.response["ResponseMetadata"]["HTTPStatusCode"]
            request_id = error.response["ResponseMetadata"].get("RequestId")
            host_id = error.response["ResponseMetadata"].get("HostId")
            error_code = error.response["Error"]["Code"]
            error_info = f"request_id: {request_id}, host_id: {host_id}, status_code: {status_code}"

            if status_code == 404:
                if error_code == "NoSuchUpload":
                    error_message = error.response["Error"]["Message"]
                    raise RetryableError(f"Multipart upload failed for {bucket}/{key}: {error_message}") from error
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}")  # pylint: disable=raise-missing-from
            elif status_code == 412:  # Precondition Failed
                raise PreconditionFailedError(
                    f"ETag mismatch for {operation} operation on {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 429:
                raise RetryableError(
                    f"Too many requests to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 503:
                raise RetryableError(
                    f"Service unavailable when {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 501:
                raise NotImplementedError(
                    f"Operation {operation} not implemented for object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 408:
                # 408 Request Timeout is returned by Google Cloud Storage
                raise RetryableError(
                    f"Request timeout when {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            else:
                raise RuntimeError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}, "
                    f"error_type: {type(error).__name__}"
                ) from error
        except RustClientError as error:
            message = error.args[0]
            status_code = error.args[1]
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {message}") from error
            elif status_code == 403:
                raise PermissionError(
                    f"Permission denied to {operation} object(s) at {bucket}/{key}. {message}"
                ) from error
            else:
                raise RetryableError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {message}. status_code: {status_code}"
                ) from error
        except (ReadTimeoutError, IncompleteReadError, ResponseStreamingError) as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to network timeout or incomplete read. "
                f"error_type: {type(error).__name__}"
            ) from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except Exception as error:
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
        content_type: Optional[str] = None,
    ) -> int:
        """
        Uploads an object to the specified S3 path.

        :param path: The S3 path where the object will be uploaded.
        :param body: The content of the object as bytes.
        :param if_match: Optional If-Match header value. Use "*" to only upload if the object already exists.
        :param if_none_match: Optional If-None-Match header value. Use "*" to only upload if the object doesn't exist.
        :param attributes: Optional attributes to attach to the object.
        :param content_type: Optional Content-Type header value.
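
        Example (a sketch of a create-only-if-absent upload; the bucket and key are placeholders)::

            self._put_object("my-bucket/data/file.bin", b"payload", if_none_match="*")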
        """
        bucket, key = split_path(path)

        def _invoke_api() -> int:
            kwargs = {"Bucket": bucket, "Key": key, "Body": body}
            if content_type:
                kwargs["ContentType"] = content_type
            if self._is_directory_bucket(bucket):
                kwargs["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
            if if_match:
                kwargs["IfMatch"] = if_match
            if if_none_match:
                kwargs["IfNoneMatch"] = if_none_match
            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                kwargs["Metadata"] = validated_attributes

            # TODO(NGCDP-5804): Add support to update ContentType header in Rust client
            rust_unsupported_feature_keys = {"Metadata", "StorageClass", "IfMatch", "IfNoneMatch", "ContentType"}
            if (
                self._rust_client
                # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
                and not path.endswith("/")
                and all(key not in kwargs for key in rust_unsupported_feature_keys)
            ):
                run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                self._s3_client.put_object(**kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)

        def _invoke_api() -> bytes:
            if byte_range:
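                # For example, Range(offset=0, size=4) maps to the HTTP header value "bytes=0-3".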
                bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
                if self._rust_client:
                    response = run_async_rust_client_method(
                        self._rust_client,
                        "get",
                        key,
                        byte_range,
                    )
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key, Range=bytes_range)
            else:
                if self._rust_client:
                    response = run_async_rust_client_method(self._rust_client, "get", key)
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key)

            return response["Body"].read()

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            self._s3_client.copy(
                CopySource={"Bucket": src_bucket, "Key": src_key},
                Bucket=dest_bucket,
                Key=dest_key,
                Config=self._transfer_config,
            )

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=dest_bucket, key=dest_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            # Conditionally delete the object if if_match (ETag) is provided; otherwise delete it unconditionally.
            if if_match:
                self._s3_client.delete_object(Bucket=bucket, Key=key, IfMatch=if_match)
            else:
                self._s3_client.delete_object(Bucket=bucket, Key=key)

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._s3_client.list_objects_v2(Bucket=bucket, Prefix=key, MaxKeys=1, Delimiter="/")

            # Check if there are any contents or common prefixes
            return bool(response.get("Contents", []) or response.get("CommonPrefixes", []))

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If the path ends with "/" or an empty key is provided, assume it is a "directory",
            # whose metadata is not guaranteed to exist (for example, a "virtual prefix"
            # that was never explicitly created).
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:

            def _invoke_api() -> ObjectMetadata:
                response = self._s3_client.head_object(Bucket=bucket, Key=key)

                return ObjectMetadata(
                    key=path,
                    type="file",
                    content_length=response["ContentLength"],
                    content_type=response.get("ContentType"),
                    last_modified=response["LastModified"],
                    etag=response["ETag"].strip('"') if "ETag" in response else None,
                    storage_class=response.get("StorageClass"),
                    metadata=response.get("Metadata"),
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist at the given path, append a trailing slash and
                    # check whether the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            paginator = self._s3_client.get_paginator("list_objects_v2")
            if include_directories:
                page_iterator = paginator.paginate(
                    Bucket=bucket, Prefix=prefix, Delimiter="/", StartAfter=(start_after or "")
                )
            else:
                page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, StartAfter=(start_after or ""))

            for page in page_iterator:
                for item in page.get("CommonPrefixes", []):
                    prefix_key = item["Prefix"].rstrip("/")
                    # Filter by start_after and end_at - S3's StartAfter doesn't filter CommonPrefixes
                    if (start_after is None or start_after < prefix_key) and (end_at is None or prefix_key <= end_at):
                        yield ObjectMetadata(
                            key=os.path.join(bucket, prefix_key),
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                    elif end_at is not None and end_at < prefix_key:
                        return

                # S3 guarantees lexicographical order for general purpose buckets (for
                # normal S3) but not directory buckets (for S3 Express One Zone).
                for response_object in page.get("Contents", []):
                    key = response_object["Key"]
                    if end_at is None or key <= end_at:
                        if key.endswith("/"):
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(bucket, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=response_object["LastModified"],
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key),
                                type="file",
                                content_length=response_object["Size"],
                                last_modified=response_object["LastModified"],
                                etag=response_object["ETag"].strip('"'),
                                storage_class=response_object.get("StorageClass"),  # Pass storage_class
                            )
                    else:
                        return

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(
        self,
        remote_path: str,
        f: Union[str, IO],
        attributes: Optional[dict[str, str]] = None,
        content_type: Optional[str] = None,
    ) -> int:
        file_size: int = 0

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._transfer_config.multipart_threshold:
                if self._rust_client and not attributes and not content_type:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes, content_type=content_type)
                return file_size

            # Upload large files using TransferConfig
            def _invoke_api() -> int:
                extra_args = {}
                if content_type:
                    extra_args["ContentType"] = content_type
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes
                if self._rust_client and not extra_args:
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    self._s3_client.upload_file(
                        Filename=f,
                        Bucket=bucket,
                        Key=key,
                        Config=self._transfer_config,
                        ExtraArgs=extra_args,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            # Upload small files
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            if file_size <= self._transfer_config.multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(
                        remote_path, f.read().encode("utf-8"), attributes=attributes, content_type=content_type
                    )
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes, content_type=content_type)
                return file_size

            # Upload large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                extra_args = {}
                if content_type:
                    extra_args["ContentType"] = content_type
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes

                if self._rust_client and isinstance(f, io.BytesIO) and not extra_args:
                    data = f.getbuffer()
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_bytes", key, data)
                else:
                    self._s3_client.upload_fileobj(
                        Fileobj=f,
                        Bucket=bucket,
                        Key=key,
                        Config=self._transfer_config,
                        ExtraArgs=extra_args,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            if os.path.dirname(f):
                safe_makedirs(os.path.dirname(f))

            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using TransferConfig
            def _invoke_api() -> int:
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    if self._rust_client:
                        run_async_rust_client_method(
                            self._rust_client, "download_multipart_to_file", key, temp_file_path
                        )
                    else:
                        self._s3_client.download_fileobj(
                            Bucket=bucket,
                            Key=key,
                            Fileobj=fp,
                            Config=self._transfer_config,
                        )

                os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:
            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                response = self._get_object(remote_path)
                # The Python client returns `bytes`, but the Rust client returns an object that implements the
                # buffer protocol, so we need to check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                self._s3_client.download_fileobj(
                    Bucket=bucket,
                    Key=key,
                    Fileobj=f,
                    Config=self._transfer_config,
                )

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)