Source code for multistorageclient.providers.gcs

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import codecs
import copy
import io
import json
import logging
import os
import tempfile
from collections.abc import Callable, Iterator
from typing import IO, Any, Optional, TypeVar, Union

from google.api_core.exceptions import GoogleAPICallError, NotFound
from google.auth import credentials as auth_credentials
from google.auth import identity_pool
from google.cloud import storage
from google.cloud.storage import transfer_manager
from google.cloud.storage.exceptions import InvalidResponse
from google.oauth2 import service_account
from google.oauth2.credentials import Credentials as OAuth2Credentials

from multistorageclient_rust import RustClient, RustClientError, RustRetryableError

from ..rust_utils import parse_retry_config, run_async_rust_client_method
from ..telemetry import Telemetry
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    NotModifiedError,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import (
    safe_makedirs,
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

PROVIDER = "gcs"

MiB = 1024 * 1024

DEFAULT_MULTIPART_THRESHOLD = 64 * MiB
DEFAULT_MULTIPART_CHUNKSIZE = 32 * MiB
DEFAULT_IO_CHUNKSIZE = 32 * MiB
PYTHON_MAX_CONCURRENCY = 8

logger = logging.getLogger(__name__)


class StringTokenSupplier(identity_pool.SubjectTokenSupplier):
    """
    Supply a string token to the Google Identity Pool.
    """

    def __init__(self, token: str):
        self._token = token

    def get_subject_token(self, context, request):
        return self._token


class GoogleIdentityPoolCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's identity pool credentials.
    """

    def __init__(self, audience: str, token_supplier: str):
        """
        Initializes the :py:class:`GoogleIdentityPoolCredentialsProvider` with the audience and token supplier.

        :param audience: The audience for the Google Identity Pool.
        :param token_supplier: The token supplier for the Google Identity Pool.
        """
        self._audience = audience
        self._token_supplier = token_supplier

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key="",
            secret_key="",
            token="",
            expiration=None,
            custom_fields={"audience": self._audience, "token": self._token_supplier},
        )

    def refresh_credentials(self) -> None:
        pass
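

# A minimal usage sketch (an assumption for illustration, not part of the original module): wiring
# the identity pool provider into the storage provider defined later in this file. The audience and
# token values are placeholders.
#
#   credentials_provider = GoogleIdentityPoolCredentialsProvider(
#       audience="//iam.googleapis.com/projects/<num>/locations/global/workloadIdentityPools/<pool>/providers/<provider>",
#       token_supplier="<subject-id-token>",
#   )
#   storage_provider = GoogleStorageProvider(
#       project_id="my-project",
#       base_path="my-bucket",
#       credentials_provider=credentials_provider,
#   )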


class GoogleServiceAccountCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's service account credentials.
    """

    #: Google service account private key file contents.
    _info: dict[str, Any]

    def __init__(self, file: Optional[str] = None, info: Optional[dict[str, Any]] = None):
        """
        Initializes the :py:class:`GoogleServiceAccountCredentialsProvider` with either a path to a
        `Google service account private key <https://docs.cloud.google.com/iam/docs/keys-create-delete#creating>`_ file
        or the file contents.

        :param file: Path to a Google service account private key file.
        :param info: Google service account private key file contents.
        """
        if all(_ is None for _ in (file, info)) or all(_ is not None for _ in (file, info)):
            raise ValueError("Must specify exactly one of file or info")

        if file is not None:
            with open(file, "r") as f:
                self._info = json.load(f)
        elif info is not None:
            self._info = copy.deepcopy(info)

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key="",
            secret_key="",
            token=None,
            expiration=None,
            custom_fields={"info": copy.deepcopy(self._info)},
        )

    def refresh_credentials(self) -> None:
        pass
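

# A minimal usage sketch (an assumption for illustration, not part of the original module): the
# provider accepts either a key file path or the already-parsed key contents, never both. The path
# is a placeholder.
#
#   credentials_provider = GoogleServiceAccountCredentialsProvider(file="/path/to/service-account-key.json")
#   # ...or, with the parsed JSON contents:
#   # credentials_provider = GoogleServiceAccountCredentialsProvider(info={...})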


class GoogleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Google Cloud Storage.
    """

    def __init__(
        self,
        project_id: str = os.getenv("GOOGLE_CLOUD_PROJECT_ID", ""),
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ):
        """
        Initializes the :py:class:`GoogleStorageProvider` with the project ID and optional credentials provider.

        :param project_id: The Google Cloud project ID.
        :param endpoint_url: The custom endpoint URL for the GCS service.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve GCS credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._project_id = project_id
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._skip_signature = kwargs.get("skip_signature", False)
        self._gcs_client = self._create_gcs_client()
        self._multipart_threshold = kwargs.get("multipart_threshold", DEFAULT_MULTIPART_THRESHOLD)
        self._multipart_chunksize = kwargs.get("multipart_chunksize", DEFAULT_MULTIPART_CHUNKSIZE)
        self._io_chunksize = kwargs.get("io_chunksize", DEFAULT_IO_CHUNKSIZE)
        self._max_concurrency = kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)
        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the rust client options from the kwargs
            rust_client_options = copy.deepcopy(kwargs["rust_client"])
            if "max_concurrency" in kwargs:
                rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
            if "multipart_chunksize" in kwargs:
                rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
            if "read_timeout" in kwargs:
                rust_client_options["read_timeout"] = kwargs["read_timeout"]
            if "connect_timeout" in kwargs:
                rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
            self._rust_client = self._create_rust_client(rust_client_options)
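
    # A configuration sketch (an assumption for illustration, not part of the original module):
    # these constructor arguments are normally supplied through an MSC profile rather than direct
    # construction. The option names mirror the parameters and kwargs read in __init__; the exact
    # YAML schema shown here is illustrative.
    #
    #   profiles:
    #     my-gcs:
    #       storage_provider:
    #         type: gcs
    #         options:
    #           project_id: my-project
    #           base_path: my-bucket/prefix
    #           multipart_threshold: 67108864   # 64 MiB
    #           max_concurrency: 8
    #           rust_client: {}                 # enable the Rust client with default options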

    def _create_gcs_client(self) -> storage.Client:
        client_options = {}
        if self._endpoint_url:
            client_options["api_endpoint"] = self._endpoint_url

        # Use anonymous credentials for public buckets when skip_signature is enabled
        if self._skip_signature:
            return storage.Client(
                project=self._project_id,
                credentials=auth_credentials.AnonymousCredentials(),
                client_options=client_options,
            )

        if self._credentials_provider:
            if isinstance(self._credentials_provider, GoogleIdentityPoolCredentialsProvider):
                audience = self._credentials_provider.get_credentials().get_custom_field("audience")
                token = self._credentials_provider.get_credentials().get_custom_field("token")

                # Use Workload Identity Federation (WIF)
                identity_pool_credentials = identity_pool.Credentials(
                    audience=audience,
                    subject_token_type="urn:ietf:params:oauth:token-type:id_token",
                    subject_token_supplier=StringTokenSupplier(token),
                )
                return storage.Client(
                    project=self._project_id, credentials=identity_pool_credentials, client_options=client_options
                )
            elif isinstance(self._credentials_provider, GoogleServiceAccountCredentialsProvider):
                # Use service account key.
                service_account_credentials = service_account.Credentials.from_service_account_info(
                    info=self._credentials_provider.get_credentials().get_custom_field("info")
                )
                return storage.Client(
                    project=self._project_id, credentials=service_account_credentials, client_options=client_options
                )
            else:
                # Use OAuth 2.0 token
                token = self._credentials_provider.get_credentials().token
                creds = OAuth2Credentials(token=token)
                return storage.Client(project=self._project_id, credentials=creds, client_options=client_options)
        else:
            return storage.Client(project=self._project_id, client_options=client_options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        if self._endpoint_url:
            logger.warning("Rust client for GCS does not support customized endpoint URL, skipping rust client")
            return None

        configs = dict(rust_client_options) if rust_client_options else {}

        # Extract and parse retry configuration
        retry_config = parse_retry_config(configs)

        if "application_credentials" not in configs and os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
            configs["application_credentials"] = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
        if "service_account_key" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY"):
            configs["service_account_key"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY")
        if "service_account_path" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT")
        if "service_account_path" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH")

        if self._skip_signature and "skip_signature" not in configs:
            configs["skip_signature"] = True

        if "bucket" not in configs:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        if self._credentials_provider:
            if isinstance(self._credentials_provider, GoogleIdentityPoolCredentialsProvider):
                # Workload Identity Federation (WIF) is not supported by the rust client:
                # https://github.com/apache/arrow-rs-object-store/issues/258
                logger.warning("Rust client for GCS doesn't support Workload Identity Federation, skipping rust client")
                return None
            if isinstance(self._credentials_provider, GoogleServiceAccountCredentialsProvider):
                # Use service account key.
                configs["service_account_key"] = json.dumps(
                    self._credentials_provider.get_credentials().get_custom_field("info")
                )
                return RustClient(
                    provider=PROVIDER,
                    configs=configs,
                    retry=retry_config,
                )
        try:
            return RustClient(
                provider=PROVIDER,
                configs=configs,
                credentials_provider=self._credentials_provider,
                retry=retry_config,
            )
        except Exception as e:
            logger.warning(f"Failed to create rust client for GCS: {e}, falling back to Python client")
            return None

    def _refresh_gcs_client_if_needed(self) -> None:
        """
        Refreshes the GCS client if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._gcs_client = self._create_gcs_client()

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual GCS operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the GCS bucket involved in the operation.
        :param key: The key of the object within the GCS bucket.

        :return: The result of the GCS operation, typically the return value of the `func` callable.
        """
        try:
            return func()
        except GoogleAPICallError as error:
            status_code = error.code if error.code else -1
            error_info = f"status_code: {status_code}, message: {error.message}"
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 304:
                # for if_none_match with a specific etag condition.
                raise NotModifiedError(f"Object {bucket}/{key} has not been modified.") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except InvalidResponse as error:
            response_text = error.response.text
            error_details = f"error: {error}, error_response_text: {response_text}"
            # Check for NoSuchUpload within the response text
            if "NoSuchUpload" in response_text:
                raise RetryableError(f"Multipart upload failed for {bucket}/{key}, {error_details}") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_details}") from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except RustClientError as error:
            message = error.args[0]
            status_code = error.args[1]
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {message}") from error
            elif status_code == 403:
                raise PermissionError(
                    f"Permission denied to {operation} object(s) at {bucket}/{key}. {message}"
                ) from error
            else:
                raise RetryableError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {message}. status_code: {status_code}"
                ) from error
        except Exception as error:
            error_details = str(error)
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}. error_type: {type(error).__name__}, {error_details}"
            ) from error
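
    # Summary of the translation performed above:
    #
    #   GoogleAPICallError 404           -> FileNotFoundError
    #   GoogleAPICallError 412           -> PreconditionFailedError
    #   GoogleAPICallError 304           -> NotModifiedError
    #   InvalidResponse ("NoSuchUpload") -> RetryableError (the multipart upload can be retried)
    #   RustRetryableError               -> RetryableError
    #   RustClientError 404 / 403        -> FileNotFoundError / PermissionError
    #   RustClientError (other)          -> RetryableError
    #   anything else                    -> RuntimeError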

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to Google Cloud Storage.

        :param path: The path to the object to upload.
        :param body: The content of the object to upload.
        :param if_match: Optional ETag that must match the object.
        :param if_none_match: Optional ETag that must not match the object.
        :param attributes: Optional attributes to attach to the object.
        """
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> int:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            kwargs = {}

            if if_match:
                kwargs["if_generation_match"] = int(if_match)  # 412 error code
            if if_none_match:
                if if_none_match == "*":
                    raise NotImplementedError("if_none_match='*' is not supported for GCS")
                else:
                    kwargs["if_generation_not_match"] = int(if_none_match)  # 304 error code

            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                blob.metadata = validated_attributes

            if (
                self._rust_client
                # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
                and not path.endswith("/")
                and not kwargs
                and not validated_attributes
            ):
                run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                blob.upload_from_string(body, **kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
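
    # A conditional-write sketch (an assumption for illustration; in practice these methods are
    # reached through the public MSC client). GCS expresses write preconditions with generation
    # numbers, so the if_match / if_none_match values are stringified generations:
    #
    #   metadata = provider._get_object_metadata("my-bucket/data.bin")
    #   provider._put_object(
    #       "my-bucket/data.bin",
    #       b"new contents",
    #       if_match=metadata.etag,  # etag holds str(blob.generation) for this provider
    #   )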

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bytes:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)
            if byte_range:
                if self._rust_client:
                    return run_async_rust_client_method(
                        self._rust_client, "get", key, byte_range.offset, byte_range.offset + byte_range.size - 1
                    )
                else:
                    return blob.download_as_bytes(
                        start=byte_range.offset, end=byte_range.offset + byte_range.size - 1, single_shot_download=True
                    )
            else:
                if self._rust_client:
                    return run_async_rust_client_method(self._rust_client, "get", key)
                else:
                    return blob.download_as_bytes(single_shot_download=True)

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)
        self._refresh_gcs_client_if_needed()

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            source_bucket_obj = self._gcs_client.bucket(src_bucket)
            source_blob = source_bucket_obj.blob(src_key)

            destination_bucket_obj = self._gcs_client.bucket(dest_bucket)
            destination_blob = destination_bucket_obj.blob(dest_key)

            rewrite_tokens = [None]
            while len(rewrite_tokens) > 0:
                rewrite_token = rewrite_tokens.pop()
                next_rewrite_token, _, _ = destination_blob.rewrite(source=source_blob, token=rewrite_token)
                if next_rewrite_token is not None:
                    rewrite_tokens.append(next_rewrite_token)

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=src_bucket, key=src_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> None:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            # If if_match is provided, use it as a precondition
            if if_match:
                generation = int(if_match)
                blob.delete(if_generation_match=generation)
            else:
                # No if_match check needed, just delete
                blob.delete()

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bool:
            bucket_obj = self._gcs_client.bucket(bucket)
            # List objects with the given prefix
            blobs = bucket_obj.list_blobs(
                prefix=key,
                delimiter="/",
            )
            # Check if there are any contents or common prefixes
            return any(True for _ in blobs) or any(True for _ in blobs.prefixes)

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If the path ends with "/" or an empty key name is provided, then assume it's a "directory",
            # whose metadata is not guaranteed to exist for cases such as a
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path, type="directory", content_length=0, last_modified=AWARE_DATETIME_MIN, etag=None
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            self._refresh_gcs_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.get_blob(key)
                if not blob:
                    raise NotFound(f"Blob {key} not found in bucket {bucket}")
                return ObjectMetadata(
                    key=path,
                    content_length=blob.size or 0,
                    content_type=blob.content_type,
                    last_modified=blob.updated or AWARE_DATETIME_MIN,
                    etag=str(blob.generation),
                    metadata=dict(blob.metadata) if blob.metadata else None,
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            bucket_obj = self._gcs_client.bucket(bucket)
            if include_directories:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                    delimiter="/",
                )
            else:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                )

            # GCS guarantees lexicographical order.
            for blob in blobs:
                key = blob.name
                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    if key.endswith("/"):
                        if include_directories:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key.rstrip("/")),
                                type="directory",
                                content_length=0,
                                last_modified=blob.updated,
                            )
                    else:
                        yield ObjectMetadata(
                            key=os.path.join(bucket, key),
                            content_length=blob.size,
                            content_type=blob.content_type,
                            last_modified=blob.updated,
                            etag=blob.etag,
                        )
                elif start_after != key:
                    return

            # The directories must be accessed last.
            if include_directories:
                for directory in blobs.prefixes:
                    prefix_key = directory.rstrip("/")
                    # Filter by start_after and end_at if specified
                    if (start_after is None or start_after < prefix_key) and (end_at is None or prefix_key <= end_at):
                        yield ObjectMetadata(
                            key=os.path.join(bucket, prefix_key),
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)
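
    # A minimal listing sketch (an assumption for illustration; normally reached through the public
    # MSC client). start_after is exclusive and end_at is inclusive; object keys arrive in
    # lexicographical order, with common prefixes ("directories") yielded last when
    # include_directories=True:
    #
    #   for obj in provider._list_objects("my-bucket/prefix/", include_directories=True):
    #       print(obj.key, obj.type, obj.content_length)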

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        bucket, key = split_path(remote_path)
        file_size: int = 0
        self._refresh_gcs_client_if_needed()

        if isinstance(f, str):
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    bucket_obj = self._gcs_client.bucket(bucket)
                    blob = bucket_obj.blob(key)
                    # GCS will raise an error if blob.metadata is None
                    validated_attributes = validate_attributes(attributes)
                    if validated_attributes is not None:
                        blob.metadata = validated_attributes
                    transfer_manager.upload_chunks_concurrently(
                        f,
                        blob,
                        chunk_size=self._multipart_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(remote_path, f.read().encode("utf-8"), attributes=attributes)
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    blob.metadata = validated_attributes
                if isinstance(f, io.StringIO):
                    mode = "w"
                else:
                    mode = "wb"

                # transfer manager does not support uploading a file object
                with tempfile.NamedTemporaryFile(mode=mode, delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    fp.write(f.read())

                transfer_manager.upload_chunks_concurrently(
                    temp_file_path,
                    blob,
                    chunk_size=self._multipart_chunksize,
                    max_workers=self._max_concurrency,
                    worker_type=transfer_manager.THREAD,
                )

                os.unlink(temp_file_path)

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        self._refresh_gcs_client_if_needed()

        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            if os.path.dirname(f):
                safe_makedirs(os.path.dirname(f))
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download_multipart_to_file", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        transfer_manager.download_chunks_concurrently(
                            blob,
                            temp_file_path,
                            chunk_size=self._io_chunksize,
                            max_workers=self._max_concurrency,
                            worker_type=transfer_manager.THREAD,
                        )
                    os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                response = self._get_object(remote_path)
                # The Python client returns `bytes`, but the Rust client returns an object that implements
                # the buffer protocol, so we need to check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)

                # transfer manager does not support downloading to a file object
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    transfer_manager.download_chunks_concurrently(
                        blob,
                        temp_file_path,
                        chunk_size=self._io_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                if isinstance(f, io.StringIO):
                    with open(temp_file_path, "r") as fp:
                        f.write(fp.read())
                else:
                    with open(temp_file_path, "rb") as fp:
                        f.write(fp.read())

                os.unlink(temp_file_path)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
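

# A minimal direct-construction sketch (an assumption for illustration; the provider is normally
# created by multistorageclient from a resolved config). With no credentials_provider, the
# google-cloud-storage client falls back to Application Default Credentials. Names are placeholders.
#
#   provider = GoogleStorageProvider(project_id="my-project", base_path="my-bucket")
#   provider._put_object("my-bucket/prefix/hello.txt", b"hello")
#   print(provider._get_object("my-bucket/prefix/hello.txt"))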