# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os
import tempfile
import time
from collections.abc import Callable, Iterator, Sequence, Sized
from typing import IO, Any, Optional, TypeVar, Union

import boto3
import botocore
import opentelemetry.metrics as api_metrics
from boto3.s3.transfer import TransferConfig
from botocore.credentials import RefreshableCredentials
from botocore.exceptions import ClientError, IncompleteReadError, ReadTimeoutError, ResponseStreamingError
from botocore.session import get_session

from ..instrumentation.utils import set_span_attribute
from ..rust_utils import run_async_rust_client_method
from ..telemetry import Telemetry
from ..telemetry.attributes.base import AttributesProvider
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import (
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

BOTO3_MAX_POOL_CONNECTIONS = 32

MB = 1024 * 1024

MULTIPART_THRESHOLD = 512 * MB
MULTIPART_CHUNK_SIZE = 256 * MB
IO_CHUNK_SIZE = 128 * MB
MAX_CONCURRENCY = 16
PROVIDER = "s3"

EXPRESS_ONEZONE_STORAGE_CLASS = "EXPRESS_ONEZONE"


def _extract_x_trans_id(response: Any) -> None:
    """Extract x-trans-id from a boto3 response and set it as a span attribute."""
    try:
        if response and isinstance(response, dict):
            headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
            if headers and isinstance(headers, dict) and "x-trans-id" in headers:
                set_span_attribute("x_trans_id", headers["x-trans-id"])
    except (KeyError, AttributeError, TypeError):
        # Silently ignore any errors in extraction
        pass


class StaticS3CredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static S3 credentials.
    """

    _access_key: str
    _secret_key: str
    _session_token: Optional[str]

    def __init__(self, access_key: str, secret_key: str, session_token: Optional[str] = None):
        """
        Initializes the :py:class:`StaticS3CredentialsProvider` with the provided access key, secret key, and optional
        session token.

        :param access_key: The access key for S3 authentication.
        :param secret_key: The secret key for S3 authentication.
        :param session_token: An optional session token for temporary credentials.
        """
        self._access_key = access_key
        self._secret_key = secret_key
        self._session_token = session_token

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key=self._access_key,
            secret_key=self._secret_key,
            token=self._session_token,
            expiration=None,
        )

    def refresh_credentials(self) -> None:
        pass


class S3StorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Amazon S3 or S3-compatible object stores.
    """

    def __init__(
        self,
        region_name: str = "",
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        metric_counters: dict[Telemetry.CounterName, api_metrics.Counter] = {},
        metric_gauges: dict[Telemetry.GaugeName, api_metrics._Gauge] = {},
        metric_attributes_providers: Sequence[AttributesProvider] = (),
        **kwargs: Any,
    ) -> None:
        """
        Initializes the :py:class:`S3StorageProvider` with the region, endpoint URL, and optional credentials provider.

        :param region_name: The AWS region where the S3 bucket is located.
        :param endpoint_url: The custom endpoint URL for the S3 service.
        :param base_path: The root prefix path within the S3 bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve S3 credentials.
        :param metric_counters: Metric counters.
        :param metric_gauges: Metric gauges.
        :param metric_attributes_providers: Metric attributes providers.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            metric_counters=metric_counters,
            metric_gauges=metric_gauges,
            metric_attributes_providers=metric_attributes_providers,
        )

        self._region_name = region_name
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider

        self._signature_version = kwargs.get("signature_version", "")
        self._s3_client = self._create_s3_client(
            request_checksum_calculation=kwargs.get("request_checksum_calculation"),
            response_checksum_validation=kwargs.get("response_checksum_validation"),
            max_pool_connections=kwargs.get("max_pool_connections", BOTO3_MAX_POOL_CONNECTIONS),
            connect_timeout=kwargs.get("connect_timeout"),
            read_timeout=kwargs.get("read_timeout"),
            retries=kwargs.get("retries"),
        )
        self._transfer_config = TransferConfig(
            multipart_threshold=int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD)),
            max_concurrency=int(kwargs.get("max_concurrency", MAX_CONCURRENCY)),
            multipart_chunksize=int(kwargs.get("multipart_chunksize", MULTIPART_CHUNK_SIZE)),
            io_chunksize=int(kwargs.get("io_chunk_size", IO_CHUNK_SIZE)),
            use_threads=True,
        )

        self._rust_client = None
        if "rust_client" in kwargs:
            self._rust_client = self._create_rust_client(kwargs.get("rust_client"))

    def _is_directory_bucket(self, bucket: str) -> bool:
        """
        Determines if the bucket is a directory bucket based on bucket name.
        """
        # S3 Express buckets have a specific naming convention
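        # (e.g., a name like "my-bucket--usw2-az1--x-s3"; the zone ID shown is illustrative)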
        return "--x-s3" in bucket

    def _create_s3_client(
        self,
        request_checksum_calculation: Optional[str] = None,
        response_checksum_validation: Optional[str] = None,
        max_pool_connections: int = BOTO3_MAX_POOL_CONNECTIONS,
        connect_timeout: Union[float, int, None] = None,
        read_timeout: Union[float, int, None] = None,
        retries: Optional[dict[str, Any]] = None,
    ):
        """
        Creates and configures the boto3 S3 client, using refreshable credentials if possible.

        :param request_checksum_calculation: When the underlying S3 client should calculate request checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param response_checksum_validation: When the underlying S3 client should validate response checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param max_pool_connections: The maximum number of connections to keep in a connection pool.
        :param connect_timeout: The time in seconds till a timeout exception is thrown when attempting to make a connection.
        :param read_timeout: The time in seconds till a timeout exception is thrown when attempting to read from a connection.
        :param retries: A dictionary for configuration related to retry behavior.

        :return: The configured S3 client.
        """
        options = {
            # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
            "config": boto3.session.Config(  # pyright: ignore [reportAttributeAccessIssue]
                max_pool_connections=max_pool_connections,
                connect_timeout=connect_timeout,
                read_timeout=read_timeout,
                retries=retries or {"mode": "standard"},
                request_checksum_calculation=request_checksum_calculation,
                response_checksum_validation=response_checksum_validation,
            ),
        }

        if self._region_name:
            options["region_name"] = self._region_name

        if self._endpoint_url:
            options["endpoint_url"] = self._endpoint_url

        if self._credentials_provider:
            creds = self._fetch_credentials()
            if "expiry_time" in creds and creds["expiry_time"]:
                # Use RefreshableCredentials if expiry_time provided.
                refreshable_credentials = RefreshableCredentials.create_from_metadata(
                    metadata=creds, refresh_using=self._fetch_credentials, method="custom-refresh"
                )

                botocore_session = get_session()
                botocore_session._credentials = refreshable_credentials

                boto3_session = boto3.Session(botocore_session=botocore_session)

                return boto3_session.client("s3", **options)
            else:
                # Add static credentials to the options dictionary
                options["aws_access_key_id"] = creds["access_key"]
                options["aws_secret_access_key"] = creds["secret_key"]
                if creds["token"]:
                    options["aws_session_token"] = creds["token"]

        if self._signature_version:
            signature_config = botocore.config.Config(  # pyright: ignore[reportAttributeAccessIssue]
                signature_version=botocore.UNSIGNED
                if self._signature_version == "UNSIGNED"
                else self._signature_version
            )
            options["config"] = options["config"].merge(signature_config)

        # Fallback to standard credential chain.
        return boto3.client("s3", **options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        """
        Creates and configures the rust client, using refreshable credentials if possible.
        """
        from multistorageclient_rust import RustClient

        configs = {}
        if self._region_name:
            configs["region_name"] = self._region_name

        bucket, _ = split_path(self._base_path)
        configs["bucket"] = bucket

        if self._endpoint_url:
            configs["endpoint_url"] = self._endpoint_url

        if self._credentials_provider:
            creds = self._fetch_credentials()
            if "expiry_time" in creds and creds["expiry_time"]:
                # TODO: Implement refreshable credentials
                raise NotImplementedError("Refreshable credentials are not yet implemented for the rust client.")
            else:
                # Add static credentials to the configs dictionary
                configs["aws_access_key_id"] = creds["access_key"]
                configs["aws_secret_access_key"] = creds["secret_key"]
                if creds["token"]:
                    configs["aws_session_token"] = creds["token"]

        if rust_client_options:
            if rust_client_options.get("allow_http", False):
                configs["allow_http"] = True

        return RustClient(
            provider=PROVIDER,
            configs=configs,
        )

    def _fetch_credentials(self) -> dict:
        """
        Refreshes the credentials using the configured credentials provider and returns them
        in the metadata format expected by botocore.
        """
        if not self._credentials_provider:
            raise RuntimeError("Cannot fetch credentials if no credential provider configured.")
        self._credentials_provider.refresh_credentials()
        credentials = self._credentials_provider.get_credentials()
        return {
            "access_key": credentials.access_key,
            "secret_key": credentials.secret_key,
            "token": credentials.token,
            "expiry_time": credentials.expiration,
        }

    def _collect_metrics(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> _T:
        """
        Collects and records performance metrics around S3 operations such as PUT, GET, DELETE, etc.

        This method wraps an S3 operation and measures the time it takes to complete, along with recording
        the size of the object if applicable. It handles errors like timeouts and client errors and ensures
        proper logging of duration and object size.

        :param func: The function that performs the actual S3 operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the S3 bucket involved in the operation.
        :param key: The key of the object within the S3 bucket.
        :param put_object_size: The size of the object being uploaded, if applicable (for PUT operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).

        :return: The result of the S3 operation, typically the return value of the `func` callable.
        """
        # Import the span attribute helper
        from ..instrumentation.utils import set_span_attribute

        # Set basic operation attributes
        set_span_attribute("s3_operation", operation)
        set_span_attribute("s3_bucket", bucket)
        set_span_attribute("s3_key", key)

        start_time = time.time()
        status_code = 200

        object_size = None
        if operation == "PUT":
            object_size = put_object_size
        elif operation == "GET" and get_object_size:
            object_size = get_object_size

        try:
            result = func()
            if operation == "GET" and object_size is None and isinstance(result, Sized):
                object_size = len(result)
            return result
        except ClientError as error:
            status_code = error.response["ResponseMetadata"]["HTTPStatusCode"]
            request_id = error.response["ResponseMetadata"].get("RequestId")
            host_id = error.response["ResponseMetadata"].get("HostId")
            header = error.response["ResponseMetadata"].get("HTTPHeaders", {})
            error_code = error.response["Error"]["Code"]

            # Ensure header is a dictionary before trying to get from it
            x_trans_id = header.get("x-trans-id") if isinstance(header, dict) else None

            # Record error details in span
            set_span_attribute("request_id", request_id)
            set_span_attribute("host_id", host_id)

            error_info = f"request_id: {request_id}, host_id: {host_id}, status_code: {status_code}"
            if x_trans_id:
                error_info += f", x-trans-id: {x_trans_id}"
                set_span_attribute("x_trans_id", x_trans_id)

            if status_code == 404:
                if error_code == "NoSuchUpload":
                    error_message = error.response["Error"]["Message"]
                    raise RetryableError(f"Multipart upload failed for {bucket}/{key}: {error_message}") from error
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}")  # pylint: disable=raise-missing-from
            elif status_code == 412:  # Precondition Failed
                raise PreconditionFailedError(
                    f"ETag mismatch for {operation} operation on {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 429:
                raise RetryableError(
                    f"Too many requests to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 503:
                raise RetryableError(
                    f"Service unavailable when {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 501:
                raise NotImplementedError(
                    f"Operation {operation} not implemented for object(s) at {bucket}/{key}. {error_info}"
                ) from error
            else:
                raise RuntimeError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}, "
                    f"error_type: {type(error).__name__}"
                ) from error
        except FileNotFoundError as error:
            status_code = -1
            raise error
        except (ReadTimeoutError, IncompleteReadError, ResponseStreamingError) as error:
            status_code = -1
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to network timeout or incomplete read. "
                f"error_type: {type(error).__name__}"
            ) from error
        except Exception as error:
            status_code = -1
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error_type: {type(error).__name__}, error: {error}"
            ) from error
        finally:
            elapsed_time = time.time() - start_time

            set_span_attribute("status_code", status_code)

            # Record metrics
            self._metric_helper.record_duration(
                elapsed_time, provider=self._provider_name, operation=operation, bucket=bucket, status_code=status_code
            )
            if object_size:
                self._metric_helper.record_object_size(
                    object_size,
                    provider=self._provider_name,
                    operation=operation,
                    bucket=bucket,
                    status_code=status_code,
                )

            set_span_attribute("object_size", object_size)

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to the specified S3 path.

        :param path: The S3 path where the object will be uploaded.
        :param body: The content of the object as bytes.
        :param if_match: Optional If-Match header value. The upload only succeeds if the existing object's ETag matches.
        :param if_none_match: Optional If-None-Match header value. Use "*" to only upload if the object doesn't exist.
        :param attributes: Optional attributes to attach to the object.
        """
        bucket, key = split_path(path)

        def _invoke_api() -> int:
            kwargs = {"Bucket": bucket, "Key": key, "Body": body}
            if self._is_directory_bucket(bucket):
                kwargs["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
            if if_match:
                kwargs["IfMatch"] = if_match
            if if_none_match:
                kwargs["IfNoneMatch"] = if_none_match
            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                kwargs["Metadata"] = validated_attributes

            rust_unsupported_feature_keys = {"Metadata", "StorageClass", "IfMatch", "IfNoneMatch"}
            if (
                self._rust_client
                # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
                and not path.endswith("/")
                and all(key not in kwargs for key in rust_unsupported_feature_keys)
            ):
                response = run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                # Capture the response from put_object
                response = self._s3_client.put_object(**kwargs)

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

            return len(body)

        return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=len(body))

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)

        def _invoke_api() -> bytes:
            if byte_range:
                bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
                if self._rust_client:
                    response = run_async_rust_client_method(
                        self._rust_client, "get", key, byte_range.offset, byte_range.offset + byte_range.size - 1
                    )
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key, Range=bytes_range)
            else:
                if self._rust_client:
                    response = run_async_rust_client_method(self._rust_client, "get", key)
                    return response
                else:
                    response = self._s3_client.get_object(Bucket=bucket, Key=key)

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

            return response["Body"].read()

        return self._collect_metrics(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            response = self._s3_client.copy(
                CopySource={"Bucket": src_bucket, "Key": src_key},
                Bucket=dest_bucket,
                Key=dest_key,
                Config=self._transfer_config,
            )

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

            return src_object.content_length

        return self._collect_metrics(
            _invoke_api,
            operation="COPY",
            bucket=dest_bucket,
            key=dest_key,
            put_object_size=src_object.content_length,
        )

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            # Delete the object conditionally when if_match (ETag) is provided; otherwise delete it unconditionally.
            if if_match:
                response = self._s3_client.delete_object(Bucket=bucket, Key=key, IfMatch=if_match)
            else:
                response = self._s3_client.delete_object(Bucket=bucket, Key=key)

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

        return self._collect_metrics(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._s3_client.list_objects_v2(Bucket=bucket, Prefix=key, MaxKeys=1, Delimiter="/")

            # Extract and set x-trans-id if present
            _extract_x_trans_id(response)

            # Check if there are any contents or common prefixes
            return bool(response.get("Contents", []) or response.get("CommonPrefixes", []))

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        if path.endswith("/"):
            # If path is a "directory", then metadata is not guaranteed to exist if
            # it is a "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            bucket, key = split_path(path)

            def _invoke_api() -> ObjectMetadata:
                response = self._s3_client.head_object(Bucket=bucket, Key=key)

                # Extract and set x-trans-id if present
                _extract_x_trans_id(response)

                return ObjectMetadata(
                    key=path,
                    type="file",
                    content_length=response["ContentLength"],
                    content_type=response["ContentType"],
                    last_modified=response["LastModified"],
                    etag=response["ETag"].strip('"'),
                    storage_class=response.get("StorageClass"),
                    metadata=response.get("Metadata"),
                )

            try:
                return self._collect_metrics(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        prefix: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(prefix)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            paginator = self._s3_client.get_paginator("list_objects_v2")
            if include_directories:
                page_iterator = paginator.paginate(
                    Bucket=bucket, Prefix=prefix, Delimiter="/", StartAfter=(start_after or "")
                )
            else:
                page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, StartAfter=(start_after or ""))

            for page in page_iterator:
                for item in page.get("CommonPrefixes", []):
                    yield ObjectMetadata(
                        key=os.path.join(bucket, item["Prefix"].rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )

                # S3 guarantees lexicographical order for general purpose buckets (for
                # normal S3) but not directory buckets (for S3 Express One Zone).
                for response_object in page.get("Contents", []):
                    key = response_object["Key"]
                    if end_at is None or key <= end_at:
                        if key.endswith("/"):
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(bucket, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=response_object["LastModified"],
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key),
                                type="file",
                                content_length=response_object["Size"],
                                last_modified=response_object["LastModified"],
                                etag=response_object["ETag"].strip('"'),
                                storage_class=response_object.get("StorageClass"),  # Pass storage_class
                            )
                    else:
                        return

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        file_size: int = 0

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._transfer_config.multipart_threshold:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes)
                return file_size

            # Upload large files using TransferConfig
            def _invoke_api() -> int:
                extra_args = {}
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes
                # TODO: Add support for multipart upload of rust client
                response = self._s3_client.upload_file(
                    Filename=f,
                    Bucket=bucket,
                    Key=key,
                    Config=self._transfer_config,
                    ExtraArgs=extra_args,
                )

                # Extract and set x-trans-id if present
                _extract_x_trans_id(response)

                return file_size

            return self._collect_metrics(
                _invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=file_size
            )
        else:
            # Upload small files
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            if file_size <= self._transfer_config.multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(remote_path, f.read().encode("utf-8"), attributes=attributes)
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes)
                return file_size

            # Upload large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                extra_args = {}
                if self._is_directory_bucket(bucket):
                    extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    extra_args["Metadata"] = validated_attributes
                self._s3_client.upload_fileobj(
                    Fileobj=f,
                    Bucket=bucket,
                    Key=key,
                    Config=self._transfer_config,
                    ExtraArgs=extra_args,
                )

                return file_size

            return self._collect_metrics(
                _invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=file_size
            )

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        if isinstance(f, str):
            bucket, key = split_path(remote_path)
            os.makedirs(os.path.dirname(f), exist_ok=True)

            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using TransferConfig
            def _invoke_api() -> int:
                response = None
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    # TODO: Add support for multipart download of rust client
                    response = self._s3_client.download_fileobj(
                        Bucket=bucket,
                        Key=key,
                        Fileobj=fp,
                        Config=self._transfer_config,
                    )

                # Extract and set x-trans-id if present
                _extract_x_trans_id(response)
                os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )
        else:
            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                if isinstance(f, io.StringIO):
                    f.write(self._get_object(remote_path).decode("utf-8"))
                else:
                    f.write(self._get_object(remote_path))
                return metadata.content_length

            # Download large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> int:
                response = self._s3_client.download_fileobj(
                    Bucket=bucket,
                    Key=key,
                    Fileobj=f,
                    Config=self._transfer_config,
                )

                # Extract and set x-trans-id if present
                _extract_x_trans_id(response)

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )