Source code for multistorageclient.providers.gcs

  1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16import codecs
 17import io
 18import logging
 19import os
 20import tempfile
 21from collections.abc import Callable, Iterator
 22from typing import IO, Any, Optional, TypeVar, Union
 23
 24from google.api_core.exceptions import GoogleAPICallError, NotFound
 25from google.auth import identity_pool
 26from google.cloud import storage
 27from google.cloud.storage import transfer_manager
 28from google.cloud.storage.exceptions import InvalidResponse
 29from google.oauth2.credentials import Credentials as OAuth2Credentials
 30
 31from multistorageclient_rust import RustClient, RustRetryableError
 32
 33from ..rust_utils import run_async_rust_client_method
 34from ..telemetry import Telemetry
 35from ..types import (
 36    AWARE_DATETIME_MIN,
 37    Credentials,
 38    CredentialsProvider,
 39    NotModifiedError,
 40    ObjectMetadata,
 41    PreconditionFailedError,
 42    Range,
 43    RetryableError,
 44)
 45from ..utils import (
 46    split_path,
 47    validate_attributes,
 48)
 49from .base import BaseStorageProvider
 50
 51_T = TypeVar("_T")
 52
 53PROVIDER = "gcs"
 54
 55MiB = 1024 * 1024
 56
 57DEFAULT_MULTIPART_THRESHOLD = 64 * MiB
 58DEFAULT_MULTIPART_CHUNKSIZE = 32 * MiB
 59DEFAULT_IO_CHUNKSIZE = 32 * MiB
 60PYTHON_MAX_CONCURRENCY = 8
 61
 62logger = logging.getLogger(__name__)
 63
 64
[docs] 65class StringTokenSupplier(identity_pool.SubjectTokenSupplier): 66 """ 67 Supply a string token to the Google Identity Pool. 68 """ 69 70 def __init__(self, token: str): 71 self._token = token 72
[docs] 73 def get_subject_token(self, context, request): 74 return self._token
75 76
[docs] 77class GoogleIdentityPoolCredentialsProvider(CredentialsProvider): 78 """ 79 A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's identity pool credentials. 80 """ 81 82 def __init__(self, audience: str, token_supplier: str): 83 """ 84 Initializes the :py:class:`GoogleIdentityPoolCredentials` with the audience and token supplier. 85 86 :param audience: The audience for the Google Identity Pool. 87 :param token_supplier: The token supplier for the Google Identity Pool. 88 """ 89 self._audience = audience 90 self._token_supplier = token_supplier 91
[docs] 92 def get_credentials(self) -> Credentials: 93 return Credentials( 94 access_key="", 95 secret_key="", 96 token="", 97 expiration=None, 98 custom_fields={"audience": self._audience, "token": self._token_supplier}, 99 )
100
[docs] 101 def refresh_credentials(self) -> None: 102 pass
103 104
[docs] 105class GoogleStorageProvider(BaseStorageProvider): 106 """ 107 A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Google Cloud Storage. 108 """ 109 110 def __init__( 111 self, 112 project_id: str = os.getenv("GOOGLE_CLOUD_PROJECT_ID", ""), 113 endpoint_url: str = "", 114 base_path: str = "", 115 credentials_provider: Optional[CredentialsProvider] = None, 116 config_dict: Optional[dict[str, Any]] = None, 117 telemetry_provider: Optional[Callable[[], Telemetry]] = None, 118 **kwargs: Any, 119 ): 120 """ 121 Initializes the :py:class:`GoogleStorageProvider` with the project ID and optional credentials provider. 122 123 :param project_id: The Google Cloud project ID. 124 :param endpoint_url: The custom endpoint URL for the GCS service. 125 :param base_path: The root prefix path within the bucket where all operations will be scoped. 126 :param credentials_provider: The provider to retrieve GCS credentials. 127 :param config_dict: Resolved MSC config. 128 :param telemetry_provider: A function that provides a telemetry instance. 129 """ 130 super().__init__( 131 base_path=base_path, 132 provider_name=PROVIDER, 133 config_dict=config_dict, 134 telemetry_provider=telemetry_provider, 135 ) 136 137 self._project_id = project_id 138 self._endpoint_url = endpoint_url 139 self._credentials_provider = credentials_provider 140 self._gcs_client = self._create_gcs_client() 141 self._multipart_threshold = kwargs.get("multipart_threshold", DEFAULT_MULTIPART_THRESHOLD) 142 self._multipart_chunksize = kwargs.get("multipart_chunksize", DEFAULT_MULTIPART_CHUNKSIZE) 143 self._io_chunksize = kwargs.get("io_chunksize", DEFAULT_IO_CHUNKSIZE) 144 self._max_concurrency = kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY) 145 self._rust_client = None 146 if "rust_client" in kwargs: 147 # Inherit the rust client options from the kwargs 148 rust_client_options = kwargs["rust_client"] 149 if "max_concurrency" in kwargs: 150 rust_client_options["max_concurrency"] = kwargs["max_concurrency"] 151 if "multipart_chunksize" in kwargs: 152 rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"] 153 if "read_timeout" in kwargs: 154 rust_client_options["read_timeout"] = kwargs["read_timeout"] 155 if "connect_timeout" in kwargs: 156 rust_client_options["connect_timeout"] = kwargs["connect_timeout"] 157 self._rust_client = self._create_rust_client(rust_client_options) 158 159 def _create_gcs_client(self) -> storage.Client: 160 client_options = {} 161 if self._endpoint_url: 162 client_options["api_endpoint"] = self._endpoint_url 163 164 if self._credentials_provider: 165 if isinstance(self._credentials_provider, GoogleIdentityPoolCredentialsProvider): 166 audience = self._credentials_provider.get_credentials().get_custom_field("audience") 167 token = self._credentials_provider.get_credentials().get_custom_field("token") 168 169 # Use Workload Identity Federation (WIF) 170 identity_pool_credentials = identity_pool.Credentials( 171 audience=audience, 172 subject_token_type="urn:ietf:params:oauth:token-type:id_token", 173 subject_token_supplier=StringTokenSupplier(token), 174 ) 175 return storage.Client( 176 project=self._project_id, credentials=identity_pool_credentials, client_options=client_options 177 ) 178 else: 179 # Use OAuth 2.0 token 180 token = self._credentials_provider.get_credentials().token 181 creds = OAuth2Credentials(token=token) 182 return storage.Client(project=self._project_id, credentials=creds, client_options=client_options) 183 else: 184 return storage.Client(project=self._project_id, client_options=client_options) 185 186 def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None): 187 configs = {} 188 189 if self._credentials_provider or self._endpoint_url: 190 # Workload Identity Federation (WIF) is not supported by the rust client: 191 # https://github.com/apache/arrow-rs-object-store/issues/258 192 logger.warning( 193 "Rust client for GCS only supports Application Default Credentials or Service Account Key, skipping rust client" 194 ) 195 return None 196 197 # Use environment variables if available, could be overridden by rust_client_options 198 if os.getenv("GOOGLE_APPLICATION_CREDENTIALS"): 199 configs["application_credentials"] = os.getenv("GOOGLE_APPLICATION_CREDENTIALS") 200 if os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY"): 201 configs["service_account_key"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY") 202 if os.getenv("GOOGLE_SERVICE_ACCOUNT"): 203 configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT") 204 if os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH"): 205 configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH") 206 207 if rust_client_options: 208 if "application_credentials" in rust_client_options: 209 configs["application_credentials"] = rust_client_options["application_credentials"] 210 if "service_account_key" in rust_client_options: 211 configs["service_account_key"] = rust_client_options["service_account_key"] 212 if "service_account_path" in rust_client_options: 213 configs["service_account_path"] = rust_client_options["service_account_path"] 214 if "skip_signature" in rust_client_options: 215 configs["skip_signature"] = rust_client_options["skip_signature"] 216 if "max_concurrency" in rust_client_options: 217 configs["max_concurrency"] = rust_client_options["max_concurrency"] 218 if "multipart_chunksize" in rust_client_options: 219 configs["multipart_chunksize"] = rust_client_options["multipart_chunksize"] 220 if "read_timeout" in rust_client_options: 221 configs["read_timeout"] = rust_client_options["read_timeout"] 222 if "connect_timeout" in rust_client_options: 223 configs["connect_timeout"] = rust_client_options["connect_timeout"] 224 225 if rust_client_options and "bucket" in rust_client_options: 226 configs["bucket"] = rust_client_options["bucket"] 227 else: 228 bucket, _ = split_path(self._base_path) 229 configs["bucket"] = bucket 230 231 return RustClient( 232 provider=PROVIDER, 233 configs=configs, 234 ) 235 236 def _refresh_gcs_client_if_needed(self) -> None: 237 """ 238 Refreshes the GCS client if the current credentials are expired. 239 """ 240 if self._credentials_provider: 241 credentials = self._credentials_provider.get_credentials() 242 if credentials.is_expired(): 243 self._credentials_provider.refresh_credentials() 244 self._gcs_client = self._create_gcs_client() 245 246 def _translate_errors( 247 self, 248 func: Callable[[], _T], 249 operation: str, 250 bucket: str, 251 key: str, 252 ) -> _T: 253 """ 254 Translates errors like timeouts and client errors. 255 256 :param func: The function that performs the actual GCS operation. 257 :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE"). 258 :param bucket: The name of the GCS bucket involved in the operation. 259 :param key: The key of the object within the GCS bucket. 260 261 :return: The result of the GCS operation, typically the return value of the `func` callable. 262 """ 263 try: 264 return func() 265 except GoogleAPICallError as error: 266 status_code = error.code if error.code else -1 267 error_info = f"status_code: {status_code}, message: {error.message}" 268 if status_code == 404: 269 raise FileNotFoundError(f"Object {bucket}/{key} does not exist.") # pylint: disable=raise-missing-from 270 elif status_code == 412: 271 raise PreconditionFailedError( 272 f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}" 273 ) from error 274 elif status_code == 304: 275 # for if_none_match with a specific etag condition. 276 raise NotModifiedError(f"Object {bucket}/{key} has not been modified.") from error 277 else: 278 raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error 279 except InvalidResponse as error: 280 response_text = error.response.text 281 error_details = f"error: {error}, error_response_text: {response_text}" 282 # Check for NoSuchUpload within the response text 283 if "NoSuchUpload" in response_text: 284 raise RetryableError(f"Multipart upload failed for {bucket}/{key}, {error_details}") from error 285 else: 286 raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_details}") from error 287 except RustRetryableError as error: 288 raise RetryableError( 289 f"Failed to {operation} object(s) at {bucket}/{key} due to exhausted retries from Rust. " 290 f"error_type: {type(error).__name__}" 291 ) from error 292 except Exception as error: 293 error_details = str(error) 294 raise RuntimeError( 295 f"Failed to {operation} object(s) at {bucket}/{key}. error_type: {type(error).__name__}, {error_details}" 296 ) from error 297 298 def _put_object( 299 self, 300 path: str, 301 body: bytes, 302 if_match: Optional[str] = None, 303 if_none_match: Optional[str] = None, 304 attributes: Optional[dict[str, str]] = None, 305 ) -> int: 306 """ 307 Uploads an object to Google Cloud Storage. 308 309 :param path: The path to the object to upload. 310 :param body: The content of the object to upload. 311 :param if_match: Optional ETag to match against the object. 312 :param if_none_match: Optional ETag to match against the object. 313 :param attributes: Optional attributes to attach to the object. 314 """ 315 bucket, key = split_path(path) 316 self._refresh_gcs_client_if_needed() 317 318 def _invoke_api() -> int: 319 bucket_obj = self._gcs_client.bucket(bucket) 320 blob = bucket_obj.blob(key) 321 322 kwargs = {} 323 324 if if_match: 325 kwargs["if_generation_match"] = int(if_match) # 412 error code 326 if if_none_match: 327 if if_none_match == "*": 328 raise NotImplementedError("if_none_match='*' is not supported for GCS") 329 else: 330 kwargs["if_generation_not_match"] = int(if_none_match) # 304 error code 331 332 validated_attributes = validate_attributes(attributes) 333 if validated_attributes: 334 blob.metadata = validated_attributes 335 336 if ( 337 self._rust_client 338 # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026 339 and not path.endswith("/") 340 and not kwargs 341 and not validated_attributes 342 ): 343 run_async_rust_client_method(self._rust_client, "put", key, body) 344 else: 345 blob.upload_from_string(body, **kwargs) 346 347 return len(body) 348 349 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key) 350 351 def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes: 352 bucket, key = split_path(path) 353 self._refresh_gcs_client_if_needed() 354 355 def _invoke_api() -> bytes: 356 bucket_obj = self._gcs_client.bucket(bucket) 357 blob = bucket_obj.blob(key) 358 if byte_range: 359 if self._rust_client: 360 return run_async_rust_client_method( 361 self._rust_client, "get", key, byte_range.offset, byte_range.offset + byte_range.size - 1 362 ) 363 else: 364 return blob.download_as_bytes( 365 start=byte_range.offset, end=byte_range.offset + byte_range.size - 1, single_shot_download=True 366 ) 367 else: 368 if self._rust_client: 369 return run_async_rust_client_method(self._rust_client, "get", key) 370 else: 371 return blob.download_as_bytes(single_shot_download=True) 372 373 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key) 374 375 def _copy_object(self, src_path: str, dest_path: str) -> int: 376 src_bucket, src_key = split_path(src_path) 377 dest_bucket, dest_key = split_path(dest_path) 378 self._refresh_gcs_client_if_needed() 379 380 src_object = self._get_object_metadata(src_path) 381 382 def _invoke_api() -> int: 383 source_bucket_obj = self._gcs_client.bucket(src_bucket) 384 source_blob = source_bucket_obj.blob(src_key) 385 386 destination_bucket_obj = self._gcs_client.bucket(dest_bucket) 387 destination_blob = destination_bucket_obj.blob(dest_key) 388 389 rewrite_tokens = [None] 390 while len(rewrite_tokens) > 0: 391 rewrite_token = rewrite_tokens.pop() 392 next_rewrite_token, _, _ = destination_blob.rewrite(source=source_blob, token=rewrite_token) 393 if next_rewrite_token is not None: 394 rewrite_tokens.append(next_rewrite_token) 395 396 return src_object.content_length 397 398 return self._translate_errors(_invoke_api, operation="COPY", bucket=src_bucket, key=src_key) 399 400 def _delete_object(self, path: str, if_match: Optional[str] = None) -> None: 401 bucket, key = split_path(path) 402 self._refresh_gcs_client_if_needed() 403 404 def _invoke_api() -> None: 405 bucket_obj = self._gcs_client.bucket(bucket) 406 blob = bucket_obj.blob(key) 407 408 # If if_match is provided, use it as a precondition 409 if if_match: 410 generation = int(if_match) 411 blob.delete(if_generation_match=generation) 412 else: 413 # No if_match check needed, just delete 414 blob.delete() 415 416 return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key) 417 418 def _is_dir(self, path: str) -> bool: 419 # Ensure the path ends with '/' to mimic a directory 420 path = self._append_delimiter(path) 421 422 bucket, key = split_path(path) 423 self._refresh_gcs_client_if_needed() 424 425 def _invoke_api() -> bool: 426 bucket_obj = self._gcs_client.bucket(bucket) 427 # List objects with the given prefix 428 blobs = bucket_obj.list_blobs( 429 prefix=key, 430 delimiter="/", 431 ) 432 # Check if there are any contents or common prefixes 433 return any(True for _ in blobs) or any(True for _ in blobs.prefixes) 434 435 return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key) 436 437 def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata: 438 bucket, key = split_path(path) 439 if path.endswith("/") or (bucket and not key): 440 # If path ends with "/" or empty key name is provided, then assume it's a "directory", 441 # which metadata is not guaranteed to exist for cases such as 442 # "virtual prefix" that was never explicitly created. 443 if self._is_dir(path): 444 return ObjectMetadata( 445 key=path, type="directory", content_length=0, last_modified=AWARE_DATETIME_MIN, etag=None 446 ) 447 else: 448 raise FileNotFoundError(f"Directory {path} does not exist.") 449 else: 450 self._refresh_gcs_client_if_needed() 451 452 def _invoke_api() -> ObjectMetadata: 453 bucket_obj = self._gcs_client.bucket(bucket) 454 blob = bucket_obj.get_blob(key) 455 if not blob: 456 raise NotFound(f"Blob {key} not found in bucket {bucket}") 457 return ObjectMetadata( 458 key=path, 459 content_length=blob.size or 0, 460 content_type=blob.content_type, 461 last_modified=blob.updated or AWARE_DATETIME_MIN, 462 etag=str(blob.generation), 463 metadata=dict(blob.metadata) if blob.metadata else None, 464 ) 465 466 try: 467 return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key) 468 except FileNotFoundError as error: 469 if strict: 470 # If the object does not exist on the given path, we will append a trailing slash and 471 # check if the path is a directory. 472 path = self._append_delimiter(path) 473 if self._is_dir(path): 474 return ObjectMetadata( 475 key=path, 476 type="directory", 477 content_length=0, 478 last_modified=AWARE_DATETIME_MIN, 479 ) 480 raise error 481 482 def _list_objects( 483 self, 484 path: str, 485 start_after: Optional[str] = None, 486 end_at: Optional[str] = None, 487 include_directories: bool = False, 488 ) -> Iterator[ObjectMetadata]: 489 bucket, prefix = split_path(path) 490 self._refresh_gcs_client_if_needed() 491 492 def _invoke_api() -> Iterator[ObjectMetadata]: 493 bucket_obj = self._gcs_client.bucket(bucket) 494 if include_directories: 495 blobs = bucket_obj.list_blobs( 496 prefix=prefix, 497 # This is ≥ instead of >. 498 start_offset=start_after, 499 delimiter="/", 500 ) 501 else: 502 blobs = bucket_obj.list_blobs( 503 prefix=prefix, 504 # This is ≥ instead of >. 505 start_offset=start_after, 506 ) 507 508 # GCS guarantees lexicographical order. 509 for blob in blobs: 510 key = blob.name 511 if (start_after is None or start_after < key) and (end_at is None or key <= end_at): 512 if key.endswith("/"): 513 if include_directories: 514 yield ObjectMetadata( 515 key=os.path.join(bucket, key.rstrip("/")), 516 type="directory", 517 content_length=0, 518 last_modified=blob.updated, 519 ) 520 else: 521 yield ObjectMetadata( 522 key=os.path.join(bucket, key), 523 content_length=blob.size, 524 content_type=blob.content_type, 525 last_modified=blob.updated, 526 etag=blob.etag, 527 ) 528 elif start_after != key: 529 return 530 531 # The directories must be accessed last. 532 if include_directories: 533 for directory in blobs.prefixes: 534 yield ObjectMetadata( 535 key=os.path.join(bucket, directory.rstrip("/")), 536 type="directory", 537 content_length=0, 538 last_modified=AWARE_DATETIME_MIN, 539 ) 540 541 return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix) 542 543 def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int: 544 bucket, key = split_path(remote_path) 545 file_size: int = 0 546 self._refresh_gcs_client_if_needed() 547 548 if isinstance(f, str): 549 file_size = os.path.getsize(f) 550 551 # Upload small files 552 if file_size <= self._multipart_threshold: 553 if self._rust_client and not attributes: 554 run_async_rust_client_method(self._rust_client, "upload", f, key) 555 else: 556 with open(f, "rb") as fp: 557 self._put_object(remote_path, fp.read(), attributes=attributes) 558 return file_size 559 560 # Upload large files using transfer manager 561 def _invoke_api() -> int: 562 if self._rust_client and not attributes: 563 run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key) 564 else: 565 bucket_obj = self._gcs_client.bucket(bucket) 566 blob = bucket_obj.blob(key) 567 # GCS will raise an error if blob.metadata is None 568 validated_attributes = validate_attributes(attributes) 569 if validated_attributes is not None: 570 blob.metadata = validated_attributes 571 transfer_manager.upload_chunks_concurrently( 572 f, 573 blob, 574 chunk_size=self._multipart_chunksize, 575 max_workers=self._max_concurrency, 576 worker_type=transfer_manager.THREAD, 577 ) 578 579 return file_size 580 581 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key) 582 else: 583 f.seek(0, io.SEEK_END) 584 file_size = f.tell() 585 f.seek(0) 586 587 # Upload small files 588 if file_size <= self._multipart_threshold: 589 if isinstance(f, io.StringIO): 590 self._put_object(remote_path, f.read().encode("utf-8"), attributes=attributes) 591 else: 592 self._put_object(remote_path, f.read(), attributes=attributes) 593 return file_size 594 595 # Upload large files using transfer manager 596 def _invoke_api() -> int: 597 bucket_obj = self._gcs_client.bucket(bucket) 598 blob = bucket_obj.blob(key) 599 validated_attributes = validate_attributes(attributes) 600 if validated_attributes: 601 blob.metadata = validated_attributes 602 if isinstance(f, io.StringIO): 603 mode = "w" 604 else: 605 mode = "wb" 606 607 # transfer manager does not support uploading a file object 608 with tempfile.NamedTemporaryFile(mode=mode, delete=False, prefix=".") as fp: 609 temp_file_path = fp.name 610 fp.write(f.read()) 611 612 transfer_manager.upload_chunks_concurrently( 613 temp_file_path, 614 blob, 615 chunk_size=self._multipart_chunksize, 616 max_workers=self._max_concurrency, 617 worker_type=transfer_manager.THREAD, 618 ) 619 620 os.unlink(temp_file_path) 621 622 return file_size 623 624 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key) 625 626 def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int: 627 self._refresh_gcs_client_if_needed() 628 629 if metadata is None: 630 metadata = self._get_object_metadata(remote_path) 631 632 bucket, key = split_path(remote_path) 633 634 if isinstance(f, str): 635 if os.path.dirname(f): 636 os.makedirs(os.path.dirname(f), exist_ok=True) 637 # Download small files 638 if metadata.content_length <= self._multipart_threshold: 639 if self._rust_client: 640 run_async_rust_client_method(self._rust_client, "download", key, f) 641 else: 642 with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp: 643 temp_file_path = fp.name 644 fp.write(self._get_object(remote_path)) 645 os.rename(src=temp_file_path, dst=f) 646 return metadata.content_length 647 648 # Download large files using transfer manager 649 def _invoke_api() -> int: 650 bucket_obj = self._gcs_client.bucket(bucket) 651 blob = bucket_obj.blob(key) 652 if self._rust_client: 653 run_async_rust_client_method(self._rust_client, "download_multipart_to_file", key, f) 654 else: 655 with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp: 656 temp_file_path = fp.name 657 transfer_manager.download_chunks_concurrently( 658 blob, 659 temp_file_path, 660 chunk_size=self._io_chunksize, 661 max_workers=self._max_concurrency, 662 worker_type=transfer_manager.THREAD, 663 ) 664 os.rename(src=temp_file_path, dst=f) 665 666 return metadata.content_length 667 668 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key) 669 else: 670 # Download small files 671 if metadata.content_length <= self._multipart_threshold: 672 response = self._get_object(remote_path) 673 # Python client returns `bytes`, but Rust client returns a object implements buffer protocol, 674 # so we need to check whether `.decode()` is available. 675 if isinstance(f, io.StringIO): 676 if hasattr(response, "decode"): 677 f.write(response.decode("utf-8")) 678 else: 679 f.write(codecs.decode(memoryview(response), "utf-8")) 680 else: 681 f.write(response) 682 return metadata.content_length 683 684 # Download large files using transfer manager 685 def _invoke_api() -> int: 686 bucket_obj = self._gcs_client.bucket(bucket) 687 blob = bucket_obj.blob(key) 688 689 # transfer manager does not support downloading to a file object 690 with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=".") as fp: 691 temp_file_path = fp.name 692 transfer_manager.download_chunks_concurrently( 693 blob, 694 temp_file_path, 695 chunk_size=self._io_chunksize, 696 max_workers=self._max_concurrency, 697 worker_type=transfer_manager.THREAD, 698 ) 699 700 if isinstance(f, io.StringIO): 701 with open(temp_file_path, "r") as fp: 702 f.write(fp.read()) 703 else: 704 with open(temp_file_path, "rb") as fp: 705 f.write(fp.read()) 706 707 os.unlink(temp_file_path) 708 709 return metadata.content_length 710 711 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)