# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import io
import logging
import os
import tempfile
from collections.abc import Callable, Iterator
from typing import IO, Any, Optional, TypeVar, Union

from google.api_core.exceptions import GoogleAPICallError, NotFound
from google.auth import identity_pool
from google.cloud import storage
from google.cloud.storage import transfer_manager
from google.cloud.storage.exceptions import InvalidResponse
from google.oauth2.credentials import Credentials as OAuth2Credentials

from multistorageclient_rust import RustClient, RustRetryableError

from ..rust_utils import run_async_rust_client_method
from ..telemetry import Telemetry
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    NotModifiedError,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import (
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

PROVIDER = "gcs"

MiB = 1024 * 1024

DEFAULT_MULTIPART_THRESHOLD = 64 * MiB
DEFAULT_MULTIPART_CHUNKSIZE = 32 * MiB
DEFAULT_IO_CHUNKSIZE = 32 * MiB
PYTHON_MAX_CONCURRENCY = 8
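
# Objects at or below DEFAULT_MULTIPART_THRESHOLD are uploaded/downloaded in a
# single request; larger objects go through the transfer manager (or the rust
# client) in DEFAULT_MULTIPART_CHUNKSIZE / DEFAULT_IO_CHUNKSIZE sized chunks
# using up to PYTHON_MAX_CONCURRENCY workers.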

logger = logging.getLogger(__name__)


class StringTokenSupplier(identity_pool.SubjectTokenSupplier):
    """
    Supply a string token to the Google Identity Pool.
    """

    def __init__(self, token: str):
        self._token = token

    def get_subject_token(self, context, request):
        return self._token


class GoogleIdentityPoolCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's identity pool credentials.
    """

    def __init__(self, audience: str, token_supplier: str):
        """
        Initializes the :py:class:`GoogleIdentityPoolCredentialsProvider` with the audience and token supplier.

        :param audience: The audience for the Google Identity Pool.
        :param token_supplier: The string subject token supplied to the Google Identity Pool.
        """
        self._audience = audience
        self._token_supplier = token_supplier

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key="",
            secret_key="",
            token="",
            expiration=None,
            custom_fields={"audience": self._audience, "token": self._token_supplier},
        )

    def refresh_credentials(self) -> None:
        pass


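# A minimal usage sketch for the provider above (the audience and token values
# are hypothetical):
#
#   provider = GoogleIdentityPoolCredentialsProvider(
#       audience="//iam.googleapis.com/projects/000/locations/global/workloadIdentityPools/my-pool/providers/my-provider",
#       token_supplier="<subject-token>",
#   )
#   credentials = provider.get_credentials()

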
class GoogleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Google Cloud Storage.
    """

    def __init__(
        self,
        project_id: str = os.getenv("GOOGLE_CLOUD_PROJECT_ID", ""),
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ):
        """
        Initializes the :py:class:`GoogleStorageProvider` with the project ID and optional credentials provider.

        :param project_id: The Google Cloud project ID.
        :param endpoint_url: The custom endpoint URL for the GCS service.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve GCS credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._project_id = project_id
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._gcs_client = self._create_gcs_client()
        self._multipart_threshold = kwargs.get("multipart_threshold", DEFAULT_MULTIPART_THRESHOLD)
        self._multipart_chunksize = kwargs.get("multipart_chunksize", DEFAULT_MULTIPART_CHUNKSIZE)
        self._io_chunksize = kwargs.get("io_chunksize", DEFAULT_IO_CHUNKSIZE)
        self._max_concurrency = kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)
        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the rust client options from the kwargs
            rust_client_options = kwargs["rust_client"]
            if "max_concurrency" in kwargs:
                rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
            if "multipart_chunksize" in kwargs:
                rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
            if "read_timeout" in kwargs:
                rust_client_options["read_timeout"] = kwargs["read_timeout"]
            if "connect_timeout" in kwargs:
                rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
            self._rust_client = self._create_rust_client(rust_client_options)

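    # A hypothetical construction sketch (the option keys mirror the handling
    # above; the values are illustrative only):
    #
    #   GoogleStorageProvider(
    #       base_path="my-bucket",
    #       multipart_chunksize=32 * MiB,
    #       max_concurrency=8,
    #       rust_client={"service_account_path": "/path/to/key.json"},
    #   )
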
    def _create_gcs_client(self) -> storage.Client:
        client_options = {}
        if self._endpoint_url:
            client_options["api_endpoint"] = self._endpoint_url

        if self._credentials_provider:
            if isinstance(self._credentials_provider, GoogleIdentityPoolCredentialsProvider):
                audience = self._credentials_provider.get_credentials().get_custom_field("audience")
                token = self._credentials_provider.get_credentials().get_custom_field("token")

                # Use Workload Identity Federation (WIF)
                identity_pool_credentials = identity_pool.Credentials(
                    audience=audience,
                    subject_token_type="urn:ietf:params:oauth:token-type:id_token",
                    subject_token_supplier=StringTokenSupplier(token),
                )
                return storage.Client(
                    project=self._project_id, credentials=identity_pool_credentials, client_options=client_options
                )
            else:
                # Use OAuth 2.0 token
                token = self._credentials_provider.get_credentials().token
                creds = OAuth2Credentials(token=token)
                return storage.Client(project=self._project_id, credentials=creds, client_options=client_options)
        else:
            return storage.Client(project=self._project_id, client_options=client_options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        configs = {}

        if self._credentials_provider or self._endpoint_url:
            # Workload Identity Federation (WIF) is not supported by the rust client:
            # https://github.com/apache/arrow-rs-object-store/issues/258
            logger.warning(
                "Rust client for GCS only supports Application Default Credentials or Service Account Key, skipping rust client"
            )
            return None

        # Use environment variables if available; they can be overridden by rust_client_options
        if os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
            configs["application_credentials"] = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
        if os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY"):
            configs["service_account_key"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY")
        if os.getenv("GOOGLE_SERVICE_ACCOUNT"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT")
        if os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH")

        if rust_client_options:
            if "application_credentials" in rust_client_options:
                configs["application_credentials"] = rust_client_options["application_credentials"]
            if "service_account_key" in rust_client_options:
                configs["service_account_key"] = rust_client_options["service_account_key"]
            if "service_account_path" in rust_client_options:
                configs["service_account_path"] = rust_client_options["service_account_path"]
            if "skip_signature" in rust_client_options:
                configs["skip_signature"] = rust_client_options["skip_signature"]
            if "max_concurrency" in rust_client_options:
                configs["max_concurrency"] = rust_client_options["max_concurrency"]
            if "multipart_chunksize" in rust_client_options:
                configs["multipart_chunksize"] = rust_client_options["multipart_chunksize"]
            if "read_timeout" in rust_client_options:
                configs["read_timeout"] = rust_client_options["read_timeout"]
            if "connect_timeout" in rust_client_options:
                configs["connect_timeout"] = rust_client_options["connect_timeout"]

        if rust_client_options and "bucket" in rust_client_options:
            configs["bucket"] = rust_client_options["bucket"]
        else:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        return RustClient(
            provider=PROVIDER,
            configs=configs,
        )

    def _refresh_gcs_client_if_needed(self) -> None:
        """
        Refreshes the GCS client if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._gcs_client = self._create_gcs_client()

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates GCS errors (API call failures, invalid responses, exhausted rust client
        retries) into the corresponding multistorageclient exceptions.

        :param func: The function that performs the actual GCS operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the GCS bucket involved in the operation.
        :param key: The key of the object within the GCS bucket.

        :return: The result of the GCS operation, typically the return value of the `func` callable.
        """
        try:
            return func()
        except GoogleAPICallError as error:
            status_code = error.code if error.code else -1
            error_info = f"status_code: {status_code}, message: {error.message}"
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 304:
                # Raised for if_none_match with a specific etag condition.
                raise NotModifiedError(f"Object {bucket}/{key} has not been modified.") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except InvalidResponse as error:
            response_text = error.response.text
            error_details = f"error: {error}, error_response_text: {response_text}"
            # Check for NoSuchUpload within the response text
            if "NoSuchUpload" in response_text:
                raise RetryableError(f"Multipart upload failed for {bucket}/{key}, {error_details}") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_details}") from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to exhausted retries from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except Exception as error:
            error_details = str(error)
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}. error_type: {type(error).__name__}, {error_details}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to Google Cloud Storage.

        :param path: The path to the object to upload.
        :param body: The content of the object to upload.
        :param if_match: Optional generation number; the upload succeeds only if it matches the object's current generation.
        :param if_none_match: Optional generation number; the upload succeeds only if it does not match the object's current generation.
        :param attributes: Optional attributes to attach to the object.

        :return: The number of bytes written.
        """
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> int:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            kwargs = {}

            if if_match:
                kwargs["if_generation_match"] = int(if_match)  # 412 error code
            if if_none_match:
                if if_none_match == "*":
                    raise NotImplementedError("if_none_match='*' is not supported for GCS")
                else:
                    kwargs["if_generation_not_match"] = int(if_none_match)  # 304 error code

            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                blob.metadata = validated_attributes

            if (
                self._rust_client
                # Rust client doesn't support creating objects with a trailing "/", see https://github.com/apache/arrow-rs/issues/7026
                and not path.endswith("/")
                and not kwargs
                and not validated_attributes
            ):
                run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                blob.upload_from_string(body, **kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

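    # Note: GCS models conditional writes with generation numbers rather than HTTP
    # ETags, so `if_match` / `if_none_match` carry generation strings here. A
    # hypothetical guarded overwrite (raises PreconditionFailedError on mismatch):
    #
    #   provider._put_object("my-bucket/key", b"data", if_match="1734567890123456")
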
    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bytes:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)
            if byte_range:
                if self._rust_client:
                    return run_async_rust_client_method(
                        self._rust_client, "get", key, byte_range.offset, byte_range.offset + byte_range.size - 1
                    )
                else:
                    return blob.download_as_bytes(
                        start=byte_range.offset, end=byte_range.offset + byte_range.size - 1, single_shot_download=True
                    )
            else:
                if self._rust_client:
                    return run_async_rust_client_method(self._rust_client, "get", key)
                else:
                    return blob.download_as_bytes(single_shot_download=True)

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)
        self._refresh_gcs_client_if_needed()

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            source_bucket_obj = self._gcs_client.bucket(src_bucket)
            source_blob = source_bucket_obj.blob(src_key)

            destination_bucket_obj = self._gcs_client.bucket(dest_bucket)
            destination_blob = destination_bucket_obj.blob(dest_key)

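            # blob.rewrite() is resumable: each call copies a bounded amount of data
            # and returns a continuation token until the server-side copy completes.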
            rewrite_tokens = [None]
            while len(rewrite_tokens) > 0:
                rewrite_token = rewrite_tokens.pop()
                next_rewrite_token, _, _ = destination_blob.rewrite(source=source_blob, token=rewrite_token)
                if next_rewrite_token is not None:
                    rewrite_tokens.append(next_rewrite_token)

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=src_bucket, key=src_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> None:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            # If if_match is provided, use it as a precondition
            if if_match:
                generation = int(if_match)
                blob.delete(if_generation_match=generation)
            else:
                # No if_match check needed, just delete
                blob.delete()

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bool:
            bucket_obj = self._gcs_client.bucket(bucket)
            # List objects with the given prefix
            blobs = bucket_obj.list_blobs(
                prefix=key,
                delimiter="/",
            )
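            # Iterating `blobs` consumes the result pages, which is also what
            # populates `blobs.prefixes` with any common prefixes encountered.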
            # Check if there are any contents or common prefixes
            return any(True for _ in blobs) or any(True for _ in blobs.prefixes)

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If the path ends with "/" or an empty key is provided, assume it's a "directory",
            # whose metadata is not guaranteed to exist, e.g. a "virtual prefix"
            # that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path, type="directory", content_length=0, last_modified=AWARE_DATETIME_MIN, etag=None
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            self._refresh_gcs_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.get_blob(key)
                if not blob:
                    raise NotFound(f"Blob {key} not found in bucket {bucket}")
                return ObjectMetadata(
                    key=path,
                    content_length=blob.size or 0,
                    content_type=blob.content_type,
                    last_modified=blob.updated or AWARE_DATETIME_MIN,
                    etag=str(blob.generation),
                    metadata=dict(blob.metadata) if blob.metadata else None,
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist at the given path, append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            bucket_obj = self._gcs_client.bucket(bucket)
            if include_directories:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                    delimiter="/",
                )
            else:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                )

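            # `start_offset` is a ">=" bound, so a key equal to `start_after` is
            # filtered out by the loop below; `end_at` is inclusive.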
            # GCS guarantees lexicographical order.
            for blob in blobs:
                key = blob.name
                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    if key.endswith("/"):
                        if include_directories:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key.rstrip("/")),
                                type="directory",
                                content_length=0,
                                last_modified=blob.updated,
                            )
                    else:
                        yield ObjectMetadata(
                            key=os.path.join(bucket, key),
                            content_length=blob.size,
                            content_type=blob.content_type,
                            last_modified=blob.updated,
                            etag=blob.etag,
                        )
                elif start_after != key:
                    return

            # The directories must be accessed last.
            if include_directories:
                for directory in blobs.prefixes:
                    yield ObjectMetadata(
                        key=os.path.join(bucket, directory.rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        bucket, key = split_path(remote_path)
        file_size: int = 0
        self._refresh_gcs_client_if_needed()

        if isinstance(f, str):
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    bucket_obj = self._gcs_client.bucket(bucket)
                    blob = bucket_obj.blob(key)
                    # GCS will raise an error if blob.metadata is None
                    validated_attributes = validate_attributes(attributes)
                    if validated_attributes is not None:
                        blob.metadata = validated_attributes
                    transfer_manager.upload_chunks_concurrently(
                        f,
                        blob,
                        chunk_size=self._multipart_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(remote_path, f.read().encode("utf-8"), attributes=attributes)
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    blob.metadata = validated_attributes
                if isinstance(f, io.StringIO):
                    mode = "w"
                else:
                    mode = "wb"

                # The transfer manager does not support uploading a file object
                with tempfile.NamedTemporaryFile(mode=mode, delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    fp.write(f.read())

                transfer_manager.upload_chunks_concurrently(
                    temp_file_path,
                    blob,
                    chunk_size=self._multipart_chunksize,
                    max_workers=self._max_concurrency,
                    worker_type=transfer_manager.THREAD,
                )

                os.unlink(temp_file_path)

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        self._refresh_gcs_client_if_needed()

        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
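                    # Write to a hidden temp file in the destination directory, then
                    # rename into place so readers never observe a partial download.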
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download_multipart_to_file", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        transfer_manager.download_chunks_concurrently(
                            blob,
                            temp_file_path,
                            chunk_size=self._io_chunksize,
                            max_workers=self._max_concurrency,
                            worker_type=transfer_manager.THREAD,
                        )
                    os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                response = self._get_object(remote_path)
                # The Python client returns `bytes`, but the Rust client returns an object that
                # implements the buffer protocol, so we need to check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)

                # The transfer manager does not support downloading to a file object
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    transfer_manager.download_chunks_concurrently(
                        blob,
                        temp_file_path,
                        chunk_size=self._io_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                if isinstance(f, io.StringIO):
                    with open(temp_file_path, "r") as fp:
                        f.write(fp.read())
                else:
                    with open(temp_file_path, "rb") as fp:
                        f.write(fp.read())

                os.unlink(temp_file_path)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
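

# A minimal usage sketch (assumes Application Default Credentials and a reachable
# bucket; the names below are hypothetical, and the underscore-prefixed helpers are
# internal, normally reached through the BaseStorageProvider interface):
#
#   provider = GoogleStorageProvider(project_id="my-project", base_path="my-bucket")
#   provider._put_object("my-bucket/hello.txt", b"hello")
#   data = provider._get_object("my-bucket/hello.txt")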