# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os
import tempfile
import time
from datetime import datetime
from typing import IO, Any, Callable, Iterator, Optional, Union

import boto3
from boto3.s3.transfer import TransferConfig
import botocore
from botocore.credentials import RefreshableCredentials
from botocore.exceptions import (
    ClientError,
    ReadTimeoutError,
    IncompleteReadError,
)
from botocore.session import get_session

from ..types import (
    Credentials,
    CredentialsProvider,
    ObjectMetadata,
    Range,
    RetryableError,
)
from ..utils import split_path
from .base import BaseStorageProvider

BOTO3_MAX_POOL_CONNECTIONS = 32
BOTO3_CONNECT_TIMEOUT = 10
BOTO3_READ_TIMEOUT = 10

MB = 1024 * 1024

MULTIPART_THRESHOLD = 512 * MB
MULTIPART_CHUNK_SIZE = 256 * MB
IO_CHUNK_SIZE = 128 * MB
MAX_CONCURRENCY = 16
PROVIDER = "s3"


class StaticS3CredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static S3 credentials.
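
    Example (illustrative only; the key values below are placeholders, not real credentials)::

        provider = StaticS3CredentialsProvider(access_key="<ACCESS_KEY>", secret_key="<SECRET_KEY>")
        credentials = provider.get_credentials()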
    """

    _access_key: str
    _secret_key: str
    _session_token: Optional[str]

    def __init__(self, access_key: str, secret_key: str, session_token: Optional[str] = None):
        """
        Initializes the :py:class:`StaticS3CredentialsProvider` with the provided access key, secret key, and optional
        session token.

        :param access_key: The access key for S3 authentication.
        :param secret_key: The secret key for S3 authentication.
        :param session_token: An optional session token for temporary credentials.
        """
        self._access_key = access_key
        self._secret_key = secret_key
        self._session_token = session_token

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key=self._access_key,
            secret_key=self._secret_key,
            token=self._session_token,
            expiration=None,
        )

    def refresh_credentials(self) -> None:
        pass


class S3StorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Amazon S3 or SwiftStack.
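
    Example (illustrative only; the values shown are placeholders)::

        credentials = StaticS3CredentialsProvider(access_key="<ACCESS_KEY>", secret_key="<SECRET_KEY>")
        provider = S3StorageProvider(region_name="us-east-1", credentials_provider=credentials)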
    """

    def __init__(
        self,
        region_name: str = "",
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the :py:class:`S3StorageProvider` with the region, endpoint URL, and optional credentials provider.

        :param region_name: The AWS region where the S3 bucket is located.
        :param endpoint_url: The custom endpoint URL for the S3 service.
        :param base_path: The root prefix path within the S3 bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve S3 credentials.
        """
        super().__init__(base_path=base_path, provider_name=PROVIDER)

        self._region_name = region_name
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._signature_version = kwargs.get("signature_version", "")
        self._s3_client = self._create_s3_client()
        self._transfer_config = TransferConfig(
            multipart_threshold=int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD)),
            max_concurrency=int(kwargs.get("max_concurrency", MAX_CONCURRENCY)),
            multipart_chunksize=int(kwargs.get("multipart_chunksize", MULTIPART_CHUNK_SIZE)),
            io_chunksize=int(kwargs.get("io_chunk_size", IO_CHUNK_SIZE)),
            use_threads=True,
        )

    def _create_s3_client(self):
        """
        Creates and configures the boto3 S3 client, using refreshable credentials if possible.

        :return: The configured S3 client.
        """
        options = {
            "region_name": self._region_name,
            "config": boto3.session.Config(  # pyright: ignore [reportAttributeAccessIssue]
                max_pool_connections=BOTO3_MAX_POOL_CONNECTIONS,
                connect_timeout=BOTO3_CONNECT_TIMEOUT,
                read_timeout=BOTO3_READ_TIMEOUT,
                retries=dict(mode="standard"),
            ),
        }
        if self._endpoint_url:
            options["endpoint_url"] = self._endpoint_url

        if self._credentials_provider:
            creds = self._fetch_credentials()
            if "expiry_time" in creds and creds["expiry_time"]:
                # Use RefreshableCredentials if expiry_time provided.
                refreshable_credentials = RefreshableCredentials.create_from_metadata(
                    metadata=creds, refresh_using=self._fetch_credentials, method="custom-refresh"
                )

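                # Attach the refreshable credentials to a botocore session so boto3 refreshes
                # them automatically before they expire (note that ``_credentials`` is a
                # private botocore session attribute).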
                botocore_session = get_session()
                botocore_session._credentials = refreshable_credentials

                boto3_session = boto3.Session(botocore_session=botocore_session)

                return boto3_session.client("s3", **options)
            else:
                # Add static credentials to the options dictionary
                options["aws_access_key_id"] = creds["access_key"]
                options["aws_secret_access_key"] = creds["secret_key"]
                if creds["token"]:
                    options["aws_session_token"] = creds["token"]

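        # "UNSIGNED" disables request signing (anonymous access); any other configured value is
        # passed through to botocore as the signature version.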
        if self._signature_version:
            signature_config = botocore.config.Config(  # pyright: ignore[reportAttributeAccessIssue]
                signature_version=botocore.UNSIGNED
                if self._signature_version == "UNSIGNED"
                else self._signature_version
            )
            options["config"] = options["config"].merge(signature_config)

        # Create the client; boto3 falls back to its standard credential chain when no
        # explicit credentials were configured above.
        return boto3.client("s3", **options)

    def _fetch_credentials(self) -> dict:
        """
        Refreshes the credentials via the configured credentials provider and returns them in the
        dictionary format expected by botocore's refreshable credentials.
        """
        if not self._credentials_provider:
            raise RuntimeError("Cannot fetch credentials if no credentials provider is configured.")
        self._credentials_provider.refresh_credentials()
        credentials = self._credentials_provider.get_credentials()
        return {
            "access_key": credentials.access_key,
            "secret_key": credentials.secret_key,
            "token": credentials.token,
            "expiry_time": credentials.expiration,
        }

    def _collect_metrics(
        self,
        func: Callable,
        operation: str,
        bucket: str,
        key: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> Any:
        """
        Collects and records performance metrics around S3 operations such as PUT, GET, DELETE, etc.

        This method wraps an S3 operation and measures the time it takes to complete, along with recording
        the size of the object if applicable. It handles errors like timeouts and client errors and ensures
        proper logging of duration and object size.

        :param func: The function that performs the actual S3 operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the S3 bucket involved in the operation.
        :param key: The key of the object within the S3 bucket.
        :param put_object_size: The size of the object being uploaded, if applicable (for PUT operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).

        :return: The result of the S3 operation, typically the return value of the `func` callable.
        """
        start_time = time.time()
        status_code = 200

        object_size = None
        if operation == "PUT":
            object_size = put_object_size
        elif operation == "GET" and get_object_size:
            object_size = get_object_size

        try:
            result = func()
            if operation == "GET" and object_size is None:
                object_size = len(result)
            return result
        except ClientError as error:
            status_code = error.response["ResponseMetadata"]["HTTPStatusCode"]
            request_id = error.response["ResponseMetadata"].get("RequestId")
            host_id = error.response["ResponseMetadata"].get("HostId")

            request_info = f"request_id: {request_id}, host_id: {host_id}, status_code: {status_code}"

            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {request_info}")  # pylint: disable=raise-missing-from
            elif status_code == 429:
                raise RetryableError(
                    f"Too many requests to {operation} object(s) at {bucket}/{key}. {request_info}"
                ) from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {request_info}") from error
        except FileNotFoundError as error:
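            # -1 is recorded as a sentinel status code for failures that do not carry an HTTP status.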
            status_code = -1
            raise error
        except (ReadTimeoutError, IncompleteReadError) as error:
            status_code = -1
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to network timeout or incomplete read."
            ) from error
        except Exception as error:
            status_code = -1
            raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}") from error
        finally:
            elapsed_time = time.time() - start_time
            self._metric_helper.record_duration(
                elapsed_time, provider=PROVIDER, operation=operation, bucket=bucket, status_code=status_code
            )
            if object_size:
                self._metric_helper.record_object_size(
                    object_size, provider=PROVIDER, operation=operation, bucket=bucket, status_code=status_code
                )

    def _put_object(self, path: str, body: bytes) -> None:
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            self._s3_client.put_object(Bucket=bucket, Key=key, Body=body)

        return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=len(body))

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)

        def _invoke_api() -> bytes:
            if byte_range:
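                # HTTP Range headers are inclusive on both ends, hence the trailing -1.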
                bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
                response = self._s3_client.get_object(Bucket=bucket, Key=key, Range=bytes_range)
            else:
                response = self._s3_client.get_object(Bucket=bucket, Key=key)
            return response["Body"].read()

        return self._collect_metrics(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> None:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)

        def _invoke_api() -> None:
            self._s3_client.copy_object(
                CopySource={"Bucket": src_bucket, "Key": src_key}, Bucket=dest_bucket, Key=dest_key
            )

        src_object = self._get_object_metadata(src_path)

        return self._collect_metrics(
            _invoke_api,
            operation="COPY",
            bucket=dest_bucket,
            key=dest_key,
            put_object_size=src_object.content_length,
        )

    def _delete_object(self, path: str) -> None:
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            self._s3_client.delete_object(Bucket=bucket, Key=key)

        return self._collect_metrics(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._s3_client.list_objects_v2(Bucket=bucket, Prefix=key, MaxKeys=1, Delimiter="/")
            # Check if there are any contents or common prefixes
            return bool(response.get("Contents", []) or response.get("CommonPrefixes", []))

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str) -> ObjectMetadata:
        if path.endswith("/"):
            # If path is a "directory", then metadata is not guaranteed to exist if
            # it is a "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=datetime.min,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            bucket, key = split_path(path)

            def _invoke_api() -> ObjectMetadata:
                response = self._s3_client.head_object(Bucket=bucket, Key=key)
                return ObjectMetadata(
                    key=path,
                    type="file",
                    content_length=response["ContentLength"],
                    content_type=response["ContentType"],
                    last_modified=response["LastModified"],
                    etag=response["ETag"].strip('"'),
                )

            try:
                return self._collect_metrics(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                # If the object does not exist on the given path, we will append a trailing slash and
                # check if the path is a directory.
                path = self._append_delimiter(path)
                if self._is_dir(path):
                    return ObjectMetadata(
                        key=path,
                        type="directory",
                        content_length=0,
                        last_modified=datetime.min,
                    )
                else:
                    raise error

    def _list_objects(
        self,
        prefix: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(prefix)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            paginator = self._s3_client.get_paginator("list_objects_v2")
            if include_directories:
                page_iterator = paginator.paginate(
                    Bucket=bucket, Prefix=prefix, Delimiter="/", StartAfter=(start_after or "")
                )
            else:
                page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, StartAfter=(start_after or ""))

            for page in page_iterator:
                for item in page.get("CommonPrefixes", []):
                    yield ObjectMetadata(
                        key=item["Prefix"].rstrip("/"),
                        type="directory",
                        content_length=0,
                        last_modified=datetime.min,
                    )

                # S3 guarantees lexicographical order for general purpose buckets (for
                # normal S3) but not directory buckets (for S3 Express One Zone).
                for response_object in page.get("Contents", []):
                    key = response_object["Key"]
                    if end_at is None or key <= end_at:
                        yield ObjectMetadata(
                            key=key,
                            type="file",
                            content_length=response_object["Size"],
                            last_modified=response_object["LastModified"],
                            etag=response_object["ETag"].strip('"'),
                        )
                    else:
                        return

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO]) -> None:
        if isinstance(f, str):
            filesize = os.path.getsize(f)

            # Upload small files
            if filesize <= self._transfer_config.multipart_threshold:
                with open(f, "rb") as fp:
                    self._put_object(remote_path, fp.read())
                return

            # Upload large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> None:
                self._s3_client.upload_file(
                    Filename=f,
                    Bucket=bucket,
                    Key=key,
                    Config=self._transfer_config,
                )

            return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=filesize)
        else:
            # Upload small files
            f.seek(0, io.SEEK_END)
            filesize = f.tell()
            f.seek(0)

            if filesize <= self._transfer_config.multipart_threshold:
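                # io.StringIO buffers hold text, so encode to bytes before uploading.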
                if isinstance(f, io.StringIO):
                    self._put_object(remote_path, f.read().encode("utf-8"))
                else:
                    self._put_object(remote_path, f.read())
                return

            # Upload large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> None:
                self._s3_client.upload_fileobj(
                    Fileobj=f,
                    Bucket=bucket,
                    Key=key,
                    Config=self._transfer_config,
                )

            return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=filesize)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> None:
        if not metadata:
            metadata = self._get_object_metadata(remote_path)

        if isinstance(f, str):
            os.makedirs(os.path.dirname(f), exist_ok=True)
            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
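                # Download to a hidden temporary file in the destination directory first, then
                # rename it into place so readers never observe a partially written file.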
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    fp.write(self._get_object(remote_path))
                os.rename(src=temp_file_path, dst=f)
                return

            # Download large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> None:
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    self._s3_client.download_fileobj(
                        Bucket=bucket,
                        Key=key,
                        Fileobj=fp,
                        Config=self._transfer_config,
                    )
                os.rename(src=temp_file_path, dst=f)

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )
        else:
            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                if isinstance(f, io.StringIO):
                    f.write(self._get_object(remote_path).decode("utf-8"))
                else:
                    f.write(self._get_object(remote_path))
                return

            # Download large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> None:
                self._s3_client.download_fileobj(
                    Bucket=bucket,
                    Key=key,
                    Fileobj=f,
                    Config=self._transfer_config,
                )

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )