Source code for multistorageclient.providers.gcs

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import copy
import io
import json
import logging
import os
import tempfile
from collections.abc import Callable, Iterator
from typing import IO, Any, Optional, TypeVar, Union

from google.api_core.exceptions import GoogleAPICallError, NotFound
from google.auth import credentials as auth_credentials
from google.auth import identity_pool
from google.cloud import storage
from google.cloud.storage import transfer_manager
from google.cloud.storage.exceptions import InvalidResponse
from google.oauth2 import service_account
from google.oauth2.credentials import Credentials as OAuth2Credentials

from multistorageclient_rust import RustClient, RustClientError, RustRetryableError

from ..rust_utils import parse_retry_config, run_async_rust_client_method
from ..telemetry import Telemetry
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    NotModifiedError,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import (
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

PROVIDER = "gcs"

MiB = 1024 * 1024

DEFAULT_MULTIPART_THRESHOLD = 64 * MiB
DEFAULT_MULTIPART_CHUNKSIZE = 32 * MiB
DEFAULT_IO_CHUNKSIZE = 32 * MiB
PYTHON_MAX_CONCURRENCY = 8

logger = logging.getLogger(__name__)

class StringTokenSupplier(identity_pool.SubjectTokenSupplier):
    """
    Supply a string token to the Google Identity Pool.
    """

    def __init__(self, token: str):
        self._token = token

    def get_subject_token(self, context, request):
        return self._token


class GoogleIdentityPoolCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's identity pool credentials.
    """

    def __init__(self, audience: str, token_supplier: str):
        """
        Initializes the :py:class:`GoogleIdentityPoolCredentialsProvider` with the audience and token supplier.

        :param audience: The audience for the Google Identity Pool.
        :param token_supplier: The token supplier for the Google Identity Pool.
        """
        self._audience = audience
        self._token_supplier = token_supplier

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key="",
            secret_key="",
            token="",
            expiration=None,
            custom_fields={"audience": self._audience, "token": self._token_supplier},
        )

    def refresh_credentials(self) -> None:
        pass


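# --- Illustrative sketch (not part of the original module) ---
# Shows how a GoogleIdentityPoolCredentialsProvider is typically constructed for
# Workload Identity Federation. The audience and subject token below are placeholder
# values (assumptions, not values from this module); the provider simply packages them
# into Credentials.custom_fields for _create_gcs_client() to consume.
def _example_identity_pool_credentials() -> Credentials:  # pragma: no cover - example only
    provider = GoogleIdentityPoolCredentialsProvider(
        audience="//iam.googleapis.com/projects/0000/locations/global/workloadIdentityPools/my-pool/providers/my-provider",  # placeholder audience
        token_supplier="<subject-id-token>",  # placeholder raw ID token string wrapped by StringTokenSupplier
    )
    credentials = provider.get_credentials()
    # custom_fields carries the audience and token; access_key/secret_key are unused for GCS.
    return credentials

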
class GoogleServiceAccountCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's service account credentials.
    """

    #: Google service account private key file contents.
    _info: dict[str, Any]

    def __init__(self, file: Optional[str] = None, info: Optional[dict[str, Any]] = None):
        """
        Initializes the :py:class:`GoogleServiceAccountCredentialsProvider` with either a path to a
        `Google service account private key <https://docs.cloud.google.com/iam/docs/keys-create-delete#creating>`_ file
        or the file contents.

        :param file: Path to a Google service account private key file.
        :param info: Google service account private key file contents.
        """
        if all(_ is None for _ in (file, info)) or all(_ is not None for _ in (file, info)):
            raise ValueError("Must specify exactly one of file or info")

        if file is not None:
            with open(file, "r") as f:
                self._info = json.load(f)
        elif info is not None:
            self._info = copy.deepcopy(info)

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key="",
            secret_key="",
            token=None,
            expiration=None,
            custom_fields={"info": copy.deepcopy(self._info)},
        )

    def refresh_credentials(self) -> None:
        pass


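# --- Illustrative sketch (not part of the original module) ---
# Shows the two mutually exclusive ways to construct a GoogleServiceAccountCredentialsProvider.
# The dict below is a placeholder (a real key file contains more fields); exactly one of
# `file` or `info` must be provided, otherwise __init__ raises ValueError.
def _example_service_account_credentials() -> Credentials:  # pragma: no cover - example only
    # Construct from parsed key contents (placeholder dict).
    # Alternatively: GoogleServiceAccountCredentialsProvider(file="/path/to/key.json").
    provider = GoogleServiceAccountCredentialsProvider(
        info={"type": "service_account", "project_id": "my-gcp-project", "private_key": "..."}
    )
    # The key material is exposed via Credentials.custom_fields["info"], which
    # _create_gcs_client() and the Rust client's service_account_key option consume.
    return provider.get_credentials()

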
class GoogleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Google Cloud Storage.
    """

    def __init__(
        self,
        project_id: str = os.getenv("GOOGLE_CLOUD_PROJECT_ID", ""),
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ):
        """
        Initializes the :py:class:`GoogleStorageProvider` with the project ID and optional credentials provider.

        :param project_id: The Google Cloud project ID.
        :param endpoint_url: The custom endpoint URL for the GCS service.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve GCS credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._project_id = project_id
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._skip_signature = kwargs.get("skip_signature", False)
        self._gcs_client = self._create_gcs_client()
        self._multipart_threshold = kwargs.get("multipart_threshold", DEFAULT_MULTIPART_THRESHOLD)
        self._multipart_chunksize = kwargs.get("multipart_chunksize", DEFAULT_MULTIPART_CHUNKSIZE)
        self._io_chunksize = kwargs.get("io_chunksize", DEFAULT_IO_CHUNKSIZE)
        self._max_concurrency = kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)
        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the rust client options from the kwargs
            rust_client_options = copy.deepcopy(kwargs["rust_client"])
            if "max_concurrency" in kwargs:
                rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
            if "multipart_chunksize" in kwargs:
                rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
            if "read_timeout" in kwargs:
                rust_client_options["read_timeout"] = kwargs["read_timeout"]
            if "connect_timeout" in kwargs:
                rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
            if "service_account_key" not in rust_client_options and isinstance(
                self._credentials_provider, GoogleServiceAccountCredentialsProvider
            ):
                rust_client_options["service_account_key"] = json.dumps(
                    self._credentials_provider.get_credentials().get_custom_field("info")
                )
            self._rust_client = self._create_rust_client(rust_client_options)

    def _create_gcs_client(self) -> storage.Client:
        client_options = {}
        if self._endpoint_url:
            client_options["api_endpoint"] = self._endpoint_url

        # Use anonymous credentials for public buckets when skip_signature is enabled
        if self._skip_signature:
            return storage.Client(
                project=self._project_id,
                credentials=auth_credentials.AnonymousCredentials(),
                client_options=client_options,
            )

        if self._credentials_provider:
            if isinstance(self._credentials_provider, GoogleIdentityPoolCredentialsProvider):
                audience = self._credentials_provider.get_credentials().get_custom_field("audience")
                token = self._credentials_provider.get_credentials().get_custom_field("token")

                # Use Workload Identity Federation (WIF)
                identity_pool_credentials = identity_pool.Credentials(
                    audience=audience,
                    subject_token_type="urn:ietf:params:oauth:token-type:id_token",
                    subject_token_supplier=StringTokenSupplier(token),
                )
                return storage.Client(
                    project=self._project_id, credentials=identity_pool_credentials, client_options=client_options
                )
            elif isinstance(self._credentials_provider, GoogleServiceAccountCredentialsProvider):
                # Use service account key.
                service_account_credentials = service_account.Credentials.from_service_account_info(
                    info=self._credentials_provider.get_credentials().get_custom_field("info")
                )
                return storage.Client(
                    project=self._project_id, credentials=service_account_credentials, client_options=client_options
                )
            else:
                # Use OAuth 2.0 token
                token = self._credentials_provider.get_credentials().token
                creds = OAuth2Credentials(token=token)
                return storage.Client(project=self._project_id, credentials=creds, client_options=client_options)
        else:
            return storage.Client(project=self._project_id, client_options=client_options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        if self._credentials_provider or self._endpoint_url:
            # Workload Identity Federation (WIF) is not supported by the rust client:
            # https://github.com/apache/arrow-rs-object-store/issues/258
            logger.warning(
                "Rust client for GCS only supports Application Default Credentials or Service Account Key, skipping rust client"
            )
            return None
        configs = dict(rust_client_options) if rust_client_options else {}

        # Extract and parse retry configuration
        retry_config = parse_retry_config(configs)

        if "application_credentials" not in configs and os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
            configs["application_credentials"] = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
        if "service_account_key" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY"):
            configs["service_account_key"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY")
        if "service_account_path" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT")
        if "service_account_path" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH")

        if self._skip_signature and "skip_signature" not in configs:
            configs["skip_signature"] = True

        if "bucket" not in configs:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        return RustClient(
            provider=PROVIDER,
            configs=configs,
            retry=retry_config,
        )

    def _refresh_gcs_client_if_needed(self) -> None:
        """
        Refreshes the GCS client if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._gcs_client = self._create_gcs_client()

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual GCS operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the GCS bucket involved in the operation.
        :param key: The key of the object within the GCS bucket.

        :return: The result of the GCS operation, typically the return value of the `func` callable.
        """
        try:
            return func()
        except GoogleAPICallError as error:
            status_code = error.code if error.code else -1
            error_info = f"status_code: {status_code}, message: {error.message}"
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 304:
                # for if_none_match with a specific etag condition.
                raise NotModifiedError(f"Object {bucket}/{key} has not been modified.") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except InvalidResponse as error:
            response_text = error.response.text
            error_details = f"error: {error}, error_response_text: {response_text}"
            # Check for NoSuchUpload within the response text
            if "NoSuchUpload" in response_text:
                raise RetryableError(f"Multipart upload failed for {bucket}/{key}, {error_details}") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_details}") from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except RustClientError as error:
            message = error.args[0]
            status_code = error.args[1]
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {message}") from error
            elif status_code == 403:
                raise PermissionError(
                    f"Permission denied to {operation} object(s) at {bucket}/{key}. {message}"
                ) from error
            else:
                raise RetryableError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {message}. status_code: {status_code}"
                ) from error
        except Exception as error:
            error_details = str(error)
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}. error_type: {type(error).__name__}, {error_details}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to Google Cloud Storage.

        :param path: The path to the object to upload.
        :param body: The content of the object to upload.
        :param if_match: Optional ETag to match against the object.
        :param if_none_match: Optional ETag to match against the object.
        :param attributes: Optional attributes to attach to the object.
378 """ 379 bucket, key = split_path(path) 380 self._refresh_gcs_client_if_needed() 381 382 def _invoke_api() -> int: 383 bucket_obj = self._gcs_client.bucket(bucket) 384 blob = bucket_obj.blob(key) 385 386 kwargs = {} 387 388 if if_match: 389 kwargs["if_generation_match"] = int(if_match) # 412 error code 390 if if_none_match: 391 if if_none_match == "*": 392 raise NotImplementedError("if_none_match='*' is not supported for GCS") 393 else: 394 kwargs["if_generation_not_match"] = int(if_none_match) # 304 error code 395 396 validated_attributes = validate_attributes(attributes) 397 if validated_attributes: 398 blob.metadata = validated_attributes 399 400 if ( 401 self._rust_client 402 # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026 403 and not path.endswith("/") 404 and not kwargs 405 and not validated_attributes 406 ): 407 run_async_rust_client_method(self._rust_client, "put", key, body) 408 else: 409 blob.upload_from_string(body, **kwargs) 410 411 return len(body) 412 413 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key) 414 415 def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes: 416 bucket, key = split_path(path) 417 self._refresh_gcs_client_if_needed() 418 419 def _invoke_api() -> bytes: 420 bucket_obj = self._gcs_client.bucket(bucket) 421 blob = bucket_obj.blob(key) 422 if byte_range: 423 if self._rust_client: 424 return run_async_rust_client_method( 425 self._rust_client, "get", key, byte_range.offset, byte_range.offset + byte_range.size - 1 426 ) 427 else: 428 return blob.download_as_bytes( 429 start=byte_range.offset, end=byte_range.offset + byte_range.size - 1, single_shot_download=True 430 ) 431 else: 432 if self._rust_client: 433 return run_async_rust_client_method(self._rust_client, "get", key) 434 else: 435 return blob.download_as_bytes(single_shot_download=True) 436 437 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key) 438 439 def _copy_object(self, src_path: str, dest_path: str) -> int: 440 src_bucket, src_key = split_path(src_path) 441 dest_bucket, dest_key = split_path(dest_path) 442 self._refresh_gcs_client_if_needed() 443 444 src_object = self._get_object_metadata(src_path) 445 446 def _invoke_api() -> int: 447 source_bucket_obj = self._gcs_client.bucket(src_bucket) 448 source_blob = source_bucket_obj.blob(src_key) 449 450 destination_bucket_obj = self._gcs_client.bucket(dest_bucket) 451 destination_blob = destination_bucket_obj.blob(dest_key) 452 453 rewrite_tokens = [None] 454 while len(rewrite_tokens) > 0: 455 rewrite_token = rewrite_tokens.pop() 456 next_rewrite_token, _, _ = destination_blob.rewrite(source=source_blob, token=rewrite_token) 457 if next_rewrite_token is not None: 458 rewrite_tokens.append(next_rewrite_token) 459 460 return src_object.content_length 461 462 return self._translate_errors(_invoke_api, operation="COPY", bucket=src_bucket, key=src_key) 463 464 def _delete_object(self, path: str, if_match: Optional[str] = None) -> None: 465 bucket, key = split_path(path) 466 self._refresh_gcs_client_if_needed() 467 468 def _invoke_api() -> None: 469 bucket_obj = self._gcs_client.bucket(bucket) 470 blob = bucket_obj.blob(key) 471 472 # If if_match is provided, use it as a precondition 473 if if_match: 474 generation = int(if_match) 475 blob.delete(if_generation_match=generation) 476 else: 477 # No if_match check needed, just delete 478 blob.delete() 479 480 return self._translate_errors(_invoke_api, 
operation="DELETE", bucket=bucket, key=key) 481 482 def _is_dir(self, path: str) -> bool: 483 # Ensure the path ends with '/' to mimic a directory 484 path = self._append_delimiter(path) 485 486 bucket, key = split_path(path) 487 self._refresh_gcs_client_if_needed() 488 489 def _invoke_api() -> bool: 490 bucket_obj = self._gcs_client.bucket(bucket) 491 # List objects with the given prefix 492 blobs = bucket_obj.list_blobs( 493 prefix=key, 494 delimiter="/", 495 ) 496 # Check if there are any contents or common prefixes 497 return any(True for _ in blobs) or any(True for _ in blobs.prefixes) 498 499 return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key) 500 501 def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata: 502 bucket, key = split_path(path) 503 if path.endswith("/") or (bucket and not key): 504 # If path ends with "/" or empty key name is provided, then assume it's a "directory", 505 # which metadata is not guaranteed to exist for cases such as 506 # "virtual prefix" that was never explicitly created. 507 if self._is_dir(path): 508 return ObjectMetadata( 509 key=path, type="directory", content_length=0, last_modified=AWARE_DATETIME_MIN, etag=None 510 ) 511 else: 512 raise FileNotFoundError(f"Directory {path} does not exist.") 513 else: 514 self._refresh_gcs_client_if_needed() 515 516 def _invoke_api() -> ObjectMetadata: 517 bucket_obj = self._gcs_client.bucket(bucket) 518 blob = bucket_obj.get_blob(key) 519 if not blob: 520 raise NotFound(f"Blob {key} not found in bucket {bucket}") 521 return ObjectMetadata( 522 key=path, 523 content_length=blob.size or 0, 524 content_type=blob.content_type, 525 last_modified=blob.updated or AWARE_DATETIME_MIN, 526 etag=str(blob.generation), 527 metadata=dict(blob.metadata) if blob.metadata else None, 528 ) 529 530 try: 531 return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key) 532 except FileNotFoundError as error: 533 if strict: 534 # If the object does not exist on the given path, we will append a trailing slash and 535 # check if the path is a directory. 536 path = self._append_delimiter(path) 537 if self._is_dir(path): 538 return ObjectMetadata( 539 key=path, 540 type="directory", 541 content_length=0, 542 last_modified=AWARE_DATETIME_MIN, 543 ) 544 raise error 545 546 def _list_objects( 547 self, 548 path: str, 549 start_after: Optional[str] = None, 550 end_at: Optional[str] = None, 551 include_directories: bool = False, 552 follow_symlinks: bool = True, 553 ) -> Iterator[ObjectMetadata]: 554 bucket, prefix = split_path(path) 555 556 # Get the prefix of the start_after and end_at paths relative to the bucket. 557 if start_after: 558 _, start_after = split_path(start_after) 559 if end_at: 560 _, end_at = split_path(end_at) 561 562 self._refresh_gcs_client_if_needed() 563 564 def _invoke_api() -> Iterator[ObjectMetadata]: 565 bucket_obj = self._gcs_client.bucket(bucket) 566 if include_directories: 567 blobs = bucket_obj.list_blobs( 568 prefix=prefix, 569 # This is ≥ instead of >. 570 start_offset=start_after, 571 delimiter="/", 572 ) 573 else: 574 blobs = bucket_obj.list_blobs( 575 prefix=prefix, 576 # This is ≥ instead of >. 577 start_offset=start_after, 578 ) 579 580 # GCS guarantees lexicographical order. 
            for blob in blobs:
                key = blob.name
                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    if key.endswith("/"):
                        if include_directories:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key.rstrip("/")),
                                type="directory",
                                content_length=0,
                                last_modified=blob.updated,
                            )
                    else:
                        yield ObjectMetadata(
                            key=os.path.join(bucket, key),
                            content_length=blob.size,
                            content_type=blob.content_type,
                            last_modified=blob.updated,
                            etag=blob.etag,
                        )
                elif start_after != key:
                    return

            # The directories must be accessed last.
            if include_directories:
                for directory in blobs.prefixes:
                    yield ObjectMetadata(
                        key=os.path.join(bucket, directory.rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        bucket, key = split_path(remote_path)
        file_size: int = 0
        self._refresh_gcs_client_if_needed()

        if isinstance(f, str):
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    bucket_obj = self._gcs_client.bucket(bucket)
                    blob = bucket_obj.blob(key)
                    # GCS will raise an error if blob.metadata is None
                    validated_attributes = validate_attributes(attributes)
                    if validated_attributes is not None:
                        blob.metadata = validated_attributes
                    transfer_manager.upload_chunks_concurrently(
                        f,
                        blob,
                        chunk_size=self._multipart_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(remote_path, f.read().encode("utf-8"), attributes=attributes)
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    blob.metadata = validated_attributes
                if isinstance(f, io.StringIO):
                    mode = "w"
                else:
                    mode = "wb"

                # transfer manager does not support uploading a file object
                with tempfile.NamedTemporaryFile(mode=mode, delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    fp.write(f.read())

                transfer_manager.upload_chunks_concurrently(
                    temp_file_path,
                    blob,
                    chunk_size=self._multipart_chunksize,
                    max_workers=self._max_concurrency,
                    worker_type=transfer_manager.THREAD,
                )

                os.unlink(temp_file_path)

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        self._refresh_gcs_client_if_needed()

        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download_multipart_to_file", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        transfer_manager.download_chunks_concurrently(
                            blob,
                            temp_file_path,
                            chunk_size=self._io_chunksize,
                            max_workers=self._max_concurrency,
                            worker_type=transfer_manager.THREAD,
                        )
                    os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                response = self._get_object(remote_path)
                # Python client returns `bytes`, but the Rust client returns an object that implements
                # the buffer protocol, so we need to check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)

                # transfer manager does not support downloading to a file object
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    transfer_manager.download_chunks_concurrently(
                        blob,
                        temp_file_path,
                        chunk_size=self._io_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                if isinstance(f, io.StringIO):
                    with open(temp_file_path, "r") as fp:
                        f.write(fp.read())
                else:
                    with open(temp_file_path, "rb") as fp:
                        f.write(fp.read())

                os.unlink(temp_file_path)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
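

# --- Illustrative sketch (not part of the original module) ---
# Shows how GoogleStorageProvider might be constructed directly. The project ID, bucket,
# and tuning values below are placeholders; in typical use the provider is instantiated from
# an MSC configuration rather than by hand. Constructing it eagerly creates a storage.Client,
# so Application Default Credentials (or an explicit credentials_provider) must be available.
def _example_storage_provider() -> "GoogleStorageProvider":  # pragma: no cover - example only
    return GoogleStorageProvider(
        project_id="my-gcp-project",  # placeholder project ID
        base_path="my-bucket/datasets",  # all operations are scoped under this bucket/prefix
        multipart_threshold=64 * MiB,  # transfers above this size use the transfer manager
        multipart_chunksize=32 * MiB,
        max_concurrency=8,
    )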