1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import codecs
17import io
18import os
19import tempfile
20from collections.abc import Callable, Iterator
21from typing import IO, Any, Optional, TypeVar, Union
22
23import boto3
24import botocore
25from boto3.s3.transfer import TransferConfig
26from botocore.credentials import RefreshableCredentials
27from botocore.exceptions import ClientError, IncompleteReadError, ReadTimeoutError, ResponseStreamingError
28from botocore.session import get_session
29from dateutil.parser import parse as dateutil_parse
30
31from multistorageclient_rust import RustClient, RustClientError, RustRetryableError
32
33from ..constants import DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT
34from ..rust_utils import parse_retry_config, run_async_rust_client_method
35from ..signers import CloudFrontURLSigner, URLSigner
36from ..telemetry import Telemetry
37from ..types import (
38 AWARE_DATETIME_MIN,
39 Credentials,
40 CredentialsProvider,
41 ObjectMetadata,
42 PreconditionFailedError,
43 Range,
44 RetryableError,
45 SignerType,
46 SymlinkHandling,
47)
48from ..utils import (
49 get_available_cpu_count,
50 safe_makedirs,
51 split_path,
52 validate_attributes,
53)
54from .base import BaseStorageProvider
55
# Generic result type for callables wrapped by _translate_errors.
_T = TypeVar("_T")

# Default connection pool size scales with CPU count or MSC Sync Threads count (minimum 64)
MAX_POOL_CONNECTIONS = max(
    64,
    get_available_cpu_count(),
    int(os.getenv("MSC_NUM_THREADS_PER_PROCESS", "0")),
)

MiB = 1024 * 1024  # One mebibyte, in bytes.

# Python and Rust share the same multipart_threshold to keep the code simple.
MULTIPART_THRESHOLD = 64 * MiB  # Size threshold for switching to multipart uploads.
MULTIPART_CHUNKSIZE = 32 * MiB  # Size of each multipart upload part.
IO_CHUNKSIZE = 32 * MiB  # Buffer size used by boto3 managed transfers.
PYTHON_MAX_CONCURRENCY = 8  # Default thread count for boto3 TransferConfig.

PROVIDER = "s3"  # Provider name reported to the base class and the Rust client.

# Storage class applied when writing to S3 Express One Zone (directory) buckets.
EXPRESS_ONEZONE_STORAGE_CLASS = "EXPRESS_ONEZONE"

# Source: https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html
SUPPORTED_CHECKSUM_ALGORITHMS = frozenset({"CRC32", "CRC32C", "SHA1", "SHA256", "CRC64NVME"})
79
80
[docs]
class StaticS3CredentialsProvider(CredentialsProvider):
    """
    :py:class:`multistorageclient.types.CredentialsProvider` implementation backed by a
    fixed, non-refreshing set of S3 credentials supplied at construction time.
    """

    _access_key: str
    _secret_key: str
    _session_token: Optional[str]

    def __init__(self, access_key: str, secret_key: str, session_token: Optional[str] = None):
        """
        Stores the static credentials returned by every :py:meth:`get_credentials` call.

        :param access_key: The access key for S3 authentication.
        :param secret_key: The secret key for S3 authentication.
        :param session_token: An optional session token for temporary credentials.
        """
        self._access_key, self._secret_key = access_key, secret_key
        self._session_token = session_token

    def get_credentials(self) -> Credentials:
        """Returns the credentials captured at construction; ``expiration`` is always ``None``."""
        return Credentials(
            access_key=self._access_key,
            secret_key=self._secret_key,
            token=self._session_token,
            expiration=None,
        )

    def refresh_credentials(self) -> None:
        """No-op: static credentials never expire and cannot be refreshed."""
        pass
113
114
DEFAULT_PRESIGN_EXPIRES_IN = 3600  # Default pre-signed URL lifetime, in seconds (1 hour).

# Maps HTTP verbs accepted by S3URLSigner.generate_presigned_url to the
# boto3 client method names used when presigning.
_S3_METHOD_MAPPING: dict[str, str] = {
    "GET": "get_object",
    "PUT": "put_object",
}
121
122
[docs]
class S3URLSigner(URLSigner):
    """Generates pre-signed URLs using the boto3 S3 client.

    When the underlying credentials are temporary (STS, IAM role, EC2 instance
    profile), the effective URL lifetime is the **shorter** of ``expires_in``
    and the remaining credential lifetime — boto3 will not warn if the
    credential expires before ``expires_in``.

    See https://docs.aws.amazon.com/AmazonS3/latest/userguide/using-presigned-url.html
    """

    def __init__(self, s3_client: Any, bucket: str, expires_in: int = DEFAULT_PRESIGN_EXPIRES_IN) -> None:
        # Bucket and expiry are fixed per signer; the client is shared with the provider.
        self._bucket = bucket
        self._expires_in = expires_in
        self._s3_client = s3_client

    def generate_presigned_url(self, path: str, *, method: str = "GET") -> str:
        """Returns a pre-signed URL for ``path``; only GET and PUT are supported."""
        verb = method.upper()
        if verb not in _S3_METHOD_MAPPING:
            raise ValueError(f"Unsupported method for S3 presigning: {method!r}")
        return self._s3_client.generate_presigned_url(
            ClientMethod=_S3_METHOD_MAPPING[verb],
            Params={"Bucket": self._bucket, "Key": path},
            ExpiresIn=self._expires_in,
        )
148
149
[docs]
150class S3StorageProvider(BaseStorageProvider):
151 """
152 A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Amazon S3 or S3-compatible object stores.
153 """
154
155 def __init__(
156 self,
157 region_name: str = "",
158 endpoint_url: str = "",
159 base_path: str = "",
160 credentials_provider: Optional[CredentialsProvider] = None,
161 profile_name: Optional[str] = None,
162 config_dict: Optional[dict[str, Any]] = None,
163 telemetry_provider: Optional[Callable[[], Telemetry]] = None,
164 verify: Optional[Union[bool, str]] = None,
165 **kwargs: Any,
166 ) -> None:
167 """
168 Initializes the :py:class:`S3StorageProvider` with the region, endpoint URL, and optional credentials provider.
169
170 :param region_name: The AWS region where the S3 bucket is located.
171 :param endpoint_url: The custom endpoint URL for the S3 service.
172 :param base_path: The root prefix path within the S3 bucket where all operations will be scoped.
173 :param credentials_provider: The provider to retrieve S3 credentials.
174 :param profile_name: AWS shared configuration + credentials files profile. For :py:class:`boto3.session.Session`. Ignored if ``credentials_provider`` is set.
175 :param config_dict: Resolved MSC config.
176 :param telemetry_provider: A function that provides a telemetry instance.
177 :param verify: Controls SSL certificate verification.
178 Can be ``True`` (verify using system CA bundle, default), ``False`` (skip verification), or a string path to a custom CA certificate bundle.
179 :param request_checksum_calculation: For :py:class:`botocore.config.Config`.
180 When the underlying S3 client should calculate request checksums.
181 See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
182 :param response_checksum_validation: For :py:class:`botocore.config.Config`.
183 When the underlying S3 client should validate response checksums.
184 See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
185 :param max_pool_connections: For :py:class:`botocore.config.Config`.
186 The maximum number of connections to keep in a connection pool.
187 :param connect_timeout: For :py:class:`botocore.config.Config`.
188 The time in seconds till a timeout exception is thrown when attempting to make a connection.
189 :param read_timeout: For :py:class:`botocore.config.Config`.
190 The time in seconds till a timeout exception is thrown when attempting to read from a connection.
191 :param retries: For :py:class:`botocore.config.Config`.
192 A dictionary for configuration related to retry behavior.
193 :param s3: For :py:class:`botocore.config.Config`.
194 A dictionary of S3 specific configurations.
195 :param checksum_algorithm: Upload-only object integrity algorithm.
196 One of ``"CRC32"``, ``"CRC32C"``, ``"SHA1"``, ``"SHA256"``, ``"CRC64NVME"`` (case-insensitive).
197 When ``rust_client`` is enabled, only ``"SHA256"`` is accepted.
198 """
199 super().__init__(
200 base_path=base_path,
201 provider_name=PROVIDER,
202 config_dict=config_dict,
203 telemetry_provider=telemetry_provider,
204 )
205
206 self._region_name = region_name
207 self._endpoint_url = endpoint_url
208 self._credentials_provider = credentials_provider
209 self._profile_name = profile_name
210 self._verify = verify
211
212 self._signature_version = kwargs.get("signature_version", "s3v4")
213 self._s3_client = self._create_s3_client(
214 request_checksum_calculation=kwargs.get("request_checksum_calculation"),
215 response_checksum_validation=kwargs.get("response_checksum_validation"),
216 max_pool_connections=kwargs.get("max_pool_connections", MAX_POOL_CONNECTIONS),
217 connect_timeout=kwargs.get("connect_timeout", DEFAULT_CONNECT_TIMEOUT),
218 read_timeout=kwargs.get("read_timeout", DEFAULT_READ_TIMEOUT),
219 retries=kwargs.get("retries"),
220 s3=kwargs.get("s3"),
221 )
222 self._transfer_config = TransferConfig(
223 multipart_threshold=int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD)),
224 max_concurrency=int(kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)),
225 multipart_chunksize=int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE)),
226 io_chunksize=int(kwargs.get("io_chunksize", IO_CHUNKSIZE)),
227 use_threads=True,
228 )
229 self._multipart_threshold = self._transfer_config.multipart_threshold
230
231 self._signer_cache: dict[tuple, URLSigner] = {}
232
233 self._checksum_algorithm: Optional[str] = self._validate_checksum_algorithm(kwargs.get("checksum_algorithm"))
234
235 self._rust_client = None
236 if "rust_client" in kwargs:
237 # Inherit the rust client options from the kwargs
238 rust_client_options = kwargs["rust_client"]
239 if "max_pool_connections" in kwargs:
240 rust_client_options["max_pool_connections"] = kwargs["max_pool_connections"]
241 if "max_concurrency" in kwargs:
242 rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
243 if "multipart_chunksize" in kwargs:
244 rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
245 if "read_timeout" in kwargs:
246 rust_client_options["read_timeout"] = kwargs["read_timeout"]
247 if "connect_timeout" in kwargs:
248 rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
249 if "checksum_algorithm" in kwargs:
250 rust_client_options["checksum_algorithm"] = kwargs["checksum_algorithm"]
251 if self._signature_version == "UNSIGNED":
252 rust_client_options["skip_signature"] = True
253
254 # Rust client only supports SHA256 today (object_store limitation).
255 rust_checksum = rust_client_options.get("checksum_algorithm")
256 if rust_checksum is not None:
257 if str(rust_checksum).upper() != "SHA256":
258 raise ValueError(
259 f"Rust client only supports checksum_algorithm='SHA256' today "
260 f"(object_store limitation), got '{rust_checksum!r}'. "
261 f"Disable rust_client or use SHA256."
262 )
263 rust_client_options["checksum_algorithm"] = "sha256"
264 self._rust_client = self._create_rust_client(rust_client_options)
265
266 @staticmethod
267 def _validate_checksum_algorithm(value: Optional[str]) -> Optional[str]:
268 """
269 Normalizes the optional ``checksum_algorithm`` to uppercase, or returns ``None`` if disabled.
270 """
271 if value is None:
272 return None
273 if not isinstance(value, str) or not value:
274 raise ValueError(
275 f"checksum_algorithm must be one of {sorted(SUPPORTED_CHECKSUM_ALGORITHMS)} or None, got {value!r}."
276 )
277 normalized = value.upper()
278 if normalized not in SUPPORTED_CHECKSUM_ALGORITHMS:
279 raise ValueError(
280 f"Unsupported checksum_algorithm '{value}'. Supported: {sorted(SUPPORTED_CHECKSUM_ALGORITHMS)}."
281 )
282 return normalized
283
284 def _is_directory_bucket(self, bucket: str) -> bool:
285 """
286 Determines if the bucket is a directory bucket based on bucket name.
287 """
288 # S3 Express buckets have a specific naming convention
289 return "--x-s3" in bucket
290
    def _create_s3_client(
        self,
        request_checksum_calculation: Optional[str] = None,
        response_checksum_validation: Optional[str] = None,
        max_pool_connections: int = MAX_POOL_CONNECTIONS,
        connect_timeout: Union[float, int, None] = None,
        read_timeout: Union[float, int, None] = None,
        retries: Optional[dict[str, Any]] = None,
        s3: Optional[dict[str, Any]] = None,
    ):
        """
        Creates and configures the boto3 S3 client, using refreshable credentials if possible.

        :param request_checksum_calculation: For :py:class:`botocore.config.Config`. When the underlying S3 client should calculate request checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param response_checksum_validation: For :py:class:`botocore.config.Config`. When the underlying S3 client should validate response checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
        :param max_pool_connections: For :py:class:`botocore.config.Config`. The maximum number of connections to keep in a connection pool.
        :param connect_timeout: For :py:class:`botocore.config.Config`. The time in seconds till a timeout exception is thrown when attempting to make a connection.
        :param read_timeout: For :py:class:`botocore.config.Config`. The time in seconds till a timeout exception is thrown when attempting to read from a connection.
        :param retries: For :py:class:`botocore.config.Config`. A dictionary for configuration related to retry behavior.
        :param s3: For :py:class:`botocore.config.Config`. A dictionary of S3 specific configurations.

        :return: The configured S3 client.
        """
        options = {
            # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
            "config": boto3.session.Config(  # pyright: ignore [reportAttributeAccessIssue]
                max_pool_connections=max_pool_connections,
                connect_timeout=connect_timeout,
                read_timeout=read_timeout,
                retries=retries or {"mode": "standard"},
                request_checksum_calculation=request_checksum_calculation,
                response_checksum_validation=response_checksum_validation,
                s3=s3,
            ),
        }

        if self._region_name:
            options["region_name"] = self._region_name

        if self._endpoint_url:
            options["endpoint_url"] = self._endpoint_url

        if self._verify is not None:
            options["verify"] = self._verify

        if self._credentials_provider:
            creds = self._fetch_credentials()
            if "expiry_time" in creds and creds["expiry_time"]:
                # Use RefreshableCredentials if expiry_time provided.
                refreshable_credentials = RefreshableCredentials.create_from_metadata(
                    metadata=creds, refresh_using=self._fetch_credentials, method="custom-refresh"
                )

                # Inject the refreshable credentials into a fresh botocore session.
                # NOTE(review): this assigns the private ``_credentials`` attribute —
                # there is no public botocore API for this; confirm on botocore upgrades.
                botocore_session = get_session()
                botocore_session._credentials = refreshable_credentials

                boto3_session = boto3.Session(botocore_session=botocore_session)

                # NOTE(review): this early return skips the signature_version config
                # applied below, so refreshable-credential clients always use the
                # session default signing — confirm this is intended.
                return boto3_session.client("s3", **options)
            else:
                # Add static credentials to the options dictionary
                options["aws_access_key_id"] = creds["access_key"]
                options["aws_secret_access_key"] = creds["secret_key"]
                if creds["token"]:
                    options["aws_session_token"] = creds["token"]

        if self._signature_version:
            # "UNSIGNED" selects anonymous (unsigned) requests; any other value is
            # passed through to botocore as-is (e.g. "s3v4").
            signature_config = botocore.config.Config(  # pyright: ignore[reportAttributeAccessIssue]
                signature_version=botocore.UNSIGNED
                if self._signature_version == "UNSIGNED"
                else self._signature_version
            )
            options["config"] = options["config"].merge(signature_config)

        # Fallback to standard credential chain.
        session = boto3.Session(profile_name=self._profile_name)
        return session.client("s3", **options)
368
369 def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
370 """
371 Creates and configures the rust client, using refreshable credentials if possible.
372 """
373 configs = dict(rust_client_options) if rust_client_options else {}
374
375 # Extract and parse retry configuration
376 retry_config = parse_retry_config(configs)
377
378 if self._region_name and "region_name" not in configs:
379 configs["region_name"] = self._region_name
380
381 if self._endpoint_url and "endpoint_url" not in configs:
382 configs["endpoint_url"] = self._endpoint_url
383
384 if self._profile_name and "profile_name" not in configs:
385 configs["profile_name"] = self._profile_name
386
387 # If the user specifies a bucket, use it. Otherwise, use the base path.
388 if "bucket" not in configs:
389 bucket, _ = split_path(self._base_path)
390 configs["bucket"] = bucket
391
392 if "max_pool_connections" not in configs:
393 configs["max_pool_connections"] = MAX_POOL_CONNECTIONS
394
395 return RustClient(
396 provider=PROVIDER,
397 configs=configs,
398 credentials_provider=self._credentials_provider,
399 retry=retry_config,
400 )
401
402 def _fetch_credentials(self) -> dict:
403 """
404 Refreshes the S3 client if the current credentials are expired.
405 """
406 if not self._credentials_provider:
407 raise RuntimeError("Cannot fetch credentials if no credential provider configured.")
408 self._credentials_provider.refresh_credentials()
409 credentials = self._credentials_provider.get_credentials()
410 return {
411 "access_key": credentials.access_key,
412 "secret_key": credentials.secret_key,
413 "token": credentials.token,
414 "expiry_time": credentials.expiration,
415 }
416
    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual S3 operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the S3 bucket involved in the operation.
        :param key: The key of the object within the S3 bucket.

        :return: The result of the S3 operation, typically the return value of the `func` callable.
        :raises FileNotFoundError: On 404 responses (object missing).
        :raises PreconditionFailedError: On 412 responses (ETag mismatch).
        :raises RetryableError: On transient conditions (throttling, 503, timeouts, Rust retryable errors).
        :raises NotImplementedError: On 501 responses.
        :raises PermissionError: On Rust-client 403 responses.
        :raises RuntimeError: On any other failure.
        """
        try:
            return func()
        except ClientError as error:
            # boto3/botocore errors carry HTTP status plus request IDs for diagnostics.
            status_code = error.response["ResponseMetadata"]["HTTPStatusCode"]
            request_id = error.response["ResponseMetadata"].get("RequestId")
            host_id = error.response["ResponseMetadata"].get("HostId")
            error_code = error.response["Error"]["Code"]
            error_info = f"request_id: {request_id}, host_id: {host_id}, status_code: {status_code}"

            if status_code == 404:
                # A missing multipart upload ID is transient (the upload can be retried
                # from scratch); a missing object is a hard not-found.
                if error_code == "NoSuchUpload":
                    error_message = error.response["Error"]["Message"]
                    raise RetryableError(f"Multipart upload failed for {bucket}/{key}: {error_message}") from error
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}")  # pylint: disable=raise-missing-from
            elif status_code == 412:  # Precondition Failed
                raise PreconditionFailedError(
                    f"ETag mismatch for {operation} operation on {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 429:
                raise RetryableError(
                    f"Too many request to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 503:
                raise RetryableError(
                    f"Service unavailable when {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 501:
                raise NotImplementedError(
                    f"Operation {operation} not implemented for object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 408:
                # 408 Request Timeout is from Google Cloud Storage
                raise RetryableError(
                    f"Request timeout when {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            else:
                raise RuntimeError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}, "
                    f"error_type: {type(error).__name__}"
                ) from error
        except RustClientError as error:
            # Rust client errors carry (message, status_code) positionally in args.
            message = error.args[0]
            status_code = error.args[1]
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {message}") from error
            elif status_code == 403:
                raise PermissionError(
                    f"Permission denied to {operation} object(s) at {bucket}/{key}. {message}"
                ) from error
            else:
                raise RetryableError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {message}. status_code: {status_code}"
                ) from error
        except (ReadTimeoutError, IncompleteReadError, ResponseStreamingError) as error:
            # Network-level stream failures are safe to retry.
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to network timeout or incomplete read. "
                f"error_type: {type(error).__name__}"
            ) from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except Exception as error:
            # Catch-all boundary: anything unrecognized is surfaced as RuntimeError
            # with the original exception chained.
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error
501
502 def _put_object(
503 self,
504 path: str,
505 body: bytes,
506 if_match: Optional[str] = None,
507 if_none_match: Optional[str] = None,
508 attributes: Optional[dict[str, str]] = None,
509 content_type: Optional[str] = None,
510 ) -> int:
511 """
512 Uploads an object to the specified S3 path.
513
514 :param path: The S3 path where the object will be uploaded.
515 :param body: The content of the object as bytes.
516 :param if_match: Optional If-Match header value. Use "*" to only upload if the object doesn't exist.
517 :param if_none_match: Optional If-None-Match header value. Use "*" to only upload if the object doesn't exist.
518 :param attributes: Optional attributes to attach to the object.
519 :param content_type: Optional Content-Type header value.
520 """
521 bucket, key = split_path(path)
522
523 def _invoke_api() -> int:
524 kwargs = {"Bucket": bucket, "Key": key, "Body": body}
525 if content_type:
526 kwargs["ContentType"] = content_type
527 if self._is_directory_bucket(bucket):
528 kwargs["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
529 if if_match:
530 kwargs["IfMatch"] = if_match
531 if if_none_match:
532 kwargs["IfNoneMatch"] = if_none_match
533 validated_attributes = validate_attributes(attributes)
534 if validated_attributes:
535 kwargs["Metadata"] = validated_attributes
536
537 # TODO(NGCDP-5804): Add support to update ContentType header in Rust client
538 rust_unsupported_feature_keys = {"Metadata", "StorageClass", "IfMatch", "IfNoneMatch", "ContentType"}
539 if (
540 self._rust_client
541 # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
542 and not path.endswith("/")
543 and all(key not in kwargs for key in rust_unsupported_feature_keys)
544 ):
545 run_async_rust_client_method(self._rust_client, "put", key, body)
546 else:
547 if self._checksum_algorithm:
548 kwargs["ChecksumAlgorithm"] = self._checksum_algorithm
549 self._s3_client.put_object(**kwargs)
550
551 return len(body)
552
553 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
554
555 def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
556 bucket, key = split_path(path)
557
558 def _invoke_api() -> bytes:
559 if byte_range:
560 bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
561 if self._rust_client:
562 response = run_async_rust_client_method(
563 self._rust_client,
564 "get",
565 key,
566 byte_range,
567 )
568 return response
569 else:
570 response = self._s3_client.get_object(Bucket=bucket, Key=key, Range=bytes_range)
571 else:
572 if self._rust_client:
573 response = run_async_rust_client_method(self._rust_client, "get", key)
574 return response
575 else:
576 response = self._s3_client.get_object(Bucket=bucket, Key=key)
577
578 return response["Body"].read()
579
580 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
581
582 def _copy_object(self, src_path: str, dest_path: str) -> int:
583 src_bucket, src_key = split_path(src_path)
584 dest_bucket, dest_key = split_path(dest_path)
585
586 src_object = self._get_object_metadata(src_path)
587
588 def _invoke_api() -> int:
589 self._s3_client.copy(
590 CopySource={"Bucket": src_bucket, "Key": src_key},
591 Bucket=dest_bucket,
592 Key=dest_key,
593 Config=self._transfer_config,
594 )
595
596 return src_object.content_length
597
598 return self._translate_errors(_invoke_api, operation="COPY", bucket=dest_bucket, key=dest_key)
599
600 def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
601 bucket, key = split_path(path)
602
603 def _invoke_api() -> None:
604 # Delete conditionally when if_match (etag) is provided; otherwise delete unconditionally
605 if if_match:
606 self._s3_client.delete_object(Bucket=bucket, Key=key, IfMatch=if_match)
607 else:
608 self._s3_client.delete_object(Bucket=bucket, Key=key)
609
610 return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)
611
612 def _delete_objects(self, paths: list[str]) -> None:
613 if not paths:
614 return
615
616 by_bucket: dict[str, list[str]] = {}
617 for p in paths:
618 bucket, key = split_path(p)
619 by_bucket.setdefault(bucket, []).append(key)
620
621 S3_BATCH_LIMIT = 1000
622
623 def _invoke_api() -> None:
624 all_errors: list[str] = []
625 for bucket, keys in by_bucket.items():
626 for i in range(0, len(keys), S3_BATCH_LIMIT):
627 chunk = keys[i : i + S3_BATCH_LIMIT]
628 response = self._s3_client.delete_objects(
629 Bucket=bucket, Delete={"Objects": [{"Key": k} for k in chunk]}
630 )
631 errors = response.get("Errors") or []
632 for e in errors:
633 all_errors.append(f"{bucket}/{e.get('Key', '?')}: {e.get('Code', '')} {e.get('Message', '')}")
634 if all_errors:
635 raise RuntimeError(f"DeleteObjects reported errors: {'; '.join(all_errors)}")
636
637 bucket_desc = "(" + "|".join(by_bucket) + ")"
638 key_desc = "(" + "|".join(str(len(keys)) for keys in by_bucket.values()) + " keys)"
639 self._translate_errors(_invoke_api, operation="DELETE_MANY", bucket=bucket_desc, key=key_desc)
640
641 def _is_dir(self, path: str) -> bool:
642 # Ensure the path ends with '/' to mimic a directory
643 path = self._append_delimiter(path)
644
645 bucket, key = split_path(path)
646
647 def _invoke_api() -> bool:
648 # List objects with the given prefix
649 response = self._s3_client.list_objects_v2(Bucket=bucket, Prefix=key, MaxKeys=1, Delimiter="/")
650
651 # Check if there are any contents or common prefixes
652 return bool(response.get("Contents", []) or response.get("CommonPrefixes", []))
653
654 return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)
655
656 def _make_symlink(self, path: str, target: str) -> None:
657 bucket, key = split_path(path)
658 target_bucket, target_key = split_path(target)
659 if bucket != target_bucket:
660 raise ValueError(f"Cannot create cross-bucket symlink: '{bucket}' -> '{target_bucket}'.")
661 relative_target = ObjectMetadata.encode_symlink_target(key, target_key)
662
663 def _invoke_api() -> None:
664 self._s3_client.put_object(
665 Bucket=bucket,
666 Key=key,
667 Body=b"",
668 Metadata={"msc-symlink-target": relative_target},
669 )
670
671 self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
672
    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Returns metadata for the object (or "directory") at ``path``.

        :param path: The S3 path to inspect. A trailing ``/`` or an empty key is
            treated as a directory probe.
        :param strict: When ``True`` and the object is missing, fall back to checking
            whether the path is a directory before raising.
        :return: The object's metadata.
        :raises FileNotFoundError: If neither an object nor a directory exists at ``path``.
        """
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If path ends with "/" or empty key name is provided, then assume it's a "directory",
            # which metadata is not guaranteed to exist for cases such as
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:

            def _invoke_api() -> ObjectMetadata:
                response = self._s3_client.head_object(Bucket=bucket, Key=key)
                # Symlinks are encoded as a user-metadata entry on an otherwise normal object.
                user_metadata = response.get("Metadata")
                symlink_target = user_metadata.get("msc-symlink-target") if user_metadata else None

                return ObjectMetadata(
                    key=path,
                    type="file",
                    content_length=response["ContentLength"],
                    content_type=response.get("ContentType"),
                    last_modified=response.get("LastModified", AWARE_DATETIME_MIN),
                    etag=response["ETag"].strip('"') if "ETag" in response else None,
                    storage_class=response.get("StorageClass"),
                    metadata=user_metadata,
                    symlink_target=symlink_target,
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                    # NOTE(review): ``raise error`` below sits inside ``if strict:``, so
                    # when strict is False a missing object falls through and this method
                    # implicitly returns None despite its ObjectMetadata annotation —
                    # confirm callers rely on that behavior.
                    raise error
722
    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        symlink_handling: SymlinkHandling = SymlinkHandling.FOLLOW,
    ) -> Iterator[ObjectMetadata]:
        """
        Lists objects under ``path`` via paginated ListObjectsV2, bounded by
        ``start_after`` (exclusive) and ``end_at`` (inclusive).

        :param path: The S3 path (bucket plus optional prefix) to list under.
        :param start_after: Yield only keys strictly greater than this path's key.
        :param end_at: Yield only keys less than or equal to this path's key.
        :param include_directories: When ``True``, list with a ``/`` delimiter and
            also yield common prefixes as directory entries.
        :param symlink_handling: NOTE(review): accepted for interface compatibility
            but not consulted anywhere in this body — confirm intended.
        :return: An iterator of object metadata entries.
        """
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            paginator = self._s3_client.get_paginator("list_objects_v2")
            if include_directories:
                page_iterator = paginator.paginate(
                    Bucket=bucket, Prefix=prefix, Delimiter="/", StartAfter=(start_after or "")
                )
            else:
                page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, StartAfter=(start_after or ""))

            for page in page_iterator:
                for item in page.get("CommonPrefixes", []):
                    prefix_key = item["Prefix"].rstrip("/")
                    # Filter by start_after and end_at - S3's StartAfter doesn't filter CommonPrefixes
                    if (start_after is None or start_after < prefix_key) and (end_at is None or prefix_key <= end_at):
                        yield ObjectMetadata(
                            key=os.path.join(bucket, prefix_key),
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                    elif end_at is not None and end_at < prefix_key:
                        # Prefixes arrive in order, so everything after this is out of range.
                        return

                # S3 guarantees lexicographical order for general purpose buckets (for
                # normal S3) but not directory buckets (for S3 Express One Zone).
                for response_object in page.get("Contents", []):
                    key = response_object["Key"]
                    if end_at is None or key <= end_at:
                        if key.endswith("/"):
                            # Zero-byte marker objects with a trailing '/' represent directories.
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(bucket, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=response_object["LastModified"],
                                )
                        else:
                            symlink_target = None
                            # Zero-byte objects may be symlinks; HEAD each one to read its metadata.
                            # NOTE(review): this issues one HEAD per empty object, which can be
                            # costly on listings with many zero-byte objects — confirm acceptable.
                            if response_object.get("Size", 0) == 0:
                                try:
                                    meta = self._get_object_metadata(os.path.join(bucket, key))
                                    symlink_target = meta.symlink_target
                                except Exception:
                                    # Best-effort: treat unreadable metadata as "not a symlink".
                                    symlink_target = None
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key),
                                type="file",
                                content_length=response_object["Size"],
                                last_modified=response_object["LastModified"],
                                etag=response_object["ETag"].strip('"'),
                                storage_class=response_object.get("StorageClass"),
                                symlink_target=symlink_target,
                            )
                    else:
                        # Keys are in lexicographic order, so the range is exhausted.
                        return

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)
796
797 @property
798 def supports_parallel_listing(self) -> bool:
799 """
800 S3 supports parallel listing via delimiter-based prefix discovery.
801
802 Note: Directory bucket handling is done in list_objects_recursive().
803 """
804 return True
805
[docs]
806 def list_objects_recursive(
807 self,
808 path: str = "",
809 start_after: Optional[str] = None,
810 end_at: Optional[str] = None,
811 max_workers: int = 32,
812 look_ahead: int = 2,
813 symlink_handling: SymlinkHandling = SymlinkHandling.FOLLOW,
814 ) -> Iterator[ObjectMetadata]:
815 """
816 List all objects recursively using parallel prefix discovery for improved performance.
817
818 For S3, uses the Rust client's list_recursive when available for maximum performance.
819 Falls back to Python implementation otherwise.
820
821 Returns files only (no directories), in lexicographic order.
822
823 :param symlink_handling: How to handle symbolic links during listing.
824 """
825 if (start_after is not None) and (end_at is not None) and not (start_after < end_at):
826 raise ValueError(f"start_after ({start_after}) must be before end_at ({end_at})!")
827
828 full_path = self._prepend_base_path(path)
829 bucket, prefix = split_path(full_path)
830
831 if self._is_directory_bucket(bucket):
832 yield from self.list_objects(
833 path, start_after, end_at, include_directories=False, symlink_handling=symlink_handling
834 )
835 return
836
837 if self._rust_client:
838 yield from self._emit_metrics(
839 operation=BaseStorageProvider._Operation.LIST,
840 f=lambda: self._list_objects_recursive_rust(path, full_path, bucket, start_after, end_at, max_workers),
841 )
842 else:
843 yield from super().list_objects_recursive(
844 path, start_after, end_at, max_workers, look_ahead, symlink_handling
845 )
846
847 def _list_objects_recursive_rust(
848 self,
849 path: str,
850 full_path: str,
851 bucket: str,
852 start_after: Optional[str],
853 end_at: Optional[str],
854 max_workers: int,
855 ) -> Iterator[ObjectMetadata]:
856 """
857 Use Rust client's list_recursive for parallel listing.
858
859 The Rust client already handles parallel listing internally.
860 Returns files only in lexicographic order.
861 """
862 _, prefix = split_path(full_path)
863
864 def _invoke_api() -> Iterator[ObjectMetadata]:
865 result = run_async_rust_client_method(
866 self._rust_client,
867 "list_recursive",
868 [prefix] if prefix else [""],
869 max_concurrency=max_workers,
870 )
871
872 start_after_full = self._prepend_base_path(start_after) if start_after else None
873 end_at_full = self._prepend_base_path(end_at) if end_at else None
874
875 for obj in result.objects:
876 full_key = os.path.join(bucket, obj.key)
877
878 if start_after_full and full_key <= start_after_full:
879 continue
880 if end_at_full and full_key > end_at_full:
881 break
882
883 relative_key = full_key.removeprefix(self._base_path).lstrip("/")
884
885 yield ObjectMetadata(
886 key=relative_key,
887 content_length=obj.content_length,
888 last_modified=dateutil_parse(obj.last_modified),
889 type="file" if obj.object_type == "object" else obj.object_type,
890 etag=obj.etag,
891 )
892
893 yield from self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)
894
895 def _upload_file(
896 self,
897 remote_path: str,
898 f: Union[str, IO],
899 attributes: Optional[dict[str, str]] = None,
900 content_type: Optional[str] = None,
901 ) -> int:
902 file_size: int = 0
903
904 if isinstance(f, str):
905 bucket, key = split_path(remote_path)
906 file_size = os.path.getsize(f)
907
908 # Upload small files
909 if file_size <= self._multipart_threshold:
910 if self._rust_client and not attributes and not content_type:
911 run_async_rust_client_method(self._rust_client, "upload", f, key)
912 else:
913 with open(f, "rb") as fp:
914 self._put_object(remote_path, fp.read(), attributes=attributes, content_type=content_type)
915 return file_size
916
917 # Upload large files using TransferConfig
918 def _invoke_api() -> int:
919 extra_args = {}
920 if content_type:
921 extra_args["ContentType"] = content_type
922 if self._is_directory_bucket(bucket):
923 extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
924 validated_attributes = validate_attributes(attributes)
925 if validated_attributes:
926 extra_args["Metadata"] = validated_attributes
927 if self._rust_client and not extra_args:
928 run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
929 else:
930 if self._checksum_algorithm:
931 extra_args["ChecksumAlgorithm"] = self._checksum_algorithm
932 self._s3_client.upload_file(
933 Filename=f,
934 Bucket=bucket,
935 Key=key,
936 Config=self._transfer_config,
937 ExtraArgs=extra_args,
938 )
939
940 return file_size
941
942 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
943 else:
944 # Upload small files
945 f.seek(0, io.SEEK_END)
946 file_size = f.tell()
947 f.seek(0)
948
949 if file_size <= self._multipart_threshold:
950 if isinstance(f, io.StringIO):
951 self._put_object(
952 remote_path, f.read().encode("utf-8"), attributes=attributes, content_type=content_type
953 )
954 else:
955 self._put_object(remote_path, f.read(), attributes=attributes, content_type=content_type)
956 return file_size
957
958 # Upload large files using TransferConfig
959 bucket, key = split_path(remote_path)
960
961 def _invoke_api() -> int:
962 extra_args = {}
963 if content_type:
964 extra_args["ContentType"] = content_type
965 if self._is_directory_bucket(bucket):
966 extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
967 validated_attributes = validate_attributes(attributes)
968 if validated_attributes:
969 extra_args["Metadata"] = validated_attributes
970
971 if self._rust_client and isinstance(f, io.BytesIO) and not extra_args:
972 data = f.getbuffer()
973 run_async_rust_client_method(self._rust_client, "upload_multipart_from_bytes", key, data)
974 else:
975 if self._checksum_algorithm:
976 extra_args["ChecksumAlgorithm"] = self._checksum_algorithm
977 self._s3_client.upload_fileobj(
978 Fileobj=f,
979 Bucket=bucket,
980 Key=key,
981 Config=self._transfer_config,
982 ExtraArgs=extra_args,
983 )
984
985 return file_size
986
987 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
988
989 def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
990 if metadata is None:
991 metadata = self._get_object_metadata(remote_path)
992
993 if isinstance(f, str):
994 bucket, key = split_path(remote_path)
995 if os.path.dirname(f):
996 safe_makedirs(os.path.dirname(f))
997
998 # Download small files
999 if metadata.content_length <= self._multipart_threshold:
1000 if self._rust_client:
1001 run_async_rust_client_method(self._rust_client, "download", key, f)
1002 else:
1003 with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
1004 temp_file_path = fp.name
1005 fp.write(self._get_object(remote_path))
1006 os.rename(src=temp_file_path, dst=f)
1007 return metadata.content_length
1008
1009 # Download large files using TransferConfig
1010 def _invoke_api() -> int:
1011 with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
1012 temp_file_path = fp.name
1013 if self._rust_client:
1014 run_async_rust_client_method(
1015 self._rust_client, "download_multipart_to_file", key, temp_file_path
1016 )
1017 else:
1018 self._s3_client.download_fileobj(
1019 Bucket=bucket,
1020 Key=key,
1021 Fileobj=fp,
1022 Config=self._transfer_config,
1023 )
1024
1025 os.rename(src=temp_file_path, dst=f)
1026
1027 return metadata.content_length
1028
1029 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
1030 else:
1031 # Download small files
1032 if metadata.content_length <= self._multipart_threshold:
1033 response = self._get_object(remote_path)
1034 # Python client returns `bytes`, but Rust client returns an object that implements the buffer protocol,
1035 # so we need to check whether `.decode()` is available.
1036 if isinstance(f, io.StringIO):
1037 if hasattr(response, "decode"):
1038 f.write(response.decode("utf-8"))
1039 else:
1040 f.write(codecs.decode(memoryview(response), "utf-8"))
1041 else:
1042 f.write(response)
1043 return metadata.content_length
1044
1045 # Download large files using TransferConfig
1046 bucket, key = split_path(remote_path)
1047
1048 def _invoke_api() -> int:
1049 self._s3_client.download_fileobj(
1050 Bucket=bucket,
1051 Key=key,
1052 Fileobj=f,
1053 Config=self._transfer_config,
1054 )
1055
1056 return metadata.content_length
1057
1058 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
1059
1060 def _generate_presigned_url(
1061 self,
1062 path: str,
1063 *,
1064 method: str = "GET",
1065 signer_type: Optional[SignerType] = None,
1066 signer_options: Optional[dict[str, Any]] = None,
1067 ) -> str:
1068 options = signer_options or {}
1069 bucket, key = split_path(path)
1070
1071 if signer_type is None or signer_type == SignerType.S3:
1072 expires_in = int(options.get("expires_in", DEFAULT_PRESIGN_EXPIRES_IN))
1073 cache_key: tuple = (SignerType.S3, bucket, expires_in)
1074 if cache_key not in self._signer_cache:
1075 self._signer_cache[cache_key] = S3URLSigner(self._s3_client, bucket, expires_in=expires_in)
1076 elif signer_type == SignerType.CLOUDFRONT:
1077 cache_key = (SignerType.CLOUDFRONT, frozenset(options.items()))
1078 if cache_key not in self._signer_cache:
1079 self._signer_cache[cache_key] = CloudFrontURLSigner(**options)
1080 else:
1081 raise ValueError(f"Unsupported signer type for S3 provider: {signer_type!r}")
1082
1083 return self._signer_cache[cache_key].generate_presigned_url(key, method=method)