Source code for multistorageclient.providers.gcs

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import io
import logging
import os
import tempfile
from collections.abc import Callable, Iterator
from typing import IO, Any, Optional, TypeVar, Union

from google.api_core.exceptions import GoogleAPICallError, NotFound
from google.auth import credentials as auth_credentials
from google.auth import identity_pool
from google.cloud import storage
from google.cloud.storage import transfer_manager
from google.cloud.storage.exceptions import InvalidResponse
from google.oauth2.credentials import Credentials as OAuth2Credentials

from multistorageclient_rust import RustClient, RustRetryableError

from ..rust_utils import run_async_rust_client_method
from ..telemetry import Telemetry
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    NotModifiedError,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import (
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

PROVIDER = "gcs"

MiB = 1024 * 1024

DEFAULT_MULTIPART_THRESHOLD = 64 * MiB
DEFAULT_MULTIPART_CHUNKSIZE = 32 * MiB
DEFAULT_IO_CHUNKSIZE = 32 * MiB
PYTHON_MAX_CONCURRENCY = 8

logger = logging.getLogger(__name__)

class StringTokenSupplier(identity_pool.SubjectTokenSupplier):
    """
    Supply a string token to the Google Identity Pool.
    """

    def __init__(self, token: str):
        self._token = token

    def get_subject_token(self, context, request):
        return self._token


class GoogleIdentityPoolCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's identity pool credentials.
    """

    def __init__(self, audience: str, token_supplier: str):
        """
        Initializes the :py:class:`GoogleIdentityPoolCredentialsProvider` with the audience and token supplier.

        :param audience: The audience for the Google Identity Pool.
        :param token_supplier: The token supplier for the Google Identity Pool.
        """
        self._audience = audience
        self._token_supplier = token_supplier

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key="",
            secret_key="",
            token="",
            expiration=None,
            custom_fields={"audience": self._audience, "token": self._token_supplier},
        )

    def refresh_credentials(self) -> None:
        pass

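# Example (illustrative, not part of the original module): wiring a pre-fetched OIDC subject
# token into GoogleIdentityPoolCredentialsProvider for Workload Identity Federation. The
# audience string and token value below are hypothetical placeholders.
#
#   credentials_provider = GoogleIdentityPoolCredentialsProvider(
#       audience="//iam.googleapis.com/projects/0000000/locations/global/workloadIdentityPools/my-pool/providers/my-provider",
#       token_supplier="<oidc-id-token>",
#   )
#   credentials_provider.get_credentials().get_custom_field("audience")
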
class GoogleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Google Cloud Storage.
    """

    def __init__(
        self,
        project_id: str = os.getenv("GOOGLE_CLOUD_PROJECT_ID", ""),
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ):
        """
        Initializes the :py:class:`GoogleStorageProvider` with the project ID and optional credentials provider.

        :param project_id: The Google Cloud project ID.
        :param endpoint_url: The custom endpoint URL for the GCS service.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve GCS credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._project_id = project_id
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._skip_signature = kwargs.get("skip_signature", False)
        self._gcs_client = self._create_gcs_client()
        self._multipart_threshold = kwargs.get("multipart_threshold", DEFAULT_MULTIPART_THRESHOLD)
        self._multipart_chunksize = kwargs.get("multipart_chunksize", DEFAULT_MULTIPART_CHUNKSIZE)
        self._io_chunksize = kwargs.get("io_chunksize", DEFAULT_IO_CHUNKSIZE)
        self._max_concurrency = kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)
        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the rust client options from the kwargs
            rust_client_options = kwargs["rust_client"]
            if "max_concurrency" in kwargs:
                rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
            if "multipart_chunksize" in kwargs:
                rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
            if "read_timeout" in kwargs:
                rust_client_options["read_timeout"] = kwargs["read_timeout"]
            if "connect_timeout" in kwargs:
                rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
            self._rust_client = self._create_rust_client(rust_client_options)

    def _create_gcs_client(self) -> storage.Client:
        client_options = {}
        if self._endpoint_url:
            client_options["api_endpoint"] = self._endpoint_url

        # Use anonymous credentials for public buckets when skip_signature is enabled
        if self._skip_signature:
            return storage.Client(
                project=self._project_id,
                credentials=auth_credentials.AnonymousCredentials(),
                client_options=client_options,
            )

        if self._credentials_provider:
            if isinstance(self._credentials_provider, GoogleIdentityPoolCredentialsProvider):
                audience = self._credentials_provider.get_credentials().get_custom_field("audience")
                token = self._credentials_provider.get_credentials().get_custom_field("token")

                # Use Workload Identity Federation (WIF)
                identity_pool_credentials = identity_pool.Credentials(
                    audience=audience,
                    subject_token_type="urn:ietf:params:oauth:token-type:id_token",
                    subject_token_supplier=StringTokenSupplier(token),
                )
                return storage.Client(
                    project=self._project_id, credentials=identity_pool_credentials, client_options=client_options
                )
            else:
                # Use OAuth 2.0 token
                token = self._credentials_provider.get_credentials().token
                creds = OAuth2Credentials(token=token)
                return storage.Client(project=self._project_id, credentials=creds, client_options=client_options)
        else:
            return storage.Client(project=self._project_id, client_options=client_options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        configs = {}

        if self._credentials_provider or self._endpoint_url:
            # Workload Identity Federation (WIF) is not supported by the rust client:
            # https://github.com/apache/arrow-rs-object-store/issues/258
            logger.warning(
                "Rust client for GCS only supports Application Default Credentials or Service Account Key, skipping rust client"
            )
            return None

        # Use environment variables if available, could be overridden by rust_client_options
        if os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
            configs["application_credentials"] = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
        if os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY"):
            configs["service_account_key"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY")
        if os.getenv("GOOGLE_SERVICE_ACCOUNT"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT")
        if os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH")

        if self._skip_signature:
            configs["skip_signature"] = True

        if rust_client_options:
            if "application_credentials" in rust_client_options:
                configs["application_credentials"] = rust_client_options["application_credentials"]
            if "service_account_key" in rust_client_options:
                configs["service_account_key"] = rust_client_options["service_account_key"]
            if "service_account_path" in rust_client_options:
                configs["service_account_path"] = rust_client_options["service_account_path"]
            if "skip_signature" in rust_client_options:
                configs["skip_signature"] = rust_client_options["skip_signature"]
            if "max_concurrency" in rust_client_options:
                configs["max_concurrency"] = rust_client_options["max_concurrency"]
            if "multipart_chunksize" in rust_client_options:
                configs["multipart_chunksize"] = rust_client_options["multipart_chunksize"]
            if "read_timeout" in rust_client_options:
                configs["read_timeout"] = rust_client_options["read_timeout"]
            if "connect_timeout" in rust_client_options:
                configs["connect_timeout"] = rust_client_options["connect_timeout"]

        if rust_client_options and "bucket" in rust_client_options:
            configs["bucket"] = rust_client_options["bucket"]
        else:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        return RustClient(
            provider=PROVIDER,
            configs=configs,
        )

    def _refresh_gcs_client_if_needed(self) -> None:
        """
        Refreshes the GCS client if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._gcs_client = self._create_gcs_client()

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual GCS operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the GCS bucket involved in the operation.
        :param key: The key of the object within the GCS bucket.

        :return: The result of the GCS operation, typically the return value of the `func` callable.
        """
        try:
            return func()
        except GoogleAPICallError as error:
            status_code = error.code if error.code else -1
            error_info = f"status_code: {status_code}, message: {error.message}"
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 304:
                # for if_none_match with a specific etag condition.
                raise NotModifiedError(f"Object {bucket}/{key} has not been modified.") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except InvalidResponse as error:
            response_text = error.response.text
            error_details = f"error: {error}, error_response_text: {response_text}"
            # Check for NoSuchUpload within the response text
            if "NoSuchUpload" in response_text:
                raise RetryableError(f"Multipart upload failed for {bucket}/{key}, {error_details}") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_details}") from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to exhausted retries from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except FileNotFoundError:
            raise
        except Exception as error:
            error_details = str(error)
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}. error_type: {type(error).__name__}, {error_details}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to Google Cloud Storage.

        :param path: The path to the object to upload.
        :param body: The content of the object to upload.
        :param if_match: Optional ETag to match against the object.
        :param if_none_match: Optional ETag to match against the object.
        :param attributes: Optional attributes to attach to the object.
329 """ 330 bucket, key = split_path(path) 331 self._refresh_gcs_client_if_needed() 332 333 def _invoke_api() -> int: 334 bucket_obj = self._gcs_client.bucket(bucket) 335 blob = bucket_obj.blob(key) 336 337 kwargs = {} 338 339 if if_match: 340 kwargs["if_generation_match"] = int(if_match) # 412 error code 341 if if_none_match: 342 if if_none_match == "*": 343 raise NotImplementedError("if_none_match='*' is not supported for GCS") 344 else: 345 kwargs["if_generation_not_match"] = int(if_none_match) # 304 error code 346 347 validated_attributes = validate_attributes(attributes) 348 if validated_attributes: 349 blob.metadata = validated_attributes 350 351 if ( 352 self._rust_client 353 # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026 354 and not path.endswith("/") 355 and not kwargs 356 and not validated_attributes 357 ): 358 run_async_rust_client_method(self._rust_client, "put", key, body) 359 else: 360 blob.upload_from_string(body, **kwargs) 361 362 return len(body) 363 364 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key) 365 366 def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes: 367 bucket, key = split_path(path) 368 self._refresh_gcs_client_if_needed() 369 370 def _invoke_api() -> bytes: 371 bucket_obj = self._gcs_client.bucket(bucket) 372 blob = bucket_obj.blob(key) 373 if byte_range: 374 if self._rust_client: 375 return run_async_rust_client_method( 376 self._rust_client, "get", key, byte_range.offset, byte_range.offset + byte_range.size - 1 377 ) 378 else: 379 return blob.download_as_bytes( 380 start=byte_range.offset, end=byte_range.offset + byte_range.size - 1, single_shot_download=True 381 ) 382 else: 383 if self._rust_client: 384 return run_async_rust_client_method(self._rust_client, "get", key) 385 else: 386 return blob.download_as_bytes(single_shot_download=True) 387 388 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key) 389 390 def _copy_object(self, src_path: str, dest_path: str) -> int: 391 src_bucket, src_key = split_path(src_path) 392 dest_bucket, dest_key = split_path(dest_path) 393 self._refresh_gcs_client_if_needed() 394 395 src_object = self._get_object_metadata(src_path) 396 397 def _invoke_api() -> int: 398 source_bucket_obj = self._gcs_client.bucket(src_bucket) 399 source_blob = source_bucket_obj.blob(src_key) 400 401 destination_bucket_obj = self._gcs_client.bucket(dest_bucket) 402 destination_blob = destination_bucket_obj.blob(dest_key) 403 404 rewrite_tokens = [None] 405 while len(rewrite_tokens) > 0: 406 rewrite_token = rewrite_tokens.pop() 407 next_rewrite_token, _, _ = destination_blob.rewrite(source=source_blob, token=rewrite_token) 408 if next_rewrite_token is not None: 409 rewrite_tokens.append(next_rewrite_token) 410 411 return src_object.content_length 412 413 return self._translate_errors(_invoke_api, operation="COPY", bucket=src_bucket, key=src_key) 414 415 def _delete_object(self, path: str, if_match: Optional[str] = None) -> None: 416 bucket, key = split_path(path) 417 self._refresh_gcs_client_if_needed() 418 419 def _invoke_api() -> None: 420 bucket_obj = self._gcs_client.bucket(bucket) 421 blob = bucket_obj.blob(key) 422 423 # If if_match is provided, use it as a precondition 424 if if_match: 425 generation = int(if_match) 426 blob.delete(if_generation_match=generation) 427 else: 428 # No if_match check needed, just delete 429 blob.delete() 430 431 return self._translate_errors(_invoke_api, 
operation="DELETE", bucket=bucket, key=key) 432 433 def _is_dir(self, path: str) -> bool: 434 # Ensure the path ends with '/' to mimic a directory 435 path = self._append_delimiter(path) 436 437 bucket, key = split_path(path) 438 self._refresh_gcs_client_if_needed() 439 440 def _invoke_api() -> bool: 441 bucket_obj = self._gcs_client.bucket(bucket) 442 # List objects with the given prefix 443 blobs = bucket_obj.list_blobs( 444 prefix=key, 445 delimiter="/", 446 ) 447 # Check if there are any contents or common prefixes 448 return any(True for _ in blobs) or any(True for _ in blobs.prefixes) 449 450 return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key) 451 452 def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata: 453 bucket, key = split_path(path) 454 if path.endswith("/") or (bucket and not key): 455 # If path ends with "/" or empty key name is provided, then assume it's a "directory", 456 # which metadata is not guaranteed to exist for cases such as 457 # "virtual prefix" that was never explicitly created. 458 if self._is_dir(path): 459 return ObjectMetadata( 460 key=path, type="directory", content_length=0, last_modified=AWARE_DATETIME_MIN, etag=None 461 ) 462 else: 463 raise FileNotFoundError(f"Directory {path} does not exist.") 464 else: 465 self._refresh_gcs_client_if_needed() 466 467 def _invoke_api() -> ObjectMetadata: 468 bucket_obj = self._gcs_client.bucket(bucket) 469 blob = bucket_obj.get_blob(key) 470 if not blob: 471 raise NotFound(f"Blob {key} not found in bucket {bucket}") 472 return ObjectMetadata( 473 key=path, 474 content_length=blob.size or 0, 475 content_type=blob.content_type, 476 last_modified=blob.updated or AWARE_DATETIME_MIN, 477 etag=str(blob.generation), 478 metadata=dict(blob.metadata) if blob.metadata else None, 479 ) 480 481 try: 482 return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key) 483 except FileNotFoundError as error: 484 if strict: 485 # If the object does not exist on the given path, we will append a trailing slash and 486 # check if the path is a directory. 487 path = self._append_delimiter(path) 488 if self._is_dir(path): 489 return ObjectMetadata( 490 key=path, 491 type="directory", 492 content_length=0, 493 last_modified=AWARE_DATETIME_MIN, 494 ) 495 raise error 496 497 def _list_objects( 498 self, 499 path: str, 500 start_after: Optional[str] = None, 501 end_at: Optional[str] = None, 502 include_directories: bool = False, 503 follow_symlinks: bool = True, 504 ) -> Iterator[ObjectMetadata]: 505 bucket, prefix = split_path(path) 506 self._refresh_gcs_client_if_needed() 507 508 def _invoke_api() -> Iterator[ObjectMetadata]: 509 bucket_obj = self._gcs_client.bucket(bucket) 510 if include_directories: 511 blobs = bucket_obj.list_blobs( 512 prefix=prefix, 513 # This is ≥ instead of >. 514 start_offset=start_after, 515 delimiter="/", 516 ) 517 else: 518 blobs = bucket_obj.list_blobs( 519 prefix=prefix, 520 # This is ≥ instead of >. 521 start_offset=start_after, 522 ) 523 524 # GCS guarantees lexicographical order. 
            for blob in blobs:
                key = blob.name
                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    if key.endswith("/"):
                        if include_directories:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key.rstrip("/")),
                                type="directory",
                                content_length=0,
                                last_modified=blob.updated,
                            )
                    else:
                        yield ObjectMetadata(
                            key=os.path.join(bucket, key),
                            content_length=blob.size,
                            content_type=blob.content_type,
                            last_modified=blob.updated,
                            etag=blob.etag,
                        )
                elif start_after != key:
                    return

            # The directories must be accessed last.
            if include_directories:
                for directory in blobs.prefixes:
                    yield ObjectMetadata(
                        key=os.path.join(bucket, directory.rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        bucket, key = split_path(remote_path)
        file_size: int = 0
        self._refresh_gcs_client_if_needed()

        if isinstance(f, str):
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    bucket_obj = self._gcs_client.bucket(bucket)
                    blob = bucket_obj.blob(key)
                    # GCS will raise an error if blob.metadata is None
                    validated_attributes = validate_attributes(attributes)
                    if validated_attributes is not None:
                        blob.metadata = validated_attributes
                    transfer_manager.upload_chunks_concurrently(
                        f,
                        blob,
                        chunk_size=self._multipart_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(remote_path, f.read().encode("utf-8"), attributes=attributes)
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    blob.metadata = validated_attributes
                if isinstance(f, io.StringIO):
                    mode = "w"
                else:
                    mode = "wb"

                # transfer manager does not support uploading a file object
                with tempfile.NamedTemporaryFile(mode=mode, delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    fp.write(f.read())

                transfer_manager.upload_chunks_concurrently(
                    temp_file_path,
                    blob,
                    chunk_size=self._multipart_chunksize,
                    max_workers=self._max_concurrency,
                    worker_type=transfer_manager.THREAD,
                )

                os.unlink(temp_file_path)

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        self._refresh_gcs_client_if_needed()

        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download_multipart_to_file", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        transfer_manager.download_chunks_concurrently(
                            blob,
                            temp_file_path,
                            chunk_size=self._io_chunksize,
                            max_workers=self._max_concurrency,
                            worker_type=transfer_manager.THREAD,
                        )
                    os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                response = self._get_object(remote_path)
                # Python client returns `bytes`, but Rust client returns an object that implements the buffer protocol,
                # so we need to check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)

                # transfer manager does not support downloading to a file object
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    transfer_manager.download_chunks_concurrently(
                        blob,
                        temp_file_path,
                        chunk_size=self._io_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                if isinstance(f, io.StringIO):
                    with open(temp_file_path, "r") as fp:
                        f.write(fp.read())
                else:
                    with open(temp_file_path, "rb") as fp:
                        f.write(fp.read())

                os.unlink(temp_file_path)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
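

# Example (illustrative sketch, not part of the original module): constructing the provider
# directly and exercising a few of the internal operations defined above. In normal use MSC
# instantiates this class from resolved configuration; the project ID, bucket, and keys below
# are hypothetical placeholders, and the underscore-prefixed methods are shown only to
# illustrate the call signatures in this file.
#
#   provider = GoogleStorageProvider(
#       project_id="my-gcp-project",
#       base_path="my-bucket/prefix",
#   )
#   provider._put_object("my-bucket/prefix/example.txt", b"hello gcs")
#   data = provider._get_object("my-bucket/prefix/example.txt")
#   metadata = provider._get_object_metadata("my-bucket/prefix/example.txt")
#   provider._delete_object("my-bucket/prefix/example.txt", if_match=metadata.etag)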