1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import codecs
17import io
18import os
19import tempfile
20from collections.abc import Callable, Iterator
21from typing import IO, Any, Optional, TypeVar, Union
22
23import boto3
24import botocore
25from boto3.s3.transfer import TransferConfig
26from botocore.credentials import RefreshableCredentials
27from botocore.exceptions import ClientError, IncompleteReadError, ReadTimeoutError, ResponseStreamingError
28from botocore.session import get_session
29from dateutil.parser import parse as dateutil_parse
30
31from multistorageclient_rust import RustClient, RustClientError, RustRetryableError
32
33from ..constants import DEFAULT_CONNECT_TIMEOUT, DEFAULT_MAX_POOL_CONNECTIONS, DEFAULT_READ_TIMEOUT
34from ..rust_utils import parse_retry_config, run_async_rust_client_method
35from ..signers import CloudFrontURLSigner, URLSigner
36from ..telemetry import Telemetry
37from ..types import (
38 AWARE_DATETIME_MIN,
39 Credentials,
40 CredentialsProvider,
41 ObjectMetadata,
42 PreconditionFailedError,
43 Range,
44 RetryableError,
45 SignerType,
46 SymlinkHandling,
47)
48from ..utils import (
49 safe_makedirs,
50 split_path,
51 validate_attributes,
52)
53from .base import BaseStorageProvider
54
# Generic result type of the callable wrapped by the provider's error-translation helper.
_T = TypeVar("_T")

# Number of bytes in one mebibyte.
MiB = 1 << 20

# Transfer tuning defaults.
# Python and Rust share the same multipart_threshold to keep the code simple.
MULTIPART_THRESHOLD = MiB * 64
MULTIPART_CHUNKSIZE = MiB * 32
IO_CHUNKSIZE = MiB * 32
PYTHON_MAX_CONCURRENCY = 8
MAX_POOL_CONNECTIONS = DEFAULT_MAX_POOL_CONNECTIONS

# Provider name handed to the base storage provider and to the Rust client.
PROVIDER = "s3"

# Storage class attached to uploads that target directory buckets.
EXPRESS_ONEZONE_STORAGE_CLASS = "EXPRESS_ONEZONE"

# Checksum algorithms accepted for the ``checksum_algorithm`` option.
# Source: https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html
SUPPORTED_CHECKSUM_ALGORITHMS = frozenset(("CRC32", "CRC32C", "SHA1", "SHA256", "CRC64NVME"))
72
73
[docs]
class StaticS3CredentialsProvider(CredentialsProvider):
    """
    A :py:class:`multistorageclient.types.CredentialsProvider` that always hands out the same, fixed S3 credentials.
    """

    _access_key: str
    _secret_key: str
    _session_token: Optional[str]

    def __init__(self, access_key: str, secret_key: str, session_token: Optional[str] = None):
        """
        Stores the access key, secret key, and optional session token that will be returned on every request
        for credentials.

        :param access_key: The access key for S3 authentication.
        :param secret_key: The secret key for S3 authentication.
        :param session_token: An optional session token for temporary credentials.
        """
        self._access_key, self._secret_key, self._session_token = access_key, secret_key, session_token

    def get_credentials(self) -> Credentials:
        """
        Builds a :py:class:`Credentials` object from the stored values.

        :return: The stored credentials; ``expiration`` is always ``None``.
        """
        static_credentials = Credentials(
            access_key=self._access_key,
            secret_key=self._secret_key,
            token=self._session_token,
            expiration=None,
        )
        return static_credentials

    def refresh_credentials(self) -> None:
        """
        Does nothing: static credentials have nothing to refresh.
        """
        return None
106
107
# Default ``ExpiresIn`` value handed to boto3 when pre-signing URLs.
DEFAULT_PRESIGN_EXPIRES_IN = 60 * 60

# Upper-cased HTTP method -> name of the boto3 S3 client method to pre-sign.
_S3_METHOD_MAPPING: dict[str, str] = dict(
    GET="get_object",
    PUT="put_object",
)
114
115
[docs]
class S3URLSigner(URLSigner):
    """Generates pre-signed URLs using the boto3 S3 client.

    When the underlying credentials are temporary (STS, IAM role, EC2 instance
    profile), the effective URL lifetime is the **shorter** of ``expires_in``
    and the remaining credential lifetime — boto3 will not warn if the
    credential expires before ``expires_in``.

    See https://docs.aws.amazon.com/AmazonS3/latest/userguide/using-presigned-url.html
    """

    def __init__(self, s3_client: Any, bucket: str, expires_in: int = DEFAULT_PRESIGN_EXPIRES_IN) -> None:
        """
        :param s3_client: Client object whose ``generate_presigned_url`` method produces the URLs.
        :param bucket: Bucket that every signed URL refers to.
        :param expires_in: Value passed as ``ExpiresIn`` for each signed URL.
        """
        self._s3_client, self._bucket, self._expires_in = s3_client, bucket, expires_in

    def generate_presigned_url(self, path: str, *, method: str = "GET") -> str:
        """
        Returns a pre-signed URL for ``path`` inside the configured bucket.

        :param path: Object key to sign.
        :param method: HTTP method (case-insensitive); only methods listed in ``_S3_METHOD_MAPPING`` are accepted.
        :raises ValueError: If ``method`` has no corresponding S3 client method.
        """
        verb = method.upper()
        if verb not in _S3_METHOD_MAPPING:
            raise ValueError(f"Unsupported method for S3 presigning: {method!r}")

        presign_params = {"Bucket": self._bucket, "Key": path}
        return self._s3_client.generate_presigned_url(
            ClientMethod=_S3_METHOD_MAPPING[verb],
            Params=presign_params,
            ExpiresIn=self._expires_in,
        )
141
142
[docs]
143class S3StorageProvider(BaseStorageProvider):
144 """
145 A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Amazon S3 or S3-compatible object stores.
146 """
147
def __init__(
    self,
    region_name: str = "",
    endpoint_url: str = "",
    base_path: str = "",
    credentials_provider: Optional[CredentialsProvider] = None,
    profile_name: Optional[str] = None,
    config_dict: Optional[dict[str, Any]] = None,
    telemetry_provider: Optional[Callable[[], Telemetry]] = None,
    verify: Optional[Union[bool, str]] = None,
    **kwargs: Any,
) -> None:
    """
    Initializes the :py:class:`S3StorageProvider` with the region, endpoint URL, and optional credentials provider.

    :param region_name: The AWS region where the S3 bucket is located.
    :param endpoint_url: The custom endpoint URL for the S3 service.
    :param base_path: The root prefix path within the S3 bucket where all operations will be scoped.
    :param credentials_provider: The provider to retrieve S3 credentials.
    :param profile_name: AWS shared configuration + credentials files profile. For :py:class:`boto3.session.Session`. Ignored if ``credentials_provider`` is set.
    :param config_dict: Resolved MSC config.
    :param telemetry_provider: A function that provides a telemetry instance.
    :param verify: Controls SSL certificate verification.
        Can be ``True`` (verify using system CA bundle, default), ``False`` (skip verification), or a string path to a custom CA certificate bundle.
    :param request_checksum_calculation: For :py:class:`botocore.config.Config`.
        When the underlying S3 client should calculate request checksums.
        See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
    :param response_checksum_validation: For :py:class:`botocore.config.Config`.
        When the underlying S3 client should validate response checksums.
        See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
    :param max_pool_connections: For :py:class:`botocore.config.Config`.
        The maximum number of connections to keep in a connection pool.
    :param connect_timeout: For :py:class:`botocore.config.Config`.
        The time in seconds till a timeout exception is thrown when attempting to make a connection.
    :param read_timeout: For :py:class:`botocore.config.Config`.
        The time in seconds till a timeout exception is thrown when attempting to read from a connection.
    :param retries: For :py:class:`botocore.config.Config`.
        A dictionary for configuration related to retry behavior.
    :param s3: For :py:class:`botocore.config.Config`.
        A dictionary of S3 specific configurations.
    :param signature_version: Signature version for the S3 client (defaults to ``"s3v4"``).
        ``"UNSIGNED"`` also sets ``skip_signature`` on the Rust client options.
    :param multipart_threshold: For :py:class:`boto3.s3.transfer.TransferConfig` (defaults to ``MULTIPART_THRESHOLD``).
    :param max_concurrency: For :py:class:`boto3.s3.transfer.TransferConfig` (defaults to ``PYTHON_MAX_CONCURRENCY``).
    :param multipart_chunksize: For :py:class:`boto3.s3.transfer.TransferConfig` (defaults to ``MULTIPART_CHUNKSIZE``).
    :param io_chunksize: For :py:class:`boto3.s3.transfer.TransferConfig` (defaults to ``IO_CHUNKSIZE``).
    :param checksum_algorithm: Upload-only object integrity algorithm.
        One of ``"CRC32"``, ``"CRC32C"``, ``"SHA1"``, ``"SHA256"``, ``"CRC64NVME"`` (case-insensitive).
        When ``rust_client`` is enabled, only ``"SHA256"`` is accepted.
    :param rust_client: Options for the Rust client. The Rust client is created only when this key is present;
        ``None`` or an empty mapping means "use defaults". The mapping passed in is never modified.
    :raises ValueError: If ``checksum_algorithm`` is invalid, or is not SHA256 while ``rust_client`` is enabled.
    """
    super().__init__(
        base_path=base_path,
        provider_name=PROVIDER,
        config_dict=config_dict,
        telemetry_provider=telemetry_provider,
    )

    self._region_name = region_name
    self._endpoint_url = endpoint_url
    self._credentials_provider = credentials_provider
    self._profile_name = profile_name
    self._verify = verify

    # Must be set before the S3 client is created: the client factory reads it.
    self._signature_version = kwargs.get("signature_version", "s3v4")
    self._s3_client = self._create_s3_client(
        request_checksum_calculation=kwargs.get("request_checksum_calculation"),
        response_checksum_validation=kwargs.get("response_checksum_validation"),
        max_pool_connections=kwargs.get("max_pool_connections", MAX_POOL_CONNECTIONS),
        connect_timeout=kwargs.get("connect_timeout", DEFAULT_CONNECT_TIMEOUT),
        read_timeout=kwargs.get("read_timeout", DEFAULT_READ_TIMEOUT),
        retries=kwargs.get("retries"),
        s3=kwargs.get("s3"),
    )
    self._transfer_config = TransferConfig(
        multipart_threshold=int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD)),
        max_concurrency=int(kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)),
        multipart_chunksize=int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE)),
        io_chunksize=int(kwargs.get("io_chunksize", IO_CHUNKSIZE)),
        use_threads=True,
    )
    self._multipart_threshold = self._transfer_config.multipart_threshold

    self._signer_cache: dict[tuple, URLSigner] = {}

    self._checksum_algorithm: Optional[str] = self._validate_checksum_algorithm(kwargs.get("checksum_algorithm"))

    self._rust_client = None
    if "rust_client" in kwargs:
        # Work on a private copy so the caller's configuration mapping is never mutated.
        # ``None`` is treated like an empty mapping, matching what ``_create_rust_client`` accepts.
        rust_client_options: dict[str, Any] = dict(kwargs["rust_client"] or {})

        # Inherit the rust client options from the kwargs; provider-level values win.
        inherited_options = (
            "max_pool_connections",
            "max_concurrency",
            "multipart_chunksize",
            "read_timeout",
            "connect_timeout",
            "checksum_algorithm",
        )
        for option in inherited_options:
            if option in kwargs:
                rust_client_options[option] = kwargs[option]
        if self._signature_version == "UNSIGNED":
            rust_client_options["skip_signature"] = True

        # Rust client only supports SHA256 today (object_store limitation).
        rust_checksum = rust_client_options.get("checksum_algorithm")
        if rust_checksum is not None:
            if str(rust_checksum).upper() != "SHA256":
                raise ValueError(
                    f"Rust client only supports checksum_algorithm='SHA256' today "
                    f"(object_store limitation), got '{rust_checksum!r}'. "
                    f"Disable rust_client or use SHA256."
                )
            # The Rust client is always handed the lower-case spelling.
            rust_client_options["checksum_algorithm"] = "sha256"
        self._rust_client = self._create_rust_client(rust_client_options)
258
259 @staticmethod
260 def _validate_checksum_algorithm(value: Optional[str]) -> Optional[str]:
261 """
262 Normalizes the optional ``checksum_algorithm`` to uppercase, or returns ``None`` if disabled.
263 """
264 if value is None:
265 return None
266 if not isinstance(value, str) or not value:
267 raise ValueError(
268 f"checksum_algorithm must be one of {sorted(SUPPORTED_CHECKSUM_ALGORITHMS)} or None, got {value!r}."
269 )
270 normalized = value.upper()
271 if normalized not in SUPPORTED_CHECKSUM_ALGORITHMS:
272 raise ValueError(
273 f"Unsupported checksum_algorithm '{value}'. Supported: {sorted(SUPPORTED_CHECKSUM_ALGORITHMS)}."
274 )
275 return normalized
276
277 def _is_directory_bucket(self, bucket: str) -> bool:
278 """
279 Determines if the bucket is a directory bucket based on bucket name.
280 """
281 # S3 Express buckets have a specific naming convention
282 return "--x-s3" in bucket
283
def _create_s3_client(
    self,
    request_checksum_calculation: Optional[str] = None,
    response_checksum_validation: Optional[str] = None,
    max_pool_connections: int = MAX_POOL_CONNECTIONS,
    connect_timeout: Union[float, int, None] = None,
    read_timeout: Union[float, int, None] = None,
    retries: Optional[dict[str, Any]] = None,
    s3: Optional[dict[str, Any]] = None,
):
    """
    Creates and configures the boto3 S3 client, using refreshable credentials if possible.

    The configured ``signature_version`` is applied to every client this method returns, including the
    one built from refreshable credentials.

    :param request_checksum_calculation: For :py:class:`botocore.config.Config`. When the underlying S3 client should calculate request checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
    :param response_checksum_validation: For :py:class:`botocore.config.Config`. When the underlying S3 client should validate response checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
    :param max_pool_connections: For :py:class:`botocore.config.Config`. The maximum number of connections to keep in a connection pool.
    :param connect_timeout: For :py:class:`botocore.config.Config`. The time in seconds till a timeout exception is thrown when attempting to make a connection.
    :param read_timeout: For :py:class:`botocore.config.Config`. The time in seconds till a timeout exception is thrown when attempting to read from a connection.
    :param retries: For :py:class:`botocore.config.Config`. A dictionary for configuration related to retry behavior.
    :param s3: For :py:class:`botocore.config.Config`. A dictionary of S3 specific configurations.

    :return: The configured S3 client.
    """
    options: dict[str, Any] = {
        # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
        "config": boto3.session.Config(  # pyright: ignore [reportAttributeAccessIssue]
            max_pool_connections=max_pool_connections,
            connect_timeout=connect_timeout,
            read_timeout=read_timeout,
            retries=retries or {"mode": "standard"},
            request_checksum_calculation=request_checksum_calculation,
            response_checksum_validation=response_checksum_validation,
            s3=s3,
        ),
    }

    if self._region_name:
        options["region_name"] = self._region_name

    if self._endpoint_url:
        options["endpoint_url"] = self._endpoint_url

    if self._verify is not None:
        options["verify"] = self._verify

    # Merge the signature version before choosing a credential source, so the early return for
    # refreshable credentials below does not skip it.
    if self._signature_version:
        signature_config = botocore.config.Config(  # pyright: ignore[reportAttributeAccessIssue]
            signature_version=botocore.UNSIGNED
            if self._signature_version == "UNSIGNED"
            else self._signature_version
        )
        options["config"] = options["config"].merge(signature_config)

    if self._credentials_provider:
        creds = self._fetch_credentials()
        if creds.get("expiry_time"):
            # Use RefreshableCredentials if expiry_time provided; botocore calls
            # ``_fetch_credentials`` again whenever it decides a refresh is needed.
            refreshable_credentials = RefreshableCredentials.create_from_metadata(
                metadata=creds, refresh_using=self._fetch_credentials, method="custom-refresh"
            )

            botocore_session = get_session()
            botocore_session._credentials = refreshable_credentials

            boto3_session = boto3.Session(botocore_session=botocore_session)

            return boto3_session.client("s3", **options)

        # Add static credentials to the options dictionary
        options["aws_access_key_id"] = creds["access_key"]
        options["aws_secret_access_key"] = creds["secret_key"]
        if creds["token"]:
            options["aws_session_token"] = creds["token"]

    # Fallback to standard credential chain.
    session = boto3.Session(profile_name=self._profile_name)
    return session.client("s3", **options)
361
def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
    """
    Builds the Rust client from a copy of ``rust_client_options`` completed with provider-level defaults.

    Options given explicitly always win; the region, endpoint URL, profile name, bucket, and connection
    pool size are only filled in when missing.

    :param rust_client_options: Optional Rust client options; the mapping itself is left untouched.
    :return: The configured :py:class:`RustClient`.
    """
    configs = {} if not rust_client_options else dict(rust_client_options)

    # Extract and parse retry configuration
    retry_config = parse_retry_config(configs)

    # Provider-level settings are inherited only when set and not already overridden.
    inherited_settings = (
        ("region_name", self._region_name),
        ("endpoint_url", self._endpoint_url),
        ("profile_name", self._profile_name),
    )
    for option, provider_value in inherited_settings:
        if provider_value:
            configs.setdefault(option, provider_value)

    # If the user specifies a bucket, use it. Otherwise, use the base path.
    if "bucket" not in configs:
        configs["bucket"] = split_path(self._base_path)[0]

    configs.setdefault("max_pool_connections", MAX_POOL_CONNECTIONS)

    return RustClient(
        provider=PROVIDER,
        configs=configs,
        credentials_provider=self._credentials_provider,
        retry=retry_config,
    )
394
395 def _fetch_credentials(self) -> dict:
396 """
397 Refreshes the S3 client if the current credentials are expired.
398 """
399 if not self._credentials_provider:
400 raise RuntimeError("Cannot fetch credentials if no credential provider configured.")
401 self._credentials_provider.refresh_credentials()
402 credentials = self._credentials_provider.get_credentials()
403 return {
404 "access_key": credentials.access_key,
405 "secret_key": credentials.secret_key,
406 "token": credentials.token,
407 "expiry_time": credentials.expiration,
408 }
409
def _translate_errors(
    self,
    func: Callable[[], _T],
    operation: str,
    bucket: str,
    key: str,
) -> _T:
    """
    Runs ``func`` and translates errors like timeouts and client errors into the provider's exception types.

    Error responses are read defensively: a :py:class:`ClientError` without ``ResponseMetadata`` or ``Error``
    entries, or a :py:class:`RustClientError` with fewer than two arguments, is still translated instead of
    failing with a ``KeyError``/``IndexError`` inside the handler.

    :param func: The function that performs the actual S3 operation.
    :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
    :param bucket: The name of the S3 bucket involved in the operation.
    :param key: The key of the object within the S3 bucket.

    :return: The result of the S3 operation, typically the return value of the `func` callable.
    :raises FileNotFoundError: On HTTP 404 (except ``NoSuchUpload``).
    :raises PreconditionFailedError: On HTTP 412.
    :raises RetryableError: On HTTP 408/429/503, ``NoSuchUpload``, timeouts, incomplete reads, and Rust
        client errors other than 403/404.
    :raises PermissionError: On HTTP 403 reported by the Rust client.
    :raises NotImplementedError: On HTTP 501.
    :raises RuntimeError: On any other failure.
    """
    try:
        return func()
    except ClientError as error:
        response_metadata = error.response.get("ResponseMetadata") or {}
        error_details = error.response.get("Error") or {}

        status_code = response_metadata.get("HTTPStatusCode")
        request_id = response_metadata.get("RequestId")
        host_id = response_metadata.get("HostId")
        error_code = error_details.get("Code")
        error_info = f"request_id: {request_id}, host_id: {host_id}, status_code: {status_code}"

        if status_code == 404:
            if error_code == "NoSuchUpload":
                error_message = error_details.get("Message")
                raise RetryableError(f"Multipart upload failed for {bucket}/{key}: {error_message}") from error
            raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}")  # pylint: disable=raise-missing-from
        elif status_code == 412:  # Precondition Failed
            raise PreconditionFailedError(
                f"ETag mismatch for {operation} operation on {bucket}/{key}. {error_info}"
            ) from error
        elif status_code == 429:
            raise RetryableError(
                f"Too many request to {operation} object(s) at {bucket}/{key}. {error_info}"
            ) from error
        elif status_code == 503:
            raise RetryableError(
                f"Service unavailable when {operation} object(s) at {bucket}/{key}. {error_info}"
            ) from error
        elif status_code == 501:
            raise NotImplementedError(
                f"Operation {operation} not implemented for object(s) at {bucket}/{key}. {error_info}"
            ) from error
        elif status_code == 408:
            # 408 Request Timeout is from Google Cloud Storage
            raise RetryableError(
                f"Request timeout when {operation} object(s) at {bucket}/{key}. {error_info}"
            ) from error
        else:
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}, "
                f"error_type: {type(error).__name__}"
            ) from error
    except RustClientError as error:
        # The handler reads args as (message, status_code); tolerate shorter argument tuples.
        rust_args = error.args
        message = rust_args[0] if rust_args else str(error)
        status_code = rust_args[1] if len(rust_args) > 1 else None
        if status_code == 404:
            raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {message}") from error
        elif status_code == 403:
            raise PermissionError(
                f"Permission denied to {operation} object(s) at {bucket}/{key}. {message}"
            ) from error
        else:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key}. {message}. status_code: {status_code}"
            ) from error
    except (ReadTimeoutError, IncompleteReadError, ResponseStreamingError) as error:
        raise RetryableError(
            f"Failed to {operation} object(s) at {bucket}/{key} due to network timeout or incomplete read. "
            f"error_type: {type(error).__name__}"
        ) from error
    except RustRetryableError as error:
        raise RetryableError(
            f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
            f"error_type: {type(error).__name__}"
        ) from error
    except Exception as error:
        raise RuntimeError(
            f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
        ) from error
494
def _put_object(
    self,
    path: str,
    body: bytes,
    if_match: Optional[str] = None,
    if_none_match: Optional[str] = None,
    attributes: Optional[dict[str, str]] = None,
    content_type: Optional[str] = None,
) -> int:
    """
    Uploads ``body`` as a single object at the given S3 path.

    The Rust client is used when it is configured, the path has no trailing ``/``, and the request needs
    none of the features it lacks; otherwise the upload goes through boto3, where the configured checksum
    algorithm (if any) is attached.

    :param path: The S3 path where the object will be uploaded.
    :param body: The content of the object as bytes.
    :param if_match: Optional value sent as the ``IfMatch`` request parameter.
    :param if_none_match: Optional If-None-Match header value. Use "*" to only upload if the object doesn't exist.
    :param attributes: Optional attributes to attach to the object.
    :param content_type: Optional Content-Type header value.
    :return: The number of bytes in ``body``.
    """
    bucket, key = split_path(path)

    def _invoke_api() -> int:
        request: dict[str, Any] = {"Bucket": bucket, "Key": key, "Body": body}

        # Only options that are actually set are added to the request.
        if content_type:
            request["ContentType"] = content_type
        if self._is_directory_bucket(bucket):
            request["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
        if if_match:
            request["IfMatch"] = if_match
        if if_none_match:
            request["IfNoneMatch"] = if_none_match
        validated_attributes = validate_attributes(attributes)
        if validated_attributes:
            request["Metadata"] = validated_attributes

        # TODO(NGCDP-5804): Add support to update ContentType header in Rust client
        rust_unsupported_feature_keys = {"Metadata", "StorageClass", "IfMatch", "IfNoneMatch", "ContentType"}
        use_rust_client = (
            bool(self._rust_client)
            # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
            and not path.endswith("/")
            and rust_unsupported_feature_keys.isdisjoint(request)
        )

        if not use_rust_client:
            if self._checksum_algorithm:
                request["ChecksumAlgorithm"] = self._checksum_algorithm
            self._s3_client.put_object(**request)
        else:
            run_async_rust_client_method(self._rust_client, "put", key, body)

        return len(body)

    return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
547
def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
    """
    Downloads an object, or a byte range of it, and returns the content.

    The Rust client is used when it is configured; otherwise the object is fetched through boto3.

    :param path: The S3 path of the object to read.
    :param byte_range: Optional range (``offset`` and ``size``) to read instead of the whole object.
    :return: The bytes that were read.
    """
    bucket, key = split_path(path)

    def _invoke_api() -> bytes:
        # The HTTP Range header uses an inclusive end offset, hence the ``- 1``.
        bytes_range = None
        if byte_range:
            bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"

        if self._rust_client:
            rust_args = (key, byte_range) if byte_range else (key,)
            return run_async_rust_client_method(self._rust_client, "get", *rust_args)

        request = {"Bucket": bucket, "Key": key}
        if bytes_range is not None:
            request["Range"] = bytes_range
        response = self._s3_client.get_object(**request)
        return response["Body"].read()

    return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
574
def _copy_object(self, src_path: str, dest_path: str) -> int:
    """
    Copies an object from ``src_path`` to ``dest_path`` using the boto3 managed copy.

    The source object's metadata is looked up first, so a missing source fails before any copy starts.

    :param src_path: The S3 path of the object to copy.
    :param dest_path: The S3 path to copy the object to.
    :return: The content length of the source object.
    """
    src_bucket, src_key = split_path(src_path)
    dest_bucket, dest_key = split_path(dest_path)

    src_object = self._get_object_metadata(src_path)
    copy_source = {"Bucket": src_bucket, "Key": src_key}

    def _invoke_api() -> int:
        self._s3_client.copy(
            CopySource=copy_source,
            Bucket=dest_bucket,
            Key=dest_key,
            Config=self._transfer_config,
        )
        return src_object.content_length

    return self._translate_errors(_invoke_api, operation="COPY", bucket=dest_bucket, key=dest_key)
592
def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
    """
    Deletes a single object.

    :param path: The S3 path of the object to delete.
    :param if_match: Optional ETag; when given, it is sent as ``IfMatch`` so the delete is conditional.
    """
    bucket, key = split_path(path)

    def _invoke_api() -> None:
        request = {"Bucket": bucket, "Key": key}
        # Delete conditionally when if_match (etag) is provided; otherwise delete unconditionally
        if if_match:
            request["IfMatch"] = if_match
        self._s3_client.delete_object(**request)

    return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)
604
def _delete_objects(self, paths: list[str]) -> None:
    """
    Deletes many objects using batched ``delete_objects`` requests.

    Paths are grouped by bucket and sent in batches of at most 1000 keys. Per-object errors reported by
    any batch are collected and raised together once every batch has been sent.

    :param paths: The S3 paths of the objects to delete; an empty list is a no-op.
    """
    if not paths:
        return

    # Group keys by bucket, keeping the input order within each bucket.
    by_bucket: dict[str, list[str]] = {}
    for full_path in paths:
        bucket_name, object_key = split_path(full_path)
        if bucket_name not in by_bucket:
            by_bucket[bucket_name] = []
        by_bucket[bucket_name].append(object_key)

    max_keys_per_request = 1000

    def _invoke_api() -> None:
        all_errors: list[str] = []
        for bucket, keys in by_bucket.items():
            offset = 0
            while offset < len(keys):
                batch = keys[offset : offset + max_keys_per_request]
                offset += max_keys_per_request
                response = self._s3_client.delete_objects(
                    Bucket=bucket, Delete={"Objects": [{"Key": k} for k in batch]}
                )
                all_errors.extend(
                    f"{bucket}/{e.get('Key', '?')}: {e.get('Code', '')} {e.get('Message', '')}"
                    for e in (response.get("Errors") or [])
                )
        if all_errors:
            raise RuntimeError(f"DeleteObjects reported errors: {'; '.join(all_errors)}")

    # Summaries used only in translated error messages: bucket names and key counts per bucket.
    bucket_desc = "(" + "|".join(by_bucket) + ")"
    key_desc = "(" + "|".join(str(len(keys)) for keys in by_bucket.values()) + " keys)"
    self._translate_errors(_invoke_api, operation="DELETE_MANY", bucket=bucket_desc, key=key_desc)
633
def _is_dir(self, path: str) -> bool:
    """
    Checks whether anything exists under ``path`` when it is treated as a directory prefix.

    :param path: The S3 path to check; the delimiter is appended before listing.
    :return: ``True`` if the listing returns at least one object or common prefix.
    """
    # Ensure the path ends with '/' to mimic a directory
    directory_path = self._append_delimiter(path)
    bucket, key = split_path(directory_path)

    def _invoke_api() -> bool:
        # A single result is enough to prove the prefix is not empty.
        listing = self._s3_client.list_objects_v2(Bucket=bucket, Prefix=key, MaxKeys=1, Delimiter="/")
        return any(listing.get(field) for field in ("Contents", "CommonPrefixes"))

    return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)
648
def _make_symlink(self, path: str, target: str) -> None:
    """
    Creates a symlink marker: an empty object whose ``msc-symlink-target`` metadata holds the encoded target.

    :param path: The S3 path of the symlink object to create.
    :param target: The S3 path the symlink points to; it must be in the same bucket as ``path``.
    :raises ValueError: If ``path`` and ``target`` are in different buckets.
    """
    bucket, key = split_path(path)
    target_bucket, target_key = split_path(target)
    if bucket != target_bucket:
        raise ValueError(f"Cannot create cross-bucket symlink: '{bucket}' -> '{target_bucket}'.")

    symlink_metadata = {"msc-symlink-target": ObjectMetadata.encode_symlink_target(key, target_key)}

    def _invoke_api() -> None:
        self._s3_client.put_object(Bucket=bucket, Key=key, Body=b"", Metadata=symlink_metadata)

    self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
665
def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
    """
    Returns metadata for the object or directory at ``path``.

    Paths ending in ``/`` (or naming only a bucket) are treated as directories and checked by listing.
    Any other path is looked up with ``head_object``.

    :param path: The S3 path to inspect.
    :param strict: When ``True`` and the object is missing, also check whether the path is a directory.
    :return: Metadata with ``type`` set to ``"file"`` or ``"directory"``.
    :raises FileNotFoundError: If neither an object nor (when applicable) a directory exists at ``path``.
    """
    bucket, key = split_path(path)

    def _directory_metadata(directory_path: str) -> ObjectMetadata:
        return ObjectMetadata(
            key=directory_path,
            type="directory",
            content_length=0,
            last_modified=AWARE_DATETIME_MIN,
        )

    if path.endswith("/") or (bucket and not key):
        # If path ends with "/" or empty key name is provided, then assume it's a "directory",
        # which metadata is not guaranteed to exist for cases such as
        # "virtual prefix" that was never explicitly created.
        if not self._is_dir(path):
            raise FileNotFoundError(f"Directory {path} does not exist.")
        return _directory_metadata(path)

    def _invoke_api() -> ObjectMetadata:
        response = self._s3_client.head_object(Bucket=bucket, Key=key)

        user_metadata = response.get("Metadata")
        symlink_target = None
        if user_metadata:
            symlink_target = user_metadata.get("msc-symlink-target")

        # The ETag is stored without its surrounding double quotes.
        etag = None
        if "ETag" in response:
            etag = response["ETag"].strip('"')

        return ObjectMetadata(
            key=path,
            type="file",
            content_length=response["ContentLength"],
            content_type=response.get("ContentType"),
            last_modified=response.get("LastModified", AWARE_DATETIME_MIN),
            etag=etag,
            storage_class=response.get("StorageClass"),
            metadata=user_metadata,
            symlink_target=symlink_target,
        )

    try:
        return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
    except FileNotFoundError as error:
        if strict:
            # If the object does not exist on the given path, we will append a trailing slash and
            # check if the path is a directory.
            directory_path = self._append_delimiter(path)
            if self._is_dir(directory_path):
                return _directory_metadata(directory_path)
        raise error
715
def _list_objects(
    self,
    path: str,
    start_after: Optional[str] = None,
    end_at: Optional[str] = None,
    include_directories: bool = False,
    symlink_handling: SymlinkHandling = SymlinkHandling.FOLLOW,
) -> Iterator[ObjectMetadata]:
    """
    Lazily lists the objects (and optionally the directories) under ``path``.

    Errors raised while the listing is being consumed (for example while fetching a later page) go
    through ``_translate_errors``, just like errors from the other operations.

    :param path: The S3 path (bucket plus prefix) to list.
    :param start_after: Only return keys after this path (exclusive).
    :param end_at: Only return keys up to this path (inclusive).
    :param include_directories: When ``True``, list one level only (``/`` delimiter) and also yield directories.
    :param symlink_handling: Accepted for interface compatibility; this method does not read it.
    :return: An iterator over the metadata of the listed entries.
    """
    bucket, prefix = split_path(path)

    # Get the prefix of the start_after and end_at paths relative to the bucket.
    if start_after:
        _, start_after = split_path(start_after)
    if end_at:
        _, end_at = split_path(end_at)

    def _invoke_api() -> Iterator[ObjectMetadata]:
        paginator = self._s3_client.get_paginator("list_objects_v2")
        paginate_kwargs = {"Bucket": bucket, "Prefix": prefix, "StartAfter": start_after or ""}
        if include_directories:
            paginate_kwargs["Delimiter"] = "/"

        for page in paginator.paginate(**paginate_kwargs):
            prefixes_past_end = False
            for item in page.get("CommonPrefixes", []):
                prefix_key = item["Prefix"].rstrip("/")
                if end_at is not None and end_at < prefix_key:
                    # Stop scanning prefixes, but do not return yet: "Contents" of this same page
                    # may still hold keys at or before end_at.
                    prefixes_past_end = True
                    break
                # Filter by start_after - S3's StartAfter doesn't filter CommonPrefixes
                if start_after is None or start_after < prefix_key:
                    yield ObjectMetadata(
                        key=os.path.join(bucket, prefix_key),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )

            # S3 guarantees lexicographical order for general purpose buckets (for
            # normal S3) but not directory buckets (for S3 Express One Zone).
            for response_object in page.get("Contents", []):
                key = response_object["Key"]
                if end_at is not None and key > end_at:
                    return

                if key.endswith("/"):
                    # Keys with a trailing slash are directory markers, reported only on request.
                    if include_directories:
                        yield ObjectMetadata(
                            key=os.path.join(bucket, key.rstrip("/")),
                            type="directory",
                            content_length=0,
                            last_modified=response_object["LastModified"],
                        )
                    continue

                symlink_target = None
                if response_object.get("Size", 0) == 0:
                    # Zero-byte objects may be symlink markers; the target lives in the object's
                    # metadata. This lookup is best-effort: a failure lists the object as a plain file.
                    try:
                        meta = self._get_object_metadata(os.path.join(bucket, key))
                        symlink_target = meta.symlink_target
                    except Exception:
                        symlink_target = None
                yield ObjectMetadata(
                    key=os.path.join(bucket, key),
                    type="file",
                    content_length=response_object["Size"],
                    last_modified=response_object["LastModified"],
                    etag=response_object["ETag"].strip('"'),
                    storage_class=response_object.get("StorageClass"),
                    symlink_target=symlink_target,
                )

            if prefixes_past_end:
                return

    # Calling a generator function runs none of its body, so wrapping only that call would translate
    # nothing. Translate each step of the iteration instead.
    results = _invoke_api()
    exhausted = object()

    def _next_result() -> Any:
        # A sentinel is used because a StopIteration would be caught and wrapped by _translate_errors.
        return next(results, exhausted)

    def _translated() -> Iterator[ObjectMetadata]:
        try:
            while True:
                entry = self._translate_errors(_next_result, operation="LIST", bucket=bucket, key=prefix)
                if entry is exhausted:
                    return
                yield entry
        finally:
            results.close()

    return _translated()
789
@property
def supports_parallel_listing(self) -> bool:
    """
    Tell callers that this provider can list objects in parallel.

    S3 supports parallel listing via delimiter-based prefix discovery, so this is always ``True``.
    Directory buckets are special-cased inside :meth:`list_objects_recursive`, not here.
    """
    return True
798
[docs]
def list_objects_recursive(
    self,
    path: str = "",
    start_after: Optional[str] = None,
    end_at: Optional[str] = None,
    max_workers: int = 32,
    look_ahead: int = 2,
    symlink_handling: SymlinkHandling = SymlinkHandling.FOLLOW,
) -> Iterator[ObjectMetadata]:
    """
    Recursively list every object under ``path``.

    Three strategies are used, in this order of preference:

    * directory buckets are delegated to :meth:`list_objects` (without directories);
    * when a Rust client is configured, its ``list_recursive`` implementation is used;
    * otherwise the base-class Python implementation is used.

    Returns files only (no directories), in lexicographic order.

    :param path: Path to list, relative to the provider's base path.
    :param start_after: Only keys strictly after this one are returned.
    :param end_at: Only keys up to and including this one are returned.
    :param max_workers: Concurrency limit handed to the parallel implementations.
    :param look_ahead: Forwarded to the base-class implementation only.
    :param symlink_handling: How to handle symbolic links during listing. It is forwarded to
        :meth:`list_objects` and to the base-class implementation; the Rust path does not receive it.
    :raises ValueError: If both bounds are given and ``start_after`` is not before ``end_at``.
    """
    if start_after is not None and end_at is not None and start_after >= end_at:
        raise ValueError(f"start_after ({start_after}) must be before end_at ({end_at})!")

    full_path = self._prepend_base_path(path)
    bucket, _ = split_path(full_path)

    # Directory buckets bypass both parallel implementations and use the plain listing.
    if self._is_directory_bucket(bucket):
        yield from self.list_objects(
            path,
            start_after,
            end_at,
            include_directories=False,
            symlink_handling=symlink_handling,
        )
        return

    # No Rust client: defer to the generic implementation of the base class.
    if not self._rust_client:
        yield from super().list_objects_recursive(
            path, start_after, end_at, max_workers, look_ahead, symlink_handling
        )
        return

    def _list_with_rust() -> Iterator[ObjectMetadata]:
        return self._list_objects_recursive_rust(path, full_path, bucket, start_after, end_at, max_workers)

    yield from self._emit_metrics(
        operation=BaseStorageProvider._Operation.LIST,
        f=_list_with_rust,
    )
839
def _list_objects_recursive_rust(
    self,
    path: str,
    full_path: str,
    bucket: str,
    start_after: Optional[str],
    end_at: Optional[str],
    max_workers: int,
) -> Iterator[ObjectMetadata]:
    """
    Use Rust client's ``list_recursive`` for parallel listing.

    The Rust client already handles parallel listing internally.
    Returns files only in lexicographic order.

    :param path: Listing path relative to the base path (not used by this implementation).
    :param full_path: Listing path with the base path prepended; its key part is the listed prefix.
    :param bucket: Bucket the prefix belongs to.
    :param start_after: Keys up to and including this one (relative to the base path) are skipped.
    :param end_at: Iteration stops at the first key after this one (relative to the base path).
    :param max_workers: Passed to the Rust client as ``max_concurrency``.
    :return: Iterator of object metadata whose keys are relative to the base path.
    """
    _, prefix = split_path(full_path)

    def _invoke_api() -> Any:
        # Plain function, not a generator: the Rust call has to execute while
        # ``_translate_errors`` is still on the call stack. A generator body would only
        # start running during iteration, after the wrapper has already returned.
        return run_async_rust_client_method(
            self._rust_client,
            "list_recursive",
            [prefix or ""],
            max_concurrency=max_workers,
        )

    result = self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    # The bounds are given relative to the base path; compare them in the same
    # "bucket/key" space as the keys built below.
    start_after_full = self._prepend_base_path(start_after) if start_after else None
    end_at_full = self._prepend_base_path(end_at) if end_at else None

    for obj in result.objects:
        full_key = os.path.join(bucket, obj.key)

        if start_after_full and full_key <= start_after_full:
            continue
        if end_at_full and full_key > end_at_full:
            # Results are in lexicographic order, so nothing later can be in range.
            break

        relative_key = full_key.removeprefix(self._base_path).lstrip("/")

        yield ObjectMetadata(
            key=relative_key,
            content_length=obj.content_length,
            last_modified=dateutil_parse(obj.last_modified),
            type="file" if obj.object_type == "object" else obj.object_type,
            etag=obj.etag,
        )
887
888 def _upload_file(
889 self,
890 remote_path: str,
891 f: Union[str, IO],
892 attributes: Optional[dict[str, str]] = None,
893 content_type: Optional[str] = None,
894 ) -> int:
895 file_size: int = 0
896
897 if isinstance(f, str):
898 bucket, key = split_path(remote_path)
899 file_size = os.path.getsize(f)
900
901 # Upload small files
902 if file_size <= self._multipart_threshold:
903 if self._rust_client and not attributes and not content_type:
904 run_async_rust_client_method(self._rust_client, "upload", f, key)
905 else:
906 with open(f, "rb") as fp:
907 self._put_object(remote_path, fp.read(), attributes=attributes, content_type=content_type)
908 return file_size
909
910 # Upload large files using TransferConfig
911 def _invoke_api() -> int:
912 extra_args = {}
913 if content_type:
914 extra_args["ContentType"] = content_type
915 if self._is_directory_bucket(bucket):
916 extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
917 validated_attributes = validate_attributes(attributes)
918 if validated_attributes:
919 extra_args["Metadata"] = validated_attributes
920 if self._rust_client and not extra_args:
921 run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
922 else:
923 if self._checksum_algorithm:
924 extra_args["ChecksumAlgorithm"] = self._checksum_algorithm
925 self._s3_client.upload_file(
926 Filename=f,
927 Bucket=bucket,
928 Key=key,
929 Config=self._transfer_config,
930 ExtraArgs=extra_args,
931 )
932
933 return file_size
934
935 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
936 else:
937 # Upload small files
938 f.seek(0, io.SEEK_END)
939 file_size = f.tell()
940 f.seek(0)
941
942 if file_size <= self._multipart_threshold:
943 if isinstance(f, io.StringIO):
944 self._put_object(
945 remote_path, f.read().encode("utf-8"), attributes=attributes, content_type=content_type
946 )
947 else:
948 self._put_object(remote_path, f.read(), attributes=attributes, content_type=content_type)
949 return file_size
950
951 # Upload large files using TransferConfig
952 bucket, key = split_path(remote_path)
953
954 def _invoke_api() -> int:
955 extra_args = {}
956 if content_type:
957 extra_args["ContentType"] = content_type
958 if self._is_directory_bucket(bucket):
959 extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
960 validated_attributes = validate_attributes(attributes)
961 if validated_attributes:
962 extra_args["Metadata"] = validated_attributes
963
964 if self._rust_client and isinstance(f, io.BytesIO) and not extra_args:
965 data = f.getbuffer()
966 run_async_rust_client_method(self._rust_client, "upload_multipart_from_bytes", key, data)
967 else:
968 if self._checksum_algorithm:
969 extra_args["ChecksumAlgorithm"] = self._checksum_algorithm
970 self._s3_client.upload_fileobj(
971 Fileobj=f,
972 Bucket=bucket,
973 Key=key,
974 Config=self._transfer_config,
975 ExtraArgs=extra_args,
976 )
977
978 return file_size
979
980 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
981
def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
    """
    Download the object at ``remote_path`` into a local file or a file-like object.

    Objects no larger than ``self._multipart_threshold`` are fetched in one piece; larger objects use
    the Rust multipart download (local files only) or the boto3 transfer manager. When the Python path
    writes a local file, the data goes to a hidden temporary file in the destination directory that is
    renamed over the destination on success and removed on failure.

    :param remote_path: Source in ``bucket/key`` form.
    :param f: Destination path, or a writable file-like object. An ``io.StringIO`` receives the
        content decoded as UTF-8.
    :param metadata: Metadata of the object if already known; fetched when omitted.
    :return: The object's content length as reported by its metadata.
    """
    if metadata is None:
        metadata = self._get_object_metadata(remote_path)

    def _write_via_temp_file(destination: str, fill: Callable[[IO, str], None]) -> None:
        """Fill a temporary file next to ``destination`` and rename it into place, cleaning up on failure."""
        temp_file_path: Optional[str] = None
        try:
            with tempfile.NamedTemporaryFile(
                mode="wb", delete=False, dir=os.path.dirname(destination), prefix="."
            ) as fp:
                temp_file_path = fp.name
                fill(fp, temp_file_path)
            os.rename(src=temp_file_path, dst=destination)
        except BaseException:
            # ``delete=False`` means nobody else removes the file; do not leave partial
            # downloads behind when the transfer or the rename fails.
            if temp_file_path is not None:
                try:
                    os.unlink(temp_file_path)
                except OSError:
                    pass
            raise

    if isinstance(f, str):
        bucket, key = split_path(remote_path)
        if os.path.dirname(f):
            safe_makedirs(os.path.dirname(f))

        # Download small files
        if metadata.content_length <= self._multipart_threshold:
            if self._rust_client:
                run_async_rust_client_method(self._rust_client, "download", key, f)
            else:
                _write_via_temp_file(f, lambda fp, _temp_path: fp.write(self._get_object(remote_path)))
            return metadata.content_length

        # Download large files using TransferConfig
        def _download_large_file() -> int:
            def _fill(fp: IO, temp_file_path: str) -> None:
                if self._rust_client:
                    run_async_rust_client_method(
                        self._rust_client, "download_multipart_to_file", key, temp_file_path
                    )
                else:
                    self._s3_client.download_fileobj(
                        Bucket=bucket,
                        Key=key,
                        Fileobj=fp,
                        Config=self._transfer_config,
                    )

            _write_via_temp_file(f, _fill)
            return metadata.content_length

        return self._translate_errors(_download_large_file, operation="GET", bucket=bucket, key=key)

    # Download small files
    if metadata.content_length <= self._multipart_threshold:
        response = self._get_object(remote_path)
        # Python client returns `bytes`, but Rust client returns an object that implements the buffer protocol,
        # so we need to check whether `.decode()` is available.
        if isinstance(f, io.StringIO):
            if hasattr(response, "decode"):
                f.write(response.decode("utf-8"))
            else:
                f.write(codecs.decode(memoryview(response), "utf-8"))
        else:
            f.write(response)
        return metadata.content_length

    # Download large files using TransferConfig
    bucket, key = split_path(remote_path)

    def _download_large_stream() -> int:
        if isinstance(f, io.StringIO):
            # boto3 writes bytes into ``Fileobj``; a text stream cannot accept them, so the
            # object is staged in a binary buffer and decoded afterwards.
            staging = io.BytesIO()
            self._s3_client.download_fileobj(
                Bucket=bucket,
                Key=key,
                Fileobj=staging,
                Config=self._transfer_config,
            )
            f.write(staging.getvalue().decode("utf-8"))
        else:
            self._s3_client.download_fileobj(
                Bucket=bucket,
                Key=key,
                Fileobj=f,
                Config=self._transfer_config,
            )
        return metadata.content_length

    return self._translate_errors(_download_large_stream, operation="GET", bucket=bucket, key=key)
1052
def _generate_presigned_url(
    self,
    path: str,
    *,
    method: str = "GET",
    signer_type: Optional[SignerType] = None,
    signer_options: Optional[dict[str, Any]] = None,
) -> str:
    """
    Produce a pre-signed URL for the object at ``path``.

    Signers are created on demand and kept in ``self._signer_cache``. S3 signers are cached per
    bucket and expiry; CloudFront signers are cached per distinct set of options.

    :param path: Object location in ``bucket/key`` form.
    :param method: Request method the URL is signed for; passed through to the signer.
    :param signer_type: Which signer to use. ``None`` selects the S3 signer.
    :param signer_options: Signer settings. The S3 signer reads ``expires_in``; for CloudFront the
        whole mapping is passed to the signer's constructor.
    :return: The signed URL.
    :raises ValueError: If ``signer_type`` is neither S3 nor CloudFront.
    """
    options = signer_options or {}
    bucket, key = split_path(path)

    wants_s3 = signer_type is None or signer_type == SignerType.S3
    if not wants_s3 and not (signer_type == SignerType.CLOUDFRONT):
        raise ValueError(f"Unsupported signer type for S3 provider: {signer_type!r}")

    signers = self._signer_cache
    cache_key: tuple
    if wants_s3:
        expires_in = int(options.get("expires_in", DEFAULT_PRESIGN_EXPIRES_IN))
        cache_key = (SignerType.S3, bucket, expires_in)
        if cache_key not in signers:
            signers[cache_key] = S3URLSigner(self._s3_client, bucket, expires_in=expires_in)
    else:
        cache_key = (SignerType.CLOUDFRONT, frozenset(options.items()))
        if cache_key not in signers:
            signers[cache_key] = CloudFrontURLSigner(**options)

    signer = signers[cache_key]
    return signer.generate_presigned_url(key, method=method)