Source code for multistorageclient.providers.gcs

  1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16import codecs
 17import copy
 18import io
 19import json
 20import logging
 21import os
 22import tempfile
 23from collections.abc import Callable, Iterator
 24from typing import IO, Any, Optional, TypeVar, Union
 25
 26from google.api_core.exceptions import GoogleAPICallError, NotFound
 27from google.auth import credentials as auth_credentials
 28from google.auth import identity_pool
 29from google.cloud import storage
 30from google.cloud.storage import transfer_manager
 31from google.cloud.storage.exceptions import InvalidResponse
 32from google.oauth2 import service_account
 33from google.oauth2.credentials import Credentials as OAuth2Credentials
 34
 35from multistorageclient_rust import RustClient, RustClientError, RustRetryableError
 36
 37from ..rust_utils import parse_retry_config, run_async_rust_client_method
 38from ..telemetry import Telemetry
 39from ..types import (
 40    AWARE_DATETIME_MIN,
 41    Credentials,
 42    CredentialsProvider,
 43    NotModifiedError,
 44    ObjectMetadata,
 45    PreconditionFailedError,
 46    Range,
 47    RetryableError,
 48    SymlinkHandling,
 49)
 50from ..utils import (
 51    safe_makedirs,
 52    split_path,
 53    validate_attributes,
 54)
 55from .base import BaseStorageProvider
 56
# Generic return type used by GoogleStorageProvider._translate_errors.
_T = TypeVar("_T")

# Provider name recorded in BaseStorageProvider and passed to the Rust client.
PROVIDER = "gcs"

MiB = 1024 * 1024

# Objects at or below the threshold are transferred in a single request;
# larger objects go through the chunked transfer-manager / multipart paths.
DEFAULT_MULTIPART_THRESHOLD = 64 * MiB
DEFAULT_MULTIPART_CHUNKSIZE = 32 * MiB
DEFAULT_IO_CHUNKSIZE = 32 * MiB
# Default worker count for the Python transfer manager (overridable via
# the "max_concurrency" kwarg).
PYTHON_MAX_CONCURRENCY = 8

logger = logging.getLogger(__name__)
 69
 70
class StringTokenSupplier(identity_pool.SubjectTokenSupplier):
    """
    Supply a string token to the Google Identity Pool.

    Adapts a pre-fetched token string to the ``SubjectTokenSupplier``
    interface expected by :py:class:`google.auth.identity_pool.Credentials`.
    """

    def __init__(self, token: str):
        # The token is held verbatim and returned on every request.
        self._subject_token = token

    def get_subject_token(self, context, request):
        # The supplier contract provides context/request, but a fixed
        # string token needs neither.
        return self._subject_token
81 82
class GoogleIdentityPoolCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's identity pool credentials.
    """

    def __init__(self, audience: str, token_supplier: str):
        """
        Initializes the provider with the identity-pool parameters.

        :param audience: The audience for the Google Identity Pool.
        :param token_supplier: The token supplier for the Google Identity Pool.
        """
        self._audience = audience
        self._token_supplier = token_supplier

    def get_credentials(self) -> Credentials:
        # Identity-pool parameters travel in custom_fields; the key/secret/token
        # slots are intentionally blank because WIF does not use them.
        identity_fields = {"audience": self._audience, "token": self._token_supplier}
        return Credentials(
            access_key="",
            secret_key="",
            token="",
            expiration=None,
            custom_fields=identity_fields,
        )

    def refresh_credentials(self) -> None:
        # The audience/token supplier pair is static; nothing to refresh.
        pass
109 110
class GoogleServiceAccountCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's service account credentials.
    """

    #: Google service account private key file contents.
    _info: dict[str, Any]

    def __init__(self, file: Optional[str] = None, info: Optional[dict[str, Any]] = None):
        """
        Initializes the :py:class:`GoogleServiceAccountCredentialsProvider` with either a path to a
        `Google service account private key <https://docs.cloud.google.com/iam/docs/keys-create-delete#creating>`_ file
        or the file contents.

        :param file: Path to a Google service account private key file.
        :param info: Google service account private key file contents.
        """
        # Exactly one source must be given: reject both-None and both-set (XOR).
        if (file is None) == (info is None):
            raise ValueError("Must specify exactly one of file or info")

        if file is not None:
            with open(file, "r") as f:
                self._info = json.load(f)
        else:
            # Deep-copy so later caller-side mutation cannot leak into us.
            self._info = copy.deepcopy(info)

    def get_credentials(self) -> Credentials:
        # Hand out a defensive copy of the key material in custom_fields;
        # the access/secret/token slots are unused for service accounts.
        key_material = copy.deepcopy(self._info)
        return Credentials(
            access_key="",
            secret_key="",
            token=None,
            expiration=None,
            custom_fields={"info": key_material},
        )

    def refresh_credentials(self) -> None:
        # Service account keys are long-lived; nothing to refresh.
        pass
148 149
class GoogleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Google Cloud Storage.
    """

    def __init__(
        self,
        # NOTE(review): this default is evaluated once at import time, so changes
        # to GOOGLE_CLOUD_PROJECT_ID after import are not picked up — confirm intended.
        project_id: str = os.getenv("GOOGLE_CLOUD_PROJECT_ID", ""),
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ):
        """
        Initializes the :py:class:`GoogleStorageProvider` with the project ID and optional credentials provider.

        :param project_id: The Google Cloud project ID.
        :param endpoint_url: The custom endpoint URL for the GCS service.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve GCS credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._project_id = project_id
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._skip_signature = kwargs.get("skip_signature", False)
        self._gcs_client = self._create_gcs_client()
        self._multipart_threshold = kwargs.get("multipart_threshold", DEFAULT_MULTIPART_THRESHOLD)
        self._multipart_chunksize = kwargs.get("multipart_chunksize", DEFAULT_MULTIPART_CHUNKSIZE)
        self._io_chunksize = kwargs.get("io_chunksize", DEFAULT_IO_CHUNKSIZE)
        self._max_concurrency = kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)
        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the rust client options from the kwargs
            rust_client_options = copy.deepcopy(kwargs["rust_client"])
            # Top-level tuning kwargs override the nested rust_client options.
            if "max_concurrency" in kwargs:
                rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
            if "multipart_chunksize" in kwargs:
                rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
            if "read_timeout" in kwargs:
                rust_client_options["read_timeout"] = kwargs["read_timeout"]
            if "connect_timeout" in kwargs:
                rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
            self._rust_client = self._create_rust_client(rust_client_options)

    def _create_gcs_client(self) -> storage.Client:
        """
        Builds a ``google.cloud.storage.Client`` from the configured endpoint,
        skip-signature flag, and credentials provider.

        Credential resolution order: anonymous (when ``skip_signature``),
        identity-pool (WIF), service-account key, OAuth2 token, and finally
        the library's default credential discovery.
        """
        client_options = {}
        if self._endpoint_url:
            client_options["api_endpoint"] = self._endpoint_url

        # Use anonymous credentials for public buckets when skip_signature is enabled
        if self._skip_signature:
            return storage.Client(
                project=self._project_id,
                credentials=auth_credentials.AnonymousCredentials(),
                client_options=client_options,
            )

        if self._credentials_provider:
            if isinstance(self._credentials_provider, GoogleIdentityPoolCredentialsProvider):
                audience = self._credentials_provider.get_credentials().get_custom_field("audience")
                token = self._credentials_provider.get_credentials().get_custom_field("token")

                # Use Workload Identity Federation (WIF)
                identity_pool_credentials = identity_pool.Credentials(
                    audience=audience,
                    subject_token_type="urn:ietf:params:oauth:token-type:id_token",
                    subject_token_supplier=StringTokenSupplier(token),
                )
                return storage.Client(
                    project=self._project_id, credentials=identity_pool_credentials, client_options=client_options
                )
            elif isinstance(self._credentials_provider, GoogleServiceAccountCredentialsProvider):
                # Use service account key.
                service_account_credentials = service_account.Credentials.from_service_account_info(
                    info=self._credentials_provider.get_credentials().get_custom_field("info")
                )
                return storage.Client(
                    project=self._project_id, credentials=service_account_credentials, client_options=client_options
                )
            else:
                # Use OAuth 2.0 token
                token = self._credentials_provider.get_credentials().token
                creds = OAuth2Credentials(token=token)
                return storage.Client(project=self._project_id, credentials=creds, client_options=client_options)
        else:
            return storage.Client(project=self._project_id, client_options=client_options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None) -> Optional[RustClient]:
        """
        Builds the optional Rust-accelerated client, or returns ``None`` when the
        configuration is unsupported (custom endpoint, WIF credentials) or the
        client fails to initialize.
        """
        if self._endpoint_url:
            logger.warning("Rust client for GCS does not support customized endpoint URL, skipping rust client")
            return None

        configs = dict(rust_client_options) if rust_client_options else {}

        # Extract and parse retry configuration
        retry_config = parse_retry_config(configs)

        # Fall back to the standard Google environment variables for credentials,
        # without overriding anything set explicitly in the options.
        if "application_credentials" not in configs and os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
            configs["application_credentials"] = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
        if "service_account_key" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY"):
            configs["service_account_key"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY")
        if "service_account_path" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT")
        if "service_account_path" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH")

        if self._skip_signature and "skip_signature" not in configs:
            configs["skip_signature"] = True

        if "bucket" not in configs:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        if self._credentials_provider:
            if isinstance(self._credentials_provider, GoogleIdentityPoolCredentialsProvider):
                # Workload Identity Federation (WIF) is not supported by the rust client:
                # https://github.com/apache/arrow-rs-object-store/issues/258
                logger.warning("Rust client for GCS doesn't support Workload Identity Federation, skipping rust client")
                return None
            if isinstance(self._credentials_provider, GoogleServiceAccountCredentialsProvider):
                # Use service account key.
                configs["service_account_key"] = json.dumps(
                    self._credentials_provider.get_credentials().get_custom_field("info")
                )
                return RustClient(
                    provider=PROVIDER,
                    configs=configs,
                    retry=retry_config,
                )
        try:
            return RustClient(
                provider=PROVIDER,
                configs=configs,
                credentials_provider=self._credentials_provider,
                retry=retry_config,
            )
        except Exception as e:
            logger.warning(f"Failed to create rust client for GCS: {e}, falling back to Python client")
            return None

    def _refresh_gcs_client_if_needed(self) -> None:
        """
        Refreshes the GCS client if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                # Rebuild the client so it picks up the refreshed credentials.
                self._gcs_client = self._create_gcs_client()

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual GCS operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the GCS bucket involved in the operation.
        :param key: The key of the object within the GCS bucket.

        :return: The result of the GCS operation, typically the return value of the `func` callable.
        """
        try:
            return func()
        except GoogleAPICallError as error:
            status_code = error.code if error.code else -1
            error_info = f"status_code: {status_code}, message: {error.message}"
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 304:
                # for if_none_match with a specific etag condition.
                raise NotModifiedError(f"Object {bucket}/{key} has not been modified.") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except InvalidResponse as error:
            response_text = error.response.text
            error_details = f"error: {error}, error_response_text: {response_text}"
            # Check for NoSuchUpload within the response text
            if "NoSuchUpload" in response_text:
                raise RetryableError(f"Multipart upload failed for {bucket}/{key}, {error_details}") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_details}") from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except RustClientError as error:
            # Rust client errors carry (message, status_code) in args.
            message = error.args[0]
            status_code = error.args[1]
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {message}") from error
            elif status_code == 403:
                raise PermissionError(
                    f"Permission denied to {operation} object(s) at {bucket}/{key}. {message}"
                ) from error
            else:
                raise RetryableError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {message}. status_code: {status_code}"
                ) from error
        except Exception as error:
            error_details = str(error)
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}. error_type: {type(error).__name__}, {error_details}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to Google Cloud Storage.

        :param path: The path to the object to upload.
        :param body: The content of the object to upload.
        :param if_match: Optional ETag to match against the object.
        :param if_none_match: Optional ETag to match against the object.
        :param attributes: Optional attributes to attach to the object.

        :return: The number of bytes written (``len(body)``).
        """
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> int:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            kwargs = {}

            # GCS expresses ETag preconditions as generation numbers.
            if if_match:
                kwargs["if_generation_match"] = int(if_match)  # 412 error code
            if if_none_match:
                if if_none_match == "*":
                    raise NotImplementedError("if_none_match='*' is not supported for GCS")
                else:
                    kwargs["if_generation_not_match"] = int(if_none_match)  # 304 error code

            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                blob.metadata = validated_attributes

            if (
                self._rust_client
                # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
                and not path.endswith("/")
                # The rust path also cannot carry preconditions or metadata.
                and not kwargs
                and not validated_attributes
            ):
                run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                blob.upload_from_string(body, **kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Downloads an object (or a byte range of it), preferring the Rust client
        when available.
        """
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bytes:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)
            if byte_range:
                # Range end is inclusive for both clients.
                if self._rust_client:
                    return run_async_rust_client_method(
                        self._rust_client, "get", key, byte_range.offset, byte_range.offset + byte_range.size - 1
                    )
                else:
                    return blob.download_as_bytes(
                        start=byte_range.offset, end=byte_range.offset + byte_range.size - 1, single_shot_download=True
                    )
            else:
                if self._rust_client:
                    return run_async_rust_client_method(self._rust_client, "get", key)
                else:
                    return blob.download_as_bytes(single_shot_download=True)

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        """
        Server-side copies an object via the GCS rewrite API and returns the
        source object's content length.
        """
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)
        self._refresh_gcs_client_if_needed()

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            source_bucket_obj = self._gcs_client.bucket(src_bucket)
            source_blob = source_bucket_obj.blob(src_key)

            destination_bucket_obj = self._gcs_client.bucket(dest_bucket)
            destination_blob = destination_bucket_obj.blob(dest_key)

            # rewrite() may require multiple calls for large objects; keep calling
            # with the continuation token until none is returned.
            rewrite_tokens = [None]
            while len(rewrite_tokens) > 0:
                rewrite_token = rewrite_tokens.pop()
                next_rewrite_token, _, _ = destination_blob.rewrite(source=source_blob, token=rewrite_token)
                if next_rewrite_token is not None:
                    rewrite_tokens.append(next_rewrite_token)

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=src_bucket, key=src_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Deletes an object, optionally guarded by a generation-match precondition.
        """
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> None:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            # If if_match is provided, use it as a precondition
            if if_match:
                generation = int(if_match)
                blob.delete(if_generation_match=generation)
            else:
                # No if_match check needed, just delete
                blob.delete()

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _delete_objects(self, paths: list[str]) -> None:
        """
        Deletes many objects, grouping by bucket and batching requests within
        GCS's 100-operations-per-batch limit.
        """
        if not paths:
            return

        # Group keys by bucket so each batch targets a single bucket.
        by_bucket: dict[str, list[str]] = {}
        for p in paths:
            bucket, key = split_path(p)
            by_bucket.setdefault(bucket, []).append(key)
        self._refresh_gcs_client_if_needed()

        GCS_BATCH_LIMIT = 100

        def _invoke_api() -> None:
            for bucket, keys in by_bucket.items():
                bucket_obj = self._gcs_client.bucket(bucket)
                for i in range(0, len(keys), GCS_BATCH_LIMIT):
                    chunk = keys[i : i + GCS_BATCH_LIMIT]
                    with self._gcs_client.batch():
                        for k in chunk:
                            bucket_obj.blob(k).delete()

        # Synthetic bucket/key descriptions for error messages covering many objects.
        bucket_desc = "(" + "|".join(by_bucket) + ")"
        key_desc = "(" + "|".join(str(len(keys)) for keys in by_bucket.values()) + " keys)"
        self._translate_errors(_invoke_api, operation="DELETE_MANY", bucket=bucket_desc, key=key_desc)

    def _is_dir(self, path: str) -> bool:
        """
        Returns True if any object or common prefix exists under ``path``
        treated as a directory prefix.
        """
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bool:
            bucket_obj = self._gcs_client.bucket(bucket)
            # List objects with the given prefix
            blobs = bucket_obj.list_blobs(
                prefix=key,
                delimiter="/",
            )
            # Check if there are any contents or common prefixes
            # NOTE(review): `prefixes` is populated as the iterator is consumed;
            # the first `any()` drives the iteration before prefixes are read.
            return any(True for _ in blobs) or any(True for _ in blobs.prefixes)

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _make_symlink(self, path: str, target: str) -> None:
        """
        Creates an MSC "symlink": an empty object whose ``msc-symlink-target``
        metadata entry encodes the relative target. Cross-bucket links are rejected.
        """
        bucket_name, key = split_path(path)
        target_bucket, target_key = split_path(target)
        if bucket_name != target_bucket:
            raise ValueError(f"Cannot create cross-bucket symlink: '{bucket_name}' -> '{target_bucket}'.")
        relative_target = ObjectMetadata.encode_symlink_target(key, target_key)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> None:
            bucket = self._gcs_client.bucket(bucket_name)
            blob = bucket.blob(key)
            blob.metadata = {"msc-symlink-target": relative_target}
            blob.upload_from_string(b"")

        self._translate_errors(_invoke_api, operation="PUT", bucket=bucket_name, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Returns metadata for an object, or synthesizes directory metadata for
        prefix-style paths. With ``strict=True``, a missing object is retried
        as a directory before raising ``FileNotFoundError``.
        """
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If path ends with "/" or empty key name is provided, then assume it's a "directory",
            # which metadata is not guaranteed to exist for cases such as
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path, type="directory", content_length=0, last_modified=AWARE_DATETIME_MIN, etag=None
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            self._refresh_gcs_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.get_blob(key)
                if not blob:
                    raise NotFound(f"Blob {key} not found in bucket {bucket}")
                user_metadata = dict(blob.metadata) if blob.metadata else None
                symlink_target = user_metadata.get("msc-symlink-target") if user_metadata else None
                # The blob generation doubles as the ETag used for preconditions.
                return ObjectMetadata(
                    key=path,
                    content_length=blob.size or 0,
                    content_type=blob.content_type,
                    last_modified=blob.updated or AWARE_DATETIME_MIN,
                    etag=str(blob.generation),
                    metadata=user_metadata,
                    symlink_target=symlink_target,
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        symlink_handling: SymlinkHandling = SymlinkHandling.FOLLOW,
    ) -> Iterator[ObjectMetadata]:
        """
        Yields object metadata under ``path`` in lexicographical order, bounded
        by ``start_after`` (exclusive) and ``end_at`` (inclusive). With
        ``include_directories``, common prefixes are yielded last as directories.
        """
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            bucket_obj = self._gcs_client.bucket(bucket)
            if include_directories:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                    delimiter="/",
                )
            else:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                )

            # GCS guarantees lexicographical order.
            for blob in blobs:
                key = blob.name
                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    if key.endswith("/"):
                        # Zero-byte "folder marker" objects surface as directories.
                        if include_directories:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key.rstrip("/")),
                                type="directory",
                                content_length=0,
                                last_modified=blob.updated,
                            )
                    else:
                        user_metadata = blob.metadata
                        symlink_target = user_metadata.get("msc-symlink-target") if user_metadata else None
                        yield ObjectMetadata(
                            key=os.path.join(bucket, key),
                            content_length=blob.size,
                            content_type=blob.content_type,
                            last_modified=blob.updated,
                            etag=blob.etag,
                            symlink_target=symlink_target,
                        )
                elif start_after != key:
                    # Past end_at — ordering guarantees nothing further matches.
                    return

            # The directories must be accessed last.
            if include_directories:
                for directory in blobs.prefixes:
                    prefix_key = directory.rstrip("/")
                    # Filter by start_after and end_at if specified
                    if (start_after is None or start_after < prefix_key) and (end_at is None or prefix_key <= end_at):
                        yield ObjectMetadata(
                            key=os.path.join(bucket, prefix_key),
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    @property
    def supports_parallel_listing(self) -> bool:
        # GCS listings can be safely partitioned across workers.
        return True

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        """
        Uploads a local file path or file-like object to ``remote_path`` and
        returns the number of bytes uploaded. Files above the multipart
        threshold go through the transfer manager (or the Rust client).
        """
        bucket, key = split_path(remote_path)
        file_size: int = 0
        self._refresh_gcs_client_if_needed()

        if isinstance(f, str):
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._multipart_threshold:
                # Rust path skipped when attributes are present: it cannot set metadata.
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    bucket_obj = self._gcs_client.bucket(bucket)
                    blob = bucket_obj.blob(key)
                    # GCS will raise an error if blob.metadata is None
                    validated_attributes = validate_attributes(attributes)
                    if validated_attributes is not None:
                        blob.metadata = validated_attributes
                    transfer_manager.upload_chunks_concurrently(
                        f,
                        blob,
                        chunk_size=self._multipart_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            # File-like object: measure size by seeking to the end and back.
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(remote_path, f.read().encode("utf-8"), attributes=attributes)
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    blob.metadata = validated_attributes
                if isinstance(f, io.StringIO):
                    mode = "w"
                else:
                    mode = "wb"

                # transfer manager does not support uploading a file object
                with tempfile.NamedTemporaryFile(mode=mode, delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    fp.write(f.read())

                transfer_manager.upload_chunks_concurrently(
                    temp_file_path,
                    blob,
                    chunk_size=self._multipart_chunksize,
                    max_workers=self._max_concurrency,
                    worker_type=transfer_manager.THREAD,
                )

                os.unlink(temp_file_path)

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """
        Downloads ``remote_path`` into a local file path or file-like object and
        returns the object's content length. Path targets are written via a
        temporary file and renamed for atomicity; large objects use the
        transfer manager (or the Rust client).
        """
        self._refresh_gcs_client_if_needed()

        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            if os.path.dirname(f):
                safe_makedirs(os.path.dirname(f))
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    # Write to a temp file in the same directory, then rename into place.
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download_multipart_to_file", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        transfer_manager.download_chunks_concurrently(
                            blob,
                            temp_file_path,
                            chunk_size=self._io_chunksize,
                            max_workers=self._max_concurrency,
                            worker_type=transfer_manager.THREAD,
                        )
                    os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                response = self._get_object(remote_path)
                # Python client returns `bytes`, but Rust client returns an object that implements the buffer protocol,
                # so we need to check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)

                # transfer manager does not support downloading to a file object
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    transfer_manager.download_chunks_concurrently(
                        blob,
                        temp_file_path,
                        chunk_size=self._io_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                # Re-open in text or binary mode to match the destination stream.
                if isinstance(f, io.StringIO):
                    with open(temp_file_path, "r") as fp:
                        f.write(fp.read())
                else:
                    with open(temp_file_path, "rb") as fp:
                        f.write(fp.read())

                os.unlink(temp_file_path)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)