Source code for multistorageclient.providers.gcs

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import io
import logging
import os
import tempfile
import time
from collections.abc import Callable, Iterator, Sequence, Sized
from typing import IO, Any, Optional, TypeVar, Union

import opentelemetry.metrics as api_metrics
from google.api_core.exceptions import GoogleAPICallError, NotFound
from google.auth import identity_pool
from google.cloud import storage
from google.cloud.storage import transfer_manager
from google.cloud.storage.exceptions import InvalidResponse
from google.oauth2.credentials import Credentials as OAuth2Credentials

from ..rust_utils import run_async_rust_client_method
from ..telemetry import Telemetry
from ..telemetry.attributes.base import AttributesProvider
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    NotModifiedError,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import (
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

PROVIDER = "gcs"

MiB = 1024 * 1024

DEFAULT_MULTIPART_THRESHOLD = 512 * MiB
DEFAULT_MULTIPART_CHUNKSIZE = 32 * MiB
DEFAULT_IO_CHUNKSIZE = 32 * MiB
# Python uses a lower default concurrency due to the GIL limiting true parallelism in threads.
PYTHON_MAX_CONCURRENCY = 16
RUST_MAX_CONCURRENCY = 32
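# Editorial note (not part of the original module): with these defaults, an upload just over
# the 512 MiB DEFAULT_MULTIPART_THRESHOLD splits into 512 / 32 = 16 chunks of
# DEFAULT_MULTIPART_CHUNKSIZE, which lines up with PYTHON_MAX_CONCURRENCY, so a single large
# transfer can keep every worker thread busy.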

logger = logging.getLogger(__name__)


class StringTokenSupplier(identity_pool.SubjectTokenSupplier):
    """
    Supply a string token to the Google Identity Pool.
    """

    def __init__(self, token: str):
        self._token = token

    def get_subject_token(self, context, request):
        return self._token


class GoogleIdentityPoolCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's identity pool credentials.
    """

    def __init__(self, audience: str, token_supplier: str):
        """
        Initializes the :py:class:`GoogleIdentityPoolCredentialsProvider` with the audience and token supplier.

        :param audience: The audience for the Google Identity Pool.
        :param token_supplier: The token supplier for the Google Identity Pool.
        """
        self._audience = audience
        self._token_supplier = token_supplier

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key="",
            secret_key="",
            token="",
            expiration=None,
            custom_fields={"audience": self._audience, "token": self._token_supplier},
        )

    def refresh_credentials(self) -> None:
        pass


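# Example (editorial sketch, not part of the original module): wiring the identity pool
# credentials provider into the GoogleStorageProvider defined below. The audience and token
# values are placeholders for your Workload Identity Federation configuration.
#
#     credentials_provider = GoogleIdentityPoolCredentialsProvider(
#         audience="//iam.googleapis.com/projects/<project-number>/locations/global/"
#         "workloadIdentityPools/<pool-id>/providers/<provider-id>",
#         token_supplier="<subject-token-from-your-identity-provider>",
#     )
#     provider = GoogleStorageProvider(
#         project_id="<gcp-project-id>",
#         base_path="<bucket-name>",
#         credentials_provider=credentials_provider,
#     )
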
class GoogleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Google Cloud Storage.
    """

    def __init__(
        self,
        project_id: str = os.getenv("GOOGLE_CLOUD_PROJECT_ID", ""),
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        metric_counters: dict[Telemetry.CounterName, api_metrics.Counter] = {},
        metric_gauges: dict[Telemetry.GaugeName, api_metrics._Gauge] = {},
        metric_attributes_providers: Sequence[AttributesProvider] = (),
        **kwargs: Any,
    ):
        """
        Initializes the :py:class:`GoogleStorageProvider` with the project ID and optional credentials provider.

        :param project_id: The Google Cloud project ID.
        :param endpoint_url: The custom endpoint URL for the GCS service.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve GCS credentials.
        :param metric_counters: Metric counters.
        :param metric_gauges: Metric gauges.
        :param metric_attributes_providers: Metric attributes providers.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            metric_counters=metric_counters,
            metric_gauges=metric_gauges,
            metric_attributes_providers=metric_attributes_providers,
        )

        self._project_id = project_id
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._gcs_client = self._create_gcs_client()
        self._multipart_threshold = kwargs.get("multipart_threshold", DEFAULT_MULTIPART_THRESHOLD)
        self._multipart_chunksize = kwargs.get("multipart_chunksize", DEFAULT_MULTIPART_CHUNKSIZE)
        self._io_chunksize = kwargs.get("io_chunksize", DEFAULT_IO_CHUNKSIZE)
        self._max_concurrency = kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)
        self._rust_client = None
        if "rust_client" in kwargs:
            self._rust_client = self._create_rust_client(kwargs.get("rust_client"))

    def _create_gcs_client(self) -> storage.Client:
        client_options = {}
        if self._endpoint_url:
            client_options["api_endpoint"] = self._endpoint_url

        if self._credentials_provider:
            if isinstance(self._credentials_provider, GoogleIdentityPoolCredentialsProvider):
                audience = self._credentials_provider.get_credentials().get_custom_field("audience")
                token = self._credentials_provider.get_credentials().get_custom_field("token")

                # Use Workload Identity Federation (WIF)
                identity_pool_credentials = identity_pool.Credentials(
                    audience=audience,
                    subject_token_type="urn:ietf:params:oauth:token-type:id_token",
                    subject_token_supplier=StringTokenSupplier(token),
                )
                return storage.Client(
                    project=self._project_id, credentials=identity_pool_credentials, client_options=client_options
                )
            else:
                # Use OAuth 2.0 token
                token = self._credentials_provider.get_credentials().token
                creds = OAuth2Credentials(token=token)
                return storage.Client(project=self._project_id, credentials=creds, client_options=client_options)
        else:
            return storage.Client(project=self._project_id, client_options=client_options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        from multistorageclient_rust import RustClient

        configs = {}

        if self._credentials_provider or self._endpoint_url:
            # Workload Identity Federation (WIF) is not supported by the rust client:
            # https://github.com/apache/arrow-rs-object-store/issues/258
            logger.warning(
                "Rust client for GCS only supports Application Default Credentials or Service Account Key, skipping rust client"
            )
            return None

        # Use environment variables if available; they can be overridden by rust_client_options
        if os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
            configs["application_credentials"] = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
        if os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY"):
            configs["service_account_key"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY")
        if os.getenv("GOOGLE_SERVICE_ACCOUNT"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT")
        if os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH")

        if rust_client_options:
            if "application_credentials" in rust_client_options:
                configs["application_credentials"] = rust_client_options["application_credentials"]
            if "service_account_key" in rust_client_options:
                configs["service_account_key"] = rust_client_options["service_account_key"]
            if "service_account_path" in rust_client_options:
                configs["service_account_path"] = rust_client_options["service_account_path"]
            if "skip_signature" in rust_client_options:
                configs["skip_signature"] = rust_client_options["skip_signature"]
            configs["max_concurrency"] = rust_client_options.get("max_concurrency", RUST_MAX_CONCURRENCY)
            configs["multipart_chunksize"] = rust_client_options.get("multipart_chunksize", DEFAULT_MULTIPART_CHUNKSIZE)

        if rust_client_options and "bucket" in rust_client_options:
            configs["bucket"] = rust_client_options["bucket"]
        else:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        return RustClient(
            provider=PROVIDER,
            configs=configs,
        )

    def _refresh_gcs_client_if_needed(self) -> None:
        """
        Refreshes the GCS client if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._gcs_client = self._create_gcs_client()

    def _collect_metrics(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> _T:
        """
        Collects and records performance metrics around GCS operations such as PUT, GET, DELETE, etc.

        This method wraps a GCS operation and measures the time it takes to complete, along with recording
        the size of the object if applicable. It handles errors like timeouts and client errors and ensures
        proper logging of duration and object size.

        :param func: The function that performs the actual GCS operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the GCS bucket involved in the operation.
        :param key: The key of the object within the GCS bucket.
        :param put_object_size: The size of the object being uploaded, if applicable (for PUT operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).

        :return: The result of the GCS operation, typically the return value of the `func` callable.
262 """ 263 start_time = time.time() 264 status_code = 200 265 266 object_size = None 267 if operation == "PUT": 268 object_size = put_object_size 269 elif operation == "GET" and get_object_size: 270 object_size = get_object_size 271 272 try: 273 result = func() 274 if operation == "GET" and object_size is None and isinstance(result, Sized): 275 object_size = len(result) 276 return result 277 except GoogleAPICallError as error: 278 status_code = error.code if error.code else -1 279 error_info = f"status_code: {status_code}, message: {error.message}" 280 if status_code == 404: 281 raise FileNotFoundError(f"Object {bucket}/{key} does not exist.") # pylint: disable=raise-missing-from 282 elif status_code == 412: 283 raise PreconditionFailedError( 284 f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}" 285 ) from error 286 elif status_code == 304: 287 # for if_none_match with a specific etag condition. 288 raise NotModifiedError(f"Object {bucket}/{key} has not been modified.") from error 289 else: 290 raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error 291 except InvalidResponse as error: 292 status_code = error.response.status_code 293 response_text = error.response.text 294 error_details = f"error: {error}, error_response_text: {response_text}" 295 # Check for NoSuchUpload within the response text 296 if "NoSuchUpload" in response_text: 297 raise RetryableError(f"Multipart upload failed for {bucket}/{key}, {error_details}") from error 298 else: 299 raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_details}") from error 300 except Exception as error: 301 status_code = -1 302 error_details = str(error) 303 raise RuntimeError( 304 f"Failed to {operation} object(s) at {bucket}/{key}. error_type: {type(error).__name__}, {error_details}" 305 ) from error 306 finally: 307 elapsed_time = time.time() - start_time 308 self._metric_helper.record_duration( 309 elapsed_time, provider=self._provider_name, operation=operation, bucket=bucket, status_code=status_code 310 ) 311 if object_size: 312 self._metric_helper.record_object_size( 313 object_size, 314 provider=self._provider_name, 315 operation=operation, 316 bucket=bucket, 317 status_code=status_code, 318 ) 319 320 def _put_object( 321 self, 322 path: str, 323 body: bytes, 324 if_match: Optional[str] = None, 325 if_none_match: Optional[str] = None, 326 attributes: Optional[dict[str, str]] = None, 327 ) -> int: 328 """ 329 Uploads an object to Google Cloud Storage. 330 331 :param path: The path to the object to upload. 332 :param body: The content of the object to upload. 333 :param if_match: Optional ETag to match against the object. 334 :param if_none_match: Optional ETag to match against the object. 335 :param attributes: Optional attributes to attach to the object. 
336 """ 337 bucket, key = split_path(path) 338 self._refresh_gcs_client_if_needed() 339 340 def _invoke_api() -> int: 341 bucket_obj = self._gcs_client.bucket(bucket) 342 blob = bucket_obj.blob(key) 343 344 kwargs = {} 345 346 if if_match: 347 kwargs["if_generation_match"] = int(if_match) # 412 error code 348 if if_none_match: 349 if if_none_match == "*": 350 raise NotImplementedError("if_none_match='*' is not supported for GCS") 351 else: 352 kwargs["if_generation_not_match"] = int(if_none_match) # 304 error code 353 354 validated_attributes = validate_attributes(attributes) 355 if validated_attributes: 356 blob.metadata = validated_attributes 357 358 if ( 359 self._rust_client 360 # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026 361 and not path.endswith("/") 362 and not kwargs 363 and not validated_attributes 364 ): 365 run_async_rust_client_method(self._rust_client, "put", key, body) 366 else: 367 blob.upload_from_string(body, **kwargs) 368 369 return len(body) 370 371 return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=len(body)) 372 373 def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes: 374 bucket, key = split_path(path) 375 self._refresh_gcs_client_if_needed() 376 377 def _invoke_api() -> bytes: 378 bucket_obj = self._gcs_client.bucket(bucket) 379 blob = bucket_obj.blob(key) 380 if byte_range: 381 if self._rust_client: 382 return run_async_rust_client_method( 383 self._rust_client, "get", key, byte_range.offset, byte_range.offset + byte_range.size - 1 384 ) 385 else: 386 return blob.download_as_bytes( 387 start=byte_range.offset, end=byte_range.offset + byte_range.size - 1, single_shot_download=True 388 ) 389 else: 390 if self._rust_client: 391 return run_async_rust_client_method(self._rust_client, "get", key) 392 else: 393 return blob.download_as_bytes(single_shot_download=True) 394 395 return self._collect_metrics(_invoke_api, operation="GET", bucket=bucket, key=key) 396 397 def _copy_object(self, src_path: str, dest_path: str) -> int: 398 src_bucket, src_key = split_path(src_path) 399 dest_bucket, dest_key = split_path(dest_path) 400 self._refresh_gcs_client_if_needed() 401 402 src_object = self._get_object_metadata(src_path) 403 404 def _invoke_api() -> int: 405 source_bucket_obj = self._gcs_client.bucket(src_bucket) 406 source_blob = source_bucket_obj.blob(src_key) 407 408 destination_bucket_obj = self._gcs_client.bucket(dest_bucket) 409 destination_blob = destination_bucket_obj.blob(dest_key) 410 411 rewrite_tokens = [None] 412 while len(rewrite_tokens) > 0: 413 rewrite_token = rewrite_tokens.pop() 414 next_rewrite_token, _, _ = destination_blob.rewrite(source=source_blob, token=rewrite_token) 415 if next_rewrite_token is not None: 416 rewrite_tokens.append(next_rewrite_token) 417 418 return src_object.content_length 419 420 return self._collect_metrics( 421 _invoke_api, 422 operation="COPY", 423 bucket=src_bucket, 424 key=src_key, 425 put_object_size=src_object.content_length, 426 ) 427 428 def _delete_object(self, path: str, if_match: Optional[str] = None) -> None: 429 bucket, key = split_path(path) 430 self._refresh_gcs_client_if_needed() 431 432 def _invoke_api() -> None: 433 bucket_obj = self._gcs_client.bucket(bucket) 434 blob = bucket_obj.blob(key) 435 436 # If if_match is provided, use it as a precondition 437 if if_match: 438 generation = int(if_match) 439 blob.delete(if_generation_match=generation) 440 else: 441 # No if_match 
check needed, just delete 442 blob.delete() 443 444 return self._collect_metrics(_invoke_api, operation="DELETE", bucket=bucket, key=key) 445 446 def _is_dir(self, path: str) -> bool: 447 # Ensure the path ends with '/' to mimic a directory 448 path = self._append_delimiter(path) 449 450 bucket, key = split_path(path) 451 self._refresh_gcs_client_if_needed() 452 453 def _invoke_api() -> bool: 454 bucket_obj = self._gcs_client.bucket(bucket) 455 # List objects with the given prefix 456 blobs = bucket_obj.list_blobs( 457 prefix=key, 458 delimiter="/", 459 ) 460 # Check if there are any contents or common prefixes 461 return any(True for _ in blobs) or any(True for _ in blobs.prefixes) 462 463 return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=key) 464 465 def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata: 466 bucket, key = split_path(path) 467 if path.endswith("/") or (bucket and not key): 468 # If path ends with "/" or empty key name is provided, then assume it's a "directory", 469 # which metadata is not guaranteed to exist for cases such as 470 # "virtual prefix" that was never explicitly created. 471 if self._is_dir(path): 472 return ObjectMetadata( 473 key=path, type="directory", content_length=0, last_modified=AWARE_DATETIME_MIN, etag=None 474 ) 475 else: 476 raise FileNotFoundError(f"Directory {path} does not exist.") 477 else: 478 self._refresh_gcs_client_if_needed() 479 480 def _invoke_api() -> ObjectMetadata: 481 bucket_obj = self._gcs_client.bucket(bucket) 482 blob = bucket_obj.get_blob(key) 483 if not blob: 484 raise NotFound(f"Blob {key} not found in bucket {bucket}") 485 return ObjectMetadata( 486 key=path, 487 content_length=blob.size or 0, 488 content_type=blob.content_type, 489 last_modified=blob.updated or AWARE_DATETIME_MIN, 490 etag=str(blob.generation), 491 metadata=dict(blob.metadata) if blob.metadata else None, 492 ) 493 494 try: 495 return self._collect_metrics(_invoke_api, operation="HEAD", bucket=bucket, key=key) 496 except FileNotFoundError as error: 497 if strict: 498 # If the object does not exist on the given path, we will append a trailing slash and 499 # check if the path is a directory. 500 path = self._append_delimiter(path) 501 if self._is_dir(path): 502 return ObjectMetadata( 503 key=path, 504 type="directory", 505 content_length=0, 506 last_modified=AWARE_DATETIME_MIN, 507 ) 508 raise error 509 510 def _list_objects( 511 self, 512 path: str, 513 start_after: Optional[str] = None, 514 end_at: Optional[str] = None, 515 include_directories: bool = False, 516 ) -> Iterator[ObjectMetadata]: 517 bucket, prefix = split_path(path) 518 self._refresh_gcs_client_if_needed() 519 520 def _invoke_api() -> Iterator[ObjectMetadata]: 521 bucket_obj = self._gcs_client.bucket(bucket) 522 if include_directories: 523 blobs = bucket_obj.list_blobs( 524 prefix=prefix, 525 # This is ≥ instead of >. 526 start_offset=start_after, 527 delimiter="/", 528 ) 529 else: 530 blobs = bucket_obj.list_blobs( 531 prefix=prefix, 532 # This is ≥ instead of >. 533 start_offset=start_after, 534 ) 535 536 # GCS guarantees lexicographical order. 
            for blob in blobs:
                key = blob.name
                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    if key.endswith("/"):
                        if include_directories:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key.rstrip("/")),
                                type="directory",
                                content_length=0,
                                last_modified=blob.updated,
                            )
                    else:
                        yield ObjectMetadata(
                            key=os.path.join(bucket, key),
                            content_length=blob.size,
                            content_type=blob.content_type,
                            last_modified=blob.updated,
                            etag=blob.etag,
                        )
                elif start_after != key:
                    return

            # The directories must be accessed last.
            if include_directories:
                for directory in blobs.prefixes:
                    yield ObjectMetadata(
                        key=os.path.join(bucket, directory.rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        bucket, key = split_path(remote_path)
        file_size: int = 0
        self._refresh_gcs_client_if_needed()

        if isinstance(f, str):
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload_multipart", f, key)
                else:
                    bucket_obj = self._gcs_client.bucket(bucket)
                    blob = bucket_obj.blob(key)
                    blob.metadata = validate_attributes(attributes)
                    transfer_manager.upload_chunks_concurrently(
                        f,
                        blob,
                        chunk_size=self._multipart_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                return file_size

            return self._collect_metrics(
                _invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=file_size
            )
        else:
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(remote_path, f.read().encode("utf-8"), attributes=attributes)
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    blob.metadata = validated_attributes
                if isinstance(f, io.StringIO):
                    mode = "w"
                else:
                    mode = "wb"

                # transfer manager does not support uploading a file object
                with tempfile.NamedTemporaryFile(mode=mode, delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    fp.write(f.read())

                transfer_manager.upload_chunks_concurrently(
                    temp_file_path,
                    blob,
                    chunk_size=self._multipart_chunksize,
                    max_workers=self._max_concurrency,
                    worker_type=transfer_manager.THREAD,
                )

                os.unlink(temp_file_path)

                return file_size

            return self._collect_metrics(
                _invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=file_size
            )

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        self._refresh_gcs_client_if_needed()

        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download_multipart", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        transfer_manager.download_chunks_concurrently(
                            blob,
                            temp_file_path,
                            chunk_size=self._io_chunksize,
                            max_workers=self._max_concurrency,
                            worker_type=transfer_manager.THREAD,
                        )
                    os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )
        else:
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                response = self._get_object(remote_path)
                # Python client returns `bytes`, but Rust client returns an object that implements the buffer protocol,
                # so we need to check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)

                # transfer manager does not support downloading to a file object
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    transfer_manager.download_chunks_concurrently(
                        blob,
                        temp_file_path,
                        chunk_size=self._io_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                if isinstance(f, io.StringIO):
                    with open(temp_file_path, "r") as fp:
                        f.write(fp.read())
                else:
                    with open(temp_file_path, "rb") as fp:
                        f.write(fp.read())

                os.unlink(temp_file_path)

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )
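

# Example usage (editorial sketch, not part of the original module). It assumes Application
# Default Credentials are available in the environment and that the placeholder bucket below
# exists; it calls the provider's internal helpers directly purely for illustration.
if __name__ == "__main__":
    provider = GoogleStorageProvider(
        project_id=os.getenv("GOOGLE_CLOUD_PROJECT_ID", ""),
        base_path="my-example-bucket",
    )

    # Write a small object, read it back, then inspect its metadata.
    provider._put_object("my-example-bucket/demo/hello.txt", b"hello from multistorageclient")
    data = provider._get_object("my-example-bucket/demo/hello.txt")
    info = provider._get_object_metadata("my-example-bucket/demo/hello.txt")
    print(len(data), info.content_length, info.etag)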