Source code for multistorageclient.providers.s3

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os
import tempfile
import time
from datetime import datetime
from typing import IO, Any, Callable, Iterator, Optional, Union

import boto3
import botocore
from boto3.s3.transfer import TransferConfig
from botocore.credentials import RefreshableCredentials
from botocore.exceptions import (
    ClientError,
    IncompleteReadError,
    ReadTimeoutError,
)
from botocore.session import get_session

from ..types import (
    Credentials,
    CredentialsProvider,
    ObjectMetadata,
    Range,
    RetryableError,
)
from ..utils import split_path
from .base import BaseStorageProvider

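# HTTP-level tuning applied to the botocore client created in _create_s3_client:
# connection pool size and connect/read socket timeouts, in seconds.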
BOTO3_MAX_POOL_CONNECTIONS = 32
BOTO3_CONNECT_TIMEOUT = 10
BOTO3_READ_TIMEOUT = 10

MB = 1024 * 1024

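# Defaults for boto3 managed transfers (TransferConfig): objects at or below
# MULTIPART_THRESHOLD move in a single request, while larger objects use multipart
# transfers with MULTIPART_CHUNK_SIZE parts, IO_CHUNK_SIZE I/O buffers, and up to
# MAX_CONCURRENCY worker threads.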
MULTIPART_THRESHOLD = 512 * MB
MULTIPART_CHUNK_SIZE = 256 * MB
IO_CHUNK_SIZE = 128 * MB
MAX_CONCURRENCY = 16
PROVIDER = "s3"


class StaticS3CredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static S3 credentials.
    """

    _access_key: str
    _secret_key: str
    _session_token: Optional[str]

    def __init__(self, access_key: str, secret_key: str, session_token: Optional[str] = None):
        """
        Initializes the :py:class:`StaticS3CredentialsProvider` with the provided access key, secret key, and optional
        session token.

        :param access_key: The access key for S3 authentication.
        :param secret_key: The secret key for S3 authentication.
        :param session_token: An optional session token for temporary credentials.
        """
        self._access_key = access_key
        self._secret_key = secret_key
        self._session_token = session_token

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key=self._access_key,
            secret_key=self._secret_key,
            token=self._session_token,
            expiration=None,
        )

    def refresh_credentials(self) -> None:
        pass


class S3StorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Amazon S3 or SwiftStack.
    """

    def __init__(
        self,
        region_name: str = "",
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the :py:class:`S3StorageProvider` with the region, endpoint URL, and optional credentials provider.

        :param region_name: The AWS region where the S3 bucket is located.
        :param endpoint_url: The custom endpoint URL for the S3 service.
        :param base_path: The root prefix path within the S3 bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve S3 credentials.
        """
        super().__init__(base_path=base_path, provider_name=PROVIDER)

        self._region_name = region_name
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._signature_version = kwargs.get("signature_version", "")
        self._s3_client = self._create_s3_client()
        self._transfer_config = TransferConfig(
            multipart_threshold=int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD)),
            max_concurrency=int(kwargs.get("max_concurrency", MAX_CONCURRENCY)),
            multipart_chunksize=int(kwargs.get("multipart_chunksize", MULTIPART_CHUNK_SIZE)),
            io_chunksize=int(kwargs.get("io_chunk_size", IO_CHUNK_SIZE)),
            use_threads=True,
        )

    def _create_s3_client(self):
        """
        Creates and configures the boto3 S3 client, using refreshable credentials if possible.

        :return: The configured S3 client.
        """
        options = {
            "region_name": self._region_name,
            "config": boto3.session.Config(  # pyright: ignore [reportAttributeAccessIssue]
                max_pool_connections=BOTO3_MAX_POOL_CONNECTIONS,
                connect_timeout=BOTO3_CONNECT_TIMEOUT,
                read_timeout=BOTO3_READ_TIMEOUT,
                retries=dict(mode="standard"),
            ),
        }
        if self._endpoint_url:
            options["endpoint_url"] = self._endpoint_url

        if self._credentials_provider:
            creds = self._fetch_credentials()
            if "expiry_time" in creds and creds["expiry_time"]:
                # Use RefreshableCredentials if expiry_time provided.
                refreshable_credentials = RefreshableCredentials.create_from_metadata(
                    metadata=creds, refresh_using=self._fetch_credentials, method="custom-refresh"
                )

                botocore_session = get_session()
                botocore_session._credentials = refreshable_credentials

                boto3_session = boto3.Session(botocore_session=botocore_session)

                return boto3_session.client("s3", **options)
            else:
                # Add static credentials to the options dictionary
                options["aws_access_key_id"] = creds["access_key"]
                options["aws_secret_access_key"] = creds["secret_key"]
                if creds["token"]:
                    options["aws_session_token"] = creds["token"]

        if self._signature_version:
            signature_config = botocore.config.Config(  # pyright: ignore[reportAttributeAccessIssue]
                signature_version=botocore.UNSIGNED
                if self._signature_version == "UNSIGNED"
                else self._signature_version
            )
            options["config"] = options["config"].merge(signature_config)

        # Fall back to the standard credential chain.
        return boto3.client("s3", **options)

    def _fetch_credentials(self) -> dict:
        """
        Refreshes the credentials via the configured provider and returns them in the
        metadata format expected by botocore's refreshable credentials.
        """
        if not self._credentials_provider:
            raise RuntimeError("Cannot fetch credentials if no credential provider is configured.")
        self._credentials_provider.refresh_credentials()
        credentials = self._credentials_provider.get_credentials()
        return {
            "access_key": credentials.access_key,
            "secret_key": credentials.secret_key,
            "token": credentials.token,
            "expiry_time": credentials.expiration,
        }

    def _collect_metrics(
        self,
        func: Callable,
        operation: str,
        bucket: str,
        key: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> Any:
        """
        Collects and records performance metrics around S3 operations such as PUT, GET, DELETE, etc.

        This method wraps an S3 operation and measures the time it takes to complete, along with recording
        the size of the object if applicable. It handles errors like timeouts and client errors and ensures
        proper logging of duration and object size.

        :param func: The function that performs the actual S3 operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the S3 bucket involved in the operation.
        :param key: The key of the object within the S3 bucket.
        :param put_object_size: The size of the object being uploaded, if applicable (for PUT operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).

        :return: The result of the S3 operation, typically the return value of the `func` callable.
        """
        start_time = time.time()
        status_code = 200

        object_size = None
        if operation == "PUT":
            object_size = put_object_size
        elif operation == "GET" and get_object_size:
            object_size = get_object_size

        try:
            result = func()
            if operation == "GET" and object_size is None:
                object_size = len(result)
            return result
        except ClientError as error:
            status_code = error.response["ResponseMetadata"]["HTTPStatusCode"]
            request_id = error.response["ResponseMetadata"].get("RequestId")
            host_id = error.response["ResponseMetadata"].get("HostId")

            request_info = f"request_id: {request_id}, host_id: {host_id}, status_code: {status_code}"

            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {request_info}")  # pylint: disable=raise-missing-from
            elif status_code == 429:
                raise RetryableError(
                    f"Too many requests to {operation} object(s) at {bucket}/{key}. {request_info}"
                ) from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {request_info}") from error
        except FileNotFoundError as error:
            status_code = -1
            raise error
        except (ReadTimeoutError, IncompleteReadError) as error:
            status_code = -1
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to network timeout or incomplete read."
            ) from error
        except Exception as error:
            status_code = -1
            raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}") from error
        finally:
            elapsed_time = time.time() - start_time
            self._metric_helper.record_duration(
                elapsed_time, provider=PROVIDER, operation=operation, bucket=bucket, status_code=status_code
            )
            if object_size:
                self._metric_helper.record_object_size(
                    object_size, provider=PROVIDER, operation=operation, bucket=bucket, status_code=status_code
                )

    def _put_object(self, path: str, body: bytes) -> None:
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            self._s3_client.put_object(Bucket=bucket, Key=key, Body=body)

        return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=len(body))

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)

        def _invoke_api() -> bytes:
            if byte_range:
                bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
                response = self._s3_client.get_object(Bucket=bucket, Key=key, Range=bytes_range)
            else:
                response = self._s3_client.get_object(Bucket=bucket, Key=key)
            return response["Body"].read()

        return self._collect_metrics(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> None:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)

        def _invoke_api() -> None:
            self._s3_client.copy_object(
                CopySource={"Bucket": src_bucket, "Key": src_key}, Bucket=dest_bucket, Key=dest_key
            )

        src_object = self._get_object_metadata(src_path)

        return self._collect_metrics(
            _invoke_api,
            operation="COPY",
            bucket=dest_bucket,
            key=dest_key,
            put_object_size=src_object.content_length,
        )

    def _delete_object(self, path: str) -> None:
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            self._s3_client.delete_object(Bucket=bucket, Key=key)

        return self._collect_metrics(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._s3_client.list_objects_v2(Bucket=bucket, Prefix=key, MaxKeys=1, Delimiter="/")
            # Check if there are any contents or common prefixes
            return bool(response.get("Contents", []) or response.get("CommonPrefixes", []))

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str) -> ObjectMetadata:
        if path.endswith("/"):
            # If path is a "directory", then metadata is not guaranteed to exist if
            # it is a "virtual prefix" that was never explicitly created.
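            # Checking _is_dir (a LIST with the prefix) instead of HEAD-ing a marker
            # object handles prefixes that only exist implicitly through their children.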
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=datetime.min,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            bucket, key = split_path(path)

            def _invoke_api() -> ObjectMetadata:
                response = self._s3_client.head_object(Bucket=bucket, Key=key)
                return ObjectMetadata(
                    key=path,
                    type="file",
                    content_length=response["ContentLength"],
                    content_type=response["ContentType"],
                    last_modified=response["LastModified"],
                    etag=response["ETag"].strip('"'),
                )

            try:
                return self._collect_metrics(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                # If the object does not exist on the given path, we will append a trailing slash and
                # check if the path is a directory.
                path = self._append_delimiter(path)
                if self._is_dir(path):
                    return ObjectMetadata(
                        key=path,
                        type="directory",
                        content_length=0,
                        last_modified=datetime.min,
                    )
                else:
                    raise error

    def _list_objects(
        self,
        prefix: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(prefix)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            paginator = self._s3_client.get_paginator("list_objects_v2")
            if include_directories:
                page_iterator = paginator.paginate(
                    Bucket=bucket, Prefix=prefix, Delimiter="/", StartAfter=(start_after or "")
                )
            else:
                page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, StartAfter=(start_after or ""))

            for page in page_iterator:
                for item in page.get("CommonPrefixes", []):
                    yield ObjectMetadata(
                        key=item["Prefix"].rstrip("/"),
                        type="directory",
                        content_length=0,
                        last_modified=datetime.min,
                    )

                # S3 guarantees lexicographical order for general purpose buckets (for
                # normal S3) but not directory buckets (for S3 Express One Zone).
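                # Where keys do arrive in order, the end_at check below can stop the
                # listing at the first key that sorts after end_at.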
                for response_object in page.get("Contents", []):
                    key = response_object["Key"]
                    if end_at is None or key <= end_at:
                        yield ObjectMetadata(
                            key=key,
                            type="file",
                            content_length=response_object["Size"],
                            last_modified=response_object["LastModified"],
                            etag=response_object["ETag"].strip('"'),
                        )
                    else:
                        return

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO]) -> None:
        if isinstance(f, str):
            filesize = os.path.getsize(f)

            # Upload small files
            if filesize <= self._transfer_config.multipart_threshold:
                with open(f, "rb") as fp:
                    self._put_object(remote_path, fp.read())
                return

            # Upload large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> None:
                self._s3_client.upload_file(
                    Filename=f,
                    Bucket=bucket,
                    Key=key,
                    Config=self._transfer_config,
                )

            return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=filesize)
        else:
            # Upload small files
            f.seek(0, io.SEEK_END)
            filesize = f.tell()
            f.seek(0)

            if filesize <= self._transfer_config.multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(remote_path, f.read().encode("utf-8"))
                else:
                    self._put_object(remote_path, f.read())
                return

            # Upload large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> None:
                self._s3_client.upload_fileobj(
                    Fileobj=f,
                    Bucket=bucket,
                    Key=key,
                    Config=self._transfer_config,
                )

            return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=filesize)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> None:
        if not metadata:
            metadata = self._get_object_metadata(remote_path)

        if isinstance(f, str):
            os.makedirs(os.path.dirname(f), exist_ok=True)
            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    fp.write(self._get_object(remote_path))
                os.rename(src=temp_file_path, dst=f)
                return

            # Download large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> None:
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    self._s3_client.download_fileobj(
                        Bucket=bucket,
                        Key=key,
                        Fileobj=fp,
                        Config=self._transfer_config,
                    )
                os.rename(src=temp_file_path, dst=f)

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )
        else:
            # Download small files
            if metadata.content_length <= self._transfer_config.multipart_threshold:
                if isinstance(f, io.StringIO):
                    f.write(self._get_object(remote_path).decode("utf-8"))
                else:
                    f.write(self._get_object(remote_path))
                return

            # Download large files using TransferConfig
            bucket, key = split_path(remote_path)

            def _invoke_api() -> None:
                self._s3_client.download_fileobj(
                    Bucket=bucket,
                    Key=key,
                    Fileobj=f,
                    Config=self._transfer_config,
                )

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )
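
# Usage sketch (illustrative only; the region, bucket/prefix, and credential values
# below are placeholders, not part of this module):
#
#   credentials_provider = StaticS3CredentialsProvider(access_key="...", secret_key="...")
#   storage_provider = S3StorageProvider(
#       region_name="us-east-1",
#       base_path="my-bucket/prefix",
#       credentials_provider=credentials_provider,
#   )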