# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import io
import os
import tempfile
import time
from collections.abc import Callable, Iterator, Sequence, Sized
from typing import IO, Any, Optional, TypeVar, Union

import boto3
import botocore
import opentelemetry.metrics as api_metrics
from boto3.s3.transfer import TransferConfig
from botocore.credentials import RefreshableCredentials
from botocore.exceptions import ClientError, IncompleteReadError, ReadTimeoutError, ResponseStreamingError
from botocore.session import get_session

from ..instrumentation.utils import set_span_attribute
from ..rust_utils import run_async_rust_client_method
from ..telemetry import Telemetry
from ..telemetry.attributes.base import AttributesProvider
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import (
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

BOTO3_MAX_POOL_CONNECTIONS = 32

MiB = 1024 * 1024

# Python and Rust share the same multipart_threshold to keep the code simple.
MULTIPART_THRESHOLD = 64 * MiB
MULTIPART_CHUNKSIZE = 32 * MiB
IO_CHUNKSIZE = 32 * MiB
# Python uses a lower default concurrency due to the GIL limiting true parallelism in threads.
PYTHON_MAX_CONCURRENCY = 16
RUST_MAX_CONCURRENCY = 32
PROVIDER = "s3"

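# Storage class applied to uploads when the target is an S3 Express One Zone directory bucket
# (see _is_directory_bucket).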
EXPRESS_ONEZONE_STORAGE_CLASS = "EXPRESS_ONEZONE"


def _extract_x_trans_id(response: Any) -> None:
    """Extract x-trans-id from boto3 response and set it as span attribute."""
    try:
        if response and isinstance(response, dict):
            headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
            if headers and isinstance(headers, dict) and "x-trans-id" in headers:
                set_span_attribute("x_trans_id", headers["x-trans-id"])
    except (KeyError, AttributeError, TypeError):
        # Silently ignore any errors in extraction
        pass


class StaticS3CredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static S3 credentials.
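
    Example (illustrative; the credential values below are placeholders)::

        credentials_provider = StaticS3CredentialsProvider(
            access_key="AKIAEXAMPLE",
            secret_key="secret",
        )
        credentials = credentials_provider.get_credentials()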
    """

    _access_key: str
    _secret_key: str
    _session_token: Optional[str]

    def __init__(self, access_key: str, secret_key: str, session_token: Optional[str] = None):
        """
        Initializes the :py:class:`StaticS3CredentialsProvider` with the provided access key, secret key, and optional
        session token.

        :param access_key: The access key for S3 authentication.
        :param secret_key: The secret key for S3 authentication.
        :param session_token: An optional session token for temporary credentials.
        """
        self._access_key = access_key
        self._secret_key = secret_key
        self._session_token = session_token

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key=self._access_key,
            secret_key=self._secret_key,
            token=self._session_token,
            expiration=None,
        )

    def refresh_credentials(self) -> None:
        pass


class S3StorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Amazon S3 or S3-compatible object stores.
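
    Example (illustrative; the bucket name and region below are placeholders)::

        provider = S3StorageProvider(
            region_name="us-east-1",
            base_path="my-bucket/prefix",
        )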
    """

    def __init__(
        self,
        region_name: str = "",
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        metric_counters: dict[Telemetry.CounterName, api_metrics.Counter] = {},
        metric_gauges: dict[Telemetry.GaugeName, api_metrics._Gauge] = {},
        metric_attributes_providers: Sequence[AttributesProvider] = (),
        **kwargs: Any,
    ) -> None:
        """
        Initializes the :py:class:`S3StorageProvider` with the region, endpoint URL, and optional credentials provider.

        :param region_name: The AWS region where the S3 bucket is located.
        :param endpoint_url: The custom endpoint URL for the S3 service.
        :param base_path: The root prefix path within the S3 bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve S3 credentials.
        :param metric_counters: Metric counters.
        :param metric_gauges: Metric gauges.
        :param metric_attributes_providers: Metric attributes providers.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            metric_counters=metric_counters,
            metric_gauges=metric_gauges,
            metric_attributes_providers=metric_attributes_providers,
        )

        self._region_name = region_name
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider

        self._signature_version = kwargs.get("signature_version", "")
        self._s3_client = self._create_s3_client(
            request_checksum_calculation=kwargs.get("request_checksum_calculation"),
            response_checksum_validation=kwargs.get("response_checksum_validation"),
            max_pool_connections=kwargs.get("max_pool_connections", BOTO3_MAX_POOL_CONNECTIONS),
            connect_timeout=kwargs.get("connect_timeout"),
            read_timeout=kwargs.get("read_timeout"),
            retries=kwargs.get("retries"),
        )
        self._transfer_config = TransferConfig(
            multipart_threshold=int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD)),
            max_concurrency=int(kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)),
            multipart_chunksize=int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE)),
            io_chunksize=int(kwargs.get("io_chunksize", IO_CHUNKSIZE)),
            use_threads=True,
        )

        self._rust_client = None
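        # Optional Rust-based client; the "rust_client" options dict may carry keys such as
        # "bucket", "allow_http", "max_concurrency", and "multipart_chunksize" (see _create_rust_client).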
        if "rust_client" in kwargs:
            self._rust_client = self._create_rust_client(kwargs.get("rust_client"))

    def _is_directory_bucket(self, bucket: str) -> bool:
        """
        Determines if the bucket is a directory bucket based on bucket name.
        """
        # S3 Express buckets have a specific naming convention
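        # e.g. "my-bucket--usw2-az1--x-s3"; directory bucket names end with the "--x-s3" suffix.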
        return "--x-s3" in bucket

    def _create_s3_client(
        self,
        request_checksum_calculation: Optional[str] = None,
        response_checksum_validation: Optional[str] = None,
        max_pool_connections: int = BOTO3_MAX_POOL_CONNECTIONS,
        connect_timeout: Union[float, int, None] = None,
        read_timeout: Union[float, int, None] = None,
        retries: Optional[dict[str, Any]] = None,
    ):
        """
        Creates and configures the boto3 S3 client, using refreshable credentials if possible.

        :param request_checksum_calculation: When the underlying S3 client should calculate request checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param response_checksum_validation: When the underlying S3 client should validate response checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param max_pool_connections: The maximum number of connections to keep in a connection pool.
        :param connect_timeout: The time in seconds till a timeout exception is thrown when attempting to make a connection.
        :param read_timeout: The time in seconds till a timeout exception is thrown when attempting to read from a connection.
        :param retries: A dictionary for configuration related to retry behavior.

        :return: The configured S3 client.
        """
        options = {
            # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
            "config": boto3.session.Config(  # pyright: ignore [reportAttributeAccessIssue]
                max_pool_connections=max_pool_connections,
                connect_timeout=connect_timeout,
                read_timeout=read_timeout,
                retries=retries or {"mode": "standard"},
                request_checksum_calculation=request_checksum_calculation,
                response_checksum_validation=response_checksum_validation,
            ),
        }

        if self._region_name:
            options["region_name"] = self._region_name

        if self._endpoint_url:
            options["endpoint_url"] = self._endpoint_url

        if self._credentials_provider:
            creds = self._fetch_credentials()
            if "expiry_time" in creds and creds["expiry_time"]:
                # Use RefreshableCredentials if expiry_time provided.
                refreshable_credentials = RefreshableCredentials.create_from_metadata(
                    metadata=creds, refresh_using=self._fetch_credentials, method="custom-refresh"
                )

                botocore_session = get_session()
                botocore_session._credentials = refreshable_credentials

                boto3_session = boto3.Session(botocore_session=botocore_session)

                return boto3_session.client("s3", **options)
            else:
                # Add static credentials to the options dictionary
                options["aws_access_key_id"] = creds["access_key"]
                options["aws_secret_access_key"] = creds["secret_key"]
                if creds["token"]:
                    options["aws_session_token"] = creds["token"]

            if self._signature_version:
                signature_config = botocore.config.Config(  # pyright: ignore[reportAttributeAccessIssue]
                    signature_version=botocore.UNSIGNED
                    if self._signature_version == "UNSIGNED"
                    else self._signature_version
                )
                options["config"] = options["config"].merge(signature_config)

        # Fallback to standard credential chain.
        return boto3.client("s3", **options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        """
        Creates and configures the rust client, using refreshable credentials if possible.
        """
        from multistorageclient_rust import RustClient

        configs = {}
        if self._region_name:
            configs["region_name"] = self._region_name

        # If the user specifies a bucket, use it. Otherwise, use the base path.
        if rust_client_options and "bucket" in rust_client_options:
            configs["bucket"] = rust_client_options["bucket"]
        else:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        if self._endpoint_url:
            configs["endpoint_url"] = self._endpoint_url

        if rust_client_options:
            if rust_client_options.get("allow_http", False):
                configs["allow_http"] = True
            configs["max_concurrency"] = rust_client_options.get("max_concurrency", RUST_MAX_CONCURRENCY)
            configs["multipart_chunksize"] = rust_client_options.get("multipart_chunksize", MULTIPART_CHUNKSIZE)

        return RustClient(
            provider=PROVIDER,
            configs=configs,
            credentials_provider=self._credentials_provider,
        )

    def _fetch_credentials(self) -> dict:
        """
        Refreshes credentials via the configured credentials provider and returns them.
        """
        if not self._credentials_provider:
            raise RuntimeError("Cannot fetch credentials if no credentials provider is configured.")
        self._credentials_provider.refresh_credentials()
        credentials = self._credentials_provider.get_credentials()
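        # Keys follow the botocore refreshable-credentials metadata format
        # (access_key, secret_key, token, expiry_time) expected by RefreshableCredentials.create_from_metadata.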
        return {
            "access_key": credentials.access_key,
            "secret_key": credentials.secret_key,
            "token": credentials.token,
            "expiry_time": credentials.expiration,
        }

    def _collect_metrics(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> _T:
        """
        Collects and records performance metrics around S3 operations such as PUT, GET, DELETE, etc.

        This method wraps an S3 operation and measures the time it takes to complete, along with recording
        the size of the object if applicable. It handles errors like timeouts and client errors and ensures
        proper logging of duration and object size.

        :param func: The function that performs the actual S3 operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the S3 bucket involved in the operation.
        :param key: The key of the object within the S3 bucket.
        :param put_object_size: The size of the object being uploaded, if applicable (for PUT operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).

        :return: The result of the S3 operation, typically the return value of the `func` callable.
        """
        # Import the span attribute helper
        from ..instrumentation.utils import set_span_attribute

        # Set basic operation attributes
        set_span_attribute("s3_operation", operation)
        set_span_attribute("s3_bucket", bucket)
        set_span_attribute("s3_key", key)

        start_time = time.time()
        status_code = 200
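        # status_code defaults to 200; it is replaced with the HTTP status from a ClientError,
        # or set to -1 for errors raised without an S3 HTTP status (e.g. timeouts).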

        object_size = None
        if operation == "PUT":
            object_size = put_object_size
        elif operation == "GET" and get_object_size:
            object_size = get_object_size

        try:
            result = func()
            if operation == "GET" and object_size is None and isinstance(result, Sized):
                object_size = len(result)
            return result
        except ClientError as error:
            status_code = error.response["ResponseMetadata"]["HTTPStatusCode"]
            request_id = error.response["ResponseMetadata"].get("RequestId")
            host_id = error.response["ResponseMetadata"].get("HostId")
            header = error.response["ResponseMetadata"].get("HTTPHeaders", {})
            error_code = error.response["Error"]["Code"]

            # Ensure header is a dictionary before trying to get from it
            x_trans_id = header.get("x-trans-id") if isinstance(header, dict) else None

            # Record error details in span
            set_span_attribute("request_id", request_id)
            set_span_attribute("host_id", host_id)

            error_info = f"request_id: {request_id}, host_id: {host_id}, status_code: {status_code}"
            if x_trans_id:
                error_info += f", x-trans-id: {x_trans_id}"
                set_span_attribute("x_trans_id", x_trans_id)

            if status_code == 404:
                if error_code == "NoSuchUpload":
                    error_message = error.response["Error"]["Message"]
                    raise RetryableError(f"Multipart upload failed for {bucket}/{key}: {error_message}") from error
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}")  # pylint: disable=raise-missing-from
            elif status_code == 412:  # Precondition Failed
                raise PreconditionFailedError(
                    f"ETag mismatch for {operation} operation on {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 429:
                raise RetryableError(
                    f"Too many requests to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 503:
                raise RetryableError(
                    f"Service unavailable when {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 501:
                raise NotImplementedError(
                    f"Operation {operation} not implemented for object(s) at {bucket}/{key}. {error_info}"
                ) from error
            else:
                raise RuntimeError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}, "
                    f"error_type: {type(error).__name__}"
                ) from error
        except FileNotFoundError as error:
            status_code = -1
            raise error
        except (ReadTimeoutError, IncompleteReadError, ResponseStreamingError) as error:
            status_code = -1
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to network timeout or incomplete read. "
                f"error_type: {type(error).__name__}"
            ) from error
        except Exception as error:
            status_code = -1
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error
        finally:
            elapsed_time = time.time() - start_time

            set_span_attribute("status_code", status_code)

            # Record metrics
            self._metric_helper.record_duration(
                elapsed_time, provider=self._provider_name, operation=operation, bucket=bucket, status_code=status_code
            )
            if object_size:
                self._metric_helper.record_object_size(
                    object_size,
                    provider=self._provider_name,
                    operation=operation,
                    bucket=bucket,
                    status_code=status_code,
                )

            set_span_attribute("object_size", object_size)

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to the specified S3 path.

        :param path: The S3 path where the object will be uploaded.
        :param body: The content of the object as bytes.
        :param if_match: Optional If-Match header value. The upload succeeds only if the existing object's ETag matches this value.
        :param if_none_match: Optional If-None-Match header value. Use "*" to only upload if the object doesn't exist.
        :param attributes: Optional attributes to attach to the object.
        """
        bucket, key = split_path(path)

        def _invoke_api() -> int:
            kwargs = {"Bucket": bucket, "Key": key, "Body": body}
            if self._is_directory_bucket(bucket):
                kwargs["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
            if if_match:
                kwargs["IfMatch"] = if_match
            if if_none_match:
                kwargs["IfNoneMatch"] = if_none_match
            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                kwargs["Metadata"] = validated_attributes

            rust_unsupported_feature_keys = {"Metadata", "StorageClass", "IfMatch", "IfNoneMatch"}
            if (
                self._rust_client
                # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
                and not path.endswith("/")
                and all(key not in kwargs for key in rust_unsupported_feature_keys)
            ):
                response = run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                # Capture the response from put_object
                response = self._s3_client.put_object(**kwargs)

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

            return len(body)

        return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=len(body))

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)

        def _invoke_api() -> bytes:
            if byte_range:
                bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
                if self._rust_client:
                    response = run_async_rust_client_method(
                        self._rust_client, "get", key, byte_range.offset, byte_range.offset + byte_range.size - 1
                    )
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key, Range=bytes_range)
            else:
                if self._rust_client:
                    response = run_async_rust_client_method(self._rust_client, "get", key)
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key)

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

            return response["Body"].read()

        return self._collect_metrics(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
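            # boto3's managed copy honors self._transfer_config and performs a multipart copy
            # for objects above the multipart threshold.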
            response = self._s3_client.copy(
                CopySource={"Bucket": src_bucket, "Key": src_key},
                Bucket=dest_bucket,
                Key=dest_key,
                Config=self._transfer_config,
            )

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

            return src_object.content_length

        return self._collect_metrics(
            _invoke_api,
            operation="COPY",
            bucket=dest_bucket,
            key=dest_key,
            put_object_size=src_object.content_length,
        )

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            # Conditionally delete the object if if_match (ETag) is provided; otherwise delete it unconditionally.
            if if_match:
                response = self._s3_client.delete_object(Bucket=bucket, Key=key, IfMatch=if_match)
            else:
                response = self._s3_client.delete_object(Bucket=bucket, Key=key)

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

        return self._collect_metrics(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._s3_client.list_objects_v2(Bucket=bucket, Prefix=key, MaxKeys=1, Delimiter="/")

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

            # Check if there are any contents or common prefixes
            return bool(response.get("Contents", []) or response.get("CommonPrefixes", []))

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If the path ends with "/" or an empty key name is provided, assume it is a "directory",
            # whose metadata is not guaranteed to exist for cases such as a
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:

            def _invoke_api() -> ObjectMetadata:
                response = self._s3_client.head_object(Bucket=bucket, Key=key)

                # Extract and set x-trans-id if present
                _extract_x_trans_id(response)

                return ObjectMetadata(
                    key=path,
                    type="file",
                    content_length=response["ContentLength"],
                    content_type=response["ContentType"],
                    last_modified=response["LastModified"],
                    etag=response["ETag"].strip('"'),
                    storage_class=response.get("StorageClass"),
                    metadata=response.get("Metadata"),
                )

            try:
                return self._collect_metrics(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(path)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            paginator = self._s3_client.get_paginator("list_objects_v2")
            if include_directories:
                page_iterator = paginator.paginate(
                    Bucket=bucket, Prefix=prefix, Delimiter="/", StartAfter=(start_after or "")
                )
            else:
                page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, StartAfter=(start_after or ""))

            for page in page_iterator:
                for item in page.get("CommonPrefixes", []):
                    yield ObjectMetadata(
                        key=os.path.join(bucket, item["Prefix"].rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )

                # S3 guarantees lexicographical order for general purpose buckets (for
                # normal S3) but not directory buckets (for S3 Express One Zone).
                for response_object in page.get("Contents", []):
                    key = response_object["Key"]
                    if end_at is None or key <= end_at:
                        if key.endswith("/"):
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(bucket, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=response_object["LastModified"],
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key),
                                type="file",
                                content_length=response_object["Size"],
                                last_modified=response_object["LastModified"],
                                etag=response_object["ETag"].strip('"'),
                                storage_class=response_object.get("StorageClass"),  # Pass storage_class
                            )
                    else:
                        return

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        file_size: int = 0

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._transfer_config.multipart_threshold:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes)
                return file_size

            # Upload large files using TransferConfig
            def _invoke_api() -> int:
                extra_args = {}
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes
                if self._rust_client and not extra_args:
                    response = run_async_rust_client_method(self._rust_client, "upload_multipart", f, key)
                else:
                    response = self._s3_client.upload_file(
                        Filename=f,
                        Bucket=bucket,
                        Key=key,
                        Config=self._transfer_config,
                        ExtraArgs=extra_args,
                    )

                # Extract and set x-trans-id if present
                _extract_x_trans_id(response)

                return file_size

            return self._collect_metrics(
                _invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=file_size
            )
        else:
            # Upload small files
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            if file_size <= self._transfer_config.multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(remote_path, f.read().encode("utf-8"), attributes=attributes)
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes)
                return file_size

            # Upload large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                extra_args = {}
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes
                self._s3_client.upload_fileobj(
                    Fileobj=f,
                    Bucket=bucket,
                    Key=key,
                    Config=self._transfer_config,
                    ExtraArgs=extra_args,
                )

                return file_size

            return self._collect_metrics(
                _invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=file_size
            )

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)

            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
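                    # Write to a hidden temp file in the destination directory, then rename it into
                    # place so a partially downloaded file never appears at the final path.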
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using TransferConfig
            def _invoke_api() -> int:
                response = None
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    if self._rust_client:
                        response = run_async_rust_client_method(
                            self._rust_client, "download_multipart", key, temp_file_path
                        )
                    else:
                        response = self._s3_client.download_fileobj(
                            Bucket=bucket,
                            Key=key,
                            Fileobj=fp,
                            Config=self._transfer_config,
                        )

                # Extract and set x-trans-id if present
                _extract_x_trans_id(response)
                os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )
        else:
            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                response = self._get_object(remote_path)
                # The Python client returns `bytes`, but the Rust client returns an object that implements
                # the buffer protocol, so we need to check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                response = self._s3_client.download_fileobj(
                    Bucket=bucket,
                    Key=key,
                    Fileobj=f,
                    Config=self._transfer_config,
                )

                # Extract and set x-trans-id if present
                _extract_x_trans_id(response)

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )