Source code for multistorageclient.providers.gcs

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os
import tempfile
import time
from collections.abc import Callable, Iterator, Sequence, Sized
from typing import IO, Any, Optional, TypeVar, Union

import opentelemetry.metrics as api_metrics
from google.api_core.exceptions import GoogleAPICallError, NotFound
from google.auth import identity_pool
from google.cloud import storage
from google.cloud.storage import transfer_manager
from google.cloud.storage.exceptions import InvalidResponse
from google.oauth2.credentials import Credentials as OAuth2Credentials

from ..telemetry import Telemetry
from ..telemetry.attributes.base import AttributesProvider
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    NotModifiedError,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import split_path, validate_attributes
from .base import BaseStorageProvider

_T = TypeVar("_T")

PROVIDER = "gcs"

MB = 1024 * 1024

DEFAULT_MULTIPART_THRESHOLD = 512 * MB
DEFAULT_MULTIPART_CHUNK_SIZE = 256 * MB
DEFAULT_IO_CHUNK_SIZE = 256 * MB
DEFAULT_MAX_CONCURRENCY = 8

class StringTokenSupplier(identity_pool.SubjectTokenSupplier):
    """
    Supply a string token to the Google Identity Pool.
    """

    def __init__(self, token: str):
        self._token = token

    def get_subject_token(self, context, request):
        return self._token


class GoogleIdentityPoolCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's identity pool credentials.
    """

    def __init__(self, audience: str, token_supplier: str):
        """
        Initializes the :py:class:`GoogleIdentityPoolCredentialsProvider` with the audience and token supplier.

        :param audience: The audience for the Google Identity Pool.
        :param token_supplier: The string token used as the subject token for the Google Identity Pool.
        """
        self._audience = audience
        self._token_supplier = token_supplier

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key="",
            secret_key="",
            token="",
            expiration=None,
            custom_fields={"audience": self._audience, "token": self._token_supplier},
        )

    def refresh_credentials(self) -> None:
        pass

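# Usage sketch (illustrative only; the audience string and subject token below are
# placeholders, not values taken from this module):
#
#     subject_token = "<pre-fetched OIDC ID token>"
#     wif_provider = GoogleIdentityPoolCredentialsProvider(
#         audience="//iam.googleapis.com/projects/<project-number>/locations/global/"
#         "workloadIdentityPools/<pool-id>/providers/<provider-id>",
#         token_supplier=subject_token,
#     )
#     # The audience and token travel through Credentials.custom_fields and are read back
#     # by GoogleStorageProvider._create_gcs_client() to build identity_pool.Credentials
#     # wrapped around a StringTokenSupplier.
#     wif_provider.get_credentials().get_custom_field("audience")
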
class GoogleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Google Cloud Storage.
    """

    def __init__(
        self,
        project_id: str = os.getenv("GOOGLE_CLOUD_PROJECT_ID", ""),
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        metric_counters: dict[Telemetry.CounterName, api_metrics.Counter] = {},
        metric_gauges: dict[Telemetry.GaugeName, api_metrics._Gauge] = {},
        metric_attributes_providers: Sequence[AttributesProvider] = (),
        **kwargs: Any,
    ):
        """
        Initializes the :py:class:`GoogleStorageProvider` with the project ID and optional credentials provider.

        :param project_id: The Google Cloud project ID.
        :param endpoint_url: The custom endpoint URL for the GCS service.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve GCS credentials.
        :param metric_counters: Metric counters.
        :param metric_gauges: Metric gauges.
        :param metric_attributes_providers: Metric attributes providers.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            metric_counters=metric_counters,
            metric_gauges=metric_gauges,
            metric_attributes_providers=metric_attributes_providers,
        )

        self._project_id = project_id
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._gcs_client = self._create_gcs_client()
        self._multipart_threshold = kwargs.get("multipart_threshold", DEFAULT_MULTIPART_THRESHOLD)
        self._multipart_chunksize = kwargs.get("multipart_chunksize", DEFAULT_MULTIPART_CHUNK_SIZE)
        self._io_chunk_size = kwargs.get("io_chunk_size", DEFAULT_IO_CHUNK_SIZE)
        self._max_concurrency = kwargs.get("max_concurrency", DEFAULT_MAX_CONCURRENCY)

    def _create_gcs_client(self) -> storage.Client:
        client_options = {}
        if self._endpoint_url:
            client_options["api_endpoint"] = self._endpoint_url

        if self._credentials_provider:
            if isinstance(self._credentials_provider, GoogleIdentityPoolCredentialsProvider):
                audience = self._credentials_provider.get_credentials().get_custom_field("audience")
                token = self._credentials_provider.get_credentials().get_custom_field("token")

                # Use Workload Identity Federation (WIF)
                identity_pool_credentials = identity_pool.Credentials(
                    audience=audience,
                    subject_token_type="urn:ietf:params:oauth:token-type:id_token",
                    subject_token_supplier=StringTokenSupplier(token),
                )
                return storage.Client(
                    project=self._project_id, credentials=identity_pool_credentials, client_options=client_options
                )
            else:
                # Use OAuth 2.0 token
                token = self._credentials_provider.get_credentials().token
                creds = OAuth2Credentials(token=token)
                return storage.Client(project=self._project_id, credentials=creds, client_options=client_options)
        else:
            return storage.Client(project=self._project_id, client_options=client_options)

    def _refresh_gcs_client_if_needed(self) -> None:
        """
        Refreshes the GCS client if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._gcs_client = self._create_gcs_client()

    def _collect_metrics(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> _T:
        """
        Collects and records performance metrics around GCS operations such as PUT, GET, DELETE, etc.

        This method wraps a GCS operation and measures the time it takes to complete, along with recording
        the size of the object if applicable. It handles errors like timeouts and client errors and ensures
        proper logging of duration and object size.

        :param func: The function that performs the actual GCS operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the GCS bucket involved in the operation.
        :param key: The key of the object within the GCS bucket.
        :param put_object_size: The size of the object being uploaded, if applicable (for PUT operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).

        :return: The result of the GCS operation, typically the return value of the `func` callable.
        """
        start_time = time.time()
        status_code = 200

        object_size = None
        if operation == "PUT":
            object_size = put_object_size
        elif operation == "GET" and get_object_size:
            object_size = get_object_size

        try:
            result = func()
            if operation == "GET" and object_size is None and isinstance(result, Sized):
                object_size = len(result)
            return result
        except GoogleAPICallError as error:
            status_code = error.code if error.code else -1
            error_info = f"status_code: {status_code}, message: {error.message}"
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 304:
                # for if_none_match with a specific etag condition.
                raise NotModifiedError(f"Object {bucket}/{key} has not been modified.") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except InvalidResponse as error:
            status_code = error.response.status_code
            response_text = error.response.text
            error_details = f"error: {error}, error_response_text: {response_text}"
            # Check for NoSuchUpload within the response text
            if "NoSuchUpload" in response_text:
                raise RetryableError(f"Multipart upload failed for {bucket}/{key}, {error_details}") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_details}") from error
        except Exception as error:
            status_code = -1
            error_details = str(error)
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}. error_type: {type(error).__name__}, {error_details}"
            ) from error
        finally:
            elapsed_time = time.time() - start_time
            self._metric_helper.record_duration(
                elapsed_time, provider=self._provider_name, operation=operation, bucket=bucket, status_code=status_code
            )
            if object_size:
                self._metric_helper.record_object_size(
                    object_size,
                    provider=self._provider_name,
                    operation=operation,
                    bucket=bucket,
                    status_code=status_code,
                )

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to Google Cloud Storage.

        :param path: The path to the object to upload.
        :param body: The content of the object to upload.
        :param if_match: Optional generation (ETag) that the existing object must match.
        :param if_none_match: Optional generation (ETag) that the existing object must not match.
        :param attributes: Optional attributes to attach to the object.
        """
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> int:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            kwargs = {}

            if if_match:
                kwargs["if_generation_match"] = int(if_match)  # 412 error code
            if if_none_match:
                if if_none_match == "*":
                    raise NotImplementedError("if_none_match='*' is not supported for GCS")
                else:
                    kwargs["if_generation_not_match"] = int(if_none_match)  # 304 error code

            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                blob.metadata = validated_attributes

            blob.upload_from_string(body, **kwargs)

            return len(body)

        return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=len(body))

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bytes:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)
            if byte_range:
                return blob.download_as_bytes(start=byte_range.offset, end=byte_range.offset + byte_range.size - 1)
            return blob.download_as_bytes()

        return self._collect_metrics(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)
        self._refresh_gcs_client_if_needed()

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            source_bucket_obj = self._gcs_client.bucket(src_bucket)
            source_blob = source_bucket_obj.blob(src_key)

            destination_bucket_obj = self._gcs_client.bucket(dest_bucket)
            destination_blob = destination_bucket_obj.blob(dest_key)

            rewrite_tokens = [None]
            while len(rewrite_tokens) > 0:
                rewrite_token = rewrite_tokens.pop()
                next_rewrite_token, _, _ = destination_blob.rewrite(source=source_blob, token=rewrite_token)
                if next_rewrite_token is not None:
                    rewrite_tokens.append(next_rewrite_token)

            return src_object.content_length

        return self._collect_metrics(
            _invoke_api,
            operation="COPY",
            bucket=src_bucket,
            key=src_key,
            put_object_size=src_object.content_length,
        )

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> None:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            # If if_match is provided, use it as a precondition
            if if_match:
                generation = int(if_match)
                blob.delete(if_generation_match=generation)
            else:
                # No if_match check needed, just delete
                blob.delete()

        return self._collect_metrics(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bool:
            bucket_obj = self._gcs_client.bucket(bucket)
            # List objects with the given prefix
            blobs = bucket_obj.list_blobs(
                prefix=key,
                delimiter="/",
            )
            # Check if there are any contents or common prefixes
            return any(True for _ in blobs) or any(True for _ in blobs.prefixes)

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        if path.endswith("/"):
            # If path is a "directory", then metadata is not guaranteed to exist if
            # it is a "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path, type="directory", content_length=0, last_modified=AWARE_DATETIME_MIN, etag=None
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            bucket, key = split_path(path)
            self._refresh_gcs_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.get_blob(key)
                if not blob:
                    raise NotFound(f"Blob {key} not found in bucket {bucket}")
                return ObjectMetadata(
                    key=path,
                    content_length=blob.size or 0,
                    content_type=blob.content_type,
                    last_modified=blob.updated or AWARE_DATETIME_MIN,
                    etag=str(blob.generation),
                    metadata=dict(blob.metadata) if blob.metadata else None,
                )

            try:
                return self._collect_metrics(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        prefix: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(prefix)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            bucket_obj = self._gcs_client.bucket(bucket)
            if include_directories:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                    delimiter="/",
                )
            else:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                )

            # GCS guarantees lexicographical order.
            for blob in blobs:
                key = blob.name
                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    if key.endswith("/"):
                        if include_directories:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key.rstrip("/")),
                                type="directory",
                                content_length=0,
                                last_modified=blob.updated,
                            )
                    else:
                        yield ObjectMetadata(
                            key=os.path.join(bucket, key),
                            content_length=blob.size,
                            content_type=blob.content_type,
                            last_modified=blob.updated,
                            etag=blob.etag,
                        )
                elif start_after != key:
                    return

            # The directories must be accessed last.
            if include_directories:
                for directory in blobs.prefixes:
                    yield ObjectMetadata(
                        key=os.path.join(bucket, directory.rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        bucket, key = split_path(remote_path)
        file_size: int = 0
        self._refresh_gcs_client_if_needed()

        if isinstance(f, str):
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._multipart_threshold:
                with open(f, "rb") as fp:
                    self._put_object(remote_path, fp.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                blob.metadata = validate_attributes(attributes)
                transfer_manager.upload_chunks_concurrently(
                    f,
                    blob,
                    chunk_size=self._multipart_chunksize,
                    max_workers=self._max_concurrency,
                    worker_type=transfer_manager.THREAD,
                )

                return file_size

            return self._collect_metrics(
                _invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=file_size
            )
        else:
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(remote_path, f.read().encode("utf-8"), attributes=attributes)
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    blob.metadata = validated_attributes
                if isinstance(f, io.StringIO):
                    mode = "w"
                else:
                    mode = "wb"

                # transfer manager does not support uploading a file object
                with tempfile.NamedTemporaryFile(mode=mode, delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    fp.write(f.read())

                transfer_manager.upload_chunks_concurrently(
                    temp_file_path,
                    blob,
                    chunk_size=self._multipart_chunksize,
                    max_workers=self._max_concurrency,
                    worker_type=transfer_manager.THREAD,
                )

                os.unlink(temp_file_path)

                return file_size

            return self._collect_metrics(
                _invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=file_size
            )

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        self._refresh_gcs_client_if_needed()

        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            os.makedirs(os.path.dirname(f), exist_ok=True)
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    fp.write(self._get_object(remote_path))
                os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)

                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    transfer_manager.download_chunks_concurrently(
                        blob,
                        temp_file_path,
                        chunk_size=self._io_chunk_size,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )
                os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )
        else:
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                if isinstance(f, io.StringIO):
                    f.write(self._get_object(remote_path).decode("utf-8"))
                else:
                    f.write(self._get_object(remote_path))
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)

                # transfer manager does not support downloading to a file object
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    transfer_manager.download_chunks_concurrently(
                        blob,
                        temp_file_path,
                        chunk_size=self._io_chunk_size,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                if isinstance(f, io.StringIO):
                    with open(temp_file_path, "r") as fp:
                        f.write(fp.read())
                else:
                    with open(temp_file_path, "rb") as fp:
                        f.write(fp.read())

                os.unlink(temp_file_path)

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )
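
# Usage sketch (illustrative; "my-project" and "my-bucket/datasets" are placeholder values,
# and the keyword overrides are simply forwarded through **kwargs in __init__ above):
#
#     gcs_provider = GoogleStorageProvider(
#         project_id="my-project",
#         base_path="my-bucket/datasets",
#         credentials_provider=wif_provider,  # optional; when omitted, storage.Client()
#                                             # falls back to application default credentials
#         multipart_threshold=64 * MB,        # overrides DEFAULT_MULTIPART_THRESHOLD
#         max_concurrency=4,                  # overrides DEFAULT_MAX_CONCURRENCY
#     )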