1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import codecs
17import io
18import os
19import tempfile
20from collections.abc import Callable, Iterator
21from typing import IO, Any, Optional, TypeVar, Union
22
23import boto3
24import botocore
25from boto3.s3.transfer import TransferConfig
26from botocore.credentials import RefreshableCredentials
27from botocore.exceptions import ClientError, IncompleteReadError, ReadTimeoutError, ResponseStreamingError
28from botocore.session import get_session
29from dateutil.parser import parse as dateutil_parse
30
31from multistorageclient_rust import RustClient, RustClientError, RustRetryableError
32
33from ..constants import DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT
34from ..rust_utils import parse_retry_config, run_async_rust_client_method
35from ..signers import CloudFrontURLSigner, URLSigner
36from ..telemetry import Telemetry
37from ..types import (
38 AWARE_DATETIME_MIN,
39 Credentials,
40 CredentialsProvider,
41 ObjectMetadata,
42 PreconditionFailedError,
43 Range,
44 RetryableError,
45 SignerType,
46)
47from ..utils import (
48 get_available_cpu_count,
49 safe_makedirs,
50 split_path,
51 validate_attributes,
52)
53from .base import BaseStorageProvider
54
# Generic return type: lets _translate_errors stay transparent to whatever func() returns.
_T = TypeVar("_T")

# Default connection pool size scales with CPU count or MSC Sync Threads count (minimum 64)
MAX_POOL_CONNECTIONS = max(
    64,
    get_available_cpu_count(),
    int(os.getenv("MSC_NUM_THREADS_PER_PROCESS", "0")),
)

MiB = 1024 * 1024  # One mebibyte, in bytes.

# Python and Rust share the same multipart_threshold to keep the code simple.
# Objects larger than MULTIPART_THRESHOLD are uploaded via multipart transfers
# (see TransferConfig usage in S3StorageProvider.__init__).
MULTIPART_THRESHOLD = 64 * MiB
MULTIPART_CHUNKSIZE = 32 * MiB
IO_CHUNKSIZE = 32 * MiB
PYTHON_MAX_CONCURRENCY = 8

# Provider name reported to BaseStorageProvider and passed to the Rust client.
PROVIDER = "s3"

# Storage class applied when writing to directory (S3 Express One Zone) buckets.
EXPRESS_ONEZONE_STORAGE_CLASS = "EXPRESS_ONEZONE"
75
76
class StaticS3CredentialsProvider(CredentialsProvider):
    """
    :py:class:`multistorageclient.types.CredentialsProvider` implementation backed by a fixed
    access key, secret key, and optional session token that never change at runtime.
    """

    _access_key: str
    _secret_key: str
    _session_token: Optional[str]

    def __init__(self, access_key: str, secret_key: str, session_token: Optional[str] = None):
        """
        Store the static S3 credentials returned by every :py:meth:`get_credentials` call.

        :param access_key: The access key for S3 authentication.
        :param secret_key: The secret key for S3 authentication.
        :param session_token: An optional session token for temporary credentials.
        """
        self._access_key, self._secret_key = access_key, secret_key
        self._session_token = session_token

    def get_credentials(self) -> Credentials:
        """Return the credentials supplied at construction time; they carry no expiration."""
        return Credentials(
            access_key=self._access_key,
            secret_key=self._secret_key,
            token=self._session_token,
            expiration=None,
        )

    def refresh_credentials(self) -> None:
        """No-op: static credentials cannot be refreshed."""
        pass
109
110
# Default lifetime, in seconds, of generated pre-signed URLs (1 hour).
DEFAULT_PRESIGN_EXPIRES_IN = 3600

# Maps HTTP verbs accepted by S3URLSigner to the boto3 client method used for presigning.
_S3_METHOD_MAPPING: dict[str, str] = {
    "GET": "get_object",
    "PUT": "put_object",
}
117
118
class S3URLSigner(URLSigner):
    """Generates pre-signed URLs via the boto3 S3 client.

    Note that with temporary credentials (STS, IAM role, EC2 instance profile) the
    URL is only valid for the **shorter** of ``expires_in`` and the remaining
    credential lifetime — boto3 emits no warning when the credential expires first.

    See https://docs.aws.amazon.com/AmazonS3/latest/userguide/using-presigned-url.html
    """

    def __init__(self, s3_client: Any, bucket: str, expires_in: int = DEFAULT_PRESIGN_EXPIRES_IN) -> None:
        self._s3_client = s3_client
        self._bucket = bucket
        self._expires_in = expires_in

    def generate_presigned_url(self, path: str, *, method: str = "GET") -> str:
        """Return a pre-signed URL for ``path`` using the given HTTP ``method`` (GET or PUT)."""
        if (client_method := _S3_METHOD_MAPPING.get(method.upper())) is None:
            raise ValueError(f"Unsupported method for S3 presigning: {method!r}")
        return self._s3_client.generate_presigned_url(
            ClientMethod=client_method,
            Params={"Bucket": self._bucket, "Key": path},
            ExpiresIn=self._expires_in,
        )
144
145
class S3StorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Amazon S3 or S3-compatible object stores.
    """

    def __init__(
        self,
        region_name: str = "",
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        profile_name: Optional[str] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        verify: Optional[Union[bool, str]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the :py:class:`S3StorageProvider` with the region, endpoint URL, and optional credentials provider.

        :param region_name: The AWS region where the S3 bucket is located.
        :param endpoint_url: The custom endpoint URL for the S3 service.
        :param base_path: The root prefix path within the S3 bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve S3 credentials.
        :param profile_name: AWS shared configuration + credentials files profile. For :py:class:`boto3.session.Session`. Ignored if ``credentials_provider`` is set.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        :param verify: Controls SSL certificate verification.
            Can be ``True`` (verify using system CA bundle, default), ``False`` (skip verification), or a string path to a custom CA certificate bundle.
        :param request_checksum_calculation: For :py:class:`botocore.config.Config`.
            When the underlying S3 client should calculate request checksums.
            See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param response_checksum_validation: For :py:class:`botocore.config.Config`.
            When the underlying S3 client should validate response checksums.
            See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param max_pool_connections: For :py:class:`botocore.config.Config`.
            The maximum number of connections to keep in a connection pool.
        :param connect_timeout: For :py:class:`botocore.config.Config`.
            The time in seconds till a timeout exception is thrown when attempting to make a connection.
        :param read_timeout: For :py:class:`botocore.config.Config`.
            The time in seconds till a timeout exception is thrown when attempting to read from a connection.
        :param retries: For :py:class:`botocore.config.Config`.
            A dictionary for configuration related to retry behavior.
        :param s3: For :py:class:`botocore.config.Config`.
            A dictionary of S3 specific configurations.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._region_name = region_name
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._profile_name = profile_name
        self._verify = verify

        # "s3v4" is the default; "UNSIGNED" selects anonymous (unsigned) requests
        # in _create_s3_client and skip_signature in the Rust client below.
        self._signature_version = kwargs.get("signature_version", "s3v4")
        self._s3_client = self._create_s3_client(
            request_checksum_calculation=kwargs.get("request_checksum_calculation"),
            response_checksum_validation=kwargs.get("response_checksum_validation"),
            max_pool_connections=kwargs.get("max_pool_connections", MAX_POOL_CONNECTIONS),
            connect_timeout=kwargs.get("connect_timeout", DEFAULT_CONNECT_TIMEOUT),
            read_timeout=kwargs.get("read_timeout", DEFAULT_READ_TIMEOUT),
            retries=kwargs.get("retries"),
            s3=kwargs.get("s3"),
        )
        self._transfer_config = TransferConfig(
            multipart_threshold=int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD)),
            max_concurrency=int(kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)),
            multipart_chunksize=int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE)),
            io_chunksize=int(kwargs.get("io_chunksize", IO_CHUNKSIZE)),
            use_threads=True,
        )
        # Mirror the (possibly overridden) threshold so _upload_file / _put_object
        # decide between single-shot and multipart paths consistently.
        self._multipart_threshold = self._transfer_config.multipart_threshold

        # Cache of URL signers; keys are tuples (presumably bucket + signer params —
        # TODO confirm at the use site, which is outside this chunk).
        self._signer_cache: dict[tuple, URLSigner] = {}

        # The Rust client is optional and only built when configured via kwargs.
        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the rust client options from the kwargs
            rust_client_options = kwargs["rust_client"]
            if "max_pool_connections" in kwargs:
                rust_client_options["max_pool_connections"] = kwargs["max_pool_connections"]
            if "max_concurrency" in kwargs:
                rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
            if "multipart_chunksize" in kwargs:
                rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
            if "read_timeout" in kwargs:
                rust_client_options["read_timeout"] = kwargs["read_timeout"]
            if "connect_timeout" in kwargs:
                rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
            if self._signature_version == "UNSIGNED":
                rust_client_options["skip_signature"] = True
            self._rust_client = self._create_rust_client(rust_client_options)
243
244 def _is_directory_bucket(self, bucket: str) -> bool:
245 """
246 Determines if the bucket is a directory bucket based on bucket name.
247 """
248 # S3 Express buckets have a specific naming convention
249 return "--x-s3" in bucket
250
251 def _create_s3_client(
252 self,
253 request_checksum_calculation: Optional[str] = None,
254 response_checksum_validation: Optional[str] = None,
255 max_pool_connections: int = MAX_POOL_CONNECTIONS,
256 connect_timeout: Union[float, int, None] = None,
257 read_timeout: Union[float, int, None] = None,
258 retries: Optional[dict[str, Any]] = None,
259 s3: Optional[dict[str, Any]] = None,
260 ):
261 """
262 Creates and configures the boto3 S3 client, using refreshable credentials if possible.
263
264 :param request_checksum_calculation: For :py:class:`botocore.config.Config`. When the underlying S3 client should calculate request checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
265 :param response_checksum_validation: For :py:class:`botocore.config.Config`. When the underlying S3 client should validate response checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
266 :param max_pool_connections: For :py:class:`botocore.config.Config`. The maximum number of connections to keep in a connection pool.
267 :param connect_timeout: For :py:class:`botocore.config.Config`. The time in seconds till a timeout exception is thrown when attempting to make a connection.
268 :param read_timeout: For :py:class:`botocore.config.Config`. The time in seconds till a timeout exception is thrown when attempting to read from a connection.
269 :param retries: For :py:class:`botocore.config.Config`. A dictionary for configuration related to retry behavior.
270 :param s3: For :py:class:`botocore.config.Config`. A dictionary of S3 specific configurations.
271
272 :return: The configured S3 client.
273 """
274 options = {
275 # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
276 "config": boto3.session.Config( # pyright: ignore [reportAttributeAccessIssue]
277 max_pool_connections=max_pool_connections,
278 connect_timeout=connect_timeout,
279 read_timeout=read_timeout,
280 retries=retries or {"mode": "standard"},
281 request_checksum_calculation=request_checksum_calculation,
282 response_checksum_validation=response_checksum_validation,
283 s3=s3,
284 ),
285 }
286
287 if self._region_name:
288 options["region_name"] = self._region_name
289
290 if self._endpoint_url:
291 options["endpoint_url"] = self._endpoint_url
292
293 if self._verify is not None:
294 options["verify"] = self._verify
295
296 if self._credentials_provider:
297 creds = self._fetch_credentials()
298 if "expiry_time" in creds and creds["expiry_time"]:
299 # Use RefreshableCredentials if expiry_time provided.
300 refreshable_credentials = RefreshableCredentials.create_from_metadata(
301 metadata=creds, refresh_using=self._fetch_credentials, method="custom-refresh"
302 )
303
304 botocore_session = get_session()
305 botocore_session._credentials = refreshable_credentials
306
307 boto3_session = boto3.Session(botocore_session=botocore_session)
308
309 return boto3_session.client("s3", **options)
310 else:
311 # Add static credentials to the options dictionary
312 options["aws_access_key_id"] = creds["access_key"]
313 options["aws_secret_access_key"] = creds["secret_key"]
314 if creds["token"]:
315 options["aws_session_token"] = creds["token"]
316
317 if self._signature_version:
318 signature_config = botocore.config.Config( # pyright: ignore[reportAttributeAccessIssue]
319 signature_version=botocore.UNSIGNED
320 if self._signature_version == "UNSIGNED"
321 else self._signature_version
322 )
323 options["config"] = options["config"].merge(signature_config)
324
325 # Fallback to standard credential chain.
326 session = boto3.Session(profile_name=self._profile_name)
327 return session.client("s3", **options)
328
329 def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
330 """
331 Creates and configures the rust client, using refreshable credentials if possible.
332 """
333 configs = dict(rust_client_options) if rust_client_options else {}
334
335 # Extract and parse retry configuration
336 retry_config = parse_retry_config(configs)
337
338 if self._region_name and "region_name" not in configs:
339 configs["region_name"] = self._region_name
340
341 if self._endpoint_url and "endpoint_url" not in configs:
342 configs["endpoint_url"] = self._endpoint_url
343
344 if self._profile_name and "profile_name" not in configs:
345 configs["profile_name"] = self._profile_name
346
347 # If the user specifies a bucket, use it. Otherwise, use the base path.
348 if "bucket" not in configs:
349 bucket, _ = split_path(self._base_path)
350 configs["bucket"] = bucket
351
352 if "max_pool_connections" not in configs:
353 configs["max_pool_connections"] = MAX_POOL_CONNECTIONS
354
355 return RustClient(
356 provider=PROVIDER,
357 configs=configs,
358 credentials_provider=self._credentials_provider,
359 retry=retry_config,
360 )
361
362 def _fetch_credentials(self) -> dict:
363 """
364 Refreshes the S3 client if the current credentials are expired.
365 """
366 if not self._credentials_provider:
367 raise RuntimeError("Cannot fetch credentials if no credential provider configured.")
368 self._credentials_provider.refresh_credentials()
369 credentials = self._credentials_provider.get_credentials()
370 return {
371 "access_key": credentials.access_key,
372 "secret_key": credentials.secret_key,
373 "token": credentials.token,
374 "expiry_time": credentials.expiration,
375 }
376
    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual S3 operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the S3 bucket involved in the operation.
        :param key: The key of the object within the S3 bucket.

        :return: The result of the S3 operation, typically the return value of the `func` callable.

        :raises FileNotFoundError: For 404 responses (except multipart ``NoSuchUpload``, which is retried).
        :raises PreconditionFailedError: For 412 (ETag precondition) responses.
        :raises RetryableError: For 408/429/503 responses, network read failures, and
            retryable Rust client errors.
        :raises NotImplementedError: For 501 responses.
        :raises PermissionError: For 403 responses surfaced by the Rust client.
        :raises RuntimeError: For any other failure.
        """
        try:
            return func()
        except ClientError as error:
            # Surface boto3's response metadata in the message to aid support/debugging.
            status_code = error.response["ResponseMetadata"]["HTTPStatusCode"]
            request_id = error.response["ResponseMetadata"].get("RequestId")
            host_id = error.response["ResponseMetadata"].get("HostId")
            error_code = error.response["Error"]["Code"]
            error_info = f"request_id: {request_id}, host_id: {host_id}, status_code: {status_code}"

            if status_code == 404:
                # A lost multipart upload session is treated as transient so callers retry the upload.
                if error_code == "NoSuchUpload":
                    error_message = error.response["Error"]["Message"]
                    raise RetryableError(f"Multipart upload failed for {bucket}/{key}: {error_message}") from error
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}")  # pylint: disable=raise-missing-from
            elif status_code == 412:  # Precondition Failed
                raise PreconditionFailedError(
                    f"ETag mismatch for {operation} operation on {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 429:
                # Throttled by the service; retryable.
                raise RetryableError(
                    f"Too many request to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 503:
                raise RetryableError(
                    f"Service unavailable when {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 501:
                raise NotImplementedError(
                    f"Operation {operation} not implemented for object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 408:
                # 408 Request Timeout is from Google Cloud Storage
                raise RetryableError(
                    f"Request timeout when {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            else:
                raise RuntimeError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}, "
                    f"error_type: {type(error).__name__}"
                ) from error
        except RustClientError as error:
            # Rust client errors carry (message, status_code) in args, per the accesses below.
            message = error.args[0]
            status_code = error.args[1]
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {message}") from error
            elif status_code == 403:
                raise PermissionError(
                    f"Permission denied to {operation} object(s) at {bucket}/{key}. {message}"
                ) from error
            else:
                # Any other Rust failure is treated as retryable.
                raise RetryableError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {message}. status_code: {status_code}"
                ) from error
        except (ReadTimeoutError, IncompleteReadError, ResponseStreamingError) as error:
            # Streaming/network interruptions from botocore are safe to retry.
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to network timeout or incomplete read. "
                f"error_type: {type(error).__name__}"
            ) from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except Exception as error:
            # Catch-all boundary: surface anything unexpected as RuntimeError with context.
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error
461
462 def _put_object(
463 self,
464 path: str,
465 body: bytes,
466 if_match: Optional[str] = None,
467 if_none_match: Optional[str] = None,
468 attributes: Optional[dict[str, str]] = None,
469 content_type: Optional[str] = None,
470 ) -> int:
471 """
472 Uploads an object to the specified S3 path.
473
474 :param path: The S3 path where the object will be uploaded.
475 :param body: The content of the object as bytes.
476 :param if_match: Optional If-Match header value. Use "*" to only upload if the object doesn't exist.
477 :param if_none_match: Optional If-None-Match header value. Use "*" to only upload if the object doesn't exist.
478 :param attributes: Optional attributes to attach to the object.
479 :param content_type: Optional Content-Type header value.
480 """
481 bucket, key = split_path(path)
482
483 def _invoke_api() -> int:
484 kwargs = {"Bucket": bucket, "Key": key, "Body": body}
485 if content_type:
486 kwargs["ContentType"] = content_type
487 if self._is_directory_bucket(bucket):
488 kwargs["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
489 if if_match:
490 kwargs["IfMatch"] = if_match
491 if if_none_match:
492 kwargs["IfNoneMatch"] = if_none_match
493 validated_attributes = validate_attributes(attributes)
494 if validated_attributes:
495 kwargs["Metadata"] = validated_attributes
496
497 # TODO(NGCDP-5804): Add support to update ContentType header in Rust client
498 rust_unsupported_feature_keys = {"Metadata", "StorageClass", "IfMatch", "IfNoneMatch", "ContentType"}
499 if (
500 self._rust_client
501 # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
502 and not path.endswith("/")
503 and all(key not in kwargs for key in rust_unsupported_feature_keys)
504 ):
505 run_async_rust_client_method(self._rust_client, "put", key, body)
506 else:
507 self._s3_client.put_object(**kwargs)
508
509 return len(body)
510
511 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
512
513 def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
514 bucket, key = split_path(path)
515
516 def _invoke_api() -> bytes:
517 if byte_range:
518 bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
519 if self._rust_client:
520 response = run_async_rust_client_method(
521 self._rust_client,
522 "get",
523 key,
524 byte_range,
525 )
526 return response
527 else:
528 response = self._s3_client.get_object(Bucket=bucket, Key=key, Range=bytes_range)
529 else:
530 if self._rust_client:
531 response = run_async_rust_client_method(self._rust_client, "get", key)
532 return response
533 else:
534 response = self._s3_client.get_object(Bucket=bucket, Key=key)
535
536 return response["Body"].read()
537
538 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
539
540 def _copy_object(self, src_path: str, dest_path: str) -> int:
541 src_bucket, src_key = split_path(src_path)
542 dest_bucket, dest_key = split_path(dest_path)
543
544 src_object = self._get_object_metadata(src_path)
545
546 def _invoke_api() -> int:
547 self._s3_client.copy(
548 CopySource={"Bucket": src_bucket, "Key": src_key},
549 Bucket=dest_bucket,
550 Key=dest_key,
551 Config=self._transfer_config,
552 )
553
554 return src_object.content_length
555
556 return self._translate_errors(_invoke_api, operation="COPY", bucket=dest_bucket, key=dest_key)
557
558 def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
559 bucket, key = split_path(path)
560
561 def _invoke_api() -> None:
562 # Delete conditionally when if_match (etag) is provided; otherwise delete unconditionally
563 if if_match:
564 self._s3_client.delete_object(Bucket=bucket, Key=key, IfMatch=if_match)
565 else:
566 self._s3_client.delete_object(Bucket=bucket, Key=key)
567
568 return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)
569
570 def _delete_objects(self, paths: list[str]) -> None:
571 if not paths:
572 return
573
574 by_bucket: dict[str, list[str]] = {}
575 for p in paths:
576 bucket, key = split_path(p)
577 by_bucket.setdefault(bucket, []).append(key)
578
579 S3_BATCH_LIMIT = 1000
580
581 def _invoke_api() -> None:
582 all_errors: list[str] = []
583 for bucket, keys in by_bucket.items():
584 for i in range(0, len(keys), S3_BATCH_LIMIT):
585 chunk = keys[i : i + S3_BATCH_LIMIT]
586 response = self._s3_client.delete_objects(
587 Bucket=bucket, Delete={"Objects": [{"Key": k} for k in chunk]}
588 )
589 errors = response.get("Errors") or []
590 for e in errors:
591 all_errors.append(f"{bucket}/{e.get('Key', '?')}: {e.get('Code', '')} {e.get('Message', '')}")
592 if all_errors:
593 raise RuntimeError(f"DeleteObjects reported errors: {'; '.join(all_errors)}")
594
595 bucket_desc = "(" + "|".join(by_bucket) + ")"
596 key_desc = "(" + "|".join(str(len(keys)) for keys in by_bucket.values()) + " keys)"
597 self._translate_errors(_invoke_api, operation="DELETE_MANY", bucket=bucket_desc, key=key_desc)
598
599 def _is_dir(self, path: str) -> bool:
600 # Ensure the path ends with '/' to mimic a directory
601 path = self._append_delimiter(path)
602
603 bucket, key = split_path(path)
604
605 def _invoke_api() -> bool:
606 # List objects with the given prefix
607 response = self._s3_client.list_objects_v2(Bucket=bucket, Prefix=key, MaxKeys=1, Delimiter="/")
608
609 # Check if there are any contents or common prefixes
610 return bool(response.get("Contents", []) or response.get("CommonPrefixes", []))
611
612 return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)
613
    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Resolve metadata for an object or "directory" at ``path``.

        :param path: Bucket-qualified path; a trailing slash or empty key selects the directory branch.
        :param strict: When True, a missing object triggers a second lookup with a trailing
            slash appended before FileNotFoundError is re-raised.
        :raises FileNotFoundError: If neither an object nor a directory exists at ``path``.
        """
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If path ends with "/" or empty key name is provided, then assume it's a "directory",
            # which metadata is not guaranteed to exist for cases such as
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:

            def _invoke_api() -> ObjectMetadata:
                # HEAD the object; _translate_errors converts a 404 ClientError to FileNotFoundError.
                response = self._s3_client.head_object(Bucket=bucket, Key=key)

                return ObjectMetadata(
                    key=path,
                    type="file",
                    content_length=response["ContentLength"],
                    content_type=response.get("ContentType"),
                    last_modified=response["LastModified"],
                    etag=response["ETag"].strip('"') if "ETag" in response else None,
                    storage_class=response.get("StorageClass"),
                    metadata=response.get("Metadata"),
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                # NOTE(review): the directory fallback runs only when strict=True, which reads
                # inverted relative to the usual "strict" convention — confirm intended semantics.
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error
660
    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        """
        Lazily list objects under ``path`` using the ``list_objects_v2`` paginator.

        :param path: Bucket-qualified prefix to list under.
        :param start_after: Exclusive lower bound; listing starts after this key.
        :param end_at: Inclusive upper bound; listing stops once keys exceed it.
        :param include_directories: When True, list with delimiter "/" and emit common
            prefixes as directory entries instead of recursing into them.
        :param follow_symlinks: Not referenced here; present for interface parity with
            POSIX providers.
        """
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            paginator = self._s3_client.get_paginator("list_objects_v2")
            if include_directories:
                page_iterator = paginator.paginate(
                    Bucket=bucket, Prefix=prefix, Delimiter="/", StartAfter=(start_after or "")
                )
            else:
                page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, StartAfter=(start_after or ""))

            for page in page_iterator:
                for item in page.get("CommonPrefixes", []):
                    prefix_key = item["Prefix"].rstrip("/")
                    # Filter by start_after and end_at - S3's StartAfter doesn't filter CommonPrefixes
                    if (start_after is None or start_after < prefix_key) and (end_at is None or prefix_key <= end_at):
                        yield ObjectMetadata(
                            key=os.path.join(bucket, prefix_key),
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                    elif end_at is not None and end_at < prefix_key:
                        # Past end_at: nothing later in the listing can be in range.
                        return

                # S3 guarantees lexicographical order for general purpose buckets (for
                # normal S3) but not directory buckets (for S3 Express One Zone).
                for response_object in page.get("Contents", []):
                    key = response_object["Key"]
                    if end_at is None or key <= end_at:
                        # Keys ending in "/" are treated as directory markers and are
                        # only surfaced when include_directories is set.
                        if key.endswith("/"):
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(bucket, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=response_object["LastModified"],
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key),
                                type="file",
                                content_length=response_object["Size"],
                                last_modified=response_object["LastModified"],
                                etag=response_object["ETag"].strip('"'),
                                storage_class=response_object.get("StorageClass"),  # Pass storage_class
                            )
                    else:
                        # Past end_at: stop the generator early.
                        return

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)
726
    @property
    def supports_parallel_listing(self) -> bool:
        """
        S3 supports parallel listing via delimiter-based prefix discovery.

        Note: Directory bucket handling is done in list_objects_recursive(), which
        falls back to sequential listing for S3 Express (directory) buckets.
        """
        return True
735
[docs]
736 def list_objects_recursive(
737 self,
738 path: str = "",
739 start_after: Optional[str] = None,
740 end_at: Optional[str] = None,
741 max_workers: int = 32,
742 look_ahead: int = 2,
743 follow_symlinks: bool = True,
744 ) -> Iterator[ObjectMetadata]:
745 """
746 List all objects recursively using parallel prefix discovery for improved performance.
747
748 For S3, uses the Rust client's list_recursive when available for maximum performance.
749 Falls back to Python implementation otherwise.
750
751 Returns files only (no directories), in lexicographic order.
752
753 :param follow_symlinks: Whether to follow symbolic links (POSIX providers only).
754 """
755 if (start_after is not None) and (end_at is not None) and not (start_after < end_at):
756 raise ValueError(f"start_after ({start_after}) must be before end_at ({end_at})!")
757
758 full_path = self._prepend_base_path(path)
759 bucket, prefix = split_path(full_path)
760
761 if self._is_directory_bucket(bucket):
762 yield from self.list_objects(
763 path, start_after, end_at, include_directories=False, follow_symlinks=follow_symlinks
764 )
765 return
766
767 if self._rust_client:
768 yield from self._emit_metrics(
769 operation=BaseStorageProvider._Operation.LIST,
770 f=lambda: self._list_objects_recursive_rust(path, full_path, bucket, start_after, end_at, max_workers),
771 )
772 else:
773 yield from super().list_objects_recursive(
774 path, start_after, end_at, max_workers, look_ahead, follow_symlinks
775 )
776
    def _list_objects_recursive_rust(
        self,
        path: str,
        full_path: str,
        bucket: str,
        start_after: Optional[str],
        end_at: Optional[str],
        max_workers: int,
    ) -> Iterator[ObjectMetadata]:
        """
        Use Rust client's list_recursive for parallel listing.

        The Rust client already handles parallel listing internally.
        Returns files only in lexicographic order.

        :param path: Caller-supplied path, relative to the base path (unused here; kept
            for signature parity with the caller).
        :param full_path: ``path`` with the base path prepended.
        :param bucket: Bucket extracted from ``full_path``.
        :param start_after: Exclusive lower bound, relative to the base path.
        :param end_at: Inclusive upper bound, relative to the base path.
        :param max_workers: Upper bound on the Rust client's listing concurrency.
        """
        _, prefix = split_path(full_path)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            result = run_async_rust_client_method(
                self._rust_client,
                "list_recursive",
                [prefix] if prefix else [""],
                max_concurrency=max_workers,
            )

            # The bounds are relative to the base path; expand them to full keys so
            # they compare against the bucket-qualified keys built below.
            start_after_full = self._prepend_base_path(start_after) if start_after else None
            end_at_full = self._prepend_base_path(end_at) if end_at else None

            for obj in result.objects:
                full_key = os.path.join(bucket, obj.key)

                if start_after_full and full_key <= start_after_full:
                    continue
                # Results are in lexicographic order, so once past end_at we can stop.
                if end_at_full and full_key > end_at_full:
                    break

                # Emit keys relative to the base path, matching the Python listing behavior.
                relative_key = full_key.removeprefix(self._base_path).lstrip("/")

                yield ObjectMetadata(
                    key=relative_key,
                    content_length=obj.content_length,
                    # NOTE(review): assumes obj.last_modified is a parseable timestamp string — confirm.
                    last_modified=dateutil_parse(obj.last_modified),
                    # Rust reports "object" for files; any other value passes through unchanged.
                    type="file" if obj.object_type == "object" else obj.object_type,
                    etag=obj.etag,
                )

        yield from self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)
824
825 def _upload_file(
826 self,
827 remote_path: str,
828 f: Union[str, IO],
829 attributes: Optional[dict[str, str]] = None,
830 content_type: Optional[str] = None,
831 ) -> int:
832 file_size: int = 0
833
834 if isinstance(f, str):
835 bucket, key = split_path(remote_path)
836 file_size = os.path.getsize(f)
837
838 # Upload small files
839 if file_size <= self._multipart_threshold:
840 if self._rust_client and not attributes and not content_type:
841 run_async_rust_client_method(self._rust_client, "upload", f, key)
842 else:
843 with open(f, "rb") as fp:
844 self._put_object(remote_path, fp.read(), attributes=attributes, content_type=content_type)
845 return file_size
846
847 # Upload large files using TransferConfig
848 def _invoke_api() -> int:
849 extra_args = {}
850 if content_type:
851 extra_args["ContentType"] = content_type
852 if self._is_directory_bucket(bucket):
853 extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
854 validated_attributes = validate_attributes(attributes)
855 if validated_attributes:
856 extra_args["Metadata"] = validated_attributes
857 if self._rust_client and not extra_args:
858 run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
859 else:
860 self._s3_client.upload_file(
861 Filename=f,
862 Bucket=bucket,
863 Key=key,
864 Config=self._transfer_config,
865 ExtraArgs=extra_args,
866 )
867
868 return file_size
869
870 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
871 else:
872 # Upload small files
873 f.seek(0, io.SEEK_END)
874 file_size = f.tell()
875 f.seek(0)
876
877 if file_size <= self._multipart_threshold:
878 if isinstance(f, io.StringIO):
879 self._put_object(
880 remote_path, f.read().encode("utf-8"), attributes=attributes, content_type=content_type
881 )
882 else:
883 self._put_object(remote_path, f.read(), attributes=attributes, content_type=content_type)
884 return file_size
885
886 # Upload large files using TransferConfig
887 bucket, key = split_path(remote_path)
888
889 def _invoke_api() -> int:
890 extra_args = {}
891 if content_type:
892 extra_args["ContentType"] = content_type
893 if self._is_directory_bucket(bucket):
894 extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
895 validated_attributes = validate_attributes(attributes)
896 if validated_attributes:
897 extra_args["Metadata"] = validated_attributes
898
899 if self._rust_client and isinstance(f, io.BytesIO) and not extra_args:
900 data = f.getbuffer()
901 run_async_rust_client_method(self._rust_client, "upload_multipart_from_bytes", key, data)
902 else:
903 self._s3_client.upload_fileobj(
904 Fileobj=f,
905 Bucket=bucket,
906 Key=key,
907 Config=self._transfer_config,
908 ExtraArgs=extra_args,
909 )
910
911 return file_size
912
913 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
914
915 def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
916 if metadata is None:
917 metadata = self._get_object_metadata(remote_path)
918
919 if isinstance(f, str):
920 bucket, key = split_path(remote_path)
921 if os.path.dirname(f):
922 safe_makedirs(os.path.dirname(f))
923
924 # Download small files
925 if metadata.content_length <= self._multipart_threshold:
926 if self._rust_client:
927 run_async_rust_client_method(self._rust_client, "download", key, f)
928 else:
929 with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
930 temp_file_path = fp.name
931 fp.write(self._get_object(remote_path))
932 os.rename(src=temp_file_path, dst=f)
933 return metadata.content_length
934
935 # Download large files using TransferConfig
936 def _invoke_api() -> int:
937 with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
938 temp_file_path = fp.name
939 if self._rust_client:
940 run_async_rust_client_method(
941 self._rust_client, "download_multipart_to_file", key, temp_file_path
942 )
943 else:
944 self._s3_client.download_fileobj(
945 Bucket=bucket,
946 Key=key,
947 Fileobj=fp,
948 Config=self._transfer_config,
949 )
950
951 os.rename(src=temp_file_path, dst=f)
952
953 return metadata.content_length
954
955 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
956 else:
957 # Download small files
958 if metadata.content_length <= self._multipart_threshold:
959 response = self._get_object(remote_path)
960 # Python client returns `bytes`, but Rust client returns an object that implements the buffer protocol,
961 # so we need to check whether `.decode()` is available.
962 if isinstance(f, io.StringIO):
963 if hasattr(response, "decode"):
964 f.write(response.decode("utf-8"))
965 else:
966 f.write(codecs.decode(memoryview(response), "utf-8"))
967 else:
968 f.write(response)
969 return metadata.content_length
970
971 # Download large files using TransferConfig
972 bucket, key = split_path(remote_path)
973
974 def _invoke_api() -> int:
975 self._s3_client.download_fileobj(
976 Bucket=bucket,
977 Key=key,
978 Fileobj=f,
979 Config=self._transfer_config,
980 )
981
982 return metadata.content_length
983
984 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
985
986 def _generate_presigned_url(
987 self,
988 path: str,
989 *,
990 method: str = "GET",
991 signer_type: Optional[SignerType] = None,
992 signer_options: Optional[dict[str, Any]] = None,
993 ) -> str:
994 options = signer_options or {}
995 bucket, key = split_path(path)
996
997 if signer_type is None or signer_type == SignerType.S3:
998 expires_in = int(options.get("expires_in", DEFAULT_PRESIGN_EXPIRES_IN))
999 cache_key: tuple = (SignerType.S3, bucket, expires_in)
1000 if cache_key not in self._signer_cache:
1001 self._signer_cache[cache_key] = S3URLSigner(self._s3_client, bucket, expires_in=expires_in)
1002 elif signer_type == SignerType.CLOUDFRONT:
1003 cache_key = (SignerType.CLOUDFRONT, frozenset(options.items()))
1004 if cache_key not in self._signer_cache:
1005 self._signer_cache[cache_key] = CloudFrontURLSigner(**options)
1006 else:
1007 raise ValueError(f"Unsupported signer type for S3 provider: {signer_type!r}")
1008
1009 return self._signer_cache[cache_key].generate_presigned_url(key, method=method)