Source code for multistorageclient.providers.gcs

  1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16import codecs
 17import copy
 18import io
 19import json
 20import logging
 21import os
 22import tempfile
 23from collections.abc import Callable, Iterator
 24from typing import IO, Any, Optional, TypeVar, Union
 25
 26from google.api_core.exceptions import GoogleAPICallError, NotFound
 27from google.auth import credentials as auth_credentials
 28from google.auth import identity_pool
 29from google.cloud import storage
 30from google.cloud.storage import transfer_manager
 31from google.cloud.storage.exceptions import InvalidResponse
 32from google.oauth2 import service_account
 33from google.oauth2.credentials import Credentials as OAuth2Credentials
 34
 35from multistorageclient_rust import RustClient, RustClientError, RustRetryableError
 36
 37from ..rust_utils import parse_retry_config, run_async_rust_client_method
 38from ..telemetry import Telemetry
 39from ..types import (
 40    AWARE_DATETIME_MIN,
 41    Credentials,
 42    CredentialsProvider,
 43    NotModifiedError,
 44    ObjectMetadata,
 45    PreconditionFailedError,
 46    Range,
 47    RetryableError,
 48)
 49from ..utils import (
 50    safe_makedirs,
 51    split_path,
 52    validate_attributes,
 53)
 54from .base import BaseStorageProvider
 55
# Generic return type used by the error-translation wrapper below.
_T = TypeVar("_T")

# Provider name reported to BaseStorageProvider / telemetry.
PROVIDER = "gcs"

MiB = 1024 * 1024

# Objects larger than the threshold use the multipart/transfer-manager paths.
DEFAULT_MULTIPART_THRESHOLD = 64 * MiB
DEFAULT_MULTIPART_CHUNKSIZE = 32 * MiB
DEFAULT_IO_CHUNKSIZE = 32 * MiB
# Default worker count for the Python transfer manager (overridable via kwargs).
PYTHON_MAX_CONCURRENCY = 8

logger = logging.getLogger(__name__)
 68
 69
class StringTokenSupplier(identity_pool.SubjectTokenSupplier):
    """
    Serve a fixed, pre-obtained subject token to Google's identity pool machinery.
    """

    def __init__(self, token: str):
        self._token = token

    def get_subject_token(self, context, request):
        # The token is static, so the request context is irrelevant here.
        return self._token
class GoogleIdentityPoolCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's identity pool credentials.
    """

    def __init__(self, audience: str, token_supplier: str):
        """
        Initializes the :py:class:`GoogleIdentityPoolCredentials` with the audience and token supplier.

        :param audience: The audience for the Google Identity Pool.
        :param token_supplier: The token supplier for the Google Identity Pool.
        """
        self._audience = audience
        self._token_supplier = token_supplier

    def get_credentials(self) -> Credentials:
        # The audience and raw token travel via custom fields; the standard
        # access/secret key slots are unused for identity-pool auth.
        extra_fields = {"audience": self._audience, "token": self._token_supplier}
        return Credentials(
            access_key="",
            secret_key="",
            token="",
            expiration=None,
            custom_fields=extra_fields,
        )

    def refresh_credentials(self) -> None:
        # The supplied token is static from this provider's point of view;
        # there is nothing to refresh.
        pass
class GoogleServiceAccountCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's service account credentials.
    """

    #: Google service account private key file contents.
    _info: dict[str, Any]

    def __init__(self, file: Optional[str] = None, info: Optional[dict[str, Any]] = None):
        """
        Initializes the :py:class:`GoogleServiceAccountCredentialsProvider` with either a path to a
        `Google service account private key <https://docs.cloud.google.com/iam/docs/keys-create-delete#creating>`_ file
        or the file contents.

        :param file: Path to a Google service account private key file.
        :param info: Google service account private key file contents.
        """
        # Exactly one of the two sources must be given: reject "both" and "neither".
        if (file is None) == (info is None):
            raise ValueError("Must specify exactly one of file or info")

        if file is not None:
            with open(file, "r") as f:
                self._info = json.load(f)
        else:
            # The validation above guarantees info is not None on this branch.
            self._info = copy.deepcopy(info)

    def get_credentials(self) -> Credentials:
        # Hand out a deep copy so callers cannot mutate the cached key material.
        return Credentials(
            access_key="",
            secret_key="",
            token=None,
            expiration=None,
            custom_fields={"info": copy.deepcopy(self._info)},
        )

    def refresh_credentials(self) -> None:
        # Service account keys are long-lived; nothing to refresh.
        pass
class GoogleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Google Cloud Storage.

    Uses the Python ``google-cloud-storage`` client for all operations, optionally
    delegating data-path operations (GET/PUT/upload/download) to a Rust client
    when one is configured and supported for the credential type in use.
    """

    def __init__(
        self,
        project_id: str = os.getenv("GOOGLE_CLOUD_PROJECT_ID", ""),
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ):
        """
        Initializes the :py:class:`GoogleStorageProvider` with the project ID and optional credentials provider.

        :param project_id: The Google Cloud project ID. Defaults to the
            ``GOOGLE_CLOUD_PROJECT_ID`` environment variable (read at import time).
        :param endpoint_url: The custom endpoint URL for the GCS service.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve GCS credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._project_id = project_id
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        # When true, anonymous credentials are used (public buckets).
        self._skip_signature = kwargs.get("skip_signature", False)
        self._gcs_client = self._create_gcs_client()
        self._multipart_threshold = kwargs.get("multipart_threshold", DEFAULT_MULTIPART_THRESHOLD)
        self._multipart_chunksize = kwargs.get("multipart_chunksize", DEFAULT_MULTIPART_CHUNKSIZE)
        self._io_chunksize = kwargs.get("io_chunksize", DEFAULT_IO_CHUNKSIZE)
        self._max_concurrency = kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)
        # Optional Rust-backed client; stays None unless configured and supported.
        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the rust client options from the kwargs
            rust_client_options = copy.deepcopy(kwargs["rust_client"])
            if "max_concurrency" in kwargs:
                rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
            if "multipart_chunksize" in kwargs:
                rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
            if "read_timeout" in kwargs:
                rust_client_options["read_timeout"] = kwargs["read_timeout"]
            if "connect_timeout" in kwargs:
                rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
            self._rust_client = self._create_rust_client(rust_client_options)

    def _create_gcs_client(self) -> storage.Client:
        """
        Builds a ``google.cloud.storage.Client`` matching the configured
        endpoint, signature mode, and credential provider type.

        :return: A configured GCS client.
        """
        client_options = {}
        if self._endpoint_url:
            client_options["api_endpoint"] = self._endpoint_url

        # Use anonymous credentials for public buckets when skip_signature is enabled
        if self._skip_signature:
            return storage.Client(
                project=self._project_id,
                credentials=auth_credentials.AnonymousCredentials(),
                client_options=client_options,
            )

        if self._credentials_provider:
            if isinstance(self._credentials_provider, GoogleIdentityPoolCredentialsProvider):
                audience = self._credentials_provider.get_credentials().get_custom_field("audience")
                token = self._credentials_provider.get_credentials().get_custom_field("token")

                # Use Workload Identity Federation (WIF)
                identity_pool_credentials = identity_pool.Credentials(
                    audience=audience,
                    subject_token_type="urn:ietf:params:oauth:token-type:id_token",
                    subject_token_supplier=StringTokenSupplier(token),
                )
                return storage.Client(
                    project=self._project_id, credentials=identity_pool_credentials, client_options=client_options
                )
            elif isinstance(self._credentials_provider, GoogleServiceAccountCredentialsProvider):
                # Use service account key.
                service_account_credentials = service_account.Credentials.from_service_account_info(
                    info=self._credentials_provider.get_credentials().get_custom_field("info")
                )
                return storage.Client(
                    project=self._project_id, credentials=service_account_credentials, client_options=client_options
                )
            else:
                # Use OAuth 2.0 token
                token = self._credentials_provider.get_credentials().token
                creds = OAuth2Credentials(token=token)
                return storage.Client(project=self._project_id, credentials=creds, client_options=client_options)
        else:
            return storage.Client(project=self._project_id, client_options=client_options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None) -> Optional["RustClient"]:
        """
        Builds the optional Rust client, or returns ``None`` when the current
        configuration is unsupported (custom endpoint, WIF credentials) or
        client construction fails.

        :param rust_client_options: Options forwarded to the Rust client config.
        :return: A ``RustClient`` or ``None`` to fall back to the Python client.
        """
        if self._endpoint_url:
            logger.warning("Rust client for GCS does not support customized endpoint URL, skipping rust client")
            return None

        configs = dict(rust_client_options) if rust_client_options else {}

        # Extract and parse retry configuration
        retry_config = parse_retry_config(configs)

        # Fall back to the standard Google environment variables when the
        # corresponding option was not set explicitly.
        if "application_credentials" not in configs and os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
            configs["application_credentials"] = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
        if "service_account_key" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY"):
            configs["service_account_key"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY")
        if "service_account_path" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT")
        if "service_account_path" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH")

        if self._skip_signature and "skip_signature" not in configs:
            configs["skip_signature"] = True

        if "bucket" not in configs:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        if self._credentials_provider:
            if isinstance(self._credentials_provider, GoogleIdentityPoolCredentialsProvider):
                # Workload Identity Federation (WIF) is not supported by the rust client:
                # https://github.com/apache/arrow-rs-object-store/issues/258
                logger.warning("Rust client for GCS doesn't support Workload Identity Federation, skipping rust client")
                return None
            if isinstance(self._credentials_provider, GoogleServiceAccountCredentialsProvider):
                # Use service account key.
                configs["service_account_key"] = json.dumps(
                    self._credentials_provider.get_credentials().get_custom_field("info")
                )
                return RustClient(
                    provider=PROVIDER,
                    configs=configs,
                    retry=retry_config,
                )
        try:
            return RustClient(
                provider=PROVIDER,
                configs=configs,
                credentials_provider=self._credentials_provider,
                retry=retry_config,
            )
        except Exception as e:
            logger.warning(f"Failed to create rust client for GCS: {e}, falling back to Python client")
            return None

    def _refresh_gcs_client_if_needed(self) -> None:
        """
        Refreshes the GCS client if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                # Rebuild the client so it picks up the refreshed credentials.
                self._gcs_client = self._create_gcs_client()

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        Maps Google API and Rust client errors onto the provider-neutral
        exception types (FileNotFoundError, PreconditionFailedError,
        NotModifiedError, PermissionError, RetryableError, RuntimeError).

        :param func: The function that performs the actual GCS operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the GCS bucket involved in the operation.
        :param key: The key of the object within the GCS bucket.

        :return: The result of the GCS operation, typically the return value of the `func` callable.
        """
        try:
            return func()
        except GoogleAPICallError as error:
            status_code = error.code if error.code else -1
            error_info = f"status_code: {status_code}, message: {error.message}"
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 304:
                # for if_none_match with a specific etag condition.
                raise NotModifiedError(f"Object {bucket}/{key} has not been modified.") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except InvalidResponse as error:
            response_text = error.response.text
            error_details = f"error: {error}, error_response_text: {response_text}"
            # Check for NoSuchUpload within the response text
            if "NoSuchUpload" in response_text:
                raise RetryableError(f"Multipart upload failed for {bucket}/{key}, {error_details}") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_details}") from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except RustClientError as error:
            # Rust client errors carry (message, status_code) in args.
            message = error.args[0]
            status_code = error.args[1]
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {message}") from error
            elif status_code == 403:
                raise PermissionError(
                    f"Permission denied to {operation} object(s) at {bucket}/{key}. {message}"
                ) from error
            else:
                raise RetryableError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {message}. status_code: {status_code}"
                ) from error
        except Exception as error:
            error_details = str(error)
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}. error_type: {type(error).__name__}, {error_details}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to Google Cloud Storage.

        :param path: The path to the object to upload.
        :param body: The content of the object to upload.
        :param if_match: Optional ETag to match against the object.
        :param if_none_match: Optional ETag to match against the object.
        :param attributes: Optional attributes to attach to the object.
        :return: The number of bytes written (``len(body)``).
        """
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> int:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            kwargs = {}

            # GCS expresses ETag preconditions as generation-number preconditions.
            if if_match:
                kwargs["if_generation_match"] = int(if_match)  # 412 error code
            if if_none_match:
                if if_none_match == "*":
                    raise NotImplementedError("if_none_match='*' is not supported for GCS")
                else:
                    kwargs["if_generation_not_match"] = int(if_none_match)  # 304 error code

            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                blob.metadata = validated_attributes

            if (
                self._rust_client
                # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
                and not path.endswith("/")
                and not kwargs
                and not validated_attributes
            ):
                run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                blob.upload_from_string(body, **kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Reads an object (or a byte range of it) and returns its content.

        :param path: The path to the object to read.
        :param byte_range: Optional inclusive byte range to read.
        :return: The object content. NOTE(review): the Rust client path may
            return a buffer-protocol object rather than ``bytes`` (see the
            decode handling in ``_download_file``).
        """
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bytes:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)
            if byte_range:
                # Range end is inclusive for both clients here.
                if self._rust_client:
                    return run_async_rust_client_method(
                        self._rust_client, "get", key, byte_range.offset, byte_range.offset + byte_range.size - 1
                    )
                else:
                    return blob.download_as_bytes(
                        start=byte_range.offset, end=byte_range.offset + byte_range.size - 1, single_shot_download=True
                    )
            else:
                if self._rust_client:
                    return run_async_rust_client_method(self._rust_client, "get", key)
                else:
                    return blob.download_as_bytes(single_shot_download=True)

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        """
        Server-side copies an object, possibly across buckets.

        :param src_path: The source object path.
        :param dest_path: The destination object path.
        :return: The content length of the source object.
        :raises FileNotFoundError: If the source object does not exist.
        """
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)
        self._refresh_gcs_client_if_needed()

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            source_bucket_obj = self._gcs_client.bucket(src_bucket)
            source_blob = source_bucket_obj.blob(src_key)

            destination_bucket_obj = self._gcs_client.bucket(dest_bucket)
            destination_blob = destination_bucket_obj.blob(dest_key)

            # GCS rewrite may require multiple calls for large objects;
            # loop until no continuation token is returned.
            rewrite_tokens = [None]
            while len(rewrite_tokens) > 0:
                rewrite_token = rewrite_tokens.pop()
                next_rewrite_token, _, _ = destination_blob.rewrite(source=source_blob, token=rewrite_token)
                if next_rewrite_token is not None:
                    rewrite_tokens.append(next_rewrite_token)

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=src_bucket, key=src_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Deletes a single object, optionally guarded by an ETag precondition.

        :param path: The path to the object to delete.
        :param if_match: Optional ETag (GCS generation number) precondition.
        """
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> None:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            # If if_match is provided, use it as a precondition
            if if_match:
                generation = int(if_match)
                blob.delete(if_generation_match=generation)
            else:
                # No if_match check needed, just delete
                blob.delete()

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _delete_objects(self, paths: list[str]) -> None:
        """
        Deletes multiple objects, grouped by bucket and batched.

        :param paths: Object paths to delete; may span multiple buckets.
        """
        if not paths:
            return

        # Group keys by bucket so each batch request targets a single bucket.
        by_bucket: dict[str, list[str]] = {}
        for p in paths:
            bucket, key = split_path(p)
            by_bucket.setdefault(bucket, []).append(key)
        self._refresh_gcs_client_if_needed()

        # Send deletions in chunks of at most 100 per batch request.
        GCS_BATCH_LIMIT = 100

        def _invoke_api() -> None:
            for bucket, keys in by_bucket.items():
                bucket_obj = self._gcs_client.bucket(bucket)
                for i in range(0, len(keys), GCS_BATCH_LIMIT):
                    chunk = keys[i : i + GCS_BATCH_LIMIT]
                    with self._gcs_client.batch():
                        for k in chunk:
                            bucket_obj.blob(k).delete()

        # Synthesize descriptive bucket/key labels for error messages.
        bucket_desc = "(" + "|".join(by_bucket) + ")"
        key_desc = "(" + "|".join(str(len(keys)) for keys in by_bucket.values()) + " keys)"
        self._translate_errors(_invoke_api, operation="DELETE_MANY", bucket=bucket_desc, key=key_desc)

    def _is_dir(self, path: str) -> bool:
        """
        Checks whether a path behaves like a directory, i.e. at least one
        object or common prefix exists under it.

        :param path: The prefix path to test.
        :return: True if any object or sub-prefix exists under the path.
        """
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bool:
            bucket_obj = self._gcs_client.bucket(bucket)
            # List objects with the given prefix
            blobs = bucket_obj.list_blobs(
                prefix=key,
                delimiter="/",
            )
            # Check if there are any contents or common prefixes
            # (prefixes are populated while the blob iterator is consumed).
            return any(True for _ in blobs) or any(True for _ in blobs.prefixes)

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Retrieves metadata for an object, with directory fallback handling.

        :param path: The path to the object or directory prefix.
        :param strict: When True and the object is missing, also check whether
            the path is a directory prefix before raising.
        :return: The object's metadata (ETag is the GCS generation number).
        :raises FileNotFoundError: If neither the object nor a matching
            directory prefix exists.
        """
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If path ends with "/" or empty key name is provided, then assume it's a "directory",
            # which metadata is not guaranteed to exist for cases such as
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path, type="directory", content_length=0, last_modified=AWARE_DATETIME_MIN, etag=None
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            self._refresh_gcs_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.get_blob(key)
                if not blob:
                    raise NotFound(f"Blob {key} not found in bucket {bucket}")
                return ObjectMetadata(
                    key=path,
                    content_length=blob.size or 0,
                    content_type=blob.content_type,
                    last_modified=blob.updated or AWARE_DATETIME_MIN,
                    etag=str(blob.generation),
                    metadata=dict(blob.metadata) if blob.metadata else None,
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        """
        Lists objects (and optionally directory prefixes) under a path prefix.

        :param path: The prefix path to list.
        :param start_after: Yield only keys strictly greater than this key.
        :param end_at: Yield only keys less than or equal to this key.
        :param include_directories: When True, list with delimiter "/" and
            also yield common prefixes as directory entries (after objects).
        :param follow_symlinks: Unused by this provider.
        :return: An iterator of object metadata in lexicographical key order.
        """
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            bucket_obj = self._gcs_client.bucket(bucket)
            if include_directories:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                    delimiter="/",
                )
            else:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                )

            # GCS guarantees lexicographical order.
            for blob in blobs:
                key = blob.name
                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    if key.endswith("/"):
                        # Zero-byte "folder marker" objects count as directories.
                        if include_directories:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key.rstrip("/")),
                                type="directory",
                                content_length=0,
                                last_modified=blob.updated,
                            )
                    else:
                        yield ObjectMetadata(
                            key=os.path.join(bucket, key),
                            content_length=blob.size,
                            content_type=blob.content_type,
                            last_modified=blob.updated,
                            etag=blob.etag,
                        )
                elif start_after != key:
                    # Keys are sorted, so once past end_at nothing further matches.
                    return

            # The directories must be accessed last.
            if include_directories:
                for directory in blobs.prefixes:
                    prefix_key = directory.rstrip("/")
                    # Filter by start_after and end_at if specified
                    if (start_after is None or start_after < prefix_key) and (end_at is None or prefix_key <= end_at):
                        yield ObjectMetadata(
                            key=os.path.join(bucket, prefix_key),
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    @property
    def supports_parallel_listing(self) -> bool:
        """Whether this provider supports parallelized listing (it does)."""
        return True

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        """
        Uploads a local file path or file object to GCS.

        Small payloads (<= multipart threshold) go through ``_put_object`` or
        the Rust client; larger ones use the transfer manager (or the Rust
        multipart upload). Attributes force the Python path.

        :param remote_path: The destination object path.
        :param f: A local file path or an open file object (text or binary).
        :param attributes: Optional attributes to attach to the object.
        :return: The number of bytes uploaded.
        """
        bucket, key = split_path(remote_path)
        file_size: int = 0
        self._refresh_gcs_client_if_needed()

        if isinstance(f, str):
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    bucket_obj = self._gcs_client.bucket(bucket)
                    blob = bucket_obj.blob(key)
                    # GCS will raise an error if blob.metadata is None
                    validated_attributes = validate_attributes(attributes)
                    if validated_attributes is not None:
                        blob.metadata = validated_attributes
                    transfer_manager.upload_chunks_concurrently(
                        f,
                        blob,
                        chunk_size=self._multipart_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            # Determine the payload size from the file object, then rewind.
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(remote_path, f.read().encode("utf-8"), attributes=attributes)
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    blob.metadata = validated_attributes
                if isinstance(f, io.StringIO):
                    mode = "w"
                else:
                    mode = "wb"

                # transfer manager does not support uploading a file object
                with tempfile.NamedTemporaryFile(mode=mode, delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    fp.write(f.read())

                transfer_manager.upload_chunks_concurrently(
                    temp_file_path,
                    blob,
                    chunk_size=self._multipart_chunksize,
                    max_workers=self._max_concurrency,
                    worker_type=transfer_manager.THREAD,
                )

                os.unlink(temp_file_path)

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """
        Downloads an object to a local file path or file object.

        Small objects (<= multipart threshold) are fetched in one GET; larger
        ones use the transfer manager (or the Rust multipart download). Local
        path targets are written via a hidden temp file then renamed, so a
        partially-written file is never visible at the destination.

        :param remote_path: The source object path.
        :param f: A local file path or an open file object (text or binary).
        :param metadata: Optional pre-fetched metadata (avoids an extra HEAD).
        :return: The object's content length.
        """
        self._refresh_gcs_client_if_needed()

        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            if os.path.dirname(f):
                safe_makedirs(os.path.dirname(f))
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download_multipart_to_file", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        transfer_manager.download_chunks_concurrently(
                            blob,
                            temp_file_path,
                            chunk_size=self._io_chunksize,
                            max_workers=self._max_concurrency,
                            worker_type=transfer_manager.THREAD,
                        )
                    os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                response = self._get_object(remote_path)
                # Python client returns `bytes`, but Rust client returns an object that implements the buffer protocol,
                # so we need to check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)

                # transfer manager does not support downloading to a file object
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    transfer_manager.download_chunks_concurrently(
                        blob,
                        temp_file_path,
                        chunk_size=self._io_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                if isinstance(f, io.StringIO):
                    with open(temp_file_path, "r") as fp:
                        f.write(fp.read())
                else:
                    with open(temp_file_path, "rb") as fp:
                        f.write(fp.read())

                os.unlink(temp_file_path)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)