# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import io
import logging
import os
import tempfile
from collections.abc import Callable, Iterator
from typing import IO, Any, Optional, TypeVar, Union

from google.api_core.exceptions import GoogleAPICallError, NotFound
from google.auth import credentials as auth_credentials
from google.auth import identity_pool
from google.cloud import storage
from google.cloud.storage import transfer_manager
from google.cloud.storage.exceptions import InvalidResponse
from google.oauth2.credentials import Credentials as OAuth2Credentials

from multistorageclient_rust import RustClient, RustRetryableError

from ..rust_utils import run_async_rust_client_method
from ..telemetry import Telemetry
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    NotModifiedError,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import (
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

PROVIDER = "gcs"

MiB = 1024 * 1024

DEFAULT_MULTIPART_THRESHOLD = 64 * MiB
DEFAULT_MULTIPART_CHUNKSIZE = 32 * MiB
DEFAULT_IO_CHUNKSIZE = 32 * MiB
PYTHON_MAX_CONCURRENCY = 8
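# Objects at or below the multipart threshold are transferred in a single request;
# larger objects are split into chunks (DEFAULT_MULTIPART_CHUNKSIZE for uploads,
# DEFAULT_IO_CHUNKSIZE for downloads) moved by up to PYTHON_MAX_CONCURRENCY workers.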

logger = logging.getLogger(__name__)


class StringTokenSupplier(identity_pool.SubjectTokenSupplier):
    """
    Supply a string token to the Google Identity Pool.
    """

    def __init__(self, token: str):
        self._token = token

    def get_subject_token(self, context, request):
        return self._token


class GoogleIdentityPoolCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's identity pool credentials.
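
    Example (illustrative audience and token values)::

        provider = GoogleIdentityPoolCredentialsProvider(
            audience="//iam.googleapis.com/projects/0000/locations/global/workloadIdentityPools/my-pool/providers/my-provider",
            token_supplier="<subject-token>",
        )
        storage_provider = GoogleStorageProvider(project_id="my-project", credentials_provider=provider)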
    """

    def __init__(self, audience: str, token_supplier: str):
        """
        Initializes the :py:class:`GoogleIdentityPoolCredentialsProvider` with the audience and token supplier.

        :param audience: The audience for the Google Identity Pool.
        :param token_supplier: The subject token string supplied to the Google Identity Pool.
        """
        self._audience = audience
        self._token_supplier = token_supplier

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key="",
            secret_key="",
            token="",
            expiration=None,
            custom_fields={"audience": self._audience, "token": self._token_supplier},
        )

    def refresh_credentials(self) -> None:
        pass


class GoogleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Google Cloud Storage.
    """

    def __init__(
        self,
        project_id: str = os.getenv("GOOGLE_CLOUD_PROJECT_ID", ""),
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ):
        """
        Initializes the :py:class:`GoogleStorageProvider` with the project ID and optional credentials provider.

        :param project_id: The Google Cloud project ID.
        :param endpoint_url: The custom endpoint URL for the GCS service.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve GCS credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
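
        Example (a minimal sketch; names and tuning values are illustrative)::

            provider = GoogleStorageProvider(
                project_id="my-project",
                base_path="my-bucket/datasets",
                # Optional tuning knobs, read from **kwargs:
                multipart_threshold=64 * MiB,
                multipart_chunksize=32 * MiB,
                max_concurrency=8,
            )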
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._project_id = project_id
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._skip_signature = kwargs.get("skip_signature", False)
        self._gcs_client = self._create_gcs_client()
        self._multipart_threshold = kwargs.get("multipart_threshold", DEFAULT_MULTIPART_THRESHOLD)
        self._multipart_chunksize = kwargs.get("multipart_chunksize", DEFAULT_MULTIPART_CHUNKSIZE)
        self._io_chunksize = kwargs.get("io_chunksize", DEFAULT_IO_CHUNKSIZE)
        self._max_concurrency = kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)
        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the rust client options from the kwargs
            rust_client_options = kwargs["rust_client"]
            if "max_concurrency" in kwargs:
                rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
            if "multipart_chunksize" in kwargs:
                rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
            if "read_timeout" in kwargs:
                rust_client_options["read_timeout"] = kwargs["read_timeout"]
            if "connect_timeout" in kwargs:
                rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
            self._rust_client = self._create_rust_client(rust_client_options)

    def _create_gcs_client(self) -> storage.Client:
        client_options = {}
        if self._endpoint_url:
            client_options["api_endpoint"] = self._endpoint_url

        # Use anonymous credentials for public buckets when skip_signature is enabled
        if self._skip_signature:
            return storage.Client(
                project=self._project_id,
                credentials=auth_credentials.AnonymousCredentials(),
                client_options=client_options,
            )

        if self._credentials_provider:
            if isinstance(self._credentials_provider, GoogleIdentityPoolCredentialsProvider):
                audience = self._credentials_provider.get_credentials().get_custom_field("audience")
                token = self._credentials_provider.get_credentials().get_custom_field("token")

                # Use Workload Identity Federation (WIF)
                identity_pool_credentials = identity_pool.Credentials(
                    audience=audience,
                    subject_token_type="urn:ietf:params:oauth:token-type:id_token",
                    subject_token_supplier=StringTokenSupplier(token),
                )
                return storage.Client(
                    project=self._project_id, credentials=identity_pool_credentials, client_options=client_options
                )
            else:
                # Use OAuth 2.0 token
                token = self._credentials_provider.get_credentials().token
                creds = OAuth2Credentials(token=token)
                return storage.Client(project=self._project_id, credentials=creds, client_options=client_options)
        else:
            return storage.Client(project=self._project_id, client_options=client_options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        configs = {}

        if self._credentials_provider or self._endpoint_url:
            # Workload Identity Federation (WIF) is not supported by the rust client:
            # https://github.com/apache/arrow-rs-object-store/issues/258
            logger.warning(
                "Rust client for GCS only supports Application Default Credentials or Service Account Key, skipping rust client"
            )
            return None

        # Use environment variables if available; they can be overridden by rust_client_options
        if os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
            configs["application_credentials"] = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
        if os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY"):
            configs["service_account_key"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY")
        if os.getenv("GOOGLE_SERVICE_ACCOUNT"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT")
        if os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH")
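        # Note: both GOOGLE_SERVICE_ACCOUNT and GOOGLE_SERVICE_ACCOUNT_PATH map to
        # "service_account_path"; when both are set, GOOGLE_SERVICE_ACCOUNT_PATH wins
        # because it is assigned last.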

        if self._skip_signature:
            configs["skip_signature"] = True

        if rust_client_options:
            if "application_credentials" in rust_client_options:
                configs["application_credentials"] = rust_client_options["application_credentials"]
            if "service_account_key" in rust_client_options:
                configs["service_account_key"] = rust_client_options["service_account_key"]
            if "service_account_path" in rust_client_options:
                configs["service_account_path"] = rust_client_options["service_account_path"]
            if "skip_signature" in rust_client_options:
                configs["skip_signature"] = rust_client_options["skip_signature"]
            if "max_concurrency" in rust_client_options:
                configs["max_concurrency"] = rust_client_options["max_concurrency"]
            if "multipart_chunksize" in rust_client_options:
                configs["multipart_chunksize"] = rust_client_options["multipart_chunksize"]
            if "read_timeout" in rust_client_options:
                configs["read_timeout"] = rust_client_options["read_timeout"]
            if "connect_timeout" in rust_client_options:
                configs["connect_timeout"] = rust_client_options["connect_timeout"]

        if rust_client_options and "bucket" in rust_client_options:
            configs["bucket"] = rust_client_options["bucket"]
        else:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        return RustClient(
            provider=PROVIDER,
            configs=configs,
        )

    def _refresh_gcs_client_if_needed(self) -> None:
        """
        Refreshes the GCS client if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._gcs_client = self._create_gcs_client()

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Runs ``func`` and translates GCS errors (e.g., timeouts and client errors) into standard or MSC exception types.

        :param func: The function that performs the actual GCS operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the GCS bucket involved in the operation.
        :param key: The key of the object within the GCS bucket.

        :return: The result of the GCS operation, typically the return value of the `func` callable.
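
        Example (illustrative; wraps a single HEAD-style call)::

            exists = self._translate_errors(
                lambda: self._gcs_client.bucket(bucket).blob(key).exists(),
                operation="HEAD",
                bucket=bucket,
                key=key,
            )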
        """
        try:
            return func()
        except GoogleAPICallError as error:
            status_code = error.code if error.code else -1
            error_info = f"status_code: {status_code}, message: {error.message}"
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 304:
                # Raised when if_none_match is used with a specific generation (ETag) condition.
                raise NotModifiedError(f"Object {bucket}/{key} has not been modified.") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except InvalidResponse as error:
            response_text = error.response.text
            error_details = f"error: {error}, error_response_text: {response_text}"
            # Check for NoSuchUpload within the response text
            if "NoSuchUpload" in response_text:
                raise RetryableError(f"Multipart upload failed for {bucket}/{key}, {error_details}") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_details}") from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to exhausted retries from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except FileNotFoundError:
            raise
        except Exception as error:
            error_details = str(error)
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}. error_type: {type(error).__name__}, {error_details}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to Google Cloud Storage.

        :param path: The path to the object to upload.
        :param body: The content of the object to upload.
        :param if_match: Optional generation (ETag) that the current object must match for the upload to succeed.
        :param if_none_match: Optional generation (ETag) that the current object must not match; ``"*"`` is not supported for GCS.
        :param attributes: Optional attributes to attach to the object.
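
        Example (illustrative; the write succeeds only if the object's generation is
        unchanged, otherwise a ``PreconditionFailedError`` is raised)::

            metadata = self._get_object_metadata("bucket/key")
            self._put_object("bucket/key", b"new content", if_match=metadata.etag)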
        """
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> int:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            kwargs = {}

            if if_match:
                kwargs["if_generation_match"] = int(if_match)  # a mismatch surfaces as a 412 error
            if if_none_match:
                if if_none_match == "*":
                    raise NotImplementedError("if_none_match='*' is not supported for GCS")
                else:
                    kwargs["if_generation_not_match"] = int(if_none_match)  # a match surfaces as a 304 error

            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                blob.metadata = validated_attributes

            if (
                self._rust_client
                # Rust client doesn't support creating objects with a trailing "/", see https://github.com/apache/arrow-rs/issues/7026
                and not path.endswith("/")
                and not kwargs
                and not validated_attributes
            ):
                run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                blob.upload_from_string(body, **kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bytes:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)
            if byte_range:
                if self._rust_client:
                    return run_async_rust_client_method(
                        self._rust_client, "get", key, byte_range.offset, byte_range.offset + byte_range.size - 1
                    )
                else:
                    return blob.download_as_bytes(
                        start=byte_range.offset, end=byte_range.offset + byte_range.size - 1, single_shot_download=True
                    )
            else:
                if self._rust_client:
                    return run_async_rust_client_method(self._rust_client, "get", key)
                else:
                    return blob.download_as_bytes(single_shot_download=True)

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)
        self._refresh_gcs_client_if_needed()

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            source_bucket_obj = self._gcs_client.bucket(src_bucket)
            source_blob = source_bucket_obj.blob(src_key)

            destination_bucket_obj = self._gcs_client.bucket(dest_bucket)
            destination_blob = destination_bucket_obj.blob(dest_key)

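            # rewrite() may need multiple service-side passes for large objects:
            # each call returns a continuation token until the copy completes.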
            rewrite_tokens = [None]
            while len(rewrite_tokens) > 0:
                rewrite_token = rewrite_tokens.pop()
                next_rewrite_token, _, _ = destination_blob.rewrite(source=source_blob, token=rewrite_token)
                if next_rewrite_token is not None:
                    rewrite_tokens.append(next_rewrite_token)

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=src_bucket, key=src_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> None:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            # If if_match is provided, use it as a precondition
            if if_match:
                generation = int(if_match)
                blob.delete(if_generation_match=generation)
            else:
                # No if_match check needed, just delete
                blob.delete()

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bool:
            bucket_obj = self._gcs_client.bucket(bucket)
            # List objects with the given prefix
            blobs = bucket_obj.list_blobs(
                prefix=key,
                delimiter="/",
            )
            # Check if there are any contents or common prefixes. Note that
            # blobs.prefixes is only populated as the iterator is consumed,
            # so the blob check must run first.
            return any(True for _ in blobs) or any(True for _ in blobs.prefixes)

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If the path ends with "/" or the key is empty, assume it refers to a
            # "directory", whose metadata is not guaranteed to exist (e.g., a
            # "virtual prefix" that was never explicitly created).
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path, type="directory", content_length=0, last_modified=AWARE_DATETIME_MIN, etag=None
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            self._refresh_gcs_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.get_blob(key)
                if not blob:
                    raise NotFound(f"Blob {key} not found in bucket {bucket}")
                return ObjectMetadata(
                    key=path,
                    content_length=blob.size or 0,
                    content_type=blob.content_type,
                    last_modified=blob.updated or AWARE_DATETIME_MIN,
                    etag=str(blob.generation),
                    metadata=dict(blob.metadata) if blob.metadata else None,
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist at the given path, append a trailing
                    # slash and check whether the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            bucket_obj = self._gcs_client.bucket(bucket)
            if include_directories:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # start_offset is inclusive (≥, not >), so start_after itself may appear and is filtered below.
                    start_offset=start_after,
                    delimiter="/",
                )
            else:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # start_offset is inclusive (≥, not >), so start_after itself may appear and is filtered below.
                    start_offset=start_after,
                )

            # GCS guarantees lexicographical order.
            for blob in blobs:
                key = blob.name
                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    if key.endswith("/"):
                        if include_directories:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key.rstrip("/")),
                                type="directory",
                                content_length=0,
                                last_modified=blob.updated,
                            )
                    else:
                        yield ObjectMetadata(
                            key=os.path.join(bucket, key),
                            content_length=blob.size,
                            content_type=blob.content_type,
                            last_modified=blob.updated,
                            etag=blob.etag,
                        )
                elif start_after != key:
                    return

            # The directories (common prefixes) must be accessed last, after the blob iterator is consumed.
            if include_directories:
                for directory in blobs.prefixes:
                    yield ObjectMetadata(
                        key=os.path.join(bucket, directory.rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
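        """
        Uploads a file to GCS from a local path or a file-like object.

        Files at or below the multipart threshold are uploaded in a single PUT;
        larger files go through the transfer manager. Local paths may use the Rust
        client when it is available and no custom attributes are set; file-like
        objects always use the Python client.
        """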
        bucket, key = split_path(remote_path)
        file_size: int = 0
        self._refresh_gcs_client_if_needed()

        if isinstance(f, str):
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    bucket_obj = self._gcs_client.bucket(bucket)
                    blob = bucket_obj.blob(key)
                    # Only set metadata when present; GCS raises an error if blob.metadata is set to None
                    validated_attributes = validate_attributes(attributes)
                    if validated_attributes is not None:
                        blob.metadata = validated_attributes
                    transfer_manager.upload_chunks_concurrently(
                        f,
                        blob,
                        chunk_size=self._multipart_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(remote_path, f.read().encode("utf-8"), attributes=attributes)
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    blob.metadata = validated_attributes
                if isinstance(f, io.StringIO):
                    mode = "w"
                else:
                    mode = "wb"

                # The transfer manager does not support uploading from a file object,
                # so spill the contents to a temporary file first.
                with tempfile.NamedTemporaryFile(mode=mode, delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    fp.write(f.read())

                transfer_manager.upload_chunks_concurrently(
                    temp_file_path,
                    blob,
                    chunk_size=self._multipart_chunksize,
                    max_workers=self._max_concurrency,
                    worker_type=transfer_manager.THREAD,
                )

                os.unlink(temp_file_path)

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
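        """
        Downloads an object from GCS to a local path or a file-like object.

        Objects at or below the multipart threshold are fetched with a single GET;
        larger objects are downloaded in chunks via the transfer manager. When the
        destination is a path, data is first written to a temporary file and then
        renamed into place.
        """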
        self._refresh_gcs_client_if_needed()

        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
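                    # Write to a temporary file in the destination directory, then
                    # rename it so readers never observe a partially written file.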
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download_multipart_to_file", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        transfer_manager.download_chunks_concurrently(
                            blob,
                            temp_file_path,
                            chunk_size=self._io_chunksize,
                            max_workers=self._max_concurrency,
                            worker_type=transfer_manager.THREAD,
                        )
                    os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                response = self._get_object(remote_path)
                # The Python client returns `bytes`, but the Rust client returns an object
                # that implements the buffer protocol, so check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)

                # The transfer manager does not support downloading into a file object,
                # so download to a temporary file and copy its contents.
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    transfer_manager.download_chunks_concurrently(
                        blob,
                        temp_file_path,
                        chunk_size=self._io_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                if isinstance(f, io.StringIO):
                    with open(temp_file_path, "r") as fp:
                        f.write(fp.read())
                else:
                    with open(temp_file_path, "rb") as fp:
                        f.write(fp.read())

                os.unlink(temp_file_path)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)