# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import io
import os
import tempfile
from collections.abc import Callable, Iterator
from typing import IO, Any, Optional, TypeVar, Union

import boto3
import botocore
from boto3.s3.transfer import TransferConfig
from botocore.credentials import RefreshableCredentials
from botocore.exceptions import ClientError, IncompleteReadError, ReadTimeoutError, ResponseStreamingError
from botocore.session import get_session

from multistorageclient_rust import RustClient, RustRetryableError

from ..instrumentation.utils import set_span_attribute
from ..rust_utils import run_async_rust_client_method
from ..telemetry import Telemetry
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import (
    get_available_cpu_count,
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

# Default connection pool size scales with CPU count (minimum 32 for backward compatibility)
MAX_POOL_CONNECTIONS = max(32, get_available_cpu_count())

MiB = 1024 * 1024

# Python and Rust share the same multipart_threshold to keep the code simple.
MULTIPART_THRESHOLD = 64 * MiB
MULTIPART_CHUNKSIZE = 32 * MiB
IO_CHUNKSIZE = 32 * MiB
PYTHON_MAX_CONCURRENCY = 8

PROVIDER = "s3"

EXPRESS_ONEZONE_STORAGE_CLASS = "EXPRESS_ONEZONE"


def _extract_x_trans_id(response: Any) -> None:
    """Extract x-trans-id from a boto3 response and set it as a span attribute."""
    try:
        if response and isinstance(response, dict):
            headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
            if headers and isinstance(headers, dict) and "x-trans-id" in headers:
                set_span_attribute("x_trans_id", headers["x-trans-id"])
    except (KeyError, AttributeError, TypeError):
        # Silently ignore any errors in extraction
        pass


class StaticS3CredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static S3 credentials.
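
    Example (an illustrative sketch; the key values are placeholders)::

        provider = StaticS3CredentialsProvider(access_key="AKIA...", secret_key="...")
        credentials = provider.get_credentials()  # static Credentials with no expiration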
84 """
85
86 _access_key: str
87 _secret_key: str
88 _session_token: Optional[str]
89
90 def __init__(self, access_key: str, secret_key: str, session_token: Optional[str] = None):
91 """
92 Initializes the :py:class:`StaticS3CredentialsProvider` with the provided access key, secret key, and optional
93 session token.
94
95 :param access_key: The access key for S3 authentication.
96 :param secret_key: The secret key for S3 authentication.
97 :param session_token: An optional session token for temporary credentials.
98 """
99 self._access_key = access_key
100 self._secret_key = secret_key
101 self._session_token = session_token
102
    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key=self._access_key,
            secret_key=self._secret_key,
            token=self._session_token,
            expiration=None,
        )

    def refresh_credentials(self) -> None:
        # Static credentials never expire, so there is nothing to refresh.
        pass


class S3StorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Amazon S3 or S3-compatible object stores.
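
    Example (a minimal sketch; the bucket name is a placeholder, and in practice the
    provider is usually constructed from the resolved MSC config rather than directly)::

        provider = S3StorageProvider(
            region_name="us-east-1",
            base_path="my-bucket/prefix",
        )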
118 """
119
120 def __init__(
121 self,
122 region_name: str = "",
123 endpoint_url: str = "",
124 base_path: str = "",
125 credentials_provider: Optional[CredentialsProvider] = None,
126 config_dict: Optional[dict[str, Any]] = None,
127 telemetry_provider: Optional[Callable[[], Telemetry]] = None,
128 **kwargs: Any,
129 ) -> None:
130 """
131 Initializes the :py:class:`S3StorageProvider` with the region, endpoint URL, and optional credentials provider.
132
133 :param region_name: The AWS region where the S3 bucket is located.
134 :param endpoint_url: The custom endpoint URL for the S3 service.
135 :param base_path: The root prefix path within the S3 bucket where all operations will be scoped.
136 :param credentials_provider: The provider to retrieve S3 credentials.
137 :param config_dict: Resolved MSC config.
138 :param telemetry_provider: A function that provides a telemetry instance.
139 """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._region_name = region_name
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider

        self._signature_version = kwargs.get("signature_version", "")
        self._s3_client = self._create_s3_client(
            request_checksum_calculation=kwargs.get("request_checksum_calculation"),
            response_checksum_validation=kwargs.get("response_checksum_validation"),
            max_pool_connections=kwargs.get("max_pool_connections", MAX_POOL_CONNECTIONS),
            connect_timeout=kwargs.get("connect_timeout"),
            read_timeout=kwargs.get("read_timeout"),
            retries=kwargs.get("retries"),
        )
        self._transfer_config = TransferConfig(
            multipart_threshold=int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD)),
            max_concurrency=int(kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)),
            multipart_chunksize=int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE)),
            io_chunksize=int(kwargs.get("io_chunksize", IO_CHUNKSIZE)),
            use_threads=True,
        )

        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the Rust client options from the kwargs.
            rust_client_options = kwargs["rust_client"]
            if "max_pool_connections" in kwargs:
                rust_client_options["max_pool_connections"] = kwargs["max_pool_connections"]
            if "max_concurrency" in kwargs:
                rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
            if "multipart_chunksize" in kwargs:
                rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
            if "read_timeout" in kwargs:
                rust_client_options["read_timeout"] = kwargs["read_timeout"]
            if "connect_timeout" in kwargs:
                rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
            if self._signature_version == "UNSIGNED":
                rust_client_options["skip_signature"] = True
            self._rust_client = self._create_rust_client(rust_client_options)

    def _is_directory_bucket(self, bucket: str) -> bool:
        """
        Determines if the bucket is a directory bucket based on the bucket name.
        """
        # S3 Express buckets have a specific naming convention
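        # e.g. "my-data--usw2-az1--x-s3" (an illustrative S3 Express One Zone bucket name).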
191 return "--x-s3" in bucket
192
    def _create_s3_client(
        self,
        request_checksum_calculation: Optional[str] = None,
        response_checksum_validation: Optional[str] = None,
        max_pool_connections: int = MAX_POOL_CONNECTIONS,
        connect_timeout: Union[float, int, None] = None,
        read_timeout: Union[float, int, None] = None,
        retries: Optional[dict[str, Any]] = None,
    ):
        """
        Creates and configures the boto3 S3 client, using refreshable credentials if possible.

        :param request_checksum_calculation: When the underlying S3 client should calculate request checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param response_checksum_validation: When the underlying S3 client should validate response checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param max_pool_connections: The maximum number of connections to keep in a connection pool.
        :param connect_timeout: The time in seconds till a timeout exception is thrown when attempting to make a connection.
        :param read_timeout: The time in seconds till a timeout exception is thrown when attempting to read from a connection.
        :param retries: A dictionary for configuration related to retry behavior.

        :return: The configured S3 client.
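
        Example ``retries`` dictionary (an illustrative sketch; botocore also accepts the
        ``"legacy"`` and ``"standard"`` retry modes)::

            retries = {"mode": "adaptive", "max_attempts": 5}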
213 """
214 options = {
215 # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
216 "config": boto3.session.Config( # pyright: ignore [reportAttributeAccessIssue]
217 max_pool_connections=max_pool_connections,
218 connect_timeout=connect_timeout,
219 read_timeout=read_timeout,
220 retries=retries or {"mode": "standard"},
221 request_checksum_calculation=request_checksum_calculation,
222 response_checksum_validation=response_checksum_validation,
223 ),
224 }
225
226 if self._region_name:
227 options["region_name"] = self._region_name
228
229 if self._endpoint_url:
230 options["endpoint_url"] = self._endpoint_url
231
232 if self._credentials_provider:
233 creds = self._fetch_credentials()
234 if "expiry_time" in creds and creds["expiry_time"]:
235 # Use RefreshableCredentials if expiry_time provided.
236 refreshable_credentials = RefreshableCredentials.create_from_metadata(
237 metadata=creds, refresh_using=self._fetch_credentials, method="custom-refresh"
238 )
239
240 botocore_session = get_session()
241 botocore_session._credentials = refreshable_credentials
242
243 boto3_session = boto3.Session(botocore_session=botocore_session)
244
245 return boto3_session.client("s3", **options)
246 else:
247 # Add static credentials to the options dictionary
248 options["aws_access_key_id"] = creds["access_key"]
249 options["aws_secret_access_key"] = creds["secret_key"]
250 if creds["token"]:
251 options["aws_session_token"] = creds["token"]
252
253 if self._signature_version:
254 signature_config = botocore.config.Config( # pyright: ignore[reportAttributeAccessIssue]
255 signature_version=botocore.UNSIGNED
256 if self._signature_version == "UNSIGNED"
257 else self._signature_version
258 )
259 options["config"] = options["config"].merge(signature_config)
260
261 # Fallback to standard credential chain.
262 return boto3.client("s3", **options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        """
        Creates and configures the Rust client, using refreshable credentials if possible.
        """
        configs = {}
        if self._region_name:
            configs["region_name"] = self._region_name

        # If the user specifies a bucket, use it. Otherwise, derive it from the base path.
        if rust_client_options and "bucket" in rust_client_options:
            configs["bucket"] = rust_client_options["bucket"]
        else:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        if self._endpoint_url:
            configs["endpoint_url"] = self._endpoint_url

        configs["max_pool_connections"] = MAX_POOL_CONNECTIONS

        if rust_client_options:
            if rust_client_options.get("allow_http", False):
                configs["allow_http"] = True
            if "max_concurrency" in rust_client_options:
                configs["max_concurrency"] = rust_client_options["max_concurrency"]
            if "max_pool_connections" in rust_client_options:
                configs["max_pool_connections"] = rust_client_options["max_pool_connections"]
            if "multipart_chunksize" in rust_client_options:
                configs["multipart_chunksize"] = rust_client_options["multipart_chunksize"]
            if "skip_signature" in rust_client_options:
                configs["skip_signature"] = rust_client_options["skip_signature"]

        return RustClient(
            provider=PROVIDER,
            configs=configs,
            credentials_provider=self._credentials_provider,
        )

    def _fetch_credentials(self) -> dict:
        """
        Refreshes credentials via the configured provider and returns them as a dictionary.
        """
        if not self._credentials_provider:
            raise RuntimeError("Cannot fetch credentials if no credential provider configured.")
        self._credentials_provider.refresh_credentials()
        credentials = self._credentials_provider.get_credentials()
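        # The returned keys follow botocore's credential-metadata format, which is what
        # RefreshableCredentials.create_from_metadata in _create_s3_client consumes.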
        return {
            "access_key": credentials.access_key,
            "secret_key": credentials.secret_key,
            "token": credentials.token,
            "expiry_time": credentials.expiration,
        }

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates boto3 client errors and network timeouts into MSC error types.

        TODO: Remove tracing from this method.

        :param func: The function that performs the actual S3 operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the S3 bucket involved in the operation.
        :param key: The key of the object within the S3 bucket.

        :return: The result of the S3 operation, typically the return value of the `func` callable.
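
        Example (an illustrative sketch of the call pattern used throughout this class)::

            metadata = self._translate_errors(
                lambda: self._s3_client.head_object(Bucket=bucket, Key=key),
                operation="HEAD",
                bucket=bucket,
                key=key,
            )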
335 """
336 # Set basic operation attributes
337 set_span_attribute("s3_operation", operation)
338 set_span_attribute("s3_bucket", bucket)
339 set_span_attribute("s3_key", key)
340
341 try:
342 return func()
343 except ClientError as error:
344 status_code = error.response["ResponseMetadata"]["HTTPStatusCode"]
345 request_id = error.response["ResponseMetadata"].get("RequestId")
346 host_id = error.response["ResponseMetadata"].get("HostId")
347 header = error.response["ResponseMetadata"].get("HTTPHeaders", {})
348 error_code = error.response["Error"]["Code"]
349
350 # Ensure header is a dictionary before trying to get from it
351 x_trans_id = header.get("x-trans-id") if isinstance(header, dict) else None
352
353 # Record error details in span
354 set_span_attribute("request_id", request_id)
355 set_span_attribute("host_id", host_id)
356
357 error_info = f"request_id: {request_id}, host_id: {host_id}, status_code: {status_code}"
358 if x_trans_id:
359 error_info += f", x-trans-id: {x_trans_id}"
360 set_span_attribute("x_trans_id", x_trans_id)
361
362 if status_code == 404:
363 if error_code == "NoSuchUpload":
364 error_message = error.response["Error"]["Message"]
365 raise RetryableError(f"Multipart upload failed for {bucket}/{key}: {error_message}") from error
366 raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}") # pylint: disable=raise-missing-from
367 elif status_code == 412: # Precondition Failed
368 raise PreconditionFailedError(
369 f"ETag mismatch for {operation} operation on {bucket}/{key}. {error_info}"
370 ) from error
371 elif status_code == 429:
372 raise RetryableError(
373 f"Too many request to {operation} object(s) at {bucket}/{key}. {error_info}"
374 ) from error
375 elif status_code == 503:
376 raise RetryableError(
377 f"Service unavailable when {operation} object(s) at {bucket}/{key}. {error_info}"
378 ) from error
379 elif status_code == 501:
380 raise NotImplementedError(
381 f"Operation {operation} not implemented for object(s) at {bucket}/{key}. {error_info}"
382 ) from error
383 else:
384 raise RuntimeError(
385 f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}, "
386 f"error_type: {type(error).__name__}"
387 ) from error
388 except FileNotFoundError:
389 raise
390 except (ReadTimeoutError, IncompleteReadError, ResponseStreamingError) as error:
391 raise RetryableError(
392 f"Failed to {operation} object(s) at {bucket}/{key} due to network timeout or incomplete read. "
393 f"error_type: {type(error).__name__}"
394 ) from error
395 except RustRetryableError as error:
396 raise RetryableError(
397 f"Failed to {operation} object(s) at {bucket}/{key} due to exhausted retries from Rust. "
398 f"error_type: {type(error).__name__}"
399 ) from error
400 except Exception as error:
401 raise RuntimeError(
402 f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
403 ) from error
404
    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
        content_type: Optional[str] = None,
    ) -> int:
        """
        Uploads an object to the specified S3 path.

        :param path: The S3 path where the object will be uploaded.
        :param body: The content of the object as bytes.
        :param if_match: Optional If-Match header value. Pass an ETag to only upload if the existing object's ETag matches.
        :param if_none_match: Optional If-None-Match header value. Use "*" to only upload if the object doesn't exist.
        :param attributes: Optional attributes to attach to the object.
        :param content_type: Optional Content-Type header value.
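
        Example of a create-only-if-absent write (an illustrative sketch; the path is a placeholder)::

            provider._put_object("my-bucket/config.json", b"{}", if_none_match="*")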
423 """
424 bucket, key = split_path(path)
425
426 def _invoke_api() -> int:
427 kwargs = {"Bucket": bucket, "Key": key, "Body": body}
428 if content_type:
429 kwargs["ContentType"] = content_type
430 if self._is_directory_bucket(bucket):
431 kwargs["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
432 if if_match:
433 kwargs["IfMatch"] = if_match
434 if if_none_match:
435 kwargs["IfNoneMatch"] = if_none_match
436 validated_attributes = validate_attributes(attributes)
437 if validated_attributes:
438 kwargs["Metadata"] = validated_attributes
439
440 # TODO(NGCDP-5804): Add support to update ContentType header in Rust client
441 rust_unsupported_feature_keys = {"Metadata", "StorageClass", "IfMatch", "IfNoneMatch", "ContentType"}
442 if (
443 self._rust_client
444 # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
445 and not path.endswith("/")
446 and all(key not in kwargs for key in rust_unsupported_feature_keys)
447 ):
448 response = run_async_rust_client_method(self._rust_client, "put", key, body)
449 else:
450 # Capture the response from put_object
451 response = self._s3_client.put_object(**kwargs)
452
453 # Extract and set x-trans-id if present
454 _extract_x_trans_id(response)
455
456 return len(body)
457
458 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
459
    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)

        def _invoke_api() -> bytes:
            if byte_range:
                bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
                if self._rust_client:
                    response = run_async_rust_client_method(
                        self._rust_client, "get", key, byte_range.offset, byte_range.offset + byte_range.size - 1
                    )
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key, Range=bytes_range)
            else:
                if self._rust_client:
                    response = run_async_rust_client_method(self._rust_client, "get", key)
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key)

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

            return response["Body"].read()

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            response = self._s3_client.copy(
                CopySource={"Bucket": src_bucket, "Key": src_key},
                Bucket=dest_bucket,
                Key=dest_key,
                Config=self._transfer_config,
            )

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=dest_bucket, key=dest_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            # Delete conditionally if if_match (an ETag) is provided; otherwise delete unconditionally.
            if if_match:
                response = self._s3_client.delete_object(Bucket=bucket, Key=key, IfMatch=if_match)
            else:
                response = self._s3_client.delete_object(Bucket=bucket, Key=key)

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._s3_client.list_objects_v2(Bucket=bucket, Prefix=key, MaxKeys=1, Delimiter="/")

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

            # Check if there are any contents or common prefixes
            return bool(response.get("Contents", []) or response.get("CommonPrefixes", []))

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If the path ends with "/" or an empty key name is provided, assume it's a "directory",
            # whose metadata is not guaranteed to exist, e.g. a "virtual prefix"
            # that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:

            def _invoke_api() -> ObjectMetadata:
                response = self._s3_client.head_object(Bucket=bucket, Key=key)

                # Extract and set x-trans-id if present
                _extract_x_trans_id(response)

                return ObjectMetadata(
                    key=path,
                    type="file",
                    content_length=response["ContentLength"],
                    content_type=response["ContentType"],
                    last_modified=response["LastModified"],
                    etag=response["ETag"].strip('"'),
                    storage_class=response.get("StorageClass"),
                    metadata=response.get("Metadata"),
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist at the given path, append a trailing slash
                    # and check whether the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(path)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            paginator = self._s3_client.get_paginator("list_objects_v2")
            if include_directories:
                page_iterator = paginator.paginate(
                    Bucket=bucket, Prefix=prefix, Delimiter="/", StartAfter=(start_after or "")
                )
            else:
                page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, StartAfter=(start_after or ""))

            for page in page_iterator:
                for item in page.get("CommonPrefixes", []):
                    yield ObjectMetadata(
                        key=os.path.join(bucket, item["Prefix"].rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )

                # S3 guarantees lexicographical order for general purpose buckets,
                # but not for directory buckets (S3 Express One Zone).
                for response_object in page.get("Contents", []):
                    key = response_object["Key"]
                    if end_at is None or key <= end_at:
                        if key.endswith("/"):
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(bucket, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=response_object["LastModified"],
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key),
                                type="file",
                                content_length=response_object["Size"],
                                last_modified=response_object["LastModified"],
                                etag=response_object["ETag"].strip('"'),
                                storage_class=response_object.get("StorageClass"),
                            )
                    else:
                        return

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(
        self,
        remote_path: str,
        f: Union[str, IO],
        attributes: Optional[dict[str, str]] = None,
        content_type: Optional[str] = None,
    ) -> int:
        file_size: int = 0

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._transfer_config.multipart_threshold:
                if self._rust_client and not attributes and not content_type:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes, content_type=content_type)
                return file_size

            # Upload large files using TransferConfig
            def _invoke_api() -> int:
                extra_args = {}
                if content_type:
                    extra_args["ContentType"] = content_type
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes
                if self._rust_client and not extra_args:
                    response = run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    response = self._s3_client.upload_file(
                        Filename=f,
                        Bucket=bucket,
                        Key=key,
                        Config=self._transfer_config,
                        ExtraArgs=extra_args,
                    )

                # Extract and set x-trans-id if present
                _extract_x_trans_id(response)

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            # Upload small files
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            if file_size <= self._transfer_config.multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(
                        remote_path, f.read().encode("utf-8"), attributes=attributes, content_type=content_type
                    )
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes, content_type=content_type)
                return file_size

            # Upload large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                extra_args = {}
                if content_type:
                    extra_args["ContentType"] = content_type
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes
                self._s3_client.upload_fileobj(
                    Fileobj=f,
                    Bucket=bucket,
                    Key=key,
                    Config=self._transfer_config,
                    ExtraArgs=extra_args,
                )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)

            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
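                    # Write to a hidden temp file in the destination directory, then rename it
                    # into place so readers never observe a partially written file.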
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using TransferConfig
            def _invoke_api() -> int:
                response = None
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    if self._rust_client:
                        response = run_async_rust_client_method(
                            self._rust_client, "download_multipart_to_file", key, temp_file_path
                        )
                    else:
                        response = self._s3_client.download_fileobj(
                            Bucket=bucket,
                            Key=key,
                            Fileobj=fp,
                            Config=self._transfer_config,
                        )

                # Extract and set x-trans-id if present
                _extract_x_trans_id(response)
                os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:
            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                response = self._get_object(remote_path)
                # The Python client returns `bytes`, but the Rust client returns an object that
                # implements the buffer protocol, so check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                response = self._s3_client.download_fileobj(
                    Bucket=bucket,
                    Key=key,
                    Fileobj=f,
                    Config=self._transfer_config,
                )

                # Extract and set x-trans-id if present
                _extract_x_trans_id(response)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)