# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import copy
import io
import json
import logging
import os
import tempfile
from collections.abc import Callable, Iterator
from typing import IO, Any, Optional, TypeVar, Union

from google.api_core.exceptions import GoogleAPICallError, NotFound
from google.auth import credentials as auth_credentials
from google.auth import identity_pool
from google.cloud import storage
from google.cloud.storage import transfer_manager
from google.cloud.storage.exceptions import InvalidResponse
from google.oauth2 import service_account
from google.oauth2.credentials import Credentials as OAuth2Credentials

from multistorageclient_rust import RustClient, RustRetryableError

from ..rust_utils import run_async_rust_client_method
from ..telemetry import Telemetry
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    NotModifiedError,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import (
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

PROVIDER = "gcs"

MiB = 1024 * 1024

DEFAULT_MULTIPART_THRESHOLD = 64 * MiB
DEFAULT_MULTIPART_CHUNKSIZE = 32 * MiB
DEFAULT_IO_CHUNKSIZE = 32 * MiB
PYTHON_MAX_CONCURRENCY = 8

logger = logging.getLogger(__name__)


class StringTokenSupplier(identity_pool.SubjectTokenSupplier):
    """
    Supply a string token to the Google Identity Pool.
    """

    def __init__(self, token: str):
        self._token = token

    def get_subject_token(self, context, request):
        return self._token


class GoogleIdentityPoolCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's identity pool credentials.
    """

    def __init__(self, audience: str, token_supplier: str):
        """
        Initializes the :py:class:`GoogleIdentityPoolCredentialsProvider` with the audience and token supplier.

        :param audience: The audience for the Google Identity Pool.
        :param token_supplier: The string subject token supplied to the Google Identity Pool.
        """
        self._audience = audience
        self._token_supplier = token_supplier

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key="",
            secret_key="",
            token="",
            expiration=None,
            custom_fields={"audience": self._audience, "token": self._token_supplier},
        )

    def refresh_credentials(self) -> None:
        pass


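# Illustrative usage sketch for GoogleIdentityPoolCredentialsProvider, kept as
# comments so nothing executes at import time. The audience and token values
# below are hypothetical placeholders, not part of this module:
#
#     provider = GoogleIdentityPoolCredentialsProvider(
#         audience="//iam.googleapis.com/projects/123/locations/global/workloadIdentityPools/pool/providers/provider",
#         token_supplier="<subject-token-string>",
#     )
#     creds = provider.get_credentials()
#     # The audience and token are surfaced via custom fields and consumed by
#     # GoogleStorageProvider._create_gcs_client() to build WIF credentials.
#     audience = creds.get_custom_field("audience")

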
class GoogleServiceAccountCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's service account credentials.
    """

    #: Google service account private key file contents.
    _info: dict[str, Any]

    def __init__(self, file: Optional[str] = None, info: Optional[dict[str, Any]] = None):
        """
        Initializes the :py:class:`GoogleServiceAccountCredentialsProvider` with either a path to a
        `Google service account private key <https://docs.cloud.google.com/iam/docs/keys-create-delete#creating>`_ file
        or the file contents.

        :param file: Path to a Google service account private key file.
        :param info: Google service account private key file contents.
        """
        if all(_ is None for _ in (file, info)) or all(_ is not None for _ in (file, info)):
            raise ValueError("Must specify exactly one of file or info")

        if file is not None:
            with open(file, "r") as f:
                self._info = json.load(f)
        elif info is not None:
            self._info = copy.deepcopy(info)

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key="",
            secret_key="",
            token=None,
            expiration=None,
            custom_fields={"info": copy.deepcopy(self._info)},
        )

    def refresh_credentials(self) -> None:
        pass


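# Illustrative usage sketch for GoogleServiceAccountCredentialsProvider, kept as
# comments so nothing executes at import time. Exactly one of `file` or `info`
# must be given; the key path below is a hypothetical placeholder:
#
#     provider = GoogleServiceAccountCredentialsProvider(file="/path/to/service-account.json")
#     # or, equivalently, pass the already-parsed key file contents:
#     # provider = GoogleServiceAccountCredentialsProvider(info=json.load(open("/path/to/service-account.json")))
#     info = provider.get_credentials().get_custom_field("info")

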
class GoogleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Google Cloud Storage.
    """

    def __init__(
        self,
        project_id: str = os.getenv("GOOGLE_CLOUD_PROJECT_ID", ""),
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ):
        """
        Initializes the :py:class:`GoogleStorageProvider` with the project ID and optional credentials provider.

        :param project_id: The Google Cloud project ID.
        :param endpoint_url: The custom endpoint URL for the GCS service.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve GCS credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._project_id = project_id
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._skip_signature = kwargs.get("skip_signature", False)
        self._gcs_client = self._create_gcs_client()
        self._multipart_threshold = kwargs.get("multipart_threshold", DEFAULT_MULTIPART_THRESHOLD)
        self._multipart_chunksize = kwargs.get("multipart_chunksize", DEFAULT_MULTIPART_CHUNKSIZE)
        self._io_chunksize = kwargs.get("io_chunksize", DEFAULT_IO_CHUNKSIZE)
        self._max_concurrency = kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)
        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the rust client options from the kwargs
            rust_client_options = copy.deepcopy(kwargs["rust_client"])
            if "max_concurrency" in kwargs:
                rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
            if "multipart_chunksize" in kwargs:
                rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
            if "read_timeout" in kwargs:
                rust_client_options["read_timeout"] = kwargs["read_timeout"]
            if "connect_timeout" in kwargs:
                rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
            if "service_account_key" not in rust_client_options and isinstance(
                self._credentials_provider, GoogleServiceAccountCredentialsProvider
            ):
                rust_client_options["service_account_key"] = json.dumps(
                    self._credentials_provider.get_credentials().get_custom_field("info")
                )
            self._rust_client = self._create_rust_client(rust_client_options)

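    # Illustrative construction sketch, kept as comments so nothing executes at
    # import time. The project, bucket, and tuning values below are hypothetical;
    # the keyword names mirror the kwargs read in __init__ above:
    #
    #     provider = GoogleStorageProvider(
    #         project_id="my-project",
    #         base_path="my-bucket/datasets",
    #         credentials_provider=GoogleServiceAccountCredentialsProvider(file="/path/to/key.json"),
    #         multipart_threshold=64 * MiB,
    #         multipart_chunksize=32 * MiB,
    #         max_concurrency=8,
    #         rust_client={},  # opt in to the Rust client with default options
    #     )
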
    def _create_gcs_client(self) -> storage.Client:
        client_options = {}
        if self._endpoint_url:
            client_options["api_endpoint"] = self._endpoint_url

        # Use anonymous credentials for public buckets when skip_signature is enabled
        if self._skip_signature:
            return storage.Client(
                project=self._project_id,
                credentials=auth_credentials.AnonymousCredentials(),
                client_options=client_options,
            )

        if self._credentials_provider:
            if isinstance(self._credentials_provider, GoogleIdentityPoolCredentialsProvider):
                audience = self._credentials_provider.get_credentials().get_custom_field("audience")
                token = self._credentials_provider.get_credentials().get_custom_field("token")

                # Use Workload Identity Federation (WIF)
                identity_pool_credentials = identity_pool.Credentials(
                    audience=audience,
                    subject_token_type="urn:ietf:params:oauth:token-type:id_token",
                    subject_token_supplier=StringTokenSupplier(token),
                )
                return storage.Client(
                    project=self._project_id, credentials=identity_pool_credentials, client_options=client_options
                )
            elif isinstance(self._credentials_provider, GoogleServiceAccountCredentialsProvider):
                # Use service account key.
                service_account_credentials = service_account.Credentials.from_service_account_info(
                    info=self._credentials_provider.get_credentials().get_custom_field("info")
                )
                return storage.Client(
                    project=self._project_id, credentials=service_account_credentials, client_options=client_options
                )
            else:
                # Use OAuth 2.0 token
                token = self._credentials_provider.get_credentials().token
                creds = OAuth2Credentials(token=token)
                return storage.Client(project=self._project_id, credentials=creds, client_options=client_options)
        else:
            return storage.Client(project=self._project_id, client_options=client_options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        if self._credentials_provider or self._endpoint_url:
            # Workload Identity Federation (WIF) is not supported by the rust client:
            # https://github.com/apache/arrow-rs-object-store/issues/258
            logger.warning(
                "Rust client for GCS only supports Application Default Credentials or Service Account Key, skipping rust client"
            )
            return None
        configs = dict(rust_client_options) if rust_client_options else {}

        if "application_credentials" not in configs and os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
            configs["application_credentials"] = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
        if "service_account_key" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY"):
            configs["service_account_key"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY")
        if "service_account_path" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT")
        if "service_account_path" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH")

        if self._skip_signature and "skip_signature" not in configs:
            configs["skip_signature"] = True

        if "bucket" not in configs:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        return RustClient(
            provider=PROVIDER,
            configs=configs,
        )

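    # Illustrative shape of `rust_client_options`, kept as a comment. Every key
    # shown below is one this method reads or fills in; the values themselves are
    # hypothetical examples:
    #
    #     rust_client_options = {
    #         "bucket": "my-bucket",
    #         "max_concurrency": 8,
    #         "multipart_chunksize": 32 * MiB,
    #         "application_credentials": "/path/to/adc.json",
    #         "skip_signature": False,
    #     }
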
    def _refresh_gcs_client_if_needed(self) -> None:
        """
        Refreshes the GCS client if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._gcs_client = self._create_gcs_client()

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual GCS operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the GCS bucket involved in the operation.
        :param key: The key of the object within the GCS bucket.

        :return: The result of the GCS operation, typically the return value of the `func` callable.
        """
        try:
            return func()
        except GoogleAPICallError as error:
            status_code = error.code if error.code else -1
            error_info = f"status_code: {status_code}, message: {error.message}"
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 304:
                # for if_none_match with a specific etag condition.
                raise NotModifiedError(f"Object {bucket}/{key} has not been modified.") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except InvalidResponse as error:
            response_text = error.response.text
            error_details = f"error: {error}, error_response_text: {response_text}"
            # Check for NoSuchUpload within the response text
            if "NoSuchUpload" in response_text:
                raise RetryableError(f"Multipart upload failed for {bucket}/{key}, {error_details}") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_details}") from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except FileNotFoundError:
            raise
        except Exception as error:
            error_details = str(error)
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}. error_type: {type(error).__name__}, {error_details}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to Google Cloud Storage.

        :param path: The path to the object to upload.
        :param body: The content of the object to upload.
        :param if_match: Optional ETag (GCS object generation) that the object's current generation must match for the upload to succeed.
        :param if_none_match: Optional ETag (GCS object generation) that the object's current generation must not match for the upload to succeed.
        :param attributes: Optional attributes to attach to the object.
        """
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> int:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            kwargs = {}

            if if_match:
                kwargs["if_generation_match"] = int(if_match)  # 412 error code
            if if_none_match:
                if if_none_match == "*":
                    raise NotImplementedError("if_none_match='*' is not supported for GCS")
                else:
                    kwargs["if_generation_not_match"] = int(if_none_match)  # 304 error code

            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                blob.metadata = validated_attributes

            if (
                self._rust_client
                # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
                and not path.endswith("/")
                and not kwargs
                and not validated_attributes
            ):
                run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                blob.upload_from_string(body, **kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

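    # Illustrative conditional-write sketch for the internal helpers above, kept
    # as comments; the path and payload are hypothetical. ETags returned by this
    # provider are GCS generation numbers (see _get_object_metadata), so they can
    # be fed back as preconditions:
    #
    #     metadata = provider._get_object_metadata("my-bucket/config.json")
    #     # Overwrite only if nobody else has modified the object since we read it;
    #     # a mismatch surfaces as PreconditionFailedError via _translate_errors().
    #     provider._put_object("my-bucket/config.json", b"{}", if_match=metadata.etag)
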
    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bytes:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)
            if byte_range:
                if self._rust_client:
                    return run_async_rust_client_method(
                        self._rust_client, "get", key, byte_range.offset, byte_range.offset + byte_range.size - 1
                    )
                else:
                    return blob.download_as_bytes(
                        start=byte_range.offset, end=byte_range.offset + byte_range.size - 1, single_shot_download=True
                    )
            else:
                if self._rust_client:
                    return run_async_rust_client_method(self._rust_client, "get", key)
                else:
                    return blob.download_as_bytes(single_shot_download=True)

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)
        self._refresh_gcs_client_if_needed()

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            source_bucket_obj = self._gcs_client.bucket(src_bucket)
            source_blob = source_bucket_obj.blob(src_key)

            destination_bucket_obj = self._gcs_client.bucket(dest_bucket)
            destination_blob = destination_bucket_obj.blob(dest_key)

            rewrite_tokens = [None]
            while len(rewrite_tokens) > 0:
                rewrite_token = rewrite_tokens.pop()
                next_rewrite_token, _, _ = destination_blob.rewrite(source=source_blob, token=rewrite_token)
                if next_rewrite_token is not None:
                    rewrite_tokens.append(next_rewrite_token)

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=src_bucket, key=src_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> None:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            # If if_match is provided, use it as a precondition
            if if_match:
                generation = int(if_match)
                blob.delete(if_generation_match=generation)
            else:
                # No if_match check needed, just delete
                blob.delete()

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bool:
            bucket_obj = self._gcs_client.bucket(bucket)
            # List objects with the given prefix
            blobs = bucket_obj.list_blobs(
                prefix=key,
                delimiter="/",
            )
            # Check if there are any contents or common prefixes
            return any(True for _ in blobs) or any(True for _ in blobs.prefixes)

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If the path ends with "/" or the key is empty, assume it refers to a "directory",
            # whose metadata is not guaranteed to exist, e.g. a "virtual prefix" that was never
            # explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path, type="directory", content_length=0, last_modified=AWARE_DATETIME_MIN, etag=None
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            self._refresh_gcs_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.get_blob(key)
                if not blob:
                    raise NotFound(f"Blob {key} not found in bucket {bucket}")
                return ObjectMetadata(
                    key=path,
                    content_length=blob.size or 0,
                    content_type=blob.content_type,
                    last_modified=blob.updated or AWARE_DATETIME_MIN,
                    etag=str(blob.generation),
                    metadata=dict(blob.metadata) if blob.metadata else None,
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist at the given path, append a trailing slash and
                    # check whether the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            bucket_obj = self._gcs_client.bucket(bucket)
            if include_directories:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                    delimiter="/",
                )
            else:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                )

            # GCS guarantees lexicographical order.
            for blob in blobs:
                key = blob.name
                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    if key.endswith("/"):
                        if include_directories:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key.rstrip("/")),
                                type="directory",
                                content_length=0,
                                last_modified=blob.updated,
                            )
                    else:
                        yield ObjectMetadata(
                            key=os.path.join(bucket, key),
                            content_length=blob.size,
                            content_type=blob.content_type,
                            last_modified=blob.updated,
                            etag=blob.etag,
                        )
                elif start_after != key:
                    return

            # The directories must be accessed last.
            if include_directories:
                for directory in blobs.prefixes:
                    yield ObjectMetadata(
                        key=os.path.join(bucket, directory.rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        bucket, key = split_path(remote_path)
        file_size: int = 0
        self._refresh_gcs_client_if_needed()

        if isinstance(f, str):
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    bucket_obj = self._gcs_client.bucket(bucket)
                    blob = bucket_obj.blob(key)
                    # GCS will raise an error if blob.metadata is None
                    validated_attributes = validate_attributes(attributes)
                    if validated_attributes is not None:
                        blob.metadata = validated_attributes
                    transfer_manager.upload_chunks_concurrently(
                        f,
                        blob,
                        chunk_size=self._multipart_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(remote_path, f.read().encode("utf-8"), attributes=attributes)
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    blob.metadata = validated_attributes
                if isinstance(f, io.StringIO):
                    mode = "w"
                else:
                    mode = "wb"

                # transfer manager does not support uploading a file object
                with tempfile.NamedTemporaryFile(mode=mode, delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    fp.write(f.read())

                transfer_manager.upload_chunks_concurrently(
                    temp_file_path,
                    blob,
                    chunk_size=self._multipart_chunksize,
                    max_workers=self._max_concurrency,
                    worker_type=transfer_manager.THREAD,
                )

                os.unlink(temp_file_path)

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        self._refresh_gcs_client_if_needed()

        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download_multipart_to_file", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        transfer_manager.download_chunks_concurrently(
                            blob,
                            temp_file_path,
                            chunk_size=self._io_chunksize,
                            max_workers=self._max_concurrency,
                            worker_type=transfer_manager.THREAD,
                        )
                    os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                response = self._get_object(remote_path)
                # The Python client returns `bytes`, but the Rust client returns an object that
                # implements the buffer protocol, so check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)

                # transfer manager does not support downloading to a file object
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    transfer_manager.download_chunks_concurrently(
                        blob,
                        temp_file_path,
                        chunk_size=self._io_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                if isinstance(f, io.StringIO):
                    with open(temp_file_path, "r") as fp:
                        f.write(fp.read())
                else:
                    with open(temp_file_path, "rb") as fp:
                        f.write(fp.read())

                os.unlink(temp_file_path)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)