1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import codecs
17import io
18import os
19import tempfile
20from collections.abc import Callable, Iterator
21from typing import IO, Any, Optional, TypeVar, Union
22
23import boto3
24import botocore
25from boto3.s3.transfer import TransferConfig
26from botocore.credentials import RefreshableCredentials
27from botocore.exceptions import ClientError, IncompleteReadError, ReadTimeoutError, ResponseStreamingError
28from botocore.session import get_session
29from dateutil.parser import parse as dateutil_parse
30
31from multistorageclient_rust import RustClient, RustClientError, RustRetryableError
32
33from ..constants import DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT
34from ..rust_utils import parse_retry_config, run_async_rust_client_method
35from ..signers import CloudFrontURLSigner, URLSigner
36from ..telemetry import Telemetry
37from ..types import (
38 AWARE_DATETIME_MIN,
39 Credentials,
40 CredentialsProvider,
41 ObjectMetadata,
42 PreconditionFailedError,
43 Range,
44 RetryableError,
45 SignerType,
46)
47from ..utils import (
48 get_available_cpu_count,
49 safe_makedirs,
50 split_path,
51 validate_attributes,
52)
53from .base import BaseStorageProvider
54
# Generic result type for helpers (e.g. _translate_errors) that wrap arbitrary callables.
_T = TypeVar("_T")

# Default connection pool size scales with CPU count or MSC Sync Threads count (minimum 64)
MAX_POOL_CONNECTIONS = max(
    64,
    get_available_cpu_count(),
    int(os.getenv("MSC_NUM_THREADS_PER_PROCESS", "0")),
)

# One mebibyte; used to express the transfer sizes below.
MiB = 1024 * 1024

# Python and Rust share the same multipart_threshold to keep the code simple.
MULTIPART_THRESHOLD = 64 * MiB
# Part size for multipart uploads (fed into boto3 TransferConfig).
MULTIPART_CHUNKSIZE = 32 * MiB
# Streaming I/O chunk size (fed into boto3 TransferConfig).
IO_CHUNKSIZE = 32 * MiB
# Max worker threads the boto3 transfer manager may use per transfer.
PYTHON_MAX_CONCURRENCY = 8

# Provider name passed to the base class and the Rust client.
PROVIDER = "s3"

# Storage class required when writing to S3 Express One Zone (directory) buckets.
EXPRESS_ONEZONE_STORAGE_CLASS = "EXPRESS_ONEZONE"
75
76
[docs]
class StaticS3CredentialsProvider(CredentialsProvider):
    """
    A :py:class:`multistorageclient.types.CredentialsProvider` backed by a fixed
    access key / secret key pair (plus an optional session token). The
    credentials never change over the lifetime of the provider.
    """

    _access_key: str
    _secret_key: str
    _session_token: Optional[str]

    def __init__(self, access_key: str, secret_key: str, session_token: Optional[str] = None):
        """
        Stores the static credential material.

        :param access_key: The access key for S3 authentication.
        :param secret_key: The secret key for S3 authentication.
        :param session_token: An optional session token for temporary credentials.
        """
        self._session_token = session_token
        self._secret_key = secret_key
        self._access_key = access_key

    def get_credentials(self) -> Credentials:
        """Returns the stored credentials; ``expiration`` is always ``None``."""
        creds = Credentials(
            access_key=self._access_key,
            secret_key=self._secret_key,
            token=self._session_token,
            expiration=None,
        )
        return creds

    def refresh_credentials(self) -> None:
        """No-op: static credentials cannot be refreshed."""
        pass
110
# Default lifetime (seconds) of generated pre-signed URLs.
DEFAULT_PRESIGN_EXPIRES_IN = 3600

# Maps the HTTP verbs accepted by S3URLSigner to the boto3 client method used for presigning.
_S3_METHOD_MAPPING: dict[str, str] = {
    "GET": "get_object",
    "PUT": "put_object",
}
117
118
[docs]
class S3URLSigner(URLSigner):
    """Generates pre-signed URLs using the boto3 S3 client.

    When the underlying credentials are temporary (STS, IAM role, EC2 instance
    profile), the effective URL lifetime is the **shorter** of ``expires_in``
    and the remaining credential lifetime — boto3 will not warn if the
    credential expires before ``expires_in``.

    See https://docs.aws.amazon.com/AmazonS3/latest/userguide/using-presigned-url.html
    """

    def __init__(self, s3_client: Any, bucket: str, expires_in: int = DEFAULT_PRESIGN_EXPIRES_IN) -> None:
        self._expires_in = expires_in
        self._bucket = bucket
        self._s3_client = s3_client

    def generate_presigned_url(self, path: str, *, method: str = "GET") -> str:
        """Returns a pre-signed URL for *path*; only methods in ``_S3_METHOD_MAPPING`` are supported."""
        verb = method.upper()
        if verb not in _S3_METHOD_MAPPING:
            raise ValueError(f"Unsupported method for S3 presigning: {method!r}")
        return self._s3_client.generate_presigned_url(
            ClientMethod=_S3_METHOD_MAPPING[verb],
            Params={"Bucket": self._bucket, "Key": path},
            ExpiresIn=self._expires_in,
        )
144
145
[docs]
146class S3StorageProvider(BaseStorageProvider):
147 """
148 A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Amazon S3 or S3-compatible object stores.
149 """
150
151 def __init__(
152 self,
153 region_name: str = "",
154 endpoint_url: str = "",
155 base_path: str = "",
156 credentials_provider: Optional[CredentialsProvider] = None,
157 config_dict: Optional[dict[str, Any]] = None,
158 telemetry_provider: Optional[Callable[[], Telemetry]] = None,
159 verify: Optional[Union[bool, str]] = None,
160 **kwargs: Any,
161 ) -> None:
162 """
163 Initializes the :py:class:`S3StorageProvider` with the region, endpoint URL, and optional credentials provider.
164
165 :param region_name: The AWS region where the S3 bucket is located.
166 :param endpoint_url: The custom endpoint URL for the S3 service.
167 :param base_path: The root prefix path within the S3 bucket where all operations will be scoped.
168 :param credentials_provider: The provider to retrieve S3 credentials.
169 :param config_dict: Resolved MSC config.
170 :param telemetry_provider: A function that provides a telemetry instance.
171 :param verify: Controls SSL certificate verification.
172 Can be ``True`` (verify using system CA bundle, default), ``False`` (skip verification), or a string path to a custom CA certificate bundle.
173 :param request_checksum_calculation: For :py:class:`botocore.config.Config`.
174 When the underlying S3 client should calculate request checksums.
175 See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
176 :param response_checksum_validation: For :py:class:`botocore.config.Config`.
177 When the underlying S3 client should validate response checksums.
178 See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
179 :param max_pool_connections: For :py:class:`botocore.config.Config`.
180 The maximum number of connections to keep in a connection pool.
181 :param connect_timeout: For :py:class:`botocore.config.Config`.
182 The time in seconds till a timeout exception is thrown when attempting to make a connection.
183 :param read_timeout: For :py:class:`botocore.config.Config`.
184 The time in seconds till a timeout exception is thrown when attempting to read from a connection.
185 :param retries: For :py:class:`botocore.config.Config`.
186 A dictionary for configuration related to retry behavior.
187 :param s3: For :py:class:`botocore.config.Config`.
188 A dictionary of S3 specific configurations.
189 """
190 super().__init__(
191 base_path=base_path,
192 provider_name=PROVIDER,
193 config_dict=config_dict,
194 telemetry_provider=telemetry_provider,
195 )
196
197 self._region_name = region_name
198 self._endpoint_url = endpoint_url
199 self._credentials_provider = credentials_provider
200 self._verify = verify
201
202 self._signature_version = kwargs.get("signature_version", "s3v4")
203 self._s3_client = self._create_s3_client(
204 request_checksum_calculation=kwargs.get("request_checksum_calculation"),
205 response_checksum_validation=kwargs.get("response_checksum_validation"),
206 max_pool_connections=kwargs.get("max_pool_connections", MAX_POOL_CONNECTIONS),
207 connect_timeout=kwargs.get("connect_timeout", DEFAULT_CONNECT_TIMEOUT),
208 read_timeout=kwargs.get("read_timeout", DEFAULT_READ_TIMEOUT),
209 retries=kwargs.get("retries"),
210 s3=kwargs.get("s3"),
211 )
212 self._transfer_config = TransferConfig(
213 multipart_threshold=int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD)),
214 max_concurrency=int(kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)),
215 multipart_chunksize=int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE)),
216 io_chunksize=int(kwargs.get("io_chunksize", IO_CHUNKSIZE)),
217 use_threads=True,
218 )
219
220 self._signer_cache: dict[tuple, URLSigner] = {}
221
222 self._rust_client = None
223 if "rust_client" in kwargs:
224 # Inherit the rust client options from the kwargs
225 rust_client_options = kwargs["rust_client"]
226 if "max_pool_connections" in kwargs:
227 rust_client_options["max_pool_connections"] = kwargs["max_pool_connections"]
228 if "max_concurrency" in kwargs:
229 rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
230 if "multipart_chunksize" in kwargs:
231 rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
232 if "read_timeout" in kwargs:
233 rust_client_options["read_timeout"] = kwargs["read_timeout"]
234 if "connect_timeout" in kwargs:
235 rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
236 if self._signature_version == "UNSIGNED":
237 rust_client_options["skip_signature"] = True
238 self._rust_client = self._create_rust_client(rust_client_options)
239
240 def _is_directory_bucket(self, bucket: str) -> bool:
241 """
242 Determines if the bucket is a directory bucket based on bucket name.
243 """
244 # S3 Express buckets have a specific naming convention
245 return "--x-s3" in bucket
246
247 def _create_s3_client(
248 self,
249 request_checksum_calculation: Optional[str] = None,
250 response_checksum_validation: Optional[str] = None,
251 max_pool_connections: int = MAX_POOL_CONNECTIONS,
252 connect_timeout: Union[float, int, None] = None,
253 read_timeout: Union[float, int, None] = None,
254 retries: Optional[dict[str, Any]] = None,
255 s3: Optional[dict[str, Any]] = None,
256 ):
257 """
258 Creates and configures the boto3 S3 client, using refreshable credentials if possible.
259
260 :param request_checksum_calculation: For :py:class:`botocore.config.Config`. When the underlying S3 client should calculate request checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
261 :param response_checksum_validation: For :py:class:`botocore.config.Config`. When the underlying S3 client should validate response checksums. See the equivalent option in the `AWS configuration file <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-a-configuration-file>`_.
262 :param max_pool_connections: For :py:class:`botocore.config.Config`. The maximum number of connections to keep in a connection pool.
263 :param connect_timeout: For :py:class:`botocore.config.Config`. The time in seconds till a timeout exception is thrown when attempting to make a connection.
264 :param read_timeout: For :py:class:`botocore.config.Config`. The time in seconds till a timeout exception is thrown when attempting to read from a connection.
265 :param retries: For :py:class:`botocore.config.Config`. A dictionary for configuration related to retry behavior.
266 :param s3: For :py:class:`botocore.config.Config`. A dictionary of S3 specific configurations.
267
268 :return: The configured S3 client.
269 """
270 options = {
271 # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
272 "config": boto3.session.Config( # pyright: ignore [reportAttributeAccessIssue]
273 max_pool_connections=max_pool_connections,
274 connect_timeout=connect_timeout,
275 read_timeout=read_timeout,
276 retries=retries or {"mode": "standard"},
277 request_checksum_calculation=request_checksum_calculation,
278 response_checksum_validation=response_checksum_validation,
279 s3=s3,
280 ),
281 }
282
283 if self._region_name:
284 options["region_name"] = self._region_name
285
286 if self._endpoint_url:
287 options["endpoint_url"] = self._endpoint_url
288
289 if self._verify is not None:
290 options["verify"] = self._verify
291
292 if self._credentials_provider:
293 creds = self._fetch_credentials()
294 if "expiry_time" in creds and creds["expiry_time"]:
295 # Use RefreshableCredentials if expiry_time provided.
296 refreshable_credentials = RefreshableCredentials.create_from_metadata(
297 metadata=creds, refresh_using=self._fetch_credentials, method="custom-refresh"
298 )
299
300 botocore_session = get_session()
301 botocore_session._credentials = refreshable_credentials
302
303 boto3_session = boto3.Session(botocore_session=botocore_session)
304
305 return boto3_session.client("s3", **options)
306 else:
307 # Add static credentials to the options dictionary
308 options["aws_access_key_id"] = creds["access_key"]
309 options["aws_secret_access_key"] = creds["secret_key"]
310 if creds["token"]:
311 options["aws_session_token"] = creds["token"]
312
313 if self._signature_version:
314 signature_config = botocore.config.Config( # pyright: ignore[reportAttributeAccessIssue]
315 signature_version=botocore.UNSIGNED
316 if self._signature_version == "UNSIGNED"
317 else self._signature_version
318 )
319 options["config"] = options["config"].merge(signature_config)
320
321 # Fallback to standard credential chain.
322 return boto3.client("s3", **options)
323
324 def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
325 """
326 Creates and configures the rust client, using refreshable credentials if possible.
327 """
328 configs = dict(rust_client_options) if rust_client_options else {}
329
330 # Extract and parse retry configuration
331 retry_config = parse_retry_config(configs)
332
333 if self._region_name and "region_name" not in configs:
334 configs["region_name"] = self._region_name
335
336 if self._endpoint_url and "endpoint_url" not in configs:
337 configs["endpoint_url"] = self._endpoint_url
338
339 # If the user specifies a bucket, use it. Otherwise, use the base path.
340 if "bucket" not in configs:
341 bucket, _ = split_path(self._base_path)
342 configs["bucket"] = bucket
343
344 if "max_pool_connections" not in configs:
345 configs["max_pool_connections"] = MAX_POOL_CONNECTIONS
346
347 return RustClient(
348 provider=PROVIDER,
349 configs=configs,
350 credentials_provider=self._credentials_provider,
351 retry=retry_config,
352 )
353
354 def _fetch_credentials(self) -> dict:
355 """
356 Refreshes the S3 client if the current credentials are expired.
357 """
358 if not self._credentials_provider:
359 raise RuntimeError("Cannot fetch credentials if no credential provider configured.")
360 self._credentials_provider.refresh_credentials()
361 credentials = self._credentials_provider.get_credentials()
362 return {
363 "access_key": credentials.access_key,
364 "secret_key": credentials.secret_key,
365 "token": credentials.token,
366 "expiry_time": credentials.expiration,
367 }
368
    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual S3 operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the S3 bucket involved in the operation.
        :param key: The key of the object within the S3 bucket.

        :return: The result of the S3 operation, typically the return value of the `func` callable.

        :raises FileNotFoundError: on HTTP 404 (except multipart ``NoSuchUpload``, which is retried).
        :raises PreconditionFailedError: on HTTP 412 (ETag mismatch).
        :raises RetryableError: on throttling/transient failures (HTTP 408/429/503,
            network read errors, Rust retryable errors).
        :raises NotImplementedError: on HTTP 501.
        :raises PermissionError: on Rust-client status 403.
        :raises RuntimeError: for any other failure.
        """
        try:
            return func()
        except ClientError as error:
            # Include request/host IDs in every message so AWS-side debugging is possible.
            status_code = error.response["ResponseMetadata"]["HTTPStatusCode"]
            request_id = error.response["ResponseMetadata"].get("RequestId")
            host_id = error.response["ResponseMetadata"].get("HostId")
            error_code = error.response["Error"]["Code"]
            error_info = f"request_id: {request_id}, host_id: {host_id}, status_code: {status_code}"

            if status_code == 404:
                if error_code == "NoSuchUpload":
                    # A vanished multipart upload can be restarted from scratch, so mark retryable.
                    error_message = error.response["Error"]["Message"]
                    raise RetryableError(f"Multipart upload failed for {bucket}/{key}: {error_message}") from error
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}")  # pylint: disable=raise-missing-from
            elif status_code == 412:  # Precondition Failed
                raise PreconditionFailedError(
                    f"ETag mismatch for {operation} operation on {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 429:
                raise RetryableError(
                    f"Too many request to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 503:
                raise RetryableError(
                    f"Service unavailable when {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 501:
                raise NotImplementedError(
                    f"Operation {operation} not implemented for object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 408:
                # 408 Request Timeout is from Google Cloud Storage
                raise RetryableError(
                    f"Request timeout when {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            else:
                raise RuntimeError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}, "
                    f"error_type: {type(error).__name__}"
                ) from error
        except RustClientError as error:
            # Rust client surfaces (message, status_code) as positional args.
            message = error.args[0]
            status_code = error.args[1]
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {message}") from error
            elif status_code == 403:
                raise PermissionError(
                    f"Permission denied to {operation} object(s) at {bucket}/{key}. {message}"
                ) from error
            else:
                # All other Rust client failures are treated as retryable.
                raise RetryableError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {message}. status_code: {status_code}"
                ) from error
        except (ReadTimeoutError, IncompleteReadError, ResponseStreamingError) as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to network timeout or incomplete read. "
                f"error_type: {type(error).__name__}"
            ) from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except Exception as error:
            # Catch-all boundary: anything unexpected becomes a RuntimeError with context.
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error
453
454 def _put_object(
455 self,
456 path: str,
457 body: bytes,
458 if_match: Optional[str] = None,
459 if_none_match: Optional[str] = None,
460 attributes: Optional[dict[str, str]] = None,
461 content_type: Optional[str] = None,
462 ) -> int:
463 """
464 Uploads an object to the specified S3 path.
465
466 :param path: The S3 path where the object will be uploaded.
467 :param body: The content of the object as bytes.
468 :param if_match: Optional If-Match header value. Use "*" to only upload if the object doesn't exist.
469 :param if_none_match: Optional If-None-Match header value. Use "*" to only upload if the object doesn't exist.
470 :param attributes: Optional attributes to attach to the object.
471 :param content_type: Optional Content-Type header value.
472 """
473 bucket, key = split_path(path)
474
475 def _invoke_api() -> int:
476 kwargs = {"Bucket": bucket, "Key": key, "Body": body}
477 if content_type:
478 kwargs["ContentType"] = content_type
479 if self._is_directory_bucket(bucket):
480 kwargs["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
481 if if_match:
482 kwargs["IfMatch"] = if_match
483 if if_none_match:
484 kwargs["IfNoneMatch"] = if_none_match
485 validated_attributes = validate_attributes(attributes)
486 if validated_attributes:
487 kwargs["Metadata"] = validated_attributes
488
489 # TODO(NGCDP-5804): Add support to update ContentType header in Rust client
490 rust_unsupported_feature_keys = {"Metadata", "StorageClass", "IfMatch", "IfNoneMatch", "ContentType"}
491 if (
492 self._rust_client
493 # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
494 and not path.endswith("/")
495 and all(key not in kwargs for key in rust_unsupported_feature_keys)
496 ):
497 run_async_rust_client_method(self._rust_client, "put", key, body)
498 else:
499 self._s3_client.put_object(**kwargs)
500
501 return len(body)
502
503 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
504
505 def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
506 bucket, key = split_path(path)
507
508 def _invoke_api() -> bytes:
509 if byte_range:
510 bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
511 if self._rust_client:
512 response = run_async_rust_client_method(
513 self._rust_client,
514 "get",
515 key,
516 byte_range,
517 )
518 return response
519 else:
520 response = self._s3_client.get_object(Bucket=bucket, Key=key, Range=bytes_range)
521 else:
522 if self._rust_client:
523 response = run_async_rust_client_method(self._rust_client, "get", key)
524 return response
525 else:
526 response = self._s3_client.get_object(Bucket=bucket, Key=key)
527
528 return response["Body"].read()
529
530 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
531
532 def _copy_object(self, src_path: str, dest_path: str) -> int:
533 src_bucket, src_key = split_path(src_path)
534 dest_bucket, dest_key = split_path(dest_path)
535
536 src_object = self._get_object_metadata(src_path)
537
538 def _invoke_api() -> int:
539 self._s3_client.copy(
540 CopySource={"Bucket": src_bucket, "Key": src_key},
541 Bucket=dest_bucket,
542 Key=dest_key,
543 Config=self._transfer_config,
544 )
545
546 return src_object.content_length
547
548 return self._translate_errors(_invoke_api, operation="COPY", bucket=dest_bucket, key=dest_key)
549
550 def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
551 bucket, key = split_path(path)
552
553 def _invoke_api() -> None:
554 # Delete conditionally when if_match (etag) is provided; otherwise delete unconditionally
555 if if_match:
556 self._s3_client.delete_object(Bucket=bucket, Key=key, IfMatch=if_match)
557 else:
558 self._s3_client.delete_object(Bucket=bucket, Key=key)
559
560 return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)
561
562 def _delete_objects(self, paths: list[str]) -> None:
563 if not paths:
564 return
565
566 by_bucket: dict[str, list[str]] = {}
567 for p in paths:
568 bucket, key = split_path(p)
569 by_bucket.setdefault(bucket, []).append(key)
570
571 S3_BATCH_LIMIT = 1000
572
573 def _invoke_api() -> None:
574 all_errors: list[str] = []
575 for bucket, keys in by_bucket.items():
576 for i in range(0, len(keys), S3_BATCH_LIMIT):
577 chunk = keys[i : i + S3_BATCH_LIMIT]
578 response = self._s3_client.delete_objects(
579 Bucket=bucket, Delete={"Objects": [{"Key": k} for k in chunk]}
580 )
581 errors = response.get("Errors") or []
582 for e in errors:
583 all_errors.append(f"{bucket}/{e.get('Key', '?')}: {e.get('Code', '')} {e.get('Message', '')}")
584 if all_errors:
585 raise RuntimeError(f"DeleteObjects reported errors: {'; '.join(all_errors)}")
586
587 bucket_desc = "(" + "|".join(by_bucket) + ")"
588 key_desc = "(" + "|".join(str(len(keys)) for keys in by_bucket.values()) + " keys)"
589 self._translate_errors(_invoke_api, operation="DELETE_MANY", bucket=bucket_desc, key=key_desc)
590
591 def _is_dir(self, path: str) -> bool:
592 # Ensure the path ends with '/' to mimic a directory
593 path = self._append_delimiter(path)
594
595 bucket, key = split_path(path)
596
597 def _invoke_api() -> bool:
598 # List objects with the given prefix
599 response = self._s3_client.list_objects_v2(Bucket=bucket, Prefix=key, MaxKeys=1, Delimiter="/")
600
601 # Check if there are any contents or common prefixes
602 return bool(response.get("Contents", []) or response.get("CommonPrefixes", []))
603
604 return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)
605
    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Returns metadata for the object (or directory-like prefix) at *path*.

        :param path: Bucket-qualified object path.
        :param strict: When True, a missing object triggers a fallback check for a
            directory (prefix) of the same name before FileNotFoundError is raised.
        :raises FileNotFoundError: if neither an object nor a directory exists at *path*.
        """
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If path ends with "/" or empty key name is provided, then assume it's a "directory",
            # which metadata is not guaranteed to exist for cases such as
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:

            def _invoke_api() -> ObjectMetadata:
                # HEAD the object; _translate_errors maps a 404 to FileNotFoundError.
                response = self._s3_client.head_object(Bucket=bucket, Key=key)

                return ObjectMetadata(
                    key=path,
                    type="file",
                    content_length=response["ContentLength"],
                    content_type=response.get("ContentType"),
                    last_modified=response["LastModified"],
                    etag=response["ETag"].strip('"') if "ETag" in response else None,
                    storage_class=response.get("StorageClass"),
                    metadata=response.get("Metadata"),
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error
652
    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        """
        Streams ObjectMetadata for keys under *path*, bounded by the
        (start_after, end_at] window, optionally yielding one level of
        directory entries (common prefixes) when include_directories is True.

        :param follow_symlinks: Unused for S3; present for interface parity with POSIX providers.
        """
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            paginator = self._s3_client.get_paginator("list_objects_v2")
            if include_directories:
                # Delimiter="/" folds sub-keys into CommonPrefixes (one level of "directories").
                page_iterator = paginator.paginate(
                    Bucket=bucket, Prefix=prefix, Delimiter="/", StartAfter=(start_after or "")
                )
            else:
                page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, StartAfter=(start_after or ""))

            for page in page_iterator:
                for item in page.get("CommonPrefixes", []):
                    prefix_key = item["Prefix"].rstrip("/")
                    # Filter by start_after and end_at - S3's StartAfter doesn't filter CommonPrefixes
                    if (start_after is None or start_after < prefix_key) and (end_at is None or prefix_key <= end_at):
                        yield ObjectMetadata(
                            key=os.path.join(bucket, prefix_key),
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                    elif end_at is not None and end_at < prefix_key:
                        # Prefixes arrive sorted, so anything past end_at means we are done.
                        return

                # S3 guarantees lexicographical order for general purpose buckets (for
                # normal S3) but not directory buckets (for S3 Express One Zone).
                for response_object in page.get("Contents", []):
                    key = response_object["Key"]
                    if end_at is None or key <= end_at:
                        if key.endswith("/"):
                            # Keys ending in "/" are treated as directory markers.
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(bucket, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=response_object["LastModified"],
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key),
                                type="file",
                                content_length=response_object["Size"],
                                last_modified=response_object["LastModified"],
                                etag=response_object["ETag"].strip('"'),
                                storage_class=response_object.get("StorageClass"),  # Pass storage_class
                            )
                    else:
                        # Past end_at; keys are sorted, so stop the whole listing.
                        return

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)
718
719 @property
720 def supports_parallel_listing(self) -> bool:
721 """
722 S3 supports parallel listing via delimiter-based prefix discovery.
723
724 Note: Directory bucket handling is done in list_objects_recursive().
725 """
726 return True
727
[docs]
728 def list_objects_recursive(
729 self,
730 path: str = "",
731 start_after: Optional[str] = None,
732 end_at: Optional[str] = None,
733 max_workers: int = 32,
734 look_ahead: int = 2,
735 follow_symlinks: bool = True,
736 ) -> Iterator[ObjectMetadata]:
737 """
738 List all objects recursively using parallel prefix discovery for improved performance.
739
740 For S3, uses the Rust client's list_recursive when available for maximum performance.
741 Falls back to Python implementation otherwise.
742
743 Returns files only (no directories), in lexicographic order.
744
745 :param follow_symlinks: Whether to follow symbolic links (POSIX providers only).
746 """
747 if (start_after is not None) and (end_at is not None) and not (start_after < end_at):
748 raise ValueError(f"start_after ({start_after}) must be before end_at ({end_at})!")
749
750 full_path = self._prepend_base_path(path)
751 bucket, prefix = split_path(full_path)
752
753 if self._is_directory_bucket(bucket):
754 yield from self.list_objects(
755 path, start_after, end_at, include_directories=False, follow_symlinks=follow_symlinks
756 )
757 return
758
759 if self._rust_client:
760 yield from self._emit_metrics(
761 operation=BaseStorageProvider._Operation.LIST,
762 f=lambda: self._list_objects_recursive_rust(path, full_path, bucket, start_after, end_at, max_workers),
763 )
764 else:
765 yield from super().list_objects_recursive(
766 path, start_after, end_at, max_workers, look_ahead, follow_symlinks
767 )
768
    def _list_objects_recursive_rust(
        self,
        path: str,
        full_path: str,
        bucket: str,
        start_after: Optional[str],
        end_at: Optional[str],
        max_workers: int,
    ) -> Iterator[ObjectMetadata]:
        """
        Use Rust client's list_recursive for parallel listing.

        The Rust client already handles parallel listing internally.
        Returns files only in lexicographic order.

        :param path: The caller-supplied (relative) listing path; not used by the body itself.
        :param full_path: *path* with the provider base path prepended; its key part is the listing prefix.
        :param bucket: Bucket extracted from *full_path*, used to build bucket-qualified keys.
        :param start_after: Exclusive lower bound (relative path) for returned keys.
        :param end_at: Inclusive upper bound (relative path) for returned keys.
        :param max_workers: Concurrency passed through to the Rust listing call.
        """
        _, prefix = split_path(full_path)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            # The Rust API takes a list of prefixes; fall back to "" for a whole-bucket listing.
            result = run_async_rust_client_method(
                self._rust_client,
                "list_recursive",
                [prefix] if prefix else [""],
                max_concurrency=max_workers,
            )

            # Bounds are compared against bucket-qualified keys, so qualify them too.
            start_after_full = self._prepend_base_path(start_after) if start_after else None
            end_at_full = self._prepend_base_path(end_at) if end_at else None

            for obj in result.objects:
                full_key = os.path.join(bucket, obj.key)

                if start_after_full and full_key <= start_after_full:
                    continue
                if end_at_full and full_key > end_at_full:
                    # Results are lexicographically ordered (see docstring), so we can stop here.
                    break

                # Strip the provider base path so keys are relative, matching list_objects output.
                relative_key = full_key.removeprefix(self._base_path).lstrip("/")

                yield ObjectMetadata(
                    key=relative_key,
                    content_length=obj.content_length,
                    last_modified=dateutil_parse(obj.last_modified),
                    type="file" if obj.object_type == "object" else obj.object_type,
                    etag=obj.etag,
                )

        yield from self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)
816
817 def _upload_file(
818 self,
819 remote_path: str,
820 f: Union[str, IO],
821 attributes: Optional[dict[str, str]] = None,
822 content_type: Optional[str] = None,
823 ) -> int:
824 file_size: int = 0
825
826 if isinstance(f, str):
827 bucket, key = split_path(remote_path)
828 file_size = os.path.getsize(f)
829
830 # Upload small files
831 if file_size <= self._transfer_config.multipart_threshold:
832 if self._rust_client and not attributes and not content_type:
833 run_async_rust_client_method(self._rust_client, "upload", f, key)
834 else:
835 with open(f, "rb") as fp:
836 self._put_object(remote_path, fp.read(), attributes=attributes, content_type=content_type)
837 return file_size
838
839 # Upload large files using TransferConfig
840 def _invoke_api() -> int:
841 extra_args = {}
842 if content_type:
843 extra_args["ContentType"] = content_type
844 if self._is_directory_bucket(bucket):
845 extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
846 validated_attributes = validate_attributes(attributes)
847 if validated_attributes:
848 extra_args["Metadata"] = validated_attributes
849 if self._rust_client and not extra_args:
850 run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
851 else:
852 self._s3_client.upload_file(
853 Filename=f,
854 Bucket=bucket,
855 Key=key,
856 Config=self._transfer_config,
857 ExtraArgs=extra_args,
858 )
859
860 return file_size
861
862 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
863 else:
864 # Upload small files
865 f.seek(0, io.SEEK_END)
866 file_size = f.tell()
867 f.seek(0)
868
869 if file_size <= self._transfer_config.multipart_threshold:
870 if isinstance(f, io.StringIO):
871 self._put_object(
872 remote_path, f.read().encode("utf-8"), attributes=attributes, content_type=content_type
873 )
874 else:
875 self._put_object(remote_path, f.read(), attributes=attributes, content_type=content_type)
876 return file_size
877
878 # Upload large files using TransferConfig
879 bucket, key = split_path(remote_path)
880
881 def _invoke_api() -> int:
882 extra_args = {}
883 if content_type:
884 extra_args["ContentType"] = content_type
885 if self._is_directory_bucket(bucket):
886 extra_args["StorageClass"] = EXPRESS_ONEZONE_STORAGE_CLASS
887 validated_attributes = validate_attributes(attributes)
888 if validated_attributes:
889 extra_args["Metadata"] = validated_attributes
890
891 if self._rust_client and isinstance(f, io.BytesIO) and not extra_args:
892 data = f.getbuffer()
893 run_async_rust_client_method(self._rust_client, "upload_multipart_from_bytes", key, data)
894 else:
895 self._s3_client.upload_fileobj(
896 Fileobj=f,
897 Bucket=bucket,
898 Key=key,
899 Config=self._transfer_config,
900 ExtraArgs=extra_args,
901 )
902
903 return file_size
904
905 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
906
907 def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
908 if metadata is None:
909 metadata = self._get_object_metadata(remote_path)
910
911 if isinstance(f, str):
912 bucket, key = split_path(remote_path)
913 if os.path.dirname(f):
914 safe_makedirs(os.path.dirname(f))
915
916 # Download small files
917 if metadata.content_length <= self._transfer_config.multipart_threshold:
918 if self._rust_client:
919 run_async_rust_client_method(self._rust_client, "download", key, f)
920 else:
921 with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
922 temp_file_path = fp.name
923 fp.write(self._get_object(remote_path))
924 os.rename(src=temp_file_path, dst=f)
925 return metadata.content_length
926
927 # Download large files using TransferConfig
928 def _invoke_api() -> int:
929 with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
930 temp_file_path = fp.name
931 if self._rust_client:
932 run_async_rust_client_method(
933 self._rust_client, "download_multipart_to_file", key, temp_file_path
934 )
935 else:
936 self._s3_client.download_fileobj(
937 Bucket=bucket,
938 Key=key,
939 Fileobj=fp,
940 Config=self._transfer_config,
941 )
942
943 os.rename(src=temp_file_path, dst=f)
944
945 return metadata.content_length
946
947 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
948 else:
949 # Download small files
950 if metadata.content_length <= self._transfer_config.multipart_threshold:
951 response = self._get_object(remote_path)
952 # Python client returns `bytes`, but Rust client returns an object that implements the buffer protocol,
953 # so we need to check whether `.decode()` is available.
954 if isinstance(f, io.StringIO):
955 if hasattr(response, "decode"):
956 f.write(response.decode("utf-8"))
957 else:
958 f.write(codecs.decode(memoryview(response), "utf-8"))
959 else:
960 f.write(response)
961 return metadata.content_length
962
963 # Download large files using TransferConfig
964 bucket, key = split_path(remote_path)
965
966 def _invoke_api() -> int:
967 self._s3_client.download_fileobj(
968 Bucket=bucket,
969 Key=key,
970 Fileobj=f,
971 Config=self._transfer_config,
972 )
973
974 return metadata.content_length
975
976 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
977
978 def _generate_presigned_url(
979 self,
980 path: str,
981 *,
982 method: str = "GET",
983 signer_type: Optional[SignerType] = None,
984 signer_options: Optional[dict[str, Any]] = None,
985 ) -> str:
986 options = signer_options or {}
987 bucket, key = split_path(path)
988
989 if signer_type is None or signer_type == SignerType.S3:
990 expires_in = int(options.get("expires_in", DEFAULT_PRESIGN_EXPIRES_IN))
991 cache_key: tuple = (SignerType.S3, bucket, expires_in)
992 if cache_key not in self._signer_cache:
993 self._signer_cache[cache_key] = S3URLSigner(self._s3_client, bucket, expires_in=expires_in)
994 elif signer_type == SignerType.CLOUDFRONT:
995 cache_key = (SignerType.CLOUDFRONT, frozenset(options.items()))
996 if cache_key not in self._signer_cache:
997 self._signer_cache[cache_key] = CloudFrontURLSigner(**options)
998 else:
999 raise ValueError(f"Unsupported signer type for S3 provider: {signer_type!r}")
1000
1001 return self._signer_cache[cache_key].generate_presigned_url(key, method=method)