# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import copy
import io
import json
import logging
import os
import tempfile
from collections.abc import Callable, Iterator
from typing import IO, Any, Optional, TypeVar, Union

from google.api_core.exceptions import GoogleAPICallError, NotFound
from google.auth import credentials as auth_credentials
from google.auth import identity_pool
from google.cloud import storage
from google.cloud.storage import transfer_manager
from google.cloud.storage.exceptions import InvalidResponse
from google.oauth2 import service_account
from google.oauth2.credentials import Credentials as OAuth2Credentials

from multistorageclient_rust import RustClient, RustClientError, RustRetryableError

from ..rust_utils import parse_retry_config, run_async_rust_client_method
from ..telemetry import Telemetry
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    NotModifiedError,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import (
    safe_makedirs,
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

PROVIDER = "gcs"

MiB = 1024 * 1024

DEFAULT_MULTIPART_THRESHOLD = 64 * MiB
DEFAULT_MULTIPART_CHUNKSIZE = 32 * MiB
DEFAULT_IO_CHUNKSIZE = 32 * MiB
PYTHON_MAX_CONCURRENCY = 8

logger = logging.getLogger(__name__)


class StringTokenSupplier(identity_pool.SubjectTokenSupplier):
    """
    Supply a string token to the Google Identity Pool.
    """

    def __init__(self, token: str):
        self._token = token

    def get_subject_token(self, context, request):
        return self._token


class GoogleIdentityPoolCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's identity pool credentials.
    """

    def __init__(self, audience: str, token_supplier: str):
        """
        Initializes the :py:class:`GoogleIdentityPoolCredentialsProvider` with the audience and token supplier.

        :param audience: The audience for the Google Identity Pool.
        :param token_supplier: The subject token to supply to the Google Identity Pool.
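
        Example (a minimal sketch; the audience and token values are placeholders)::

            provider = GoogleIdentityPoolCredentialsProvider(
                audience="//iam.googleapis.com/projects/000000000000/locations/global/workloadIdentityPools/my-pool/providers/my-provider",
                token_supplier="<subject-token>",
            )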
        """
        self._audience = audience
        self._token_supplier = token_supplier

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key="",
            secret_key="",
            token="",
            expiration=None,
            custom_fields={"audience": self._audience, "token": self._token_supplier},
        )

    def refresh_credentials(self) -> None:
        pass


class GoogleServiceAccountCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's service account credentials.
    """

    #: Google service account private key file contents.
    _info: dict[str, Any]

    def __init__(self, file: Optional[str] = None, info: Optional[dict[str, Any]] = None):
        """
        Initializes the :py:class:`GoogleServiceAccountCredentialsProvider` with either a path to a
        `Google service account private key <https://docs.cloud.google.com/iam/docs/keys-create-delete#creating>`_ file
        or the file contents.

        :param file: Path to a Google service account private key file.
        :param info: Google service account private key file contents.
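
        Example (a minimal sketch; the key file path and parsed contents are placeholders)::

            # From a private key file on disk:
            provider = GoogleServiceAccountCredentialsProvider(file="/path/to/key.json")

            # Or from already-parsed key file contents:
            provider = GoogleServiceAccountCredentialsProvider(info=key_info)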
        """
        if all(_ is None for _ in (file, info)) or all(_ is not None for _ in (file, info)):
            raise ValueError("Must specify exactly one of file or info")

        if file is not None:
            with open(file, "r") as f:
                self._info = json.load(f)
        elif info is not None:
            self._info = copy.deepcopy(info)

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key="",
            secret_key="",
            token=None,
            expiration=None,
            custom_fields={"info": copy.deepcopy(self._info)},
        )

    def refresh_credentials(self) -> None:
        pass


class GoogleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Google Cloud Storage.
    """

    def __init__(
        self,
        project_id: str = os.getenv("GOOGLE_CLOUD_PROJECT_ID", ""),
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ):
        """
        Initializes the :py:class:`GoogleStorageProvider` with the project ID and optional credentials provider.

        :param project_id: The Google Cloud project ID.
        :param endpoint_url: The custom endpoint URL for the GCS service.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve GCS credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
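
        Example (a minimal sketch; the project ID, bucket, and key file are placeholders)::

            provider = GoogleStorageProvider(
                project_id="my-project",
                base_path="my-bucket",
                credentials_provider=GoogleServiceAccountCredentialsProvider(file="/path/to/key.json"),
            )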
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._project_id = project_id
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._skip_signature = kwargs.get("skip_signature", False)
        self._gcs_client = self._create_gcs_client()
        self._multipart_threshold = kwargs.get("multipart_threshold", DEFAULT_MULTIPART_THRESHOLD)
        self._multipart_chunksize = kwargs.get("multipart_chunksize", DEFAULT_MULTIPART_CHUNKSIZE)
        self._io_chunksize = kwargs.get("io_chunksize", DEFAULT_IO_CHUNKSIZE)
        self._max_concurrency = kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)
        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the rust client options from the kwargs
            rust_client_options = copy.deepcopy(kwargs["rust_client"])
            if "max_concurrency" in kwargs:
                rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
            if "multipart_chunksize" in kwargs:
                rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
            if "read_timeout" in kwargs:
                rust_client_options["read_timeout"] = kwargs["read_timeout"]
            if "connect_timeout" in kwargs:
                rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
            self._rust_client = self._create_rust_client(rust_client_options)

    def _create_gcs_client(self) -> storage.Client:
        client_options = {}
        if self._endpoint_url:
            client_options["api_endpoint"] = self._endpoint_url

        # Use anonymous credentials for public buckets when skip_signature is enabled
        if self._skip_signature:
            return storage.Client(
                project=self._project_id,
                credentials=auth_credentials.AnonymousCredentials(),
                client_options=client_options,
            )

        if self._credentials_provider:
            if isinstance(self._credentials_provider, GoogleIdentityPoolCredentialsProvider):
                audience = self._credentials_provider.get_credentials().get_custom_field("audience")
                token = self._credentials_provider.get_credentials().get_custom_field("token")

                # Use Workload Identity Federation (WIF)
                identity_pool_credentials = identity_pool.Credentials(
                    audience=audience,
                    subject_token_type="urn:ietf:params:oauth:token-type:id_token",
                    subject_token_supplier=StringTokenSupplier(token),
                )
                return storage.Client(
                    project=self._project_id, credentials=identity_pool_credentials, client_options=client_options
                )
            elif isinstance(self._credentials_provider, GoogleServiceAccountCredentialsProvider):
                # Use service account key.
                service_account_credentials = service_account.Credentials.from_service_account_info(
                    info=self._credentials_provider.get_credentials().get_custom_field("info")
                )
                return storage.Client(
                    project=self._project_id, credentials=service_account_credentials, client_options=client_options
                )
            else:
                # Use OAuth 2.0 token
                token = self._credentials_provider.get_credentials().token
                creds = OAuth2Credentials(token=token)
                return storage.Client(project=self._project_id, credentials=creds, client_options=client_options)
        else:
            return storage.Client(project=self._project_id, client_options=client_options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        if self._endpoint_url:
            logger.warning("Rust client for GCS does not support customized endpoint URL, skipping rust client")
            return None

        configs = dict(rust_client_options) if rust_client_options else {}

        # Extract and parse retry configuration
        retry_config = parse_retry_config(configs)

        if "application_credentials" not in configs and os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
            configs["application_credentials"] = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
        if "service_account_key" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY"):
            configs["service_account_key"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY")
        if "service_account_path" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT")
        if "service_account_path" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH")

        if self._skip_signature and "skip_signature" not in configs:
            configs["skip_signature"] = True

        if "bucket" not in configs:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        if self._credentials_provider:
            if isinstance(self._credentials_provider, GoogleIdentityPoolCredentialsProvider):
                # Workload Identity Federation (WIF) is not supported by the rust client:
                # https://github.com/apache/arrow-rs-object-store/issues/258
                logger.warning("Rust client for GCS doesn't support Workload Identity Federation, skipping rust client")
                return None
            if isinstance(self._credentials_provider, GoogleServiceAccountCredentialsProvider):
                # Use service account key.
                configs["service_account_key"] = json.dumps(
                    self._credentials_provider.get_credentials().get_custom_field("info")
                )
                return RustClient(
                    provider=PROVIDER,
                    configs=configs,
                    retry=retry_config,
                )
        try:
            return RustClient(
                provider=PROVIDER,
                configs=configs,
                credentials_provider=self._credentials_provider,
                retry=retry_config,
            )
        except Exception as e:
            logger.warning(f"Failed to create rust client for GCS: {e}, falling back to Python client")
            return None

    def _refresh_gcs_client_if_needed(self) -> None:
        """
        Refreshes the GCS client if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._gcs_client = self._create_gcs_client()

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors such as timeouts and client errors into MSC exception types.

        :param func: The function that performs the actual GCS operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the GCS bucket involved in the operation.
        :param key: The key of the object within the GCS bucket.

        :return: The result of the GCS operation, typically the return value of the `func` callable.
        """
        try:
            return func()
        except GoogleAPICallError as error:
            status_code = error.code if error.code else -1
            error_info = f"status_code: {status_code}, message: {error.message}"
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 304:
                # For if_none_match with a specific etag condition.
                raise NotModifiedError(f"Object {bucket}/{key} has not been modified.") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except InvalidResponse as error:
            response_text = error.response.text
            error_details = f"error: {error}, error_response_text: {response_text}"
            # Check for NoSuchUpload within the response text
            if "NoSuchUpload" in response_text:
                raise RetryableError(f"Multipart upload failed for {bucket}/{key}, {error_details}") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_details}") from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except RustClientError as error:
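            # RustClientError carries (message, status_code) in error.args.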
            message = error.args[0]
            status_code = error.args[1]
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {message}") from error
            elif status_code == 403:
                raise PermissionError(
                    f"Permission denied to {operation} object(s) at {bucket}/{key}. {message}"
                ) from error
            else:
                raise RetryableError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {message}. status_code: {status_code}"
                ) from error
        except Exception as error:
            error_details = str(error)
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}. error_type: {type(error).__name__}, {error_details}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to Google Cloud Storage.

        :param path: The path to the object to upload.
        :param body: The content of the object to upload.
        :param if_match: Optional ETag (GCS object generation) that the existing object must match.
        :param if_none_match: Optional ETag (GCS object generation) that the existing object must not match.
        :param attributes: Optional attributes to attach to the object.
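
        Example (a minimal sketch; the path and generation number are placeholders)::

            # Upload only if the current object generation matches; on mismatch GCS
            # returns 412, which is translated to PreconditionFailedError.
            provider._put_object("my-bucket/key", b"data", if_match="123456789")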
        """
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> int:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            kwargs = {}

            if if_match:
                kwargs["if_generation_match"] = int(if_match)  # 412 error code
            if if_none_match:
                if if_none_match == "*":
                    raise NotImplementedError("if_none_match='*' is not supported for GCS")
                else:
                    kwargs["if_generation_not_match"] = int(if_none_match)  # 304 error code

            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                blob.metadata = validated_attributes

            if (
                self._rust_client
                # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
                and not path.endswith("/")
                and not kwargs
                and not validated_attributes
            ):
                run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                blob.upload_from_string(body, **kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bytes:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)
            if byte_range:
                if self._rust_client:
                    return run_async_rust_client_method(
                        self._rust_client, "get", key, byte_range.offset, byte_range.offset + byte_range.size - 1
                    )
                else:
                    return blob.download_as_bytes(
                        start=byte_range.offset, end=byte_range.offset + byte_range.size - 1, single_shot_download=True
                    )
            else:
                if self._rust_client:
                    return run_async_rust_client_method(self._rust_client, "get", key)
                else:
                    return blob.download_as_bytes(single_shot_download=True)

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)
        self._refresh_gcs_client_if_needed()

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            source_bucket_obj = self._gcs_client.bucket(src_bucket)
            source_blob = source_bucket_obj.blob(src_key)

            destination_bucket_obj = self._gcs_client.bucket(dest_bucket)
            destination_blob = destination_bucket_obj.blob(dest_key)

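            # A single rewrite() call may not finish copying a large object; GCS returns
            # a continuation token that is passed back until the rewrite completes.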
            rewrite_tokens = [None]
            while len(rewrite_tokens) > 0:
                rewrite_token = rewrite_tokens.pop()
                next_rewrite_token, _, _ = destination_blob.rewrite(source=source_blob, token=rewrite_token)
                if next_rewrite_token is not None:
                    rewrite_tokens.append(next_rewrite_token)

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=src_bucket, key=src_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> None:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            # If if_match is provided, use it as a precondition
            if if_match:
                generation = int(if_match)
                blob.delete(if_generation_match=generation)
            else:
                # No if_match check needed, just delete
                blob.delete()

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _delete_objects(self, paths: list[str]) -> None:
        if not paths:
            return

        by_bucket: dict[str, list[str]] = {}
        for p in paths:
            bucket, key = split_path(p)
            by_bucket.setdefault(bucket, []).append(key)
        self._refresh_gcs_client_if_needed()

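        # GCS batch requests accept at most 100 operations per batch.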
        GCS_BATCH_LIMIT = 100

        def _invoke_api() -> None:
            for bucket, keys in by_bucket.items():
                bucket_obj = self._gcs_client.bucket(bucket)
                for i in range(0, len(keys), GCS_BATCH_LIMIT):
                    chunk = keys[i : i + GCS_BATCH_LIMIT]
                    with self._gcs_client.batch():
                        for k in chunk:
                            bucket_obj.blob(k).delete()

        bucket_desc = "(" + "|".join(by_bucket) + ")"
        key_desc = "(" + "|".join(str(len(keys)) for keys in by_bucket.values()) + " keys)"
        self._translate_errors(_invoke_api, operation="DELETE_MANY", bucket=bucket_desc, key=key_desc)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bool:
            bucket_obj = self._gcs_client.bucket(bucket)
            # List objects with the given prefix
            blobs = bucket_obj.list_blobs(
                prefix=key,
                delimiter="/",
            )
            # Check if there are any contents or common prefixes
            return any(True for _ in blobs) or any(True for _ in blobs.prefixes)

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If the path ends with "/" or the key is empty, assume it refers to a "directory",
            # whose metadata is not guaranteed to exist, e.g., a "virtual prefix"
            # that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path, type="directory", content_length=0, last_modified=AWARE_DATETIME_MIN, etag=None
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            self._refresh_gcs_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.get_blob(key)
                if not blob:
                    raise NotFound(f"Blob {key} not found in bucket {bucket}")
                return ObjectMetadata(
                    key=path,
                    content_length=blob.size or 0,
                    content_type=blob.content_type,
                    last_modified=blob.updated or AWARE_DATETIME_MIN,
                    etag=str(blob.generation),
                    metadata=dict(blob.metadata) if blob.metadata else None,
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist at the given path, append a trailing slash and
                    # check whether the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            bucket_obj = self._gcs_client.bucket(bucket)
            if include_directories:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                    delimiter="/",
                )
            else:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                )

            # GCS guarantees lexicographical order.
            for blob in blobs:
                key = blob.name
                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    if key.endswith("/"):
                        if include_directories:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key.rstrip("/")),
                                type="directory",
                                content_length=0,
                                last_modified=blob.updated,
                            )
                    else:
                        yield ObjectMetadata(
                            key=os.path.join(bucket, key),
                            content_length=blob.size,
                            content_type=blob.content_type,
                            last_modified=blob.updated,
                            etag=blob.etag,
                        )
                elif start_after != key:
                    return

            # Directories (common prefixes) are only available after the blob iterator
            # has been consumed, so they must be yielded last.
            if include_directories:
                for directory in blobs.prefixes:
                    prefix_key = directory.rstrip("/")
                    # Filter by start_after and end_at if specified
                    if (start_after is None or start_after < prefix_key) and (end_at is None or prefix_key <= end_at):
                        yield ObjectMetadata(
                            key=os.path.join(bucket, prefix_key),
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    @property
    def supports_parallel_listing(self) -> bool:
        return True

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        bucket, key = split_path(remote_path)
        file_size: int = 0
        self._refresh_gcs_client_if_needed()

        if isinstance(f, str):
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    bucket_obj = self._gcs_client.bucket(bucket)
                    blob = bucket_obj.blob(key)
                    # GCS will raise an error if blob.metadata is set to None
                    validated_attributes = validate_attributes(attributes)
                    if validated_attributes is not None:
                        blob.metadata = validated_attributes
                    transfer_manager.upload_chunks_concurrently(
                        f,
                        blob,
                        chunk_size=self._multipart_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(remote_path, f.read().encode("utf-8"), attributes=attributes)
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    blob.metadata = validated_attributes
                if isinstance(f, io.StringIO):
                    mode = "w"
                else:
                    mode = "wb"

                # The transfer manager does not support uploading from a file object,
                # so stage the contents in a named temporary file first.
                with tempfile.NamedTemporaryFile(mode=mode, delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    fp.write(f.read())

                transfer_manager.upload_chunks_concurrently(
                    temp_file_path,
                    blob,
                    chunk_size=self._multipart_chunksize,
                    max_workers=self._max_concurrency,
                    worker_type=transfer_manager.THREAD,
                )

                os.unlink(temp_file_path)

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        self._refresh_gcs_client_if_needed()

        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            if os.path.dirname(f):
                safe_makedirs(os.path.dirname(f))
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download_multipart_to_file", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        transfer_manager.download_chunks_concurrently(
                            blob,
                            temp_file_path,
                            chunk_size=self._io_chunksize,
                            max_workers=self._max_concurrency,
                            worker_type=transfer_manager.THREAD,
                        )
                    os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                response = self._get_object(remote_path)
                # The Python client returns `bytes`, but the Rust client returns an object that
                # implements the buffer protocol, so we need to check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)

                # The transfer manager does not support downloading to a file object,
                # so download to a named temporary file and copy its contents into `f`.
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    transfer_manager.download_chunks_concurrently(
                        blob,
                        temp_file_path,
                        chunk_size=self._io_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                if isinstance(f, io.StringIO):
                    with open(temp_file_path, "r") as fp:
                        f.write(fp.read())
                else:
                    with open(temp_file_path, "rb") as fp:
                        f.write(fp.read())

                os.unlink(temp_file_path)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)