Source code for multistorageclient.providers.gcs

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import copy
import io
import json
import logging
import os
import tempfile
from collections.abc import Callable, Iterator
from typing import IO, Any, Optional, TypeVar, Union

from google.api_core.exceptions import GoogleAPICallError, NotFound
from google.auth import credentials as auth_credentials
from google.auth import identity_pool
from google.cloud import storage
from google.cloud.storage import transfer_manager
from google.cloud.storage.exceptions import InvalidResponse
from google.oauth2 import service_account
from google.oauth2.credentials import Credentials as OAuth2Credentials

from multistorageclient_rust import RustClient, RustRetryableError

from ..rust_utils import run_async_rust_client_method
from ..telemetry import Telemetry
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    NotModifiedError,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import (
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

PROVIDER = "gcs"

MiB = 1024 * 1024

DEFAULT_MULTIPART_THRESHOLD = 64 * MiB
DEFAULT_MULTIPART_CHUNKSIZE = 32 * MiB
DEFAULT_IO_CHUNKSIZE = 32 * MiB
PYTHON_MAX_CONCURRENCY = 8

logger = logging.getLogger(__name__)

class StringTokenSupplier(identity_pool.SubjectTokenSupplier):
    """
    Supply a string token to the Google Identity Pool.
    """

    def __init__(self, token: str):
        self._token = token

    def get_subject_token(self, context, request):
        return self._token

class GoogleIdentityPoolCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's identity pool credentials.
    """

    def __init__(self, audience: str, token_supplier: str):
        """
        Initializes the :py:class:`GoogleIdentityPoolCredentialsProvider` with the audience and token supplier.

        :param audience: The audience for the Google Identity Pool.
        :param token_supplier: The token supplier for the Google Identity Pool.
        """
        self._audience = audience
        self._token_supplier = token_supplier

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key="",
            secret_key="",
            token="",
            expiration=None,
            custom_fields={"audience": self._audience, "token": self._token_supplier},
        )

    def refresh_credentials(self) -> None:
        pass

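# Illustrative usage sketch (placeholder values; not part of this module):
# GoogleStorageProvider below wraps the supplied token in StringTokenSupplier to build
# Workload Identity Federation credentials.
#
#   wif_provider = GoogleIdentityPoolCredentialsProvider(
#       audience="//iam.googleapis.com/projects/PROJECT_NUMBER/locations/global/"
#       "workloadIdentityPools/POOL_ID/providers/PROVIDER_ID",
#       token_supplier="<subject-id-token>",
#   )
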
class GoogleServiceAccountCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's service account credentials.
    """

    #: Google service account private key file contents.
    _info: dict[str, Any]

    def __init__(self, file: Optional[str] = None, info: Optional[dict[str, Any]] = None):
        """
        Initializes the :py:class:`GoogleServiceAccountCredentialsProvider` with either a path to a
        `Google service account private key <https://docs.cloud.google.com/iam/docs/keys-create-delete#creating>`_ file
        or the file contents.

        :param file: Path to a Google service account private key file.
        :param info: Google service account private key file contents.
        """
        if all(_ is None for _ in (file, info)) or all(_ is not None for _ in (file, info)):
            raise ValueError("Must specify exactly one of file or info")

        if file is not None:
            with open(file, "r") as f:
                self._info = json.load(f)
        elif info is not None:
            self._info = copy.deepcopy(info)

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key="",
            secret_key="",
            token=None,
            expiration=None,
            custom_fields={"info": copy.deepcopy(self._info)},
        )

    def refresh_credentials(self) -> None:
        pass

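# Illustrative usage sketch (placeholder path; not part of this module): the provider
# accepts either a path to a service account key file or the parsed key contents,
# but not both.
#
#   sa_provider = GoogleServiceAccountCredentialsProvider(file="/path/to/service-account.json")
#   # or, with the already-parsed key contents:
#   # sa_provider = GoogleServiceAccountCredentialsProvider(info=parsed_key_dict)
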
class GoogleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Google Cloud Storage.
    """

    def __init__(
        self,
        project_id: str = os.getenv("GOOGLE_CLOUD_PROJECT_ID", ""),
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ):
        """
        Initializes the :py:class:`GoogleStorageProvider` with the project ID and optional credentials provider.

        :param project_id: The Google Cloud project ID.
        :param endpoint_url: The custom endpoint URL for the GCS service.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve GCS credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._project_id = project_id
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._skip_signature = kwargs.get("skip_signature", False)
        self._gcs_client = self._create_gcs_client()
        self._multipart_threshold = kwargs.get("multipart_threshold", DEFAULT_MULTIPART_THRESHOLD)
        self._multipart_chunksize = kwargs.get("multipart_chunksize", DEFAULT_MULTIPART_CHUNKSIZE)
        self._io_chunksize = kwargs.get("io_chunksize", DEFAULT_IO_CHUNKSIZE)
        self._max_concurrency = kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)
        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the rust client options from the kwargs
            rust_client_options = copy.deepcopy(kwargs["rust_client"])
            if "max_concurrency" in kwargs:
                rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
            if "multipart_chunksize" in kwargs:
                rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
            if "read_timeout" in kwargs:
                rust_client_options["read_timeout"] = kwargs["read_timeout"]
            if "connect_timeout" in kwargs:
                rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
            if "service_account_key" not in rust_client_options and isinstance(
                self._credentials_provider, GoogleServiceAccountCredentialsProvider
            ):
                rust_client_options["service_account_key"] = json.dumps(
                    self._credentials_provider.get_credentials().get_custom_field("info")
                )
            self._rust_client = self._create_rust_client(rust_client_options)

    def _create_gcs_client(self) -> storage.Client:
        client_options = {}
        if self._endpoint_url:
            client_options["api_endpoint"] = self._endpoint_url

        # Use anonymous credentials for public buckets when skip_signature is enabled
        if self._skip_signature:
            return storage.Client(
                project=self._project_id,
                credentials=auth_credentials.AnonymousCredentials(),
                client_options=client_options,
            )

        if self._credentials_provider:
            if isinstance(self._credentials_provider, GoogleIdentityPoolCredentialsProvider):
                audience = self._credentials_provider.get_credentials().get_custom_field("audience")
                token = self._credentials_provider.get_credentials().get_custom_field("token")

                # Use Workload Identity Federation (WIF)
                identity_pool_credentials = identity_pool.Credentials(
                    audience=audience,
                    subject_token_type="urn:ietf:params:oauth:token-type:id_token",
                    subject_token_supplier=StringTokenSupplier(token),
                )
                return storage.Client(
                    project=self._project_id, credentials=identity_pool_credentials, client_options=client_options
                )
            elif isinstance(self._credentials_provider, GoogleServiceAccountCredentialsProvider):
                # Use service account key.
                service_account_credentials = service_account.Credentials.from_service_account_info(
                    info=self._credentials_provider.get_credentials().get_custom_field("info")
                )
                return storage.Client(
                    project=self._project_id, credentials=service_account_credentials, client_options=client_options
                )
            else:
                # Use OAuth 2.0 token
                token = self._credentials_provider.get_credentials().token
                creds = OAuth2Credentials(token=token)
                return storage.Client(project=self._project_id, credentials=creds, client_options=client_options)
        else:
            return storage.Client(project=self._project_id, client_options=client_options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        if self._credentials_provider or self._endpoint_url:
            # Workload Identity Federation (WIF) is not supported by the rust client:
            # https://github.com/apache/arrow-rs-object-store/issues/258
            logger.warning(
                "Rust client for GCS only supports Application Default Credentials or Service Account Key, skipping rust client"
            )
            return None
        configs = dict(rust_client_options) if rust_client_options else {}

        if "application_credentials" not in configs and os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
            configs["application_credentials"] = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
        if "service_account_key" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY"):
            configs["service_account_key"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY")
        if "service_account_path" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT")
        if "service_account_path" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH")

        if self._skip_signature and "skip_signature" not in configs:
            configs["skip_signature"] = True

        if "bucket" not in configs:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        return RustClient(
            provider=PROVIDER,
            configs=configs,
        )

    def _refresh_gcs_client_if_needed(self) -> None:
        """
        Refreshes the GCS client if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._gcs_client = self._create_gcs_client()

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual GCS operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the GCS bucket involved in the operation.
        :param key: The key of the object within the GCS bucket.

        :return: The result of the GCS operation, typically the return value of the `func` callable.
        """
        try:
            return func()
        except GoogleAPICallError as error:
            status_code = error.code if error.code else -1
            error_info = f"status_code: {status_code}, message: {error.message}"
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 304:
                # for if_none_match with a specific etag condition.
                raise NotModifiedError(f"Object {bucket}/{key} has not been modified.") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except InvalidResponse as error:
            response_text = error.response.text
            error_details = f"error: {error}, error_response_text: {response_text}"
            # Check for NoSuchUpload within the response text
            if "NoSuchUpload" in response_text:
                raise RetryableError(f"Multipart upload failed for {bucket}/{key}, {error_details}") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_details}") from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except FileNotFoundError:
            raise
        except Exception as error:
            error_details = str(error)
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}. error_type: {type(error).__name__}, {error_details}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to Google Cloud Storage.

        :param path: The path to the object to upload.
        :param body: The content of the object to upload.
        :param if_match: Optional ETag to match against the object.
        :param if_none_match: Optional ETag that the object must not match.
        :param attributes: Optional attributes to attach to the object.
        """
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> int:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            kwargs = {}

            if if_match:
                kwargs["if_generation_match"] = int(if_match)  # 412 error code
            if if_none_match:
                if if_none_match == "*":
                    raise NotImplementedError("if_none_match='*' is not supported for GCS")
                else:
                    kwargs["if_generation_not_match"] = int(if_none_match)  # 304 error code

            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                blob.metadata = validated_attributes

            if (
                self._rust_client
                # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
                and not path.endswith("/")
                and not kwargs
                and not validated_attributes
            ):
                run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                blob.upload_from_string(body, **kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bytes:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)
            if byte_range:
                if self._rust_client:
                    return run_async_rust_client_method(
                        self._rust_client, "get", key, byte_range.offset, byte_range.offset + byte_range.size - 1
                    )
                else:
                    return blob.download_as_bytes(
                        start=byte_range.offset, end=byte_range.offset + byte_range.size - 1, single_shot_download=True
                    )
            else:
                if self._rust_client:
                    return run_async_rust_client_method(self._rust_client, "get", key)
                else:
                    return blob.download_as_bytes(single_shot_download=True)

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)
        self._refresh_gcs_client_if_needed()

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            source_bucket_obj = self._gcs_client.bucket(src_bucket)
            source_blob = source_bucket_obj.blob(src_key)

            destination_bucket_obj = self._gcs_client.bucket(dest_bucket)
            destination_blob = destination_bucket_obj.blob(dest_key)

            rewrite_tokens = [None]
            while len(rewrite_tokens) > 0:
                rewrite_token = rewrite_tokens.pop()
                next_rewrite_token, _, _ = destination_blob.rewrite(source=source_blob, token=rewrite_token)
                if next_rewrite_token is not None:
                    rewrite_tokens.append(next_rewrite_token)

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=src_bucket, key=src_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> None:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            # If if_match is provided, use it as a precondition
            if if_match:
                generation = int(if_match)
                blob.delete(if_generation_match=generation)
            else:
                # No if_match check needed, just delete
                blob.delete()

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bool:
            bucket_obj = self._gcs_client.bucket(bucket)
            # List objects with the given prefix
            blobs = bucket_obj.list_blobs(
                prefix=key,
                delimiter="/",
            )
            # Check if there are any contents or common prefixes
            return any(True for _ in blobs) or any(True for _ in blobs.prefixes)

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If the path ends with "/" or an empty key name is provided, assume it's a "directory",
            # whose metadata is not guaranteed to exist for cases such as a
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path, type="directory", content_length=0, last_modified=AWARE_DATETIME_MIN, etag=None
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            self._refresh_gcs_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.get_blob(key)
                if not blob:
                    raise NotFound(f"Blob {key} not found in bucket {bucket}")
                return ObjectMetadata(
                    key=path,
                    content_length=blob.size or 0,
                    content_type=blob.content_type,
                    last_modified=blob.updated or AWARE_DATETIME_MIN,
                    etag=str(blob.generation),
                    metadata=dict(blob.metadata) if blob.metadata else None,
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            bucket_obj = self._gcs_client.bucket(bucket)
            if include_directories:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                    delimiter="/",
                )
            else:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                )

            # GCS guarantees lexicographical order.
            for blob in blobs:
                key = blob.name
                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    if key.endswith("/"):
                        if include_directories:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key.rstrip("/")),
                                type="directory",
                                content_length=0,
                                last_modified=blob.updated,
                            )
                    else:
                        yield ObjectMetadata(
                            key=os.path.join(bucket, key),
                            content_length=blob.size,
                            content_type=blob.content_type,
                            last_modified=blob.updated,
                            etag=blob.etag,
                        )
                elif start_after != key:
                    return

            # The directories must be accessed last.
            if include_directories:
                for directory in blobs.prefixes:
                    yield ObjectMetadata(
                        key=os.path.join(bucket, directory.rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        bucket, key = split_path(remote_path)
        file_size: int = 0
        self._refresh_gcs_client_if_needed()

        if isinstance(f, str):
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    bucket_obj = self._gcs_client.bucket(bucket)
                    blob = bucket_obj.blob(key)
                    # GCS will raise an error if blob.metadata is None
                    validated_attributes = validate_attributes(attributes)
                    if validated_attributes is not None:
                        blob.metadata = validated_attributes
                    transfer_manager.upload_chunks_concurrently(
                        f,
                        blob,
                        chunk_size=self._multipart_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(remote_path, f.read().encode("utf-8"), attributes=attributes)
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    blob.metadata = validated_attributes
                if isinstance(f, io.StringIO):
                    mode = "w"
                else:
                    mode = "wb"

                # transfer manager does not support uploading a file object
                with tempfile.NamedTemporaryFile(mode=mode, delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    fp.write(f.read())

                transfer_manager.upload_chunks_concurrently(
                    temp_file_path,
                    blob,
                    chunk_size=self._multipart_chunksize,
                    max_workers=self._max_concurrency,
                    worker_type=transfer_manager.THREAD,
                )

                os.unlink(temp_file_path)

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        self._refresh_gcs_client_if_needed()

        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download_multipart_to_file", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        transfer_manager.download_chunks_concurrently(
                            blob,
                            temp_file_path,
                            chunk_size=self._io_chunksize,
                            max_workers=self._max_concurrency,
                            worker_type=transfer_manager.THREAD,
                        )
                    os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                response = self._get_object(remote_path)
                # Python client returns `bytes`, but Rust client returns an object that implements the buffer protocol,
                # so we need to check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)

                # transfer manager does not support downloading to a file object
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    transfer_manager.download_chunks_concurrently(
                        blob,
                        temp_file_path,
                        chunk_size=self._io_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                if isinstance(f, io.StringIO):
                    with open(temp_file_path, "r") as fp:
                        f.write(fp.read())
                else:
                    with open(temp_file_path, "rb") as fp:
                        f.write(fp.read())

                os.unlink(temp_file_path)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
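

# Illustrative usage sketch (placeholder values; not part of this module). The provider
# is typically constructed by multistorageclient from the resolved MSC config, but it can
# also be built directly; kwargs such as multipart_threshold override the defaults above.
#
#   provider = GoogleStorageProvider(
#       project_id="my-project",
#       base_path="my-bucket/datasets",
#       credentials_provider=GoogleServiceAccountCredentialsProvider(file="/path/to/service-account.json"),
#       multipart_threshold=64 * MiB,
#   )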