Source code for multistorageclient.providers.gcs

  1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16import codecs
 17import copy
 18import io
 19import json
 20import logging
 21import os
 22import tempfile
 23from collections.abc import Callable, Iterator
 24from typing import IO, Any, Optional, TypeVar, Union
 25
 26from google.api_core.exceptions import GoogleAPICallError, NotFound
 27from google.auth import credentials as auth_credentials
 28from google.auth import identity_pool
 29from google.cloud import storage
 30from google.cloud.storage import transfer_manager
 31from google.cloud.storage.exceptions import InvalidResponse
 32from google.oauth2 import service_account
 33from google.oauth2.credentials import Credentials as OAuth2Credentials
 34
 35from multistorageclient_rust import RustClient, RustClientError, RustRetryableError
 36
 37from ..constants import DEFAULT_MAX_POOL_CONNECTIONS
 38from ..rust_utils import parse_retry_config, run_async_rust_client_method
 39from ..telemetry import Telemetry
 40from ..types import (
 41    AWARE_DATETIME_MIN,
 42    Credentials,
 43    CredentialsProvider,
 44    NotModifiedError,
 45    ObjectMetadata,
 46    PreconditionFailedError,
 47    Range,
 48    RetryableError,
 49    SymlinkHandling,
 50)
 51from ..utils import (
 52    safe_makedirs,
 53    split_path,
 54    validate_attributes,
 55)
 56from .base import BaseStorageProvider
 57
# Generic return type of the callable wrapped by ``GoogleStorageProvider._translate_errors``.
_T = TypeVar("_T")

# Provider name handed to the base storage provider and to the Rust client.
PROVIDER = "gcs"

# Number of bytes in one mebibyte.
MiB = 1024 * 1024

# Default for the ``multipart_threshold`` option: files larger than this take the
# chunked (transfer manager / Rust multipart) upload and download paths.
DEFAULT_MULTIPART_THRESHOLD = 64 * MiB
# Default for the ``multipart_chunksize`` option: chunk size used for chunked uploads.
DEFAULT_MULTIPART_CHUNKSIZE = 32 * MiB
# Default for the ``io_chunksize`` option: chunk size used for chunked downloads.
DEFAULT_IO_CHUNKSIZE = 32 * MiB
# Default for the ``max_concurrency`` option: worker count given to the transfer manager.
PYTHON_MAX_CONCURRENCY = 8

# Module-level logger.
logger = logging.getLogger(__name__)
 70
 71
class StringTokenSupplier(identity_pool.SubjectTokenSupplier):
    """
    Subject token supplier for the Google Identity Pool that always hands back a fixed token string.
    """

    def __init__(self, token: str):
        # The token is stored as-is and returned unchanged on every request.
        self._token = token

    def get_subject_token(self, context, request):
        """
        Return the token given at construction time.

        :param context: Supplier context provided by the caller (not used here).
        :param request: Request object provided by the caller (not used here).
        :return: The stored token string.
        """
        return self._token
82 83
class GoogleIdentityPoolCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's identity pool credentials.
    """

    def __init__(self, audience: str, token_supplier: str):
        """
        Initializes the :py:class:`GoogleIdentityPoolCredentialsProvider` with the audience and token supplier.

        :param audience: The audience for the Google Identity Pool.
        :param token_supplier: The token supplier for the Google Identity Pool.
        """
        self._audience = audience
        self._token_supplier = token_supplier

    def get_credentials(self) -> Credentials:
        """
        Build a :py:class:`Credentials` object whose key/token fields are empty strings.

        The audience and the token supplier value travel in ``custom_fields`` under the
        ``"audience"`` and ``"token"`` keys.

        :return: Credentials carrying the identity pool settings as custom fields.
        """
        identity_pool_fields = {"audience": self._audience, "token": self._token_supplier}
        return Credentials(
            access_key="",
            secret_key="",
            token="",
            expiration=None,
            custom_fields=identity_pool_fields,
        )

    def refresh_credentials(self) -> None:
        """No-op: this provider holds nothing that it refreshes."""
        pass
110 111
class GoogleServiceAccountCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's service account credentials.
    """

    #: Google service account private key file contents.
    _info: dict[str, Any]

    def __init__(self, file: Optional[str] = None, info: Optional[dict[str, Any]] = None):
        """
        Initializes the :py:class:`GoogleServiceAccountCredentialsProvider` with either a path to a
        `Google service account private key <https://docs.cloud.google.com/iam/docs/keys-create-delete#creating>`_ file
        or the file contents.

        :param file: Path to a Google service account private key file.
        :param info: Google service account private key file contents.
        :raises ValueError: If both or neither of ``file`` and ``info`` are given.
        """
        # Exactly one of the two sources must be provided.
        if (file is None) == (info is None):
            raise ValueError("Must specify exactly one of file or info")

        if file is not None:
            # Key files are JSON; read them as UTF-8 instead of the locale's default encoding.
            with open(file, "r", encoding="utf-8") as f:
                self._info = json.load(f)
        else:
            # Copy so later mutations of the caller's dict do not leak into this provider.
            self._info = copy.deepcopy(info)

    def get_credentials(self) -> Credentials:
        """
        Build a :py:class:`Credentials` object carrying a copy of the service account key
        contents in ``custom_fields`` under the ``"info"`` key.

        :return: Credentials with empty access/secret keys and no token or expiration.
        """
        return Credentials(
            access_key="",
            secret_key="",
            token=None,
            expiration=None,
            custom_fields={"info": copy.deepcopy(self._info)},
        )

    def refresh_credentials(self) -> None:
        """No-op: the key contents are loaded once in the constructor."""
        pass
149 150
class GoogleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Google Cloud Storage.
    """

    def __init__(
        self,
        project_id: str = os.getenv("GOOGLE_CLOUD_PROJECT_ID", ""),
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ):
        """
        Initializes the :py:class:`GoogleStorageProvider` with the project ID and optional credentials provider.

        :param project_id: The Google Cloud project ID. The default is the value of the
            ``GOOGLE_CLOUD_PROJECT_ID`` environment variable as read when this module is imported.
        :param endpoint_url: The custom endpoint URL for the GCS service.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve GCS credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        :param skip_signature: Use anonymous credentials (for public buckets). Defaults to ``False``.
        :param multipart_threshold: File size in bytes above which chunked transfers are used.
        :param multipart_chunksize: Chunk size in bytes for chunked uploads.
        :param io_chunksize: Chunk size in bytes for chunked downloads.
        :param max_concurrency: Number of workers for chunked transfers.
        :param rust_client: Options for the Rust client. The Rust client is only created when this key is present.
        :param max_pool_connections: Maximum connection pool size for the Rust client.
        :param read_timeout: Forwarded to the Rust client options when given.
        :param connect_timeout: Forwarded to the Rust client options when given.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._project_id = project_id
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._skip_signature = kwargs.get("skip_signature", False)
        self._gcs_client = self._create_gcs_client()
        self._multipart_threshold = kwargs.get("multipart_threshold", DEFAULT_MULTIPART_THRESHOLD)
        self._multipart_chunksize = kwargs.get("multipart_chunksize", DEFAULT_MULTIPART_CHUNKSIZE)
        self._io_chunksize = kwargs.get("io_chunksize", DEFAULT_IO_CHUNKSIZE)
        self._max_concurrency = kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)
        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the rust client options from the kwargs. A ``None`` value is treated as
            # "no options" so the overrides below do not fail on item assignment.
            rust_client_options = copy.deepcopy(kwargs["rust_client"]) or {}
            for option in (
                "max_pool_connections",
                "max_concurrency",
                "multipart_chunksize",
                "read_timeout",
                "connect_timeout",
            ):
                if option in kwargs:
                    rust_client_options[option] = kwargs[option]
            self._rust_client = self._create_rust_client(rust_client_options)

    @staticmethod
    def _remove_file_quietly(path: Optional[str]) -> None:
        """
        Best-effort removal of a temporary file.

        Used from ``finally`` blocks, so a failure here is logged instead of raised to avoid
        masking the error that is already propagating.

        :param path: Path of the file to remove, or ``None`` when there is nothing to clean up.
        """
        if path is None:
            return
        try:
            os.unlink(path)
        except OSError as error:
            logger.warning("Failed to remove temporary file %s: %s", path, error)

    def _create_gcs_client(self) -> storage.Client:
        """
        Creates the Python GCS client for the configured endpoint and credentials.

        Credentials are chosen in this order: anonymous (``skip_signature``), Workload Identity
        Federation, service account key, OAuth 2.0 token from a generic credentials provider, and
        finally whatever ``storage.Client`` resolves on its own when no provider is configured.

        :return: A new ``storage.Client``.
        """
        client_options = {}
        if self._endpoint_url:
            client_options["api_endpoint"] = self._endpoint_url

        # Use anonymous credentials for public buckets when skip_signature is enabled
        if self._skip_signature:
            return storage.Client(
                project=self._project_id,
                credentials=auth_credentials.AnonymousCredentials(),
                client_options=client_options,
            )

        if self._credentials_provider:
            if isinstance(self._credentials_provider, GoogleIdentityPoolCredentialsProvider):
                audience = self._credentials_provider.get_credentials().get_custom_field("audience")
                token = self._credentials_provider.get_credentials().get_custom_field("token")

                # Use Workload Identity Federation (WIF)
                identity_pool_credentials = identity_pool.Credentials(
                    audience=audience,
                    subject_token_type="urn:ietf:params:oauth:token-type:id_token",
                    subject_token_supplier=StringTokenSupplier(token),
                )
                return storage.Client(
                    project=self._project_id, credentials=identity_pool_credentials, client_options=client_options
                )
            elif isinstance(self._credentials_provider, GoogleServiceAccountCredentialsProvider):
                # Use service account key.
                service_account_credentials = service_account.Credentials.from_service_account_info(
                    info=self._credentials_provider.get_credentials().get_custom_field("info")
                )
                return storage.Client(
                    project=self._project_id, credentials=service_account_credentials, client_options=client_options
                )
            else:
                # Use OAuth 2.0 token
                token = self._credentials_provider.get_credentials().token
                creds = OAuth2Credentials(token=token)
                return storage.Client(project=self._project_id, credentials=creds, client_options=client_options)
        else:
            return storage.Client(project=self._project_id, client_options=client_options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        """
        Creates the optional Rust client.

        :param rust_client_options: Options for the Rust client; missing entries are filled from
            environment variables and provider settings.
        :return: The Rust client, or ``None`` when it is unsupported for this configuration
            (custom endpoint, Workload Identity Federation) or its construction fails.
        """
        if self._endpoint_url:
            logger.warning("Rust client for GCS does not support customized endpoint URL, skipping rust client")
            return None

        configs = dict(rust_client_options) if rust_client_options else {}

        # Extract and parse retry configuration
        retry_config = parse_retry_config(configs)

        if "application_credentials" not in configs and os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
            configs["application_credentials"] = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
        if "service_account_key" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY"):
            configs["service_account_key"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY")
        if "service_account_path" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT")
        if "service_account_path" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH")

        if self._skip_signature and "skip_signature" not in configs:
            configs["skip_signature"] = True

        if "bucket" not in configs:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        if "max_pool_connections" not in configs:
            configs["max_pool_connections"] = DEFAULT_MAX_POOL_CONNECTIONS

        if isinstance(self._credentials_provider, GoogleIdentityPoolCredentialsProvider):
            # Workload Identity Federation (WIF) is not supported by the rust client:
            # https://github.com/apache/arrow-rs-object-store/issues/258
            logger.warning("Rust client for GCS doesn't support Workload Identity Federation, skipping rust client")
            return None

        # Every construction path sits inside the same try block so that any failure falls back
        # to the Python client instead of aborting provider initialization.
        try:
            if isinstance(self._credentials_provider, GoogleServiceAccountCredentialsProvider):
                # Use service account key.
                configs["service_account_key"] = json.dumps(
                    self._credentials_provider.get_credentials().get_custom_field("info")
                )
                return RustClient(
                    provider=PROVIDER,
                    configs=configs,
                    retry=retry_config,
                )
            return RustClient(
                provider=PROVIDER,
                configs=configs,
                credentials_provider=self._credentials_provider,
                retry=retry_config,
            )
        except Exception as e:
            logger.warning(f"Failed to create rust client for GCS: {e}, falling back to Python client")
            return None

    def _refresh_gcs_client_if_needed(self) -> None:
        """
        Refreshes the GCS client if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._gcs_client = self._create_gcs_client()

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual GCS operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the GCS bucket involved in the operation.
        :param key: The key of the object within the GCS bucket.

        :return: The result of the GCS operation, typically the return value of the `func` callable.
        :raises FileNotFoundError: On a 404 from the Python or Rust client.
        :raises PreconditionFailedError: On a 412 from the Python client.
        :raises NotModifiedError: On a 304 from the Python client.
        :raises PermissionError: On a 403 from the Rust client.
        :raises RetryableError: On retryable Rust errors, other Rust client errors, and ``NoSuchUpload`` responses.
        :raises RuntimeError: On any other failure.
        """
        try:
            return func()
        except GoogleAPICallError as error:
            status_code = error.code if error.code else -1
            error_info = f"status_code: {status_code}, message: {error.message}"
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 304:
                # for if_none_match with a specific etag condition.
                raise NotModifiedError(f"Object {bucket}/{key} has not been modified.") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except InvalidResponse as error:
            response_text = error.response.text
            error_details = f"error: {error}, error_response_text: {response_text}"
            # Check for NoSuchUpload within the response text
            if "NoSuchUpload" in response_text:
                raise RetryableError(f"Multipart upload failed for {bucket}/{key}, {error_details}") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_details}") from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except RustClientError as error:
            # Read the (message, status_code) pair defensively so a differently shaped error
            # does not raise IndexError from inside this handler.
            message = error.args[0] if error.args else str(error)
            status_code = error.args[1] if len(error.args) > 1 else None
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {message}") from error
            elif status_code == 403:
                raise PermissionError(
                    f"Permission denied to {operation} object(s) at {bucket}/{key}. {message}"
                ) from error
            else:
                raise RetryableError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {message}. status_code: {status_code}"
                ) from error
        except Exception as error:
            error_details = str(error)
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}. error_type: {type(error).__name__}, {error_details}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to Google Cloud Storage.

        :param path: The path to the object to upload.
        :param body: The content of the object to upload.
        :param if_match: Optional generation number (as a string); the upload only succeeds if the
            object's current generation matches it.
        :param if_none_match: Optional generation number (as a string); the upload only succeeds if the
            object's current generation differs from it. ``"*"`` is not supported.
        :param attributes: Optional attributes to attach to the object.
        :return: The number of bytes in ``body``.
        """
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> int:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            kwargs = {}

            if if_match:
                kwargs["if_generation_match"] = int(if_match)  # 412 error code
            if if_none_match:
                if if_none_match == "*":
                    raise NotImplementedError("if_none_match='*' is not supported for GCS")
                else:
                    kwargs["if_generation_not_match"] = int(if_none_match)  # 304 error code

            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                blob.metadata = validated_attributes

            if (
                self._rust_client
                # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
                and not path.endswith("/")
                and not kwargs
                and not validated_attributes
            ):
                run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                blob.upload_from_string(body, **kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Downloads an object, or an inclusive byte range of it, into memory.

        :param path: The path to the object to download.
        :param byte_range: Optional range; bytes ``offset`` through ``offset + size - 1`` are requested.
        :return: The object content. The Rust client is used when available, otherwise the Python client.
        """
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bytes:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)
            if byte_range:
                if self._rust_client:
                    return run_async_rust_client_method(
                        self._rust_client, "get", key, byte_range.offset, byte_range.offset + byte_range.size - 1
                    )
                else:
                    return blob.download_as_bytes(
                        start=byte_range.offset, end=byte_range.offset + byte_range.size - 1, single_shot_download=True
                    )
            else:
                if self._rust_client:
                    return run_async_rust_client_method(self._rust_client, "get", key)
                else:
                    return blob.download_as_bytes(single_shot_download=True)

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        """
        Copies an object with the GCS rewrite API, following rewrite tokens until the copy completes.

        :param src_path: The path of the source object.
        :param dest_path: The path of the destination object.
        :return: The content length of the source object as reported by its metadata.
        """
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)
        self._refresh_gcs_client_if_needed()

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            source_bucket_obj = self._gcs_client.bucket(src_bucket)
            source_blob = source_bucket_obj.blob(src_key)

            destination_bucket_obj = self._gcs_client.bucket(dest_bucket)
            destination_blob = destination_bucket_obj.blob(dest_key)

            # A rewrite may need several calls; keep going while a continuation token is returned.
            rewrite_tokens = [None]
            while len(rewrite_tokens) > 0:
                rewrite_token = rewrite_tokens.pop()
                next_rewrite_token, _, _ = destination_blob.rewrite(source=source_blob, token=rewrite_token)
                if next_rewrite_token is not None:
                    rewrite_tokens.append(next_rewrite_token)

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=src_bucket, key=src_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Deletes a single object.

        :param path: The path of the object to delete.
        :param if_match: Optional generation number (as a string) used as a delete precondition.
        """
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> None:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            # If if_match is provided, use it as a precondition
            if if_match:
                generation = int(if_match)
                blob.delete(if_generation_match=generation)
            else:
                # No if_match check needed, just delete
                blob.delete()

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _delete_objects(self, paths: list[str]) -> None:
        """
        Deletes many objects, grouped per bucket and sent in batches of at most 100 deletes.

        :param paths: The paths of the objects to delete. An empty list is a no-op.
        """
        if not paths:
            return

        by_bucket: dict[str, list[str]] = {}
        for p in paths:
            bucket, key = split_path(p)
            by_bucket.setdefault(bucket, []).append(key)
        self._refresh_gcs_client_if_needed()

        GCS_BATCH_LIMIT = 100

        def _invoke_api() -> None:
            for bucket, keys in by_bucket.items():
                bucket_obj = self._gcs_client.bucket(bucket)
                for i in range(0, len(keys), GCS_BATCH_LIMIT):
                    chunk = keys[i : i + GCS_BATCH_LIMIT]
                    with self._gcs_client.batch():
                        for k in chunk:
                            bucket_obj.blob(k).delete()

        # Summaries used only for error messages: the bucket names and the key count per bucket.
        bucket_desc = "(" + "|".join(by_bucket) + ")"
        key_desc = "(" + "|".join(str(len(keys)) for keys in by_bucket.values()) + " keys)"
        self._translate_errors(_invoke_api, operation="DELETE_MANY", bucket=bucket_desc, key=key_desc)

    def _is_dir(self, path: str) -> bool:
        """
        Checks whether any object or common prefix exists under ``path`` treated as a directory.

        :param path: The path to check.
        :return: ``True`` if the listing under the delimiter-terminated prefix is non-empty.
        """
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bool:
            bucket_obj = self._gcs_client.bucket(bucket)
            # List objects with the given prefix
            blobs = bucket_obj.list_blobs(
                prefix=key,
                delimiter="/",
            )
            # Check if there are any contents or common prefixes
            return any(True for _ in blobs) or any(True for _ in blobs.prefixes)

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _make_symlink(self, path: str, target: str) -> None:
        """
        Creates a symlink as an empty object whose ``msc-symlink-target`` metadata holds the encoded target.

        :param path: The path of the symlink object to create.
        :param target: The path the symlink points to; must be in the same bucket as ``path``.
        :raises ValueError: If ``path`` and ``target`` are in different buckets.
        """
        bucket_name, key = split_path(path)
        target_bucket, target_key = split_path(target)
        if bucket_name != target_bucket:
            raise ValueError(f"Cannot create cross-bucket symlink: '{bucket_name}' -> '{target_bucket}'.")
        relative_target = ObjectMetadata.encode_symlink_target(key, target_key)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> None:
            bucket = self._gcs_client.bucket(bucket_name)
            blob = bucket.blob(key)
            blob.metadata = {"msc-symlink-target": relative_target}
            blob.upload_from_string(b"")

        self._translate_errors(_invoke_api, operation="PUT", bucket=bucket_name, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Retrieves metadata for an object or a "directory" prefix.

        :param path: The path to inspect.
        :param strict: When ``True`` and no object exists at ``path``, additionally check whether
            ``path`` is a directory before giving up.
        :return: The metadata; for objects the ``etag`` is the blob generation as a string.
        :raises FileNotFoundError: If neither an object nor a directory exists at ``path``.
        """
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If path ends with "/" or empty key name is provided, then assume it's a "directory",
            # which metadata is not guaranteed to exist for cases such as
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path, type="directory", content_length=0, last_modified=AWARE_DATETIME_MIN, etag=None
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            self._refresh_gcs_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.get_blob(key)
                if not blob:
                    # Raised as a Google 404 so _translate_errors turns it into FileNotFoundError.
                    raise NotFound(f"Blob {key} not found in bucket {bucket}")
                user_metadata = dict(blob.metadata) if blob.metadata else None
                symlink_target = user_metadata.get("msc-symlink-target") if user_metadata else None
                return ObjectMetadata(
                    key=path,
                    content_length=blob.size or 0,
                    content_type=blob.content_type,
                    last_modified=blob.updated or AWARE_DATETIME_MIN,
                    etag=str(blob.generation),
                    metadata=user_metadata,
                    symlink_target=symlink_target,
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        symlink_handling: SymlinkHandling = SymlinkHandling.FOLLOW,
    ) -> Iterator[ObjectMetadata]:
        """
        Lazily lists objects under ``path`` in the order returned by GCS.

        :param path: The bucket/prefix to list.
        :param start_after: Only keys strictly greater than this path's key are yielded.
        :param end_at: Only keys less than or equal to this path's key are yielded.
        :param include_directories: When ``True``, list with a ``/`` delimiter and also yield
            directory entries (after the objects).
        :param symlink_handling: Accepted for interface compatibility; not used by this implementation.
        :return: An iterator of object metadata. Errors raised while iterating are translated the
            same way as for the other operations.
        """
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            bucket_obj = self._gcs_client.bucket(bucket)
            if include_directories:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                    delimiter="/",
                )
            else:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                )

            # GCS guarantees lexicographical order.
            for blob in blobs:
                key = blob.name
                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    if key.endswith("/"):
                        if include_directories:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key.rstrip("/")),
                                type="directory",
                                content_length=0,
                                last_modified=blob.updated,
                            )
                    else:
                        user_metadata = blob.metadata
                        symlink_target = user_metadata.get("msc-symlink-target") if user_metadata else None
                        # NOTE(review): this reports blob.etag while _get_object_metadata reports
                        # str(blob.generation), and if_match is parsed as a generation number —
                        # confirm whether listing should expose the generation as well.
                        yield ObjectMetadata(
                            key=os.path.join(bucket, key),
                            content_length=blob.size,
                            content_type=blob.content_type,
                            last_modified=blob.updated,
                            etag=blob.etag,
                            symlink_target=symlink_target,
                        )
                elif start_after != key:
                    # Past end_at (the start_after key itself is merely skipped): stop listing.
                    return

            # The directories must be accessed last.
            if include_directories:
                for directory in blobs.prefixes:
                    prefix_key = directory.rstrip("/")
                    # Filter by start_after and end_at if specified
                    if (start_after is None or start_after < prefix_key) and (end_at is None or prefix_key <= end_at):
                        yield ObjectMetadata(
                            key=os.path.join(bucket, prefix_key),
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )

        # _invoke_api is a generator function: calling it runs none of its body, so wrapping only
        # the call in _translate_errors would translate nothing. Instead, every step of the
        # iteration goes through _translate_errors. A sentinel marks exhaustion because a
        # StopIteration escaping from next() would itself be translated into a RuntimeError.
        listing = _invoke_api()
        exhausted = object()

        def _translated() -> Iterator[ObjectMetadata]:
            try:
                while True:
                    item = self._translate_errors(
                        lambda: next(listing, exhausted), operation="LIST", bucket=bucket, key=prefix
                    )
                    if item is exhausted:
                        return
                    yield item
            finally:
                listing.close()

        return _translated()

    @property
    def supports_parallel_listing(self) -> bool:
        """Whether this provider supports parallel listing; always ``True`` here."""
        return True

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        """
        Uploads a local file or a file-like object.

        Content up to ``multipart_threshold`` bytes is uploaded in one request; larger content is
        uploaded in chunks (Rust multipart upload or the Python transfer manager).

        :param remote_path: The destination object path.
        :param f: A local file path or a seekable file-like object. ``io.StringIO`` content is
            encoded as UTF-8.
        :param attributes: Optional attributes to attach to the object.
        :return: The number of bytes uploaded.
        """
        bucket, key = split_path(remote_path)
        file_size: int = 0
        self._refresh_gcs_client_if_needed()

        if isinstance(f, str):
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if self._rust_client and not attributes:
                    # Routed through _translate_errors like every other Rust call so callers see
                    # the same exception types regardless of which client served the request.
                    self._translate_errors(
                        lambda: run_async_rust_client_method(self._rust_client, "upload", f, key),
                        operation="PUT",
                        bucket=bucket,
                        key=key,
                    )
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    bucket_obj = self._gcs_client.bucket(bucket)
                    blob = bucket_obj.blob(key)
                    # GCS will raise an error if blob.metadata is None
                    validated_attributes = validate_attributes(attributes)
                    if validated_attributes is not None:
                        blob.metadata = validated_attributes
                    transfer_manager.upload_chunks_concurrently(
                        f,
                        blob,
                        chunk_size=self._multipart_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            # Upload small files
            if file_size <= self._multipart_threshold:
                body = f.read()
                if isinstance(f, io.StringIO):
                    body = body.encode("utf-8")
                # _put_object returns len(body), i.e. the encoded byte count even for text streams.
                return self._put_object(remote_path, body, attributes=attributes)

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    blob.metadata = validated_attributes

                # Text streams are encoded as UTF-8, matching the small-file path, instead of
                # relying on the locale encoding and newline translation of a text-mode temp file.
                payload = f.read()
                if isinstance(f, io.StringIO):
                    payload = payload.encode("utf-8")

                temp_file_path = None
                try:
                    # transfer manager does not support uploading a file object
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(payload)

                    transfer_manager.upload_chunks_concurrently(
                        temp_file_path,
                        blob,
                        chunk_size=self._multipart_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )
                finally:
                    # Remove the temp file even when the upload fails.
                    self._remove_file_quietly(temp_file_path)

                return len(payload)

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """
        Downloads an object to a local file path or into a file-like object.

        Objects up to ``multipart_threshold`` bytes are fetched in one request; larger objects are
        downloaded in chunks. Downloads to a path are written to a temporary file in the
        destination directory first and then renamed into place.

        :param remote_path: The source object path.
        :param f: A local file path or a writable file-like object. ``io.StringIO`` receives the
            content decoded as UTF-8.
        :param metadata: Optional pre-fetched metadata; fetched when omitted.
        :return: The object's content length as reported by its metadata.
        """
        self._refresh_gcs_client_if_needed()

        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            if os.path.dirname(f):
                safe_makedirs(os.path.dirname(f))
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                if self._rust_client:
                    # Routed through _translate_errors like every other Rust call.
                    self._translate_errors(
                        lambda: run_async_rust_client_method(self._rust_client, "download", key, f),
                        operation="GET",
                        bucket=bucket,
                        key=key,
                    )
                else:
                    temp_file_path = None
                    try:
                        with tempfile.NamedTemporaryFile(
                            mode="wb", delete=False, dir=os.path.dirname(f), prefix="."
                        ) as fp:
                            temp_file_path = fp.name
                            fp.write(self._get_object(remote_path))
                        os.rename(src=temp_file_path, dst=f)
                        # Renamed into place: nothing left to clean up.
                        temp_file_path = None
                    finally:
                        self._remove_file_quietly(temp_file_path)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download_multipart_to_file", key, f)
                    return metadata.content_length

                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                temp_file_path = None
                try:
                    with tempfile.NamedTemporaryFile(
                        mode="wb", delete=False, dir=os.path.dirname(f), prefix="."
                    ) as fp:
                        temp_file_path = fp.name
                    transfer_manager.download_chunks_concurrently(
                        blob,
                        temp_file_path,
                        chunk_size=self._io_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )
                    os.rename(src=temp_file_path, dst=f)
                    # Renamed into place: nothing left to clean up.
                    temp_file_path = None
                finally:
                    # Do not leave a partial hidden file next to the destination on failure.
                    self._remove_file_quietly(temp_file_path)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                response = self._get_object(remote_path)
                # Python client returns `bytes`, but Rust client returns an object that implements the buffer protocol,
                # so we need to check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)

                temp_file_path = None
                try:
                    # transfer manager does not support downloading to a file object
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=".") as fp:
                        temp_file_path = fp.name
                    transfer_manager.download_chunks_concurrently(
                        blob,
                        temp_file_path,
                        chunk_size=self._io_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                    with open(temp_file_path, "rb") as fp:
                        content = fp.read()
                    if isinstance(f, io.StringIO):
                        # Decode as UTF-8, matching the small-file path, rather than using the
                        # locale encoding and newline translation of a text-mode read.
                        f.write(content.decode("utf-8"))
                    else:
                        f.write(content)
                finally:
                    # Remove the temp file even when the download or the copy into `f` fails.
                    self._remove_file_quietly(temp_file_path)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)