# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import io
import os
import tempfile
import time
from collections.abc import Callable, Iterator, Sequence, Sized
from typing import IO, Any, Optional, TypeVar, Union

import boto3
import botocore
import opentelemetry.metrics as api_metrics
from boto3.s3.transfer import TransferConfig
from botocore.credentials import RefreshableCredentials
from botocore.exceptions import ClientError, IncompleteReadError, ReadTimeoutError, ResponseStreamingError
from botocore.session import get_session

from ..instrumentation.utils import set_span_attribute
from ..rust_utils import run_async_rust_client_method
from ..telemetry import Telemetry
from ..telemetry.attributes.base import AttributesProvider
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import (
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

BOTO3_MAX_POOL_CONNECTIONS = 32

MiB = 1024 * 1024

# Python and Rust share the same multipart_threshold to keep the code simple.
MULTIPART_THRESHOLD = 64 * MiB
MULTIPART_CHUNKSIZE = 32 * MiB
IO_CHUNKSIZE = 32 * MiB
# Python uses a lower default concurrency due to the GIL limiting true parallelism in threads.
PYTHON_MAX_CONCURRENCY = 16
RUST_MAX_CONCURRENCY = 32
PROVIDER = "s3"

EXPRESS_ONEZONE_STORAGE_CLASS = "EXPRESS_ONEZONE"


def _extract_x_trans_id(response: Any) -> None:
    """Extract x-trans-id from a boto3 response and set it as a span attribute."""
    try:
        if response and isinstance(response, dict):
            headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
            if headers and isinstance(headers, dict) and "x-trans-id" in headers:
                set_span_attribute("x_trans_id", headers["x-trans-id"])
    except (KeyError, AttributeError, TypeError):
        # Silently ignore any errors in extraction
        pass


class StaticS3CredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static S3 credentials.
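
    Example (illustrative; the credential values are placeholders)::

        provider = StaticS3CredentialsProvider(access_key="AKIA...", secret_key="...")
        credentials = provider.get_credentials()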
84 """
85
86 _access_key: str
87 _secret_key: str
88 _session_token: Optional[str]
89
90 def __init__(self, access_key: str, secret_key: str, session_token: Optional[str] = None):
91 """
92 Initializes the :py:class:`StaticS3CredentialsProvider` with the provided access key, secret key, and optional
93 session token.
94
95 :param access_key: The access key for S3 authentication.
96 :param secret_key: The secret key for S3 authentication.
97 :param session_token: An optional session token for temporary credentials.
98 """
99 self._access_key = access_key
100 self._secret_key = secret_key
101 self._session_token = session_token
102
    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key=self._access_key,
            secret_key=self._secret_key,
            token=self._session_token,
            expiration=None,
        )

    def refresh_credentials(self) -> None:
        pass


class S3StorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Amazon S3 or S3-compatible object stores.
    """

    def __init__(
        self,
        region_name: str = "",
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        metric_counters: dict[Telemetry.CounterName, api_metrics.Counter] = {},
        metric_gauges: dict[Telemetry.GaugeName, api_metrics._Gauge] = {},
        metric_attributes_providers: Sequence[AttributesProvider] = (),
        **kwargs: Any,
    ) -> None:
        """
        Initializes the :py:class:`S3StorageProvider` with the region, endpoint URL, and optional credentials provider.

        :param region_name: The AWS region where the S3 bucket is located.
        :param endpoint_url: The custom endpoint URL for the S3 service.
        :param base_path: The root prefix path within the S3 bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve S3 credentials.
        :param metric_counters: Metric counters.
        :param metric_gauges: Metric gauges.
        :param metric_attributes_providers: Metric attributes providers.
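
        Example (illustrative; the region, base path, and keyword overrides are placeholders)::

            provider = S3StorageProvider(
                region_name="us-east-1",
                base_path="my-bucket/prefix",
                credentials_provider=StaticS3CredentialsProvider("AKIA...", "..."),
                multipart_threshold=64 * MiB,
            )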
141 """
142 super().__init__(
143 base_path=base_path,
144 provider_name=PROVIDER,
145 metric_counters=metric_counters,
146 metric_gauges=metric_gauges,
147 metric_attributes_providers=metric_attributes_providers,
148 )
149
150 self._region_name = region_name
151 self._endpoint_url = endpoint_url
152 self._credentials_provider = credentials_provider
153
154 self._signature_version = kwargs.get("signature_version", "")
155 self._s3_client = self._create_s3_client(
156 request_checksum_calculation=kwargs.get("request_checksum_calculation"),
157 response_checksum_validation=kwargs.get("response_checksum_validation"),
158 max_pool_connections=kwargs.get("max_pool_connections", BOTO3_MAX_POOL_CONNECTIONS),
159 connect_timeout=kwargs.get("connect_timeout"),
160 read_timeout=kwargs.get("read_timeout"),
161 retries=kwargs.get("retries"),
162 )
163 self._transfer_config = TransferConfig(
164 multipart_threshold=int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD)),
165 max_concurrency=int(kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)),
166 multipart_chunksize=int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE)),
167 io_chunksize=int(kwargs.get("io_chunksize", IO_CHUNKSIZE)),
168 use_threads=True,
169 )
170
171 self._rust_client = None
172 if "rust_client" in kwargs:
173 self._rust_client = self._create_rust_client(kwargs.get("rust_client"))
174
175 def _is_directory_bucket(self, bucket: str) -> bool:
176 """
177 Determines if the bucket is a directory bucket based on bucket name.
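
        Example (illustrative bucket names)::

            self._is_directory_bucket("my-bucket--usw2-az1--x-s3")  # True: S3 Express One Zone
            self._is_directory_bucket("my-bucket")                  # False: general purpose bucket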
178 """
179 # S3 Express buckets have a specific naming convention
180 return "--x-s3" in bucket
181
182 def _create_s3_client(
183 self,
184 request_checksum_calculation: Optional[str] = None,
185 response_checksum_validation: Optional[str] = None,
186 max_pool_connections: int = BOTO3_MAX_POOL_CONNECTIONS,
187 connect_timeout: Union[float, int, None] = None,
188 read_timeout: Union[float, int, None] = None,
189 retries: Optional[dict[str, Any]] = None,
190 ):
191 """
192 Creates and configures the boto3 S3 client, using refreshable credentials if possible.
193
194 :param request_checksum_calculation: When the underlying S3 client should calculate request checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
195 :param response_checksum_validation: When the underlying S3 client should validate response checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
196 :param max_pool_connections: The maximum number of connections to keep in a connection pool.
197 :param connect_timeout: The time in seconds till a timeout exception is thrown when attempting to make a connection.
198 :param read_timeout: The time in seconds till a timeout exception is thrown when attempting to read from a connection.
199 :param retries: A dictionary for configuration related to retry behavior.
200
201 :return: The configured S3 client.
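
        Example ``retries`` value (illustrative; see the botocore ``Config`` reference for the full set of options)::

            {"max_attempts": 5, "mode": "adaptive"}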
202 """
203 options = {
204 # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
205 "config": boto3.session.Config( # pyright: ignore [reportAttributeAccessIssue]
206 max_pool_connections=max_pool_connections,
207 connect_timeout=connect_timeout,
208 read_timeout=read_timeout,
209 retries=retries or {"mode": "standard"},
210 request_checksum_calculation=request_checksum_calculation,
211 response_checksum_validation=response_checksum_validation,
212 ),
213 }
214
215 if self._region_name:
216 options["region_name"] = self._region_name
217
218 if self._endpoint_url:
219 options["endpoint_url"] = self._endpoint_url
220
221 if self._credentials_provider:
222 creds = self._fetch_credentials()
223 if "expiry_time" in creds and creds["expiry_time"]:
224 # Use RefreshableCredentials if expiry_time provided.
225 refreshable_credentials = RefreshableCredentials.create_from_metadata(
226 metadata=creds, refresh_using=self._fetch_credentials, method="custom-refresh"
227 )
228
229 botocore_session = get_session()
230 botocore_session._credentials = refreshable_credentials
231
232 boto3_session = boto3.Session(botocore_session=botocore_session)
233
234 return boto3_session.client("s3", **options)
235 else:
236 # Add static credentials to the options dictionary
237 options["aws_access_key_id"] = creds["access_key"]
238 options["aws_secret_access_key"] = creds["secret_key"]
239 if creds["token"]:
240 options["aws_session_token"] = creds["token"]
241
242 if self._signature_version:
243 signature_config = botocore.config.Config( # pyright: ignore[reportAttributeAccessIssue]
244 signature_version=botocore.UNSIGNED
245 if self._signature_version == "UNSIGNED"
246 else self._signature_version
247 )
248 options["config"] = options["config"].merge(signature_config)
249
250 # Fallback to standard credential chain.
251 return boto3.client("s3", **options)
252
    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        """
        Creates and configures the rust client, using refreshable credentials if possible.
        """
        from multistorageclient_rust import RustClient

        configs = {}
        if self._region_name:
            configs["region_name"] = self._region_name

        # If the user specifies a bucket, use it. Otherwise, use the base path.
        if rust_client_options and "bucket" in rust_client_options:
            configs["bucket"] = rust_client_options["bucket"]
        else:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        if self._endpoint_url:
            configs["endpoint_url"] = self._endpoint_url

        if self._credentials_provider:
            creds = self._fetch_credentials()
            if "expiry_time" in creds and creds["expiry_time"]:
                # TODO: Implement refreshable credentials
                raise NotImplementedError("Refreshable credentials are not yet implemented for the rust client.")
            else:
                # Add static credentials to the configs dictionary
                configs["aws_access_key_id"] = creds["access_key"]
                configs["aws_secret_access_key"] = creds["secret_key"]
                if creds["token"]:
                    configs["aws_session_token"] = creds["token"]

        if rust_client_options:
            if rust_client_options.get("allow_http", False):
                configs["allow_http"] = True
            configs["max_concurrency"] = rust_client_options.get("max_concurrency", RUST_MAX_CONCURRENCY)
            configs["multipart_chunksize"] = rust_client_options.get("multipart_chunksize", MULTIPART_CHUNKSIZE)

        return RustClient(
            provider=PROVIDER,
            configs=configs,
        )

    def _fetch_credentials(self) -> dict:
        """
        Refreshes the credentials via the configured provider and returns them in the metadata format expected by botocore.
        """
        if not self._credentials_provider:
            raise RuntimeError("Cannot fetch credentials if no credential provider configured.")
        self._credentials_provider.refresh_credentials()
        credentials = self._credentials_provider.get_credentials()
        return {
            "access_key": credentials.access_key,
            "secret_key": credentials.secret_key,
            "token": credentials.token,
            "expiry_time": credentials.expiration,
        }

    def _collect_metrics(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> _T:
        """
        Collects and records performance metrics around S3 operations such as PUT, GET, DELETE, etc.

        This method wraps an S3 operation and measures the time it takes to complete, along with recording
        the size of the object if applicable. It handles errors like timeouts and client errors and ensures
        proper logging of duration and object size.

        :param func: The function that performs the actual S3 operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the S3 bucket involved in the operation.
        :param key: The key of the object within the S3 bucket.
        :param put_object_size: The size of the object being uploaded, if applicable (for PUT operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).

        :return: The result of the S3 operation, typically the return value of the `func` callable.
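
        Example (illustrative; mirrors how ``_get_object`` wraps its API call)::

            content = self._collect_metrics(
                lambda: self._s3_client.get_object(Bucket=bucket, Key=key)["Body"].read(),
                operation="GET",
                bucket=bucket,
                key=key,
            )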
335 """
336 # Import the span attribute helper
337 from ..instrumentation.utils import set_span_attribute
338
339 # Set basic operation attributes
340 set_span_attribute("s3_operation", operation)
341 set_span_attribute("s3_bucket", bucket)
342 set_span_attribute("s3_key", key)
343
344 start_time = time.time()
345 status_code = 200
346
347 object_size = None
348 if operation == "PUT":
349 object_size = put_object_size
350 elif operation == "GET" and get_object_size:
351 object_size = get_object_size
352
353 try:
354 result = func()
355 if operation == "GET" and object_size is None and isinstance(result, Sized):
356 object_size = len(result)
357 return result
358 except ClientError as error:
359 status_code = error.response["ResponseMetadata"]["HTTPStatusCode"]
360 request_id = error.response["ResponseMetadata"].get("RequestId")
361 host_id = error.response["ResponseMetadata"].get("HostId")
362 header = error.response["ResponseMetadata"].get("HTTPHeaders", {})
363 error_code = error.response["Error"]["Code"]
364
365 # Ensure header is a dictionary before trying to get from it
366 x_trans_id = header.get("x-trans-id") if isinstance(header, dict) else None
367
368 # Record error details in span
369 set_span_attribute("request_id", request_id)
370 set_span_attribute("host_id", host_id)
371
372 error_info = f"request_id: {request_id}, host_id: {host_id}, status_code: {status_code}"
373 if x_trans_id:
374 error_info += f", x-trans-id: {x_trans_id}"
375 set_span_attribute("x_trans_id", x_trans_id)
376
377 if status_code == 404:
378 if error_code == "NoSuchUpload":
379 error_message = error.response["Error"]["Message"]
380 raise RetryableError(f"Multipart upload failed for {bucket}/{key}: {error_message}") from error
381 raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}") # pylint: disable=raise-missing-from
382 elif status_code == 412: # Precondition Failed
383 raise PreconditionFailedError(
384 f"ETag mismatch for {operation} operation on {bucket}/{key}. {error_info}"
385 ) from error
386 elif status_code == 429:
387 raise RetryableError(
388 f"Too many request to {operation} object(s) at {bucket}/{key}. {error_info}"
389 ) from error
390 elif status_code == 503:
391 raise RetryableError(
392 f"Service unavailable when {operation} object(s) at {bucket}/{key}. {error_info}"
393 ) from error
394 elif status_code == 501:
395 raise NotImplementedError(
396 f"Operation {operation} not implemented for object(s) at {bucket}/{key}. {error_info}"
397 ) from error
398 else:
399 raise RuntimeError(
400 f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}, "
401 f"error_type: {type(error).__name__}"
402 ) from error
403 except FileNotFoundError as error:
404 status_code = -1
405 raise error
406 except (ReadTimeoutError, IncompleteReadError, ResponseStreamingError) as error:
407 status_code = -1
408 raise RetryableError(
409 f"Failed to {operation} object(s) at {bucket}/{key} due to network timeout or incomplete read. "
410 f"error_type: {type(error).__name__}"
411 ) from error
412 except Exception as error:
413 status_code = -1
414 raise RuntimeError(
415 f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
416 ) from error
417 finally:
418 elapsed_time = time.time() - start_time
419
420 set_span_attribute("status_code", status_code)
421
422 # Record metrics
423 self._metric_helper.record_duration(
424 elapsed_time, provider=self._provider_name, operation=operation, bucket=bucket, status_code=status_code
425 )
426 if object_size:
427 self._metric_helper.record_object_size(
428 object_size,
429 provider=self._provider_name,
430 operation=operation,
431 bucket=bucket,
432 status_code=status_code,
433 )
434
435 set_span_attribute("object_size", object_size)
436
    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to the specified S3 path.

        :param path: The S3 path where the object will be uploaded.
        :param body: The content of the object as bytes.
        :param if_match: Optional If-Match header value. Use an ETag value to only overwrite the object if its current ETag matches.
        :param if_none_match: Optional If-None-Match header value. Use "*" to only upload if the object doesn't exist.
        :param attributes: Optional attributes to attach to the object.
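
        Example of a conditional create (illustrative)::

            # Only create the object if it does not already exist.
            self._put_object(path, body, if_none_match="*")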
453 """
454 bucket, key = split_path(path)
455
456 def _invoke_api() -> int:
457 kwargs = {"Bucket": bucket, "Key": key, "Body": body}
458 if self._is_directory_bucket(bucket):
459 kwargs["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
460 if if_match:
461 kwargs["IfMatch"] = if_match
462 if if_none_match:
463 kwargs["IfNoneMatch"] = if_none_match
464 validated_attributes = validate_attributes(attributes)
465 if validated_attributes:
466 kwargs["Metadata"] = validated_attributes
467
468 rust_unsupported_feature_keys = {"Metadata", "StorageClass", "IfMatch", "IfNoneMatch"}
469 if (
470 self._rust_client
471 # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
472 and not path.endswith("/")
473 and all(key not in kwargs for key in rust_unsupported_feature_keys)
474 ):
475 response = run_async_rust_client_method(self._rust_client, "put", key, body)
476 else:
477 # Capture the response from put_object
478 response = self._s3_client.put_object(**kwargs)
479
480 # Extract and set x-trans-id if present
481 _extract_x_trans_id(response)
482
483 return len(body)
484
485 return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=len(body))
486
    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)

        def _invoke_api() -> bytes:
            if byte_range:
                bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
                if self._rust_client:
                    response = run_async_rust_client_method(
                        self._rust_client, "get", key, byte_range.offset, byte_range.offset + byte_range.size - 1
                    )
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key, Range=bytes_range)
            else:
                if self._rust_client:
                    response = run_async_rust_client_method(self._rust_client, "get", key)
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key)

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

            return response["Body"].read()

        return self._collect_metrics(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            response = self._s3_client.copy(
                CopySource={"Bucket": src_bucket, "Key": src_key},
                Bucket=dest_bucket,
                Key=dest_key,
                Config=self._transfer_config,
            )

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

            return src_object.content_length

        return self._collect_metrics(
            _invoke_api,
            operation="COPY",
            bucket=dest_bucket,
            key=dest_key,
            put_object_size=src_object.content_length,
        )

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            # Delete conditionally if if_match (an ETag) is provided; otherwise delete unconditionally.
            if if_match:
                response = self._s3_client.delete_object(Bucket=bucket, Key=key, IfMatch=if_match)
            else:
                response = self._s3_client.delete_object(Bucket=bucket, Key=key)

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

        return self._collect_metrics(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._s3_client.list_objects_v2(Bucket=bucket, Prefix=key, MaxKeys=1, Delimiter="/")

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

            # Check if there are any contents or common prefixes
            return bool(response.get("Contents", []) or response.get("CommonPrefixes", []))

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If the path ends with "/" or an empty key name is provided, assume it's a "directory",
            # whose metadata is not guaranteed to exist, e.g. a "virtual prefix" that was never
            # explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:

            def _invoke_api() -> ObjectMetadata:
                response = self._s3_client.head_object(Bucket=bucket, Key=key)

                # Extract and set x-trans-id if present
                _extract_x_trans_id(response)

                return ObjectMetadata(
                    key=path,
                    type="file",
                    content_length=response["ContentLength"],
                    content_type=response["ContentType"],
                    last_modified=response["LastModified"],
                    etag=response["ETag"].strip('"'),
                    storage_class=response.get("StorageClass"),
                    metadata=response.get("Metadata"),
                )

            try:
                return self._collect_metrics(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist at the given path, append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(path)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            paginator = self._s3_client.get_paginator("list_objects_v2")
            if include_directories:
                page_iterator = paginator.paginate(
                    Bucket=bucket, Prefix=prefix, Delimiter="/", StartAfter=(start_after or "")
                )
            else:
                page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, StartAfter=(start_after or ""))

            for page in page_iterator:
                for item in page.get("CommonPrefixes", []):
                    yield ObjectMetadata(
                        key=os.path.join(bucket, item["Prefix"].rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )

                # S3 guarantees lexicographical order for general purpose buckets, but not for
                # directory buckets (S3 Express One Zone).
                for response_object in page.get("Contents", []):
                    key = response_object["Key"]
                    if end_at is None or key <= end_at:
                        if key.endswith("/"):
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(bucket, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=response_object["LastModified"],
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key),
                                type="file",
                                content_length=response_object["Size"],
                                last_modified=response_object["LastModified"],
                                etag=response_object["ETag"].strip('"'),
                                storage_class=response_object.get("StorageClass"),
                            )
                    else:
                        return

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        file_size: int = 0

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._transfer_config.multipart_threshold:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes)
                return file_size

            # Upload large files using TransferConfig
            def _invoke_api() -> int:
                extra_args = {}
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes
                if self._rust_client and not extra_args:
                    response = run_async_rust_client_method(self._rust_client, "upload_multipart", f, key)
                else:
                    response = self._s3_client.upload_file(
                        Filename=f,
                        Bucket=bucket,
                        Key=key,
                        Config=self._transfer_config,
                        ExtraArgs=extra_args,
                    )

                # Extract and set x-trans-id if present
                _extract_x_trans_id(response)

                return file_size

            return self._collect_metrics(
                _invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=file_size
            )
        else:
            # Determine the size of the file-like object
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            # Upload small files
            if file_size <= self._transfer_config.multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(remote_path, f.read().encode("utf-8"), attributes=attributes)
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes)
                return file_size

            # Upload large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                extra_args = {}
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes
                self._s3_client.upload_fileobj(
                    Fileobj=f,
                    Bucket=bucket,
                    Key=key,
                    Config=self._transfer_config,
                    ExtraArgs=extra_args,
                )

                return file_size

            return self._collect_metrics(
                _invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=file_size
            )

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)

            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using TransferConfig
            def _invoke_api() -> int:
                response = None
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    if self._rust_client:
                        response = run_async_rust_client_method(
                            self._rust_client, "download_multipart", key, temp_file_path
                        )
                    else:
                        response = self._s3_client.download_fileobj(
                            Bucket=bucket,
                            Key=key,
                            Fileobj=fp,
                            Config=self._transfer_config,
                        )

                # Extract and set x-trans-id if present
                _extract_x_trans_id(response)
                os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )
        else:
            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                response = self._get_object(remote_path)
                # The Python client returns `bytes`, but the Rust client returns an object that
                # implements the buffer protocol, so check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                response = self._s3_client.download_fileobj(
                    Bucket=bucket,
                    Key=key,
                    Fileobj=f,
                    Config=self._transfer_config,
                )

                # Extract and set x-trans-id if present
                _extract_x_trans_id(response)

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )