1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import codecs
17import io
18import os
19import tempfile
20from collections.abc import Callable, Iterator
21from typing import IO, Any, Optional, TypeVar, Union
22
23import boto3
24import botocore
25from boto3.s3.transfer import TransferConfig
26from botocore.credentials import RefreshableCredentials
27from botocore.exceptions import ClientError, IncompleteReadError, ReadTimeoutError, ResponseStreamingError
28from botocore.session import get_session
29from dateutil.parser import parse as dateutil_parse
30
31from multistorageclient_rust import RustClient, RustClientError, RustRetryableError
32
33from ..constants import DEFAULT_CONNECT_TIMEOUT, DEFAULT_MAX_POOL_CONNECTIONS, DEFAULT_READ_TIMEOUT
34from ..rust_utils import parse_retry_config, run_async_rust_client_method
35from ..signers import CloudFrontURLSigner, URLSigner
36from ..telemetry import Telemetry
37from ..types import (
38 AWARE_DATETIME_MIN,
39 Credentials,
40 CredentialsProvider,
41 ObjectMetadata,
42 PreconditionFailedError,
43 Range,
44 RetryableError,
45 SignerType,
46 SymlinkHandling,
47)
48from ..utils import (
49 safe_makedirs,
50 split_path,
51 validate_attributes,
52)
53from .base import BaseStorageProvider
54
_T = TypeVar("_T")

# One mebibyte, in bytes.
MiB = 1 << 20

# Python and Rust share the same multipart_threshold to keep the code simple.
MULTIPART_THRESHOLD = MiB * 64
MULTIPART_CHUNKSIZE = MiB * 32
IO_CHUNKSIZE = MiB * 32
# Default max_concurrency for the boto3 TransferConfig.
PYTHON_MAX_CONCURRENCY = 8
MAX_POOL_CONNECTIONS = DEFAULT_MAX_POOL_CONNECTIONS

# Provider name passed to the base class and to the Rust client.
PROVIDER = "s3"

# Storage class set on uploads to S3 Express One Zone directory buckets.
EXPRESS_ONEZONE_STORAGE_CLASS = "EXPRESS_ONEZONE"

# Upper-case algorithm names accepted for the upload-only ``checksum_algorithm`` option.
# Source: https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html
SUPPORTED_CHECKSUM_ALGORITHMS = frozenset(("CRC32", "CRC32C", "SHA1", "SHA256", "CRC64NVME"))
72
73
[docs]
class StaticS3CredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static S3 credentials.

    The credentials are fixed at construction time: they carry no expiration and refreshing them does nothing.
    """

    _access_key: str
    _secret_key: str
    _session_token: Optional[str]

    def __init__(self, access_key: str, secret_key: str, session_token: Optional[str] = None):
        """
        Stores the static credentials.

        :param access_key: The access key for S3 authentication.
        :param secret_key: The secret key for S3 authentication.
        :param session_token: An optional session token for temporary credentials.
        """
        self._access_key, self._secret_key, self._session_token = access_key, secret_key, session_token

    def get_credentials(self) -> Credentials:
        """
        Returns the stored credentials, without an expiration.
        """
        return Credentials(
            expiration=None,
            token=self._session_token,
            secret_key=self._secret_key,
            access_key=self._access_key,
        )

    def refresh_credentials(self) -> None:
        """
        No-op: static credentials never need refreshing.
        """
106
107
# Default ``ExpiresIn`` value used by S3URLSigner (one hour).
DEFAULT_PRESIGN_EXPIRES_IN = 60 * 60

# HTTP method -> boto3 client method that gets presigned for it.
_S3_METHOD_MAPPING: dict[str, str] = dict(GET="get_object", PUT="put_object")
114
115
[docs]
class S3URLSigner(URLSigner):
    """Generates pre-signed URLs through a boto3 S3 client.

    With temporary credentials (STS, IAM role, EC2 instance profile), the URL stops working at
    whichever comes first: ``expires_in`` or the expiry of the underlying credential. boto3 does
    not warn when the credential expires before ``expires_in``.

    See https://docs.aws.amazon.com/AmazonS3/latest/userguide/using-presigned-url.html
    """

    def __init__(self, s3_client: Any, bucket: str, expires_in: int = DEFAULT_PRESIGN_EXPIRES_IN) -> None:
        """
        :param s3_client: boto3 S3 client used for signing.
        :param bucket: Bucket containing every path that will be signed.
        :param expires_in: Requested URL lifetime, passed to boto3 as ``ExpiresIn``.
        """
        self._s3_client = s3_client
        self._bucket = bucket
        self._expires_in = expires_in

    def generate_presigned_url(self, path: str, *, method: str = "GET") -> str:
        """
        Returns a pre-signed URL for the object key ``path`` in the configured bucket.

        :param path: Object key to sign.
        :param method: HTTP method, ``"GET"`` or ``"PUT"`` (case-insensitive).
        :raises ValueError: If ``method`` has no presigning mapping.
        """
        normalized_method = method.upper()
        if normalized_method not in _S3_METHOD_MAPPING:
            raise ValueError(f"Unsupported method for S3 presigning: {method!r}")
        return self._s3_client.generate_presigned_url(
            ClientMethod=_S3_METHOD_MAPPING[normalized_method],
            Params=dict(Bucket=self._bucket, Key=path),
            ExpiresIn=self._expires_in,
        )
141
142
[docs]
143class S3StorageProvider(BaseStorageProvider):
144 """
145 A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Amazon S3 or S3-compatible object stores.
146 """
147
148 def __init__(
149 self,
150 region_name: str = "",
151 endpoint_url: str = "",
152 base_path: str = "",
153 credentials_provider: Optional[CredentialsProvider] = None,
154 profile_name: Optional[str] = None,
155 config_dict: Optional[dict[str, Any]] = None,
156 telemetry_provider: Optional[Callable[[], Telemetry]] = None,
157 verify: Optional[Union[bool, str]] = None,
158 **kwargs: Any,
159 ) -> None:
160 """
161 Initializes the :py:class:`S3StorageProvider` with the region, endpoint URL, and optional credentials provider.
162
163 :param region_name: The AWS region where the S3 bucket is located.
164 :param endpoint_url: The custom endpoint URL for the S3 service.
165 :param base_path: The root prefix path within the S3 bucket where all operations will be scoped.
166 :param credentials_provider: The provider to retrieve S3 credentials.
167 :param profile_name: AWS shared configuration + credentials files profile. For :py:class:`boto3.session.Session`. Ignored if ``credentials_provider`` is set.
168 :param config_dict: Resolved MSC config.
169 :param telemetry_provider: A function that provides a telemetry instance.
170 :param verify: Controls SSL certificate verification.
171 Can be ``True`` (verify using system CA bundle, default), ``False`` (skip verification), or a string path to a custom CA certificate bundle.
172 :param request_checksum_calculation: For :py:class:`botocore.config.Config`.
173 When the underlying S3 client should calculate request checksums.
174 :param response_checksum_validation: For :py:class:`botocore.config.Config`.
175 When the underlying S3 client should validate response checksums.
176 :param max_pool_connections: For :py:class:`botocore.config.Config`.
177 The maximum number of connections to keep in a connection pool.
178 :param connect_timeout: For :py:class:`botocore.config.Config`.
179 The time in seconds till a timeout exception is thrown when attempting to make a connection.
180 :param read_timeout: For :py:class:`botocore.config.Config`.
181 The time in seconds till a timeout exception is thrown when attempting to read from a connection.
182 :param retries: For :py:class:`botocore.config.Config`.
183 A dictionary for configuration related to retry behavior.
184 :param s3: For :py:class:`botocore.config.Config`.
185 A dictionary of S3 specific configurations.
186 :param checksum_algorithm: Upload-only object integrity algorithm.
187 One of ``"CRC32"``, ``"CRC32C"``, ``"SHA1"``, ``"SHA256"``, ``"CRC64NVME"`` (case-insensitive).
188 When ``rust_client`` is enabled, only ``"SHA256"`` is accepted.
189 :param multipart_threshold: For :py:class:`boto3.s3.transfer.TransferConfig`.
190 The transfer size threshold for which multipart uploads, downloads, and copies will automatically be triggered.
191 :param max_concurrency: For :py:class:`boto3.s3.transfer.TransferConfig`.
192 The maximum number of threads that will be making requests to perform a transfer.
193 If ``use_threads`` is set to ``False``, the value provided is ignored as the transfer will only ever use the current thread.
194 :param multipart_chunksize: For :py:class:`boto3.s3.transfer.TransferConfig`.
195 The partition size of each part for a multipart transfer.
196 :param io_chunksize: For :py:class:`boto3.s3.transfer.TransferConfig`.
197 The max size of each chunk in the ``io`` queue. Currently, this is the size used when read is called on the downloaded stream as well.
198 Note: This value is ignored when resolved transfer manager type is CRTTransferManager.
199 """
200 super().__init__(
201 base_path=base_path,
202 provider_name=PROVIDER,
203 config_dict=config_dict,
204 telemetry_provider=telemetry_provider,
205 )
206
207 self._region_name = region_name
208 self._endpoint_url = endpoint_url
209 self._credentials_provider = credentials_provider
210 self._profile_name = profile_name
211 self._verify = verify
212
213 self._signature_version = kwargs.get("signature_version", "s3v4")
214 self._s3_client = self._create_s3_client(
215 request_checksum_calculation=kwargs.get("request_checksum_calculation"),
216 response_checksum_validation=kwargs.get("response_checksum_validation"),
217 max_pool_connections=kwargs.get("max_pool_connections", MAX_POOL_CONNECTIONS),
218 connect_timeout=kwargs.get("connect_timeout", DEFAULT_CONNECT_TIMEOUT),
219 read_timeout=kwargs.get("read_timeout", DEFAULT_READ_TIMEOUT),
220 retries=kwargs.get("retries"),
221 s3=kwargs.get("s3"),
222 )
223 self._transfer_config = TransferConfig(
224 multipart_threshold=int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD)),
225 max_concurrency=int(kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)),
226 multipart_chunksize=int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE)),
227 io_chunksize=int(kwargs.get("io_chunksize", IO_CHUNKSIZE)),
228 use_threads=True,
229 )
230 self._multipart_threshold = self._transfer_config.multipart_threshold
231
232 self._signer_cache: dict[tuple, URLSigner] = {}
233
234 self._checksum_algorithm: Optional[str] = self._validate_checksum_algorithm(kwargs.get("checksum_algorithm"))
235
236 self._rust_client = None
237 if "rust_client" in kwargs:
238 # Inherit the rust client options from the kwargs
239 rust_client_options = kwargs["rust_client"]
240 if "max_pool_connections" in kwargs:
241 rust_client_options["max_pool_connections"] = kwargs["max_pool_connections"]
242 if "max_concurrency" in kwargs:
243 rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
244 if "multipart_chunksize" in kwargs:
245 rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
246 if "read_timeout" in kwargs:
247 rust_client_options["read_timeout"] = kwargs["read_timeout"]
248 if "connect_timeout" in kwargs:
249 rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
250 if "checksum_algorithm" in kwargs:
251 rust_client_options["checksum_algorithm"] = kwargs["checksum_algorithm"]
252 if self._signature_version == "UNSIGNED":
253 rust_client_options["skip_signature"] = True
254
255 # Rust client only supports SHA256 today (object_store limitation).
256 rust_checksum = rust_client_options.get("checksum_algorithm")
257 if rust_checksum is not None:
258 if str(rust_checksum).upper() != "SHA256":
259 raise ValueError(
260 f"Rust client only supports checksum_algorithm='SHA256' today "
261 f"(object_store limitation), got '{rust_checksum!r}'. "
262 f"Disable rust_client or use SHA256."
263 )
264 rust_client_options["checksum_algorithm"] = "sha256"
265 self._rust_client = self._create_rust_client(rust_client_options)
266
267 @staticmethod
268 def _validate_checksum_algorithm(value: Optional[str]) -> Optional[str]:
269 """
270 Normalizes the optional ``checksum_algorithm`` to uppercase, or returns ``None`` if disabled.
271 """
272 if value is None:
273 return None
274 if not isinstance(value, str) or not value:
275 raise ValueError(
276 f"checksum_algorithm must be one of {sorted(SUPPORTED_CHECKSUM_ALGORITHMS)} or None, got {value!r}."
277 )
278 normalized = value.upper()
279 if normalized not in SUPPORTED_CHECKSUM_ALGORITHMS:
280 raise ValueError(
281 f"Unsupported checksum_algorithm '{value}'. Supported: {sorted(SUPPORTED_CHECKSUM_ALGORITHMS)}."
282 )
283 return normalized
284
285 def _is_directory_bucket(self, bucket: str) -> bool:
286 """
287 Determines if the bucket is a directory bucket based on bucket name.
288 """
289 # S3 Express buckets have a specific naming convention
290 return "--x-s3" in bucket
291
292 def _create_s3_client(
293 self,
294 request_checksum_calculation: Optional[str] = None,
295 response_checksum_validation: Optional[str] = None,
296 max_pool_connections: int = MAX_POOL_CONNECTIONS,
297 connect_timeout: Union[float, int, None] = None,
298 read_timeout: Union[float, int, None] = None,
299 retries: Optional[dict[str, Any]] = None,
300 s3: Optional[dict[str, Any]] = None,
301 ):
302 """
303 Creates and configures the boto3 S3 client, using refreshable credentials if possible.
304
305 :param request_checksum_calculation: For :py:class:`botocore.config.Config`. When the underlying S3 client should calculate request checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
306 :param response_checksum_validation: For :py:class:`botocore.config.Config`. When the underlying S3 client should validate response checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
307 :param max_pool_connections: For :py:class:`botocore.config.Config`. The maximum number of connections to keep in a connection pool.
308 :param connect_timeout: For :py:class:`botocore.config.Config`. The time in seconds till a timeout exception is thrown when attempting to make a connection.
309 :param read_timeout: For :py:class:`botocore.config.Config`. The time in seconds till a timeout exception is thrown when attempting to read from a connection.
310 :param retries: For :py:class:`botocore.config.Config`. A dictionary for configuration related to retry behavior.
311 :param s3: For :py:class:`botocore.config.Config`. A dictionary of S3 specific configurations.
312
313 :return: The configured S3 client.
314 """
315 options = {
316 # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
317 "config": boto3.session.Config( # pyright: ignore [reportAttributeAccessIssue]
318 max_pool_connections=max_pool_connections,
319 connect_timeout=connect_timeout,
320 read_timeout=read_timeout,
321 retries=retries or {"mode": "standard"},
322 request_checksum_calculation=request_checksum_calculation,
323 response_checksum_validation=response_checksum_validation,
324 s3=s3,
325 ),
326 }
327
328 if self._region_name:
329 options["region_name"] = self._region_name
330
331 if self._endpoint_url:
332 options["endpoint_url"] = self._endpoint_url
333
334 if self._verify is not None:
335 options["verify"] = self._verify
336
337 if self._credentials_provider:
338 creds = self._fetch_credentials()
339 if "expiry_time" in creds and creds["expiry_time"]:
340 # Use RefreshableCredentials if expiry_time provided.
341 refreshable_credentials = RefreshableCredentials.create_from_metadata(
342 metadata=creds, refresh_using=self._fetch_credentials, method="custom-refresh"
343 )
344
345 botocore_session = get_session()
346 botocore_session._credentials = refreshable_credentials
347
348 boto3_session = boto3.Session(botocore_session=botocore_session)
349
350 return boto3_session.client("s3", **options)
351 else:
352 # Add static credentials to the options dictionary
353 options["aws_access_key_id"] = creds["access_key"]
354 options["aws_secret_access_key"] = creds["secret_key"]
355 if creds["token"]:
356 options["aws_session_token"] = creds["token"]
357
358 if self._signature_version:
359 signature_config = botocore.config.Config( # pyright: ignore[reportAttributeAccessIssue]
360 signature_version=botocore.UNSIGNED
361 if self._signature_version == "UNSIGNED"
362 else self._signature_version
363 )
364 options["config"] = options["config"].merge(signature_config)
365
366 # Fallback to standard credential chain.
367 session = boto3.Session(profile_name=self._profile_name)
368 return session.client("s3", **options)
369
370 def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
371 """
372 Creates and configures the rust client, using refreshable credentials if possible.
373 """
374 configs = dict(rust_client_options) if rust_client_options else {}
375
376 # Extract and parse retry configuration
377 retry_config = parse_retry_config(configs)
378
379 if self._region_name and "region_name" not in configs:
380 configs["region_name"] = self._region_name
381
382 if self._endpoint_url and "endpoint_url" not in configs:
383 configs["endpoint_url"] = self._endpoint_url
384
385 if self._profile_name and "profile_name" not in configs:
386 configs["profile_name"] = self._profile_name
387
388 # If the user specifies a bucket, use it. Otherwise, use the base path.
389 if "bucket" not in configs:
390 bucket, _ = split_path(self._base_path)
391 configs["bucket"] = bucket
392
393 if "max_pool_connections" not in configs:
394 configs["max_pool_connections"] = MAX_POOL_CONNECTIONS
395
396 return RustClient(
397 provider=PROVIDER,
398 configs=configs,
399 credentials_provider=self._credentials_provider,
400 retry=retry_config,
401 )
402
403 def _fetch_credentials(self) -> dict:
404 """
405 Refreshes the S3 client if the current credentials are expired.
406 """
407 if not self._credentials_provider:
408 raise RuntimeError("Cannot fetch credentials if no credential provider configured.")
409 self._credentials_provider.refresh_credentials()
410 credentials = self._credentials_provider.get_credentials()
411 return {
412 "access_key": credentials.access_key,
413 "secret_key": credentials.secret_key,
414 "token": credentials.token,
415 "expiry_time": credentials.expiration,
416 }
417
418 def _translate_errors(
419 self,
420 func: Callable[[], _T],
421 operation: str,
422 bucket: str,
423 key: str,
424 ) -> _T:
425 """
426 Translates errors like timeouts and client errors.
427
428 :param func: The function that performs the actual S3 operation.
429 :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
430 :param bucket: The name of the S3 bucket involved in the operation.
431 :param key: The key of the object within the S3 bucket.
432
433 :return: The result of the S3 operation, typically the return value of the `func` callable.
434 """
435 try:
436 return func()
437 except ClientError as error:
438 status_code = error.response["ResponseMetadata"]["HTTPStatusCode"]
439 request_id = error.response["ResponseMetadata"].get("RequestId")
440 host_id = error.response["ResponseMetadata"].get("HostId")
441 error_code = error.response["Error"]["Code"]
442 error_info = f"request_id: {request_id}, host_id: {host_id}, status_code: {status_code}"
443
444 if status_code == 404:
445 if error_code == "NoSuchUpload":
446 error_message = error.response["Error"]["Message"]
447 raise RetryableError(f"Multipart upload failed for {bucket}/{key}: {error_message}") from error
448 raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}") # pylint: disable=raise-missing-from
449 elif status_code == 412: # Precondition Failed
450 raise PreconditionFailedError(
451 f"ETag mismatch for {operation} operation on {bucket}/{key}. {error_info}"
452 ) from error
453 elif status_code == 429:
454 raise RetryableError(
455 f"Too many request to {operation} object(s) at {bucket}/{key}. {error_info}"
456 ) from error
457 elif status_code == 503:
458 raise RetryableError(
459 f"Service unavailable when {operation} object(s) at {bucket}/{key}. {error_info}"
460 ) from error
461 elif status_code == 501:
462 raise NotImplementedError(
463 f"Operation {operation} not implemented for object(s) at {bucket}/{key}. {error_info}"
464 ) from error
465 elif status_code == 408:
466 # 408 Request Timeout is from Google Cloud Storage
467 raise RetryableError(
468 f"Request timeout when {operation} object(s) at {bucket}/{key}. {error_info}"
469 ) from error
470 else:
471 raise RuntimeError(
472 f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}, "
473 f"error_type: {type(error).__name__}"
474 ) from error
475 except RustClientError as error:
476 message = error.args[0]
477 status_code = error.args[1]
478 if status_code == 404:
479 raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {message}") from error
480 elif status_code == 403:
481 raise PermissionError(
482 f"Permission denied to {operation} object(s) at {bucket}/{key}. {message}"
483 ) from error
484 else:
485 raise RetryableError(
486 f"Failed to {operation} object(s) at {bucket}/{key}. {message}. status_code: {status_code}"
487 ) from error
488 except (ReadTimeoutError, IncompleteReadError, ResponseStreamingError) as error:
489 raise RetryableError(
490 f"Failed to {operation} object(s) at {bucket}/{key} due to network timeout or incomplete read. "
491 f"error_type: {type(error).__name__}"
492 ) from error
493 except RustRetryableError as error:
494 raise RetryableError(
495 f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
496 f"error_type: {type(error).__name__}"
497 ) from error
498 except Exception as error:
499 raise RuntimeError(
500 f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
501 ) from error
502
503 def _put_object(
504 self,
505 path: str,
506 body: bytes,
507 if_match: Optional[str] = None,
508 if_none_match: Optional[str] = None,
509 attributes: Optional[dict[str, str]] = None,
510 content_type: Optional[str] = None,
511 ) -> int:
512 """
513 Uploads an object to the specified S3 path.
514
515 :param path: The S3 path where the object will be uploaded.
516 :param body: The content of the object as bytes.
517 :param if_match: Optional If-Match header value. Use "*" to only upload if the object doesn't exist.
518 :param if_none_match: Optional If-None-Match header value. Use "*" to only upload if the object doesn't exist.
519 :param attributes: Optional attributes to attach to the object.
520 :param content_type: Optional Content-Type header value.
521 """
522 bucket, key = split_path(path)
523
524 def _invoke_api() -> int:
525 kwargs = {"Bucket": bucket, "Key": key, "Body": body}
526 if content_type:
527 kwargs["ContentType"] = content_type
528 if self._is_directory_bucket(bucket):
529 kwargs["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
530 if if_match:
531 kwargs["IfMatch"] = if_match
532 if if_none_match:
533 kwargs["IfNoneMatch"] = if_none_match
534 validated_attributes = validate_attributes(attributes)
535 if validated_attributes:
536 kwargs["Metadata"] = validated_attributes
537
538 # TODO(NGCDP-5804): Add support to update ContentType header in Rust client
539 rust_unsupported_feature_keys = {"Metadata", "StorageClass", "IfMatch", "IfNoneMatch", "ContentType"}
540 if (
541 self._rust_client
542 # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
543 and not path.endswith("/")
544 and all(key not in kwargs for key in rust_unsupported_feature_keys)
545 ):
546 run_async_rust_client_method(self._rust_client, "put", key, body)
547 else:
548 if self._checksum_algorithm:
549 kwargs["ChecksumAlgorithm"] = self._checksum_algorithm
550 self._s3_client.put_object(**kwargs)
551
552 return len(body)
553
554 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
555
556 def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
557 bucket, key = split_path(path)
558
559 def _invoke_api() -> bytes:
560 if byte_range:
561 bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
562 if self._rust_client:
563 response = run_async_rust_client_method(
564 self._rust_client,
565 "get",
566 key,
567 byte_range,
568 )
569 return response
570 else:
571 response = self._s3_client.get_object(Bucket=bucket, Key=key, Range=bytes_range)
572 else:
573 if self._rust_client:
574 response = run_async_rust_client_method(self._rust_client, "get", key)
575 return response
576 else:
577 response = self._s3_client.get_object(Bucket=bucket, Key=key)
578
579 return response["Body"].read()
580
581 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
582
583 def _copy_object(self, src_path: str, dest_path: str) -> int:
584 src_bucket, src_key = split_path(src_path)
585 dest_bucket, dest_key = split_path(dest_path)
586
587 src_object = self._get_object_metadata(src_path)
588
589 def _invoke_api() -> int:
590 self._s3_client.copy(
591 CopySource={"Bucket": src_bucket, "Key": src_key},
592 Bucket=dest_bucket,
593 Key=dest_key,
594 Config=self._transfer_config,
595 )
596
597 return src_object.content_length
598
599 return self._translate_errors(_invoke_api, operation="COPY", bucket=dest_bucket, key=dest_key)
600
601 def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
602 bucket, key = split_path(path)
603
604 def _invoke_api() -> None:
605 # Delete conditionally when if_match (etag) is provided; otherwise delete unconditionally
606 if if_match:
607 self._s3_client.delete_object(Bucket=bucket, Key=key, IfMatch=if_match)
608 else:
609 self._s3_client.delete_object(Bucket=bucket, Key=key)
610
611 return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)
612
613 def _delete_objects(self, paths: list[str]) -> None:
614 if not paths:
615 return
616
617 by_bucket: dict[str, list[str]] = {}
618 for p in paths:
619 bucket, key = split_path(p)
620 by_bucket.setdefault(bucket, []).append(key)
621
622 S3_BATCH_LIMIT = 1000
623
624 def _invoke_api() -> None:
625 all_errors: list[str] = []
626 for bucket, keys in by_bucket.items():
627 for i in range(0, len(keys), S3_BATCH_LIMIT):
628 chunk = keys[i : i + S3_BATCH_LIMIT]
629 response = self._s3_client.delete_objects(
630 Bucket=bucket, Delete={"Objects": [{"Key": k} for k in chunk]}
631 )
632 errors = response.get("Errors") or []
633 for e in errors:
634 all_errors.append(f"{bucket}/{e.get('Key', '?')}: {e.get('Code', '')} {e.get('Message', '')}")
635 if all_errors:
636 raise RuntimeError(f"DeleteObjects reported errors: {'; '.join(all_errors)}")
637
638 bucket_desc = "(" + "|".join(by_bucket) + ")"
639 key_desc = "(" + "|".join(str(len(keys)) for keys in by_bucket.values()) + " keys)"
640 self._translate_errors(_invoke_api, operation="DELETE_MANY", bucket=bucket_desc, key=key_desc)
641
642 def _is_dir(self, path: str) -> bool:
643 # Ensure the path ends with '/' to mimic a directory
644 path = self._append_delimiter(path)
645
646 bucket, key = split_path(path)
647
648 def _invoke_api() -> bool:
649 # List objects with the given prefix
650 response = self._s3_client.list_objects_v2(Bucket=bucket, Prefix=key, MaxKeys=1, Delimiter="/")
651
652 # Check if there are any contents or common prefixes
653 return bool(response.get("Contents", []) or response.get("CommonPrefixes", []))
654
655 return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)
656
657 def _make_symlink(self, path: str, target: str) -> None:
658 bucket, key = split_path(path)
659 target_bucket, target_key = split_path(target)
660 if bucket != target_bucket:
661 raise ValueError(f"Cannot create cross-bucket symlink: '{bucket}' -> '{target_bucket}'.")
662 relative_target = ObjectMetadata.encode_symlink_target(key, target_key)
663
664 def _invoke_api() -> None:
665 self._s3_client.put_object(
666 Bucket=bucket,
667 Key=key,
668 Body=b"",
669 Metadata={"msc-symlink-target": relative_target},
670 )
671
672 self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
673
674 def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
675 bucket, key = split_path(path)
676 if path.endswith("/") or (bucket and not key):
677 # If path ends with "/" or empty key name is provided, then assume it's a "directory",
678 # which metadata is not guaranteed to exist for cases such as
679 # "virtual prefix" that was never explicitly created.
680 if self._is_dir(path):
681 return ObjectMetadata(
682 key=path,
683 type="directory",
684 content_length=0,
685 last_modified=AWARE_DATETIME_MIN,
686 )
687 else:
688 raise FileNotFoundError(f"Directory {path} does not exist.")
689 else:
690
691 def _invoke_api() -> ObjectMetadata:
692 response = self._s3_client.head_object(Bucket=bucket, Key=key)
693 user_metadata = response.get("Metadata")
694 symlink_target = user_metadata.get("msc-symlink-target") if user_metadata else None
695
696 return ObjectMetadata(
697 key=path,
698 type="file",
699 content_length=response["ContentLength"],
700 content_type=response.get("ContentType"),
701 last_modified=response.get("LastModified", AWARE_DATETIME_MIN),
702 etag=response["ETag"].strip('"') if "ETag" in response else None,
703 storage_class=response.get("StorageClass"),
704 metadata=user_metadata,
705 symlink_target=symlink_target,
706 )
707
708 try:
709 return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
710 except FileNotFoundError as error:
711 if strict:
712 # If the object does not exist on the given path, we will append a trailing slash and
713 # check if the path is a directory.
714 path = self._append_delimiter(path)
715 if self._is_dir(path):
716 return ObjectMetadata(
717 key=path,
718 type="directory",
719 content_length=0,
720 last_modified=AWARE_DATETIME_MIN,
721 )
722 raise error
723
724 def _list_objects(
725 self,
726 path: str,
727 start_after: Optional[str] = None,
728 end_at: Optional[str] = None,
729 include_directories: bool = False,
730 symlink_handling: SymlinkHandling = SymlinkHandling.FOLLOW,
731 ) -> Iterator[ObjectMetadata]:
732 bucket, prefix = split_path(path)
733
734 # Get the prefix of the start_after and end_at paths relative to the bucket.
735 if start_after:
736 _, start_after = split_path(start_after)
737 if end_at:
738 _, end_at = split_path(end_at)
739
740 def _invoke_api() -> Iterator[ObjectMetadata]:
741 paginator = self._s3_client.get_paginator("list_objects_v2")
742 if include_directories:
743 page_iterator = paginator.paginate(
744 Bucket=bucket, Prefix=prefix, Delimiter="/", StartAfter=(start_after or "")
745 )
746 else:
747 page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, StartAfter=(start_after or ""))
748
749 for page in page_iterator:
750 for item in page.get("CommonPrefixes", []):
751 prefix_key = item["Prefix"].rstrip("/")
752 # Filter by start_after and end_at - S3's StartAfter doesn't filter CommonPrefixes
753 if (start_after is None or start_after < prefix_key) and (end_at is None or prefix_key <= end_at):
754 yield ObjectMetadata(
755 key=os.path.join(bucket, prefix_key),
756 type="directory",
757 content_length=0,
758 last_modified=AWARE_DATETIME_MIN,
759 )
760 elif end_at is not None and end_at < prefix_key:
761 return
762
763 # S3 guarantees lexicographical order for general purpose buckets (for
764 # normal S3) but not directory buckets (for S3 Express One Zone).
765 for response_object in page.get("Contents", []):
766 key = response_object["Key"]
767 if end_at is None or key <= end_at:
768 if key.endswith("/"):
769 if include_directories:
770 yield ObjectMetadata(
771 key=os.path.join(bucket, key.rstrip("/")),
772 type="directory",
773 content_length=0,
774 last_modified=response_object["LastModified"],
775 )
776 else:
777 symlink_target = None
778 if response_object.get("Size", 0) == 0:
779 try:
780 meta = self._get_object_metadata(os.path.join(bucket, key))
781 symlink_target = meta.symlink_target
782 except Exception:
783 symlink_target = None
784 yield ObjectMetadata(
785 key=os.path.join(bucket, key),
786 type="file",
787 content_length=response_object["Size"],
788 last_modified=response_object["LastModified"],
789 etag=response_object["ETag"].strip('"'),
790 storage_class=response_object.get("StorageClass"),
791 symlink_target=symlink_target,
792 )
793 else:
794 return
795
796 return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)
797
@property
def supports_parallel_listing(self) -> bool:
    """
    Report whether this provider can list prefixes in parallel.

    S3 always can: sub-prefixes are discovered with a delimiter and listed concurrently.
    Directory buckets (S3 Express One Zone) are special-cased inside
    ``list_objects_recursive()`` rather than here.
    """
    parallel_listing_supported = True
    return parallel_listing_supported
806
[docs]
def list_objects_recursive(
    self,
    path: str = "",
    start_after: Optional[str] = None,
    end_at: Optional[str] = None,
    max_workers: int = 32,
    look_ahead: int = 2,
    symlink_handling: SymlinkHandling = SymlinkHandling.FOLLOW,
) -> Iterator[ObjectMetadata]:
    """
    Recursively list the files under ``path``, discovering prefixes in parallel where possible.

    One of three strategies is chosen:

    * Directory buckets (S3 Express One Zone) fall back to :py:meth:`list_objects` with
      directories excluded. S3 does not guarantee lexicographic order for these buckets.
    * When a Rust client is configured, its ``list_recursive`` performs the parallel listing
      (via ``_list_objects_recursive_rust``). ``look_ahead`` and ``symlink_handling`` are not
      forwarded on this path.
    * Otherwise the base-class parallel implementation is used.

    Directories are never returned. Apart from the directory-bucket case, files are yielded in
    lexicographic order.

    :param path: Path to list, relative to the provider's base path.
    :param start_after: Exclusive lower bound on returned keys.
    :param end_at: Inclusive upper bound on returned keys.
    :param max_workers: Concurrency limit forwarded to the Rust client or the base class.
    :param look_ahead: Forwarded to the base-class implementation only.
    :param symlink_handling: How to handle symbolic links during listing.
    :raises ValueError: If both bounds are given and ``start_after`` is not before ``end_at``.
        This method is a generator, so the error surfaces when iteration starts.
    """
    if start_after is not None and end_at is not None and start_after >= end_at:
        raise ValueError(f"start_after ({start_after}) must be before end_at ({end_at})!")

    full_path = self._prepend_base_path(path)
    bucket, _ = split_path(full_path)

    if self._is_directory_bucket(bucket):
        yield from self.list_objects(
            path, start_after, end_at, include_directories=False, symlink_handling=symlink_handling
        )
    elif self._rust_client:

        def _rust_listing() -> Iterator[ObjectMetadata]:
            return self._list_objects_recursive_rust(path, full_path, bucket, start_after, end_at, max_workers)

        yield from self._emit_metrics(operation=BaseStorageProvider._Operation.LIST, f=_rust_listing)
    else:
        yield from super().list_objects_recursive(
            path, start_after, end_at, max_workers, look_ahead, symlink_handling
        )
847
def _list_objects_recursive_rust(
    self,
    path: str,
    full_path: str,
    bucket: str,
    start_after: Optional[str],
    end_at: Optional[str],
    max_workers: int,
) -> Iterator[ObjectMetadata]:
    """
    List every object under ``full_path`` using the Rust client's ``list_recursive``.

    The Rust client performs the parallel listing itself (``max_concurrency=max_workers``).
    This method filters the result to the ``(start_after, end_at]`` window and converts each
    entry to :py:class:`ObjectMetadata` with a key relative to the provider's base path.
    Entries whose Rust ``object_type`` is ``"object"`` are reported as ``"file"``; any other
    type is passed through unchanged.

    :param path: Caller-supplied path. Unused here; kept for signature compatibility.
    :param full_path: ``path`` with the base path prepended, in ``bucket/prefix`` form.
    :param bucket: Bucket name parsed from ``full_path``.
    :param start_after: Exclusive lower bound, relative to the base path.
    :param end_at: Inclusive upper bound, relative to the base path.
    :param max_workers: Concurrency limit passed to the Rust client.
    :return: Iterator of object metadata. Iteration stops at the first key past ``end_at``,
        which assumes the Rust client returns keys in lexicographic order.
    """
    _, prefix = split_path(full_path)

    def _to_metadata(objects: Any) -> Iterator[ObjectMetadata]:
        # Bounds are relative to the base path; qualify them so they compare against the
        # bucket-qualified keys built below.
        start_after_full = self._prepend_base_path(start_after) if start_after else None
        end_at_full = self._prepend_base_path(end_at) if end_at else None

        for obj in objects:
            full_key = os.path.join(bucket, obj.key)

            if start_after_full and full_key <= start_after_full:
                continue
            if end_at_full and full_key > end_at_full:
                # Relies on list_recursive returning keys in lexicographic order.
                break

            yield ObjectMetadata(
                key=full_key.removeprefix(self._base_path).lstrip("/"),
                content_length=obj.content_length,
                last_modified=dateutil_parse(obj.last_modified),
                type="file" if obj.object_type == "object" else obj.object_type,
                etag=obj.etag,
            )

    def _invoke_api() -> Iterator[ObjectMetadata]:
        # Deliberately not a generator: the Rust call runs while _translate_errors is executing
        # this function, so its errors get translated. A generator body would only start running
        # after _translate_errors had already returned.
        result = run_async_rust_client_method(
            self._rust_client,
            "list_recursive",
            [prefix or ""],
            max_concurrency=max_workers,
        )
        return _to_metadata(result.objects)

    yield from self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)
895
def _upload_file(
    self,
    remote_path: str,
    f: Union[str, IO],
    attributes: Optional[dict[str, str]] = None,
    content_type: Optional[str] = None,
) -> int:
    """
    Upload a local file or a file-like object to ``remote_path``.

    Payloads no larger than ``self._multipart_threshold`` go through a single request.
    Larger payloads use a multipart upload. The Rust client is used when no extra request
    arguments (content type, directory-bucket storage class, metadata) are needed; otherwise
    boto3's managed transfer with ``self._transfer_config`` is used.

    :param remote_path: Destination in ``bucket/key`` form.
    :param f: Path of a local file, or a readable file-like object. Streams are rewound and
        read from the start. :py:class:`io.StringIO` content is encoded as UTF-8 before upload.
    :param attributes: Optional user metadata for the object.
    :param content_type: Optional ``Content-Type`` for the object.
    :return: Number of bytes uploaded.
    """
    bucket, key = split_path(remote_path)

    def _multipart_extra_args() -> dict[str, Any]:
        # boto3 ExtraArgs for a multipart upload; an empty result allows the Rust client path.
        extra_args: dict[str, Any] = {}
        if content_type:
            extra_args["ContentType"] = content_type
        if self._is_directory_bucket(bucket):
            extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
        validated_attributes = validate_attributes(attributes)
        if validated_attributes:
            extra_args["Metadata"] = validated_attributes
        return extra_args

    if isinstance(f, str):
        local_path = f
        file_size = os.path.getsize(local_path)

        # Upload small files
        if file_size <= self._multipart_threshold:
            if self._rust_client and not attributes and not content_type:
                # Translate Rust client errors exactly like the multipart path below does.
                self._translate_errors(
                    lambda: run_async_rust_client_method(self._rust_client, "upload", local_path, key),
                    operation="PUT",
                    bucket=bucket,
                    key=key,
                )
            else:
                with open(local_path, "rb") as fp:
                    self._put_object(remote_path, fp.read(), attributes=attributes, content_type=content_type)
            return file_size

        # Upload large files using TransferConfig
        def _invoke_path_upload() -> int:
            extra_args = _multipart_extra_args()
            if self._rust_client and not extra_args:
                run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", local_path, key)
            else:
                if self._checksum_algorithm:
                    extra_args["ChecksumAlgorithm"] = self._checksum_algorithm
                self._s3_client.upload_file(
                    Filename=local_path,
                    Bucket=bucket,
                    Key=key,
                    Config=self._transfer_config,
                    ExtraArgs=extra_args,
                )
            return file_size

        return self._translate_errors(_invoke_path_upload, operation="PUT", bucket=bucket, key=key)

    stream: IO = f
    if isinstance(stream, io.StringIO):
        # StringIO.tell() counts characters rather than bytes, and boto3's upload_fileobj needs a
        # binary stream, so encode the whole text once and upload the resulting bytes instead.
        stream.seek(0)
        stream = io.BytesIO(stream.read().encode("utf-8"))

    stream.seek(0, io.SEEK_END)
    file_size = stream.tell()
    stream.seek(0)

    # Upload small files
    if file_size <= self._multipart_threshold:
        self._put_object(remote_path, stream.read(), attributes=attributes, content_type=content_type)
        return file_size

    # Upload large files using TransferConfig
    def _invoke_stream_upload() -> int:
        extra_args = _multipart_extra_args()
        if self._rust_client and isinstance(stream, io.BytesIO) and not extra_args:
            # Zero-copy view of the buffer; released afterwards because a BytesIO cannot be
            # resized or closed while an exported view exists.
            data = stream.getbuffer()
            try:
                run_async_rust_client_method(self._rust_client, "upload_multipart_from_bytes", key, data)
            finally:
                data.release()
        else:
            if self._checksum_algorithm:
                extra_args["ChecksumAlgorithm"] = self._checksum_algorithm
            self._s3_client.upload_fileobj(
                Fileobj=stream,
                Bucket=bucket,
                Key=key,
                Config=self._transfer_config,
                ExtraArgs=extra_args,
            )
        return file_size

    return self._translate_errors(_invoke_stream_upload, operation="PUT", bucket=bucket, key=key)
992
def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
    """
    Download the object at ``remote_path`` into a local file or a writable file-like object.

    Objects no larger than ``self._multipart_threshold`` are fetched with a single request;
    larger ones use a multipart download (the Rust client for local paths when configured,
    boto3's managed transfer with ``self._transfer_config`` otherwise).

    Except for the small-object Rust download, local paths are written through a hidden
    temporary file in the destination directory. That file is renamed over the target once
    complete, and removed if the download fails.

    :param remote_path: Source in ``bucket/key`` form.
    :param f: Destination path, or a writable file-like object. :py:class:`io.StringIO`
        targets receive the content decoded as UTF-8.
    :param metadata: Object metadata if already known; fetched with ``_get_object_metadata``
        otherwise.
    :return: ``metadata.content_length``.
    """
    if metadata is None:
        metadata = self._get_object_metadata(remote_path)

    content_length = metadata.content_length
    is_small = content_length <= self._multipart_threshold
    bucket, key = split_path(remote_path)

    if isinstance(f, str):
        local_path = f
        local_dir = os.path.dirname(local_path)
        if local_dir:
            safe_makedirs(local_dir)

        def _stage_and_replace(fill: Callable[[IO[bytes], str], object]) -> None:
            # Write into a hidden temp file beside the target, then rename it over the target so
            # the target never holds partial content. The temp file is created with delete=False,
            # so on failure it is removed here instead of being left behind.
            temp_file_path: Optional[str] = None
            try:
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=local_dir, prefix=".") as fp:
                    temp_file_path = fp.name
                    fill(fp, temp_file_path)
                # Unlike os.rename, os.replace also overwrites an existing target on Windows.
                os.replace(temp_file_path, local_path)
            except BaseException:
                if temp_file_path is not None:
                    try:
                        os.remove(temp_file_path)
                    except OSError:
                        pass
                raise

        # Download small files
        if is_small:
            if self._rust_client:
                # Translate Rust client errors exactly like the multipart path below does.
                self._translate_errors(
                    lambda: run_async_rust_client_method(self._rust_client, "download", key, local_path),
                    operation="GET",
                    bucket=bucket,
                    key=key,
                )
            else:
                # Left outside _translate_errors as before; _get_object presumably handles its own errors.
                _stage_and_replace(lambda fp, _temp_path: fp.write(self._get_object(remote_path)))
            return content_length

        # Download large files using TransferConfig
        def _invoke_path_download() -> int:
            if self._rust_client:
                _stage_and_replace(
                    lambda _fp, temp_path: run_async_rust_client_method(
                        self._rust_client, "download_multipart_to_file", key, temp_path
                    )
                )
            else:
                _stage_and_replace(
                    lambda fp, _temp_path: self._s3_client.download_fileobj(
                        Bucket=bucket,
                        Key=key,
                        Fileobj=fp,
                        Config=self._transfer_config,
                    )
                )
            return content_length

        return self._translate_errors(_invoke_path_download, operation="GET", bucket=bucket, key=key)

    # Download small files
    if is_small:
        response = self._get_object(remote_path)
        if isinstance(f, io.StringIO):
            # Python client returns `bytes`, but Rust client returns an object that implements the
            # buffer protocol, so check whether `.decode()` is available.
            if hasattr(response, "decode"):
                f.write(response.decode("utf-8"))
            else:
                f.write(codecs.decode(memoryview(response), "utf-8"))
        else:
            f.write(response)
        return content_length

    # Download large files using TransferConfig. download_fileobj writes bytes, so a StringIO
    # target is staged through a BytesIO and decoded afterwards.
    staging = io.BytesIO() if isinstance(f, io.StringIO) else None

    def _invoke_stream_download() -> int:
        self._s3_client.download_fileobj(
            Bucket=bucket,
            Key=key,
            Fileobj=f if staging is None else staging,
            Config=self._transfer_config,
        )
        return content_length

    downloaded = self._translate_errors(_invoke_stream_download, operation="GET", bucket=bucket, key=key)
    if staging is not None:
        f.write(staging.getvalue().decode("utf-8"))
    return downloaded
1063
def _generate_presigned_url(
    self,
    path: str,
    *,
    method: str = "GET",
    signer_type: Optional[SignerType] = None,
    signer_options: Optional[dict[str, Any]] = None,
) -> str:
    """
    Build a presigned URL for ``path`` using an S3 or CloudFront signer.

    Signers are memoized in ``self._signer_cache``: S3 signers per ``(bucket, expires_in)``,
    CloudFront signers per distinct set of options.

    :param path: Object path in ``bucket/key`` form.
    :param method: HTTP method the URL is signed for.
    :param signer_type: ``SignerType.S3`` (also used when ``None``) or ``SignerType.CLOUDFRONT``.
    :param signer_options: For S3, an optional ``expires_in`` (defaults to
        ``DEFAULT_PRESIGN_EXPIRES_IN``). For CloudFront, keyword arguments for
        ``CloudFrontURLSigner``; their values must be hashable because they form the cache key.
    :raises ValueError: If ``signer_type`` is neither S3 nor CloudFront.
    """
    opts = signer_options or {}
    bucket, key = split_path(path)

    if signer_type is None or signer_type == SignerType.S3:
        expiry = int(opts.get("expires_in", DEFAULT_PRESIGN_EXPIRES_IN))
        cache_key: tuple = (SignerType.S3, bucket, expiry)

        def _build_signer() -> URLSigner:
            return S3URLSigner(self._s3_client, bucket, expires_in=expiry)

    elif signer_type == SignerType.CLOUDFRONT:
        cache_key = (SignerType.CLOUDFRONT, frozenset(opts.items()))

        def _build_signer() -> URLSigner:
            return CloudFrontURLSigner(**opts)

    else:
        raise ValueError(f"Unsupported signer type for S3 provider: {signer_type!r}")

    # Only construct a signer the first time this cache key is seen.
    if cache_key not in self._signer_cache:
        self._signer_cache[cache_key] = _build_signer()

    signer = self._signer_cache[cache_key]
    return signer.generate_presigned_url(key, method=method)