# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import io
import logging
import os
import tempfile
import time
from collections.abc import Callable, Iterator, Sequence, Sized
from typing import IO, Any, Optional, TypeVar, Union

import opentelemetry.metrics as api_metrics
from google.api_core.exceptions import GoogleAPICallError, NotFound
from google.auth import identity_pool
from google.cloud import storage
from google.cloud.storage import transfer_manager
from google.cloud.storage.exceptions import InvalidResponse
from google.oauth2.credentials import Credentials as OAuth2Credentials

from ..rust_utils import run_async_rust_client_method
from ..telemetry import Telemetry
from ..telemetry.attributes.base import AttributesProvider
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    NotModifiedError,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import (
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

PROVIDER = "gcs"

MiB = 1024 * 1024

DEFAULT_MULTIPART_THRESHOLD = 512 * MiB
DEFAULT_MULTIPART_CHUNKSIZE = 32 * MiB
DEFAULT_IO_CHUNKSIZE = 32 * MiB
# Python uses a lower default concurrency due to the GIL limiting true parallelism in threads.
PYTHON_MAX_CONCURRENCY = 16
RUST_MAX_CONCURRENCY = 32
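# Sizing note (illustrative): with these defaults, a 1 GiB file exceeds DEFAULT_MULTIPART_THRESHOLD
# (512 MiB), so it is transferred by the transfer manager in 32 MiB chunks (32 chunks per GiB), using
# up to PYTHON_MAX_CONCURRENCY (16) threads, or RUST_MAX_CONCURRENCY (32) when the Rust client is used.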

logger = logging.getLogger(__name__)


class StringTokenSupplier(identity_pool.SubjectTokenSupplier):
    """
    Supply a string token to the Google Identity Pool.
    """

    def __init__(self, token: str):
        self._token = token

    def get_subject_token(self, context, request):
        return self._token


class GoogleIdentityPoolCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's identity pool credentials.
    """

    def __init__(self, audience: str, token_supplier: str):
        """
        Initializes the :py:class:`GoogleIdentityPoolCredentialsProvider` with the audience and token supplier.

        :param audience: The audience for the Google Identity Pool.
        :param token_supplier: The subject token string supplied to the Google Identity Pool.
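
        Example (illustrative; the audience and token values below are placeholders for values
        issued by your Workload Identity Federation setup):

        .. code-block:: python

            credentials_provider = GoogleIdentityPoolCredentialsProvider(
                audience="//iam.googleapis.com/projects/0000/locations/global/workloadIdentityPools/my-pool/providers/my-provider",
                token_supplier="<subject-id-token>",
            )
            provider = GoogleStorageProvider(project_id="my-project", credentials_provider=credentials_provider)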
91 """
92 self._audience = audience
93 self._token_supplier = token_supplier
94
    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key="",
            secret_key="",
            token="",
            expiration=None,
            custom_fields={"audience": self._audience, "token": self._token_supplier},
        )

    def refresh_credentials(self) -> None:
        pass


class GoogleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Google Cloud Storage.
    """

    def __init__(
        self,
        project_id: str = os.getenv("GOOGLE_CLOUD_PROJECT_ID", ""),
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        metric_counters: dict[Telemetry.CounterName, api_metrics.Counter] = {},
        metric_gauges: dict[Telemetry.GaugeName, api_metrics._Gauge] = {},
        metric_attributes_providers: Sequence[AttributesProvider] = (),
        **kwargs: Any,
    ):
        """
        Initializes the :py:class:`GoogleStorageProvider` with the project ID and optional credentials provider.

        :param project_id: The Google Cloud project ID.
        :param endpoint_url: The custom endpoint URL for the GCS service.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve GCS credentials.
        :param metric_counters: Metric counters.
        :param metric_gauges: Metric gauges.
        :param metric_attributes_providers: Metric attributes providers.
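
        Additional tuning options may be passed via ``kwargs`` (as read in the constructor):
        ``multipart_threshold``, ``multipart_chunksize``, ``io_chunksize``, ``max_concurrency``,
        and ``rust_client``. Example (illustrative; values are placeholders):

        .. code-block:: python

            provider = GoogleStorageProvider(
                project_id="my-project",
                base_path="my-bucket/datasets",
                multipart_threshold=256 * 1024 * 1024,  # switch to chunked transfers above 256 MiB
                max_concurrency=8,
            )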
134 """
135 super().__init__(
136 base_path=base_path,
137 provider_name=PROVIDER,
138 metric_counters=metric_counters,
139 metric_gauges=metric_gauges,
140 metric_attributes_providers=metric_attributes_providers,
141 )
142
143 self._project_id = project_id
144 self._endpoint_url = endpoint_url
145 self._credentials_provider = credentials_provider
146 self._gcs_client = self._create_gcs_client()
147 self._multipart_threshold = kwargs.get("multipart_threshold", DEFAULT_MULTIPART_THRESHOLD)
148 self._multipart_chunksize = kwargs.get("multipart_chunksize", DEFAULT_MULTIPART_CHUNKSIZE)
149 self._io_chunksize = kwargs.get("io_chunksize", DEFAULT_IO_CHUNKSIZE)
150 self._max_concurrency = kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)
151 self._rust_client = None
152 if "rust_client" in kwargs:
153 self._rust_client = self._create_rust_client(kwargs.get("rust_client"))
154
155 def _create_gcs_client(self) -> storage.Client:
156 client_options = {}
157 if self._endpoint_url:
158 client_options["api_endpoint"] = self._endpoint_url
159
160 if self._credentials_provider:
161 if isinstance(self._credentials_provider, GoogleIdentityPoolCredentialsProvider):
162 audience = self._credentials_provider.get_credentials().get_custom_field("audience")
163 token = self._credentials_provider.get_credentials().get_custom_field("token")
164
165 # Use Workload Identity Federation (WIF)
166 identity_pool_credentials = identity_pool.Credentials(
167 audience=audience,
168 subject_token_type="urn:ietf:params:oauth:token-type:id_token",
169 subject_token_supplier=StringTokenSupplier(token),
170 )
171 return storage.Client(
172 project=self._project_id, credentials=identity_pool_credentials, client_options=client_options
173 )
174 else:
175 # Use OAuth 2.0 token
176 token = self._credentials_provider.get_credentials().token
177 creds = OAuth2Credentials(token=token)
178 return storage.Client(project=self._project_id, credentials=creds, client_options=client_options)
179 else:
180 return storage.Client(project=self._project_id, client_options=client_options)
181
    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        from multistorageclient_rust import RustClient

        configs = {}

        if self._credentials_provider or self._endpoint_url:
            # Workload Identity Federation (WIF) is not supported by the rust client:
            # https://github.com/apache/arrow-rs-object-store/issues/258
            logger.warning(
                "Rust client for GCS only supports Application Default Credentials or Service Account Key, skipping rust client"
            )
            return None

        # Use environment variables if available; they can be overridden by rust_client_options.
        if os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
            configs["application_credentials"] = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
        if os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY"):
            configs["service_account_key"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY")
        if os.getenv("GOOGLE_SERVICE_ACCOUNT"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT")
        if os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH")

        if rust_client_options:
            if "application_credentials" in rust_client_options:
                configs["application_credentials"] = rust_client_options["application_credentials"]
            if "service_account_key" in rust_client_options:
                configs["service_account_key"] = rust_client_options["service_account_key"]
            if "service_account_path" in rust_client_options:
                configs["service_account_path"] = rust_client_options["service_account_path"]
            if "skip_signature" in rust_client_options:
                configs["skip_signature"] = rust_client_options["skip_signature"]
            configs["max_concurrency"] = rust_client_options.get("max_concurrency", RUST_MAX_CONCURRENCY)
            configs["multipart_chunksize"] = rust_client_options.get("multipart_chunksize", DEFAULT_MULTIPART_CHUNKSIZE)

        if rust_client_options and "bucket" in rust_client_options:
            configs["bucket"] = rust_client_options["bucket"]
        else:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        return RustClient(
            provider=PROVIDER,
            configs=configs,
        )

    def _refresh_gcs_client_if_needed(self) -> None:
        """
        Refreshes the GCS client if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._gcs_client = self._create_gcs_client()

    def _collect_metrics(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> _T:
        """
        Collects and records performance metrics around GCS operations such as PUT, GET, DELETE, etc.

        This method wraps a GCS operation and measures the time it takes to complete, along with recording
        the size of the object if applicable. It handles errors such as timeouts and client errors and
        records the operation duration and object size.

        :param func: The function that performs the actual GCS operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the GCS bucket involved in the operation.
        :param key: The key of the object within the GCS bucket.
        :param put_object_size: The size of the object being uploaded, if applicable (for PUT operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).

        :return: The result of the GCS operation, typically the return value of the `func` callable.
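
        Example (mirroring how this provider's own methods wrap their API calls):

        .. code-block:: python

            def _invoke_api() -> bytes:
                ...

            return self._collect_metrics(_invoke_api, operation="GET", bucket=bucket, key=key)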
262 """
263 start_time = time.time()
264 status_code = 200
265
266 object_size = None
267 if operation == "PUT":
268 object_size = put_object_size
269 elif operation == "GET" and get_object_size:
270 object_size = get_object_size
271
272 try:
273 result = func()
274 if operation == "GET" and object_size is None and isinstance(result, Sized):
275 object_size = len(result)
276 return result
277 except GoogleAPICallError as error:
278 status_code = error.code if error.code else -1
279 error_info = f"status_code: {status_code}, message: {error.message}"
280 if status_code == 404:
281 raise FileNotFoundError(f"Object {bucket}/{key} does not exist.") # pylint: disable=raise-missing-from
282 elif status_code == 412:
283 raise PreconditionFailedError(
284 f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
285 ) from error
286 elif status_code == 304:
                # Returned for if_none_match when the specified generation (ETag) still matches,
                # i.e. the object has not been modified.
                raise NotModifiedError(f"Object {bucket}/{key} has not been modified.") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except InvalidResponse as error:
            status_code = error.response.status_code
            response_text = error.response.text
            error_details = f"error: {error}, error_response_text: {response_text}"
            # Check for NoSuchUpload within the response text
            if "NoSuchUpload" in response_text:
                raise RetryableError(f"Multipart upload failed for {bucket}/{key}, {error_details}") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_details}") from error
        except Exception as error:
            status_code = -1
            error_details = str(error)
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}. error_type: {type(error).__name__}, {error_details}"
            ) from error
        finally:
            elapsed_time = time.time() - start_time
            self._metric_helper.record_duration(
                elapsed_time, provider=self._provider_name, operation=operation, bucket=bucket, status_code=status_code
            )
            if object_size:
                self._metric_helper.record_object_size(
                    object_size,
                    provider=self._provider_name,
                    operation=operation,
                    bucket=bucket,
                    status_code=status_code,
                )

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to Google Cloud Storage.

        :param path: The path to the object to upload.
        :param body: The content of the object to upload.
        :param if_match: Optional generation (ETag) that the existing object must match for the upload to succeed.
        :param if_none_match: Optional generation (ETag) that the existing object must not match for the upload to succeed.
        :param attributes: Optional attributes to attach to the object.
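
        Example (illustrative; assumes ``provider`` is an initialized :py:class:`GoogleStorageProvider`
        and the generation value is a placeholder):

        .. code-block:: python

            # Overwrite only if the current generation is 123456789 (maps to if_generation_match).
            provider._put_object("my-bucket/config.json", b"{}", if_match="123456789")

            # Raise NotModifiedError if the generation is still 123456789 (maps to if_generation_not_match).
            provider._put_object("my-bucket/config.json", b"{}", if_none_match="123456789")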
336 """
337 bucket, key = split_path(path)
338 self._refresh_gcs_client_if_needed()
339
340 def _invoke_api() -> int:
341 bucket_obj = self._gcs_client.bucket(bucket)
342 blob = bucket_obj.blob(key)
343
344 kwargs = {}
345
346 if if_match:
347 kwargs["if_generation_match"] = int(if_match) # 412 error code
348 if if_none_match:
349 if if_none_match == "*":
350 raise NotImplementedError("if_none_match='*' is not supported for GCS")
351 else:
352 kwargs["if_generation_not_match"] = int(if_none_match) # 304 error code
353
354 validated_attributes = validate_attributes(attributes)
355 if validated_attributes:
356 blob.metadata = validated_attributes
357
358 if (
359 self._rust_client
360 # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
361 and not path.endswith("/")
362 and not kwargs
363 and not validated_attributes
364 ):
365 run_async_rust_client_method(self._rust_client, "put", key, body)
366 else:
367 blob.upload_from_string(body, **kwargs)
368
369 return len(body)
370
371 return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=len(body))

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bytes:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)
            if byte_range:
                if self._rust_client:
                    return run_async_rust_client_method(
                        self._rust_client, "get", key, byte_range.offset, byte_range.offset + byte_range.size - 1
                    )
                else:
                    return blob.download_as_bytes(
                        start=byte_range.offset, end=byte_range.offset + byte_range.size - 1, single_shot_download=True
                    )
            else:
                if self._rust_client:
                    return run_async_rust_client_method(self._rust_client, "get", key)
                else:
                    return blob.download_as_bytes(single_shot_download=True)

        return self._collect_metrics(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)
        self._refresh_gcs_client_if_needed()

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            source_bucket_obj = self._gcs_client.bucket(src_bucket)
            source_blob = source_bucket_obj.blob(src_key)

            destination_bucket_obj = self._gcs_client.bucket(dest_bucket)
            destination_blob = destination_bucket_obj.blob(dest_key)

            rewrite_tokens = [None]
            while len(rewrite_tokens) > 0:
                rewrite_token = rewrite_tokens.pop()
                next_rewrite_token, _, _ = destination_blob.rewrite(source=source_blob, token=rewrite_token)
                if next_rewrite_token is not None:
                    rewrite_tokens.append(next_rewrite_token)

            return src_object.content_length

        return self._collect_metrics(
            _invoke_api,
            operation="COPY",
            bucket=src_bucket,
            key=src_key,
            put_object_size=src_object.content_length,
        )

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> None:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            # If if_match is provided, use it as a precondition
            if if_match:
                generation = int(if_match)
                blob.delete(if_generation_match=generation)
            else:
                # No if_match check needed, just delete
                blob.delete()

        return self._collect_metrics(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bool:
            bucket_obj = self._gcs_client.bucket(bucket)
            # List objects with the given prefix
            blobs = bucket_obj.list_blobs(
                prefix=key,
                delimiter="/",
            )
            # Check if there are any contents or common prefixes
            return any(True for _ in blobs) or any(True for _ in blobs.prefixes)

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If the path ends with "/" or the key is empty, assume it is a "directory",
            # whose metadata is not guaranteed to exist, e.g. a "virtual prefix" that was
            # never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path, type="directory", content_length=0, last_modified=AWARE_DATETIME_MIN, etag=None
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            self._refresh_gcs_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.get_blob(key)
                if not blob:
                    raise NotFound(f"Blob {key} not found in bucket {bucket}")
                return ObjectMetadata(
                    key=path,
                    content_length=blob.size or 0,
                    content_type=blob.content_type,
                    last_modified=blob.updated or AWARE_DATETIME_MIN,
                    etag=str(blob.generation),
                    metadata=dict(blob.metadata) if blob.metadata else None,
                )

            try:
                return self._collect_metrics(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist at the given path, append a trailing slash and
                    # check whether the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            bucket_obj = self._gcs_client.bucket(bucket)
            if include_directories:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                    delimiter="/",
                )
            else:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                )

            # GCS guarantees lexicographical order.
            for blob in blobs:
                key = blob.name
                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    if key.endswith("/"):
                        if include_directories:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key.rstrip("/")),
                                type="directory",
                                content_length=0,
                                last_modified=blob.updated,
                            )
                    else:
                        yield ObjectMetadata(
                            key=os.path.join(bucket, key),
                            content_length=blob.size,
                            content_type=blob.content_type,
                            last_modified=blob.updated,
                            etag=blob.etag,
                        )
                elif start_after != key:
                    return

            # The directories must be accessed last.
            if include_directories:
                for directory in blobs.prefixes:
                    yield ObjectMetadata(
                        key=os.path.join(bucket, directory.rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        bucket, key = split_path(remote_path)
        file_size: int = 0
        self._refresh_gcs_client_if_needed()

        if isinstance(f, str):
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload_multipart", f, key)
                else:
                    bucket_obj = self._gcs_client.bucket(bucket)
                    blob = bucket_obj.blob(key)
                    blob.metadata = validate_attributes(attributes)
                    transfer_manager.upload_chunks_concurrently(
                        f,
                        blob,
                        chunk_size=self._multipart_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                return file_size

            return self._collect_metrics(
                _invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=file_size
            )
        else:
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(remote_path, f.read().encode("utf-8"), attributes=attributes)
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    blob.metadata = validated_attributes
                if isinstance(f, io.StringIO):
                    mode = "w"
                else:
                    mode = "wb"

                # transfer manager does not support uploading a file object
                with tempfile.NamedTemporaryFile(mode=mode, delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    fp.write(f.read())

                transfer_manager.upload_chunks_concurrently(
                    temp_file_path,
                    blob,
                    chunk_size=self._multipart_chunksize,
                    max_workers=self._max_concurrency,
                    worker_type=transfer_manager.THREAD,
                )

                os.unlink(temp_file_path)

                return file_size

            return self._collect_metrics(
                _invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=file_size
            )

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        self._refresh_gcs_client_if_needed()

        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download_multipart", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        transfer_manager.download_chunks_concurrently(
                            blob,
                            temp_file_path,
                            chunk_size=self._io_chunksize,
                            max_workers=self._max_concurrency,
                            worker_type=transfer_manager.THREAD,
                        )
                    os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )
        else:
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                response = self._get_object(remote_path)
                # The Python client returns `bytes`, but the Rust client returns an object that implements
                # the buffer protocol, so we need to check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)

                # transfer manager does not support downloading to a file object
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    transfer_manager.download_chunks_concurrently(
                        blob,
                        temp_file_path,
                        chunk_size=self._io_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                if isinstance(f, io.StringIO):
                    with open(temp_file_path, "r") as fp:
                        f.write(fp.read())
                else:
                    with open(temp_file_path, "rb") as fp:
                        f.write(fp.read())

                os.unlink(temp_file_path)

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )