Source code for multistorageclient.providers.gcs

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import copy
import io
import json
import logging
import os
import tempfile
from collections.abc import Callable, Iterator
from typing import IO, Any, Optional, TypeVar, Union

from google.api_core.exceptions import GoogleAPICallError, NotFound
from google.auth import credentials as auth_credentials
from google.auth import identity_pool
from google.cloud import storage
from google.cloud.storage import transfer_manager
from google.cloud.storage.exceptions import InvalidResponse
from google.oauth2 import service_account
from google.oauth2.credentials import Credentials as OAuth2Credentials

from multistorageclient_rust import RustClient, RustClientError, RustRetryableError

from ..rust_utils import run_async_rust_client_method
from ..telemetry import Telemetry
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    NotModifiedError,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import (
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

PROVIDER = "gcs"

MiB = 1024 * 1024

DEFAULT_MULTIPART_THRESHOLD = 64 * MiB
DEFAULT_MULTIPART_CHUNKSIZE = 32 * MiB
DEFAULT_IO_CHUNKSIZE = 32 * MiB
PYTHON_MAX_CONCURRENCY = 8

logger = logging.getLogger(__name__)

class StringTokenSupplier(identity_pool.SubjectTokenSupplier):
    """
    Supply a string token to the Google Identity Pool.
    """

    def __init__(self, token: str):
        self._token = token

    def get_subject_token(self, context, request):
        return self._token


class GoogleIdentityPoolCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's identity pool credentials.
    """

    def __init__(self, audience: str, token_supplier: str):
        """
        Initializes the :py:class:`GoogleIdentityPoolCredentialsProvider` with the audience and token supplier.

        :param audience: The audience for the Google Identity Pool.
        :param token_supplier: The subject token string supplied to the Google Identity Pool.
        """
        self._audience = audience
        self._token_supplier = token_supplier

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key="",
            secret_key="",
            token="",
            expiration=None,
            custom_fields={"audience": self._audience, "token": self._token_supplier},
        )

    def refresh_credentials(self) -> None:
        pass


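# A minimal sketch of wiring up Workload Identity Federation credentials, assuming an
# external process has already minted an OIDC ID token. The audience value and the
# `read_id_token()` helper are placeholders, not part of this module.
#
#     id_token = read_id_token()  # hypothetical helper returning a raw token string
#     credentials_provider = GoogleIdentityPoolCredentialsProvider(
#         audience="//iam.googleapis.com/projects/123/locations/global/workloadIdentityPools/pool/providers/provider",
#         token_supplier=id_token,
#     )
#     # GoogleStorageProvider exchanges this token via identity_pool.Credentials in _create_gcs_client().

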
class GoogleServiceAccountCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's service account credentials.
    """

    #: Google service account private key file contents.
    _info: dict[str, Any]

    def __init__(self, file: Optional[str] = None, info: Optional[dict[str, Any]] = None):
        """
        Initializes the :py:class:`GoogleServiceAccountCredentialsProvider` with either a path to a
        `Google service account private key <https://docs.cloud.google.com/iam/docs/keys-create-delete#creating>`_ file
        or the file contents.

        :param file: Path to a Google service account private key file.
        :param info: Google service account private key file contents.
        """
        if all(_ is None for _ in (file, info)) or all(_ is not None for _ in (file, info)):
            raise ValueError("Must specify exactly one of file or info")

        if file is not None:
            with open(file, "r") as f:
                self._info = json.load(f)
        elif info is not None:
            self._info = copy.deepcopy(info)

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key="",
            secret_key="",
            token=None,
            expiration=None,
            custom_fields={"info": copy.deepcopy(self._info)},
        )

    def refresh_credentials(self) -> None:
        pass


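# A minimal sketch of the two mutually exclusive ways to construct the service account
# provider; the path and the environment variable below are placeholders, not part of
# this module.
#
#     provider = GoogleServiceAccountCredentialsProvider(file="/path/to/sa-key.json")
#     # ...or, equivalently, from already-loaded key contents:
#     sa_info = json.loads(os.environ["SA_KEY_JSON"])  # hypothetical environment variable
#     provider = GoogleServiceAccountCredentialsProvider(info=sa_info)
#     provider.get_credentials().get_custom_field("info")  # returns a deep copy of the key dict

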
class GoogleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Google Cloud Storage.
    """

    def __init__(
        self,
        project_id: str = os.getenv("GOOGLE_CLOUD_PROJECT_ID", ""),
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ):
        """
        Initializes the :py:class:`GoogleStorageProvider` with the project ID and optional credentials provider.

        :param project_id: The Google Cloud project ID.
        :param endpoint_url: The custom endpoint URL for the GCS service.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve GCS credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._project_id = project_id
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._skip_signature = kwargs.get("skip_signature", False)
        self._gcs_client = self._create_gcs_client()
        self._multipart_threshold = kwargs.get("multipart_threshold", DEFAULT_MULTIPART_THRESHOLD)
        self._multipart_chunksize = kwargs.get("multipart_chunksize", DEFAULT_MULTIPART_CHUNKSIZE)
        self._io_chunksize = kwargs.get("io_chunksize", DEFAULT_IO_CHUNKSIZE)
        self._max_concurrency = kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)
        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the rust client options from the kwargs
            rust_client_options = copy.deepcopy(kwargs["rust_client"])
            if "max_concurrency" in kwargs:
                rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
            if "multipart_chunksize" in kwargs:
                rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
            if "read_timeout" in kwargs:
                rust_client_options["read_timeout"] = kwargs["read_timeout"]
            if "connect_timeout" in kwargs:
                rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
            if "service_account_key" not in rust_client_options and isinstance(
                self._credentials_provider, GoogleServiceAccountCredentialsProvider
            ):
                rust_client_options["service_account_key"] = json.dumps(
                    self._credentials_provider.get_credentials().get_custom_field("info")
                )
            self._rust_client = self._create_rust_client(rust_client_options)

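    # A hedged configuration sketch for this constructor: the values below are purely
    # illustrative (not defaults beyond those defined above), and the Rust client is only
    # attempted when a "rust_client" options dict is passed in.
    #
    #     provider = GoogleStorageProvider(
    #         project_id="my-project",              # placeholder project ID
    #         base_path="my-bucket/datasets",       # bucket plus root prefix
    #         multipart_threshold=128 * MiB,        # overrides DEFAULT_MULTIPART_THRESHOLD
    #         max_concurrency=16,                   # shared with the Rust client options
    #         rust_client={},                       # opt in to the Rust client with default options
    #     )
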
    def _create_gcs_client(self) -> storage.Client:
        client_options = {}
        if self._endpoint_url:
            client_options["api_endpoint"] = self._endpoint_url

        # Use anonymous credentials for public buckets when skip_signature is enabled
        if self._skip_signature:
            return storage.Client(
                project=self._project_id,
                credentials=auth_credentials.AnonymousCredentials(),
                client_options=client_options,
            )

        if self._credentials_provider:
            if isinstance(self._credentials_provider, GoogleIdentityPoolCredentialsProvider):
                audience = self._credentials_provider.get_credentials().get_custom_field("audience")
                token = self._credentials_provider.get_credentials().get_custom_field("token")

                # Use Workload Identity Federation (WIF)
                identity_pool_credentials = identity_pool.Credentials(
                    audience=audience,
                    subject_token_type="urn:ietf:params:oauth:token-type:id_token",
                    subject_token_supplier=StringTokenSupplier(token),
                )
                return storage.Client(
                    project=self._project_id, credentials=identity_pool_credentials, client_options=client_options
                )
            elif isinstance(self._credentials_provider, GoogleServiceAccountCredentialsProvider):
                # Use service account key.
                service_account_credentials = service_account.Credentials.from_service_account_info(
                    info=self._credentials_provider.get_credentials().get_custom_field("info")
                )
                return storage.Client(
                    project=self._project_id, credentials=service_account_credentials, client_options=client_options
                )
            else:
                # Use OAuth 2.0 token
                token = self._credentials_provider.get_credentials().token
                creds = OAuth2Credentials(token=token)
                return storage.Client(project=self._project_id, credentials=creds, client_options=client_options)
        else:
            return storage.Client(project=self._project_id, client_options=client_options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        if self._credentials_provider or self._endpoint_url:
            # Workload Identity Federation (WIF) is not supported by the rust client:
            # https://github.com/apache/arrow-rs-object-store/issues/258
            logger.warning(
                "Rust client for GCS only supports Application Default Credentials or Service Account Key, skipping rust client"
            )
            return None
        configs = dict(rust_client_options) if rust_client_options else {}

        if "application_credentials" not in configs and os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
            configs["application_credentials"] = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
        if "service_account_key" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY"):
            configs["service_account_key"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY")
        if "service_account_path" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT")
        if "service_account_path" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH")

        if self._skip_signature and "skip_signature" not in configs:
            configs["skip_signature"] = True

        if "bucket" not in configs:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        return RustClient(
            provider=PROVIDER,
            configs=configs,
        )

    def _refresh_gcs_client_if_needed(self) -> None:
        """
        Refreshes the GCS client if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._gcs_client = self._create_gcs_client()

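    # A sketch of anonymous access to a publicly readable bucket; the bucket name is a
    # placeholder. With skip_signature=True, _create_gcs_client() uses AnonymousCredentials
    # and _create_rust_client() forwards skip_signature to the Rust client configs.
    #
    #     provider = GoogleStorageProvider(
    #         base_path="some-public-bucket",
    #         skip_signature=True,
    #     )
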
    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual GCS operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the GCS bucket involved in the operation.
        :param key: The key of the object within the GCS bucket.

        :return: The result of the GCS operation, typically the return value of the `func` callable.
        """
        try:
            return func()
        except GoogleAPICallError as error:
            status_code = error.code if error.code else -1
            error_info = f"status_code: {status_code}, message: {error.message}"
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 304:
                # for if_none_match with a specific etag condition.
                raise NotModifiedError(f"Object {bucket}/{key} has not been modified.") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except InvalidResponse as error:
            response_text = error.response.text
            error_details = f"error: {error}, error_response_text: {response_text}"
            # Check for NoSuchUpload within the response text
            if "NoSuchUpload" in response_text:
                raise RetryableError(f"Multipart upload failed for {bucket}/{key}, {error_details}") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_details}") from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except RustClientError as error:
            message = error.args[0]
            status_code = error.args[1]
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {message}") from error
            elif status_code == 403:
                raise PermissionError(
                    f"Permission denied to {operation} object(s) at {bucket}/{key}. {message}"
                ) from error
            else:
                raise RetryableError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {message}. status_code: {status_code}"
                ) from error
        except Exception as error:
            error_details = str(error)
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}. error_type: {type(error).__name__}, {error_details}"
            ) from error

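    # A sketch of how callers might react to the translated exceptions above; the handler
    # structure and names (payload, etag) are illustrative, not part of this module.
    #
    #     try:
    #         provider._put_object("my-bucket/config.json", payload, if_match=etag)
    #     except PreconditionFailedError:
    #         ...  # generation changed since it was read; re-read and retry
    #     except RetryableError:
    #         ...  # transient failure (e.g. NoSuchUpload); safe to retry the operation
    #     except FileNotFoundError:
    #         ...  # 404 from GCS
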
    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to Google Cloud Storage.

        :param path: The path to the object to upload.
        :param body: The content of the object to upload.
        :param if_match: Optional ETag to match against the object.
        :param if_none_match: Optional ETag to match against the object.
        :param attributes: Optional attributes to attach to the object.
        """
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> int:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            kwargs = {}

            if if_match:
                kwargs["if_generation_match"] = int(if_match)  # 412 error code
            if if_none_match:
                if if_none_match == "*":
                    raise NotImplementedError("if_none_match='*' is not supported for GCS")
                else:
                    kwargs["if_generation_not_match"] = int(if_none_match)  # 304 error code

            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                blob.metadata = validated_attributes

            if (
                self._rust_client
                # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
                and not path.endswith("/")
                and not kwargs
                and not validated_attributes
            ):
                run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                blob.upload_from_string(body, **kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bytes:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)
            if byte_range:
                if self._rust_client:
                    return run_async_rust_client_method(
                        self._rust_client, "get", key, byte_range.offset, byte_range.offset + byte_range.size - 1
                    )
                else:
                    return blob.download_as_bytes(
                        start=byte_range.offset, end=byte_range.offset + byte_range.size - 1, single_shot_download=True
                    )
            else:
                if self._rust_client:
                    return run_async_rust_client_method(self._rust_client, "get", key)
                else:
                    return blob.download_as_bytes(single_shot_download=True)

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)
        self._refresh_gcs_client_if_needed()

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            source_bucket_obj = self._gcs_client.bucket(src_bucket)
            source_blob = source_bucket_obj.blob(src_key)

            destination_bucket_obj = self._gcs_client.bucket(dest_bucket)
            destination_blob = destination_bucket_obj.blob(dest_key)

            rewrite_tokens = [None]
            while len(rewrite_tokens) > 0:
                rewrite_token = rewrite_tokens.pop()
                next_rewrite_token, _, _ = destination_blob.rewrite(source=source_blob, token=rewrite_token)
                if next_rewrite_token is not None:
                    rewrite_tokens.append(next_rewrite_token)

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=src_bucket, key=src_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> None:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            # If if_match is provided, use it as a precondition
            if if_match:
                generation = int(if_match)
                blob.delete(if_generation_match=generation)
            else:
                # No if_match check needed, just delete
                blob.delete()

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

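    # A hedged sketch of generation-based conditional writes: this provider exposes the GCS
    # object generation as the etag, so if_match / if_none_match take generation numbers as
    # strings. The bucket, key, and new_body names below are placeholders.
    #
    #     meta = provider._get_object_metadata("my-bucket/state.json")
    #     provider._put_object("my-bucket/state.json", new_body, if_match=meta.etag)   # 412 -> PreconditionFailedError
    #     provider._delete_object("my-bucket/state.json", if_match=meta.etag)          # delete only if unchanged
    #     # Note: if_none_match="*" is rejected with NotImplementedError for GCS.
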
    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bool:
            bucket_obj = self._gcs_client.bucket(bucket)
            # List objects with the given prefix
            blobs = bucket_obj.list_blobs(
                prefix=key,
                delimiter="/",
            )
            # Check if there are any contents or common prefixes
            return any(True for _ in blobs) or any(True for _ in blobs.prefixes)

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If path ends with "/" or empty key name is provided, then assume it's a "directory",
            # which metadata is not guaranteed to exist for cases such as
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path, type="directory", content_length=0, last_modified=AWARE_DATETIME_MIN, etag=None
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            self._refresh_gcs_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.get_blob(key)
                if not blob:
                    raise NotFound(f"Blob {key} not found in bucket {bucket}")
                return ObjectMetadata(
                    key=path,
                    content_length=blob.size or 0,
                    content_type=blob.content_type,
                    last_modified=blob.updated or AWARE_DATETIME_MIN,
                    etag=str(blob.generation),
                    metadata=dict(blob.metadata) if blob.metadata else None,
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            bucket_obj = self._gcs_client.bucket(bucket)
            if include_directories:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                    delimiter="/",
                )
            else:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                )

            # GCS guarantees lexicographical order.
            for blob in blobs:
                key = blob.name
                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    if key.endswith("/"):
                        if include_directories:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key.rstrip("/")),
                                type="directory",
                                content_length=0,
                                last_modified=blob.updated,
                            )
                    else:
                        yield ObjectMetadata(
                            key=os.path.join(bucket, key),
                            content_length=blob.size,
                            content_type=blob.content_type,
                            last_modified=blob.updated,
                            etag=blob.etag,
                        )
                elif start_after != key:
                    return

            # The directories must be accessed last.
            if include_directories:
                for directory in blobs.prefixes:
                    yield ObjectMetadata(
                        key=os.path.join(bucket, directory.rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

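    # A sketch of the listing semantics above, with placeholder paths: start_offset is
    # inclusive (≥), so entries equal to start_after are filtered out in the loop, and
    # iteration stops once keys pass end_at because GCS lists in lexicographical order.
    #
    #     for obj in provider._list_objects(
    #         "my-bucket/logs/",
    #         start_after="my-bucket/logs/2024-01-01",
    #         end_at="my-bucket/logs/2024-12-31",
    #     ):
    #         print(obj.key, obj.content_length, obj.etag)
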
    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        bucket, key = split_path(remote_path)
        file_size: int = 0
        self._refresh_gcs_client_if_needed()

        if isinstance(f, str):
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    bucket_obj = self._gcs_client.bucket(bucket)
                    blob = bucket_obj.blob(key)
                    # GCS will raise an error if blob.metadata is None
                    validated_attributes = validate_attributes(attributes)
                    if validated_attributes is not None:
                        blob.metadata = validated_attributes
                    transfer_manager.upload_chunks_concurrently(
                        f,
                        blob,
                        chunk_size=self._multipart_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(remote_path, f.read().encode("utf-8"), attributes=attributes)
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    blob.metadata = validated_attributes
                if isinstance(f, io.StringIO):
                    mode = "w"
                else:
                    mode = "wb"

                # transfer manager does not support uploading a file object
                with tempfile.NamedTemporaryFile(mode=mode, delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    fp.write(f.read())

                transfer_manager.upload_chunks_concurrently(
                    temp_file_path,
                    blob,
                    chunk_size=self._multipart_chunksize,
                    max_workers=self._max_concurrency,
                    worker_type=transfer_manager.THREAD,
                )

                os.unlink(temp_file_path)

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        self._refresh_gcs_client_if_needed()

        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download_multipart_to_file", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        transfer_manager.download_chunks_concurrently(
                            blob,
                            temp_file_path,
                            chunk_size=self._io_chunksize,
                            max_workers=self._max_concurrency,
                            worker_type=transfer_manager.THREAD,
                        )
                    os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                response = self._get_object(remote_path)
                # Python client returns `bytes`, but Rust client returns an object implementing the buffer protocol,
                # so we need to check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)

                # transfer manager does not support downloading to a file object
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    transfer_manager.download_chunks_concurrently(
                        blob,
                        temp_file_path,
                        chunk_size=self._io_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                if isinstance(f, io.StringIO):
                    with open(temp_file_path, "r") as fp:
                        f.write(fp.read())
                else:
                    with open(temp_file_path, "rb") as fp:
                        f.write(fp.read())

                os.unlink(temp_file_path)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
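

# End-to-end sketch, assuming Application Default Credentials are available in the
# environment; the project, bucket, keys, and local paths are placeholders, and the
# underscore-prefixed methods shown are normally invoked through the higher-level
# multistorageclient API rather than called directly.
#
#     provider = GoogleStorageProvider(project_id="my-project", base_path="my-bucket")
#     provider._put_object("my-bucket/hello.txt", b"hello")
#     data = provider._get_object("my-bucket/hello.txt", byte_range=Range(offset=0, size=5))
#     provider._upload_file("my-bucket/model.bin", "/tmp/model.bin")
#     provider._download_file("my-bucket/model.bin", "/tmp/model-copy.bin")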