1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import codecs
17import copy
18import io
19import json
20import logging
21import os
22import tempfile
23from collections.abc import Callable, Iterator
24from typing import IO, Any, Optional, TypeVar, Union
25
26from google.api_core.exceptions import GoogleAPICallError, NotFound
27from google.auth import credentials as auth_credentials
28from google.auth import identity_pool
29from google.cloud import storage
30from google.cloud.storage import transfer_manager
31from google.cloud.storage.exceptions import InvalidResponse
32from google.oauth2 import service_account
33from google.oauth2.credentials import Credentials as OAuth2Credentials
34
35from multistorageclient_rust import RustClient, RustClientError, RustRetryableError
36
37from ..constants import DEFAULT_MAX_POOL_CONNECTIONS
38from ..rust_utils import parse_retry_config, run_async_rust_client_method
39from ..telemetry import Telemetry
40from ..types import (
41 AWARE_DATETIME_MIN,
42 Credentials,
43 CredentialsProvider,
44 NotModifiedError,
45 ObjectMetadata,
46 PreconditionFailedError,
47 Range,
48 RetryableError,
49 SymlinkHandling,
50)
51from ..utils import (
52 safe_makedirs,
53 split_path,
54 validate_attributes,
55)
56from .base import BaseStorageProvider
57
# Return type of the callable wrapped by ``GoogleStorageProvider._translate_errors``.
_T = TypeVar("_T")

# Provider identifier handed to the base storage provider and to the Rust client.
PROVIDER = "gcs"

# Number of bytes in one mebibyte.
MiB = 1024**2

# Transfers at or below this size take the single-request path; larger ones are chunked.
DEFAULT_MULTIPART_THRESHOLD = MiB * 64
# Chunk size used for chunked uploads.
DEFAULT_MULTIPART_CHUNKSIZE = MiB * 32
# Chunk size used for chunked downloads.
DEFAULT_IO_CHUNKSIZE = MiB * 32
# Worker count used for chunked transfers done by the Python client.
PYTHON_MAX_CONCURRENCY = 8

logger = logging.getLogger(__name__)
70
71
[docs]
class StringTokenSupplier(identity_pool.SubjectTokenSupplier):
    """
    Subject token supplier for the Google Identity Pool that hands back a fixed string token.
    """

    def __init__(self, token: str):
        """
        Remember the token to supply.

        :param token: The subject token string returned on every request.
        """
        self._token = token

    def get_subject_token(self, context, request):
        """
        Return the token given at construction time.

        :param context: Supplier context passed by the caller; not used here.
        :param request: Request object passed by the caller; not used here.
        :return: The stored token string.
        """
        stored_token = self._token
        return stored_token
82
83
[docs]
class GoogleIdentityPoolCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's identity pool credentials.
    """

    def __init__(self, audience: str, token_supplier: str):
        """
        Initializes the :py:class:`GoogleIdentityPoolCredentialsProvider` with the audience and token supplier.

        :param audience: The audience for the Google Identity Pool.
        :param token_supplier: The token supplier for the Google Identity Pool.
        """
        self._audience = audience
        self._token_supplier = token_supplier

    def get_credentials(self) -> Credentials:
        """
        Build a :py:class:`Credentials` object carrying the identity pool settings.

        The key/token fields are left empty; the audience and the token supplier value travel in
        ``custom_fields`` under the ``"audience"`` and ``"token"`` keys.

        :return: Credentials with no expiration and the identity pool custom fields.
        """
        identity_pool_fields = {"audience": self._audience, "token": self._token_supplier}
        return Credentials(
            access_key="",
            secret_key="",
            token="",
            expiration=None,
            custom_fields=identity_pool_fields,
        )

    def refresh_credentials(self) -> None:
        """
        Do nothing; this provider holds no state that is refreshed here.
        """
        return None
110
111
[docs]
class GoogleServiceAccountCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's service account credentials.
    """

    #: Google service account private key file contents.
    _info: dict[str, Any]

    def __init__(self, file: Optional[str] = None, info: Optional[dict[str, Any]] = None):
        """
        Initializes the :py:class:`GoogleServiceAccountCredentialsProvider` with either a path to a
        `Google service account private key <https://docs.cloud.google.com/iam/docs/keys-create-delete#creating>`_ file
        or the file contents.

        :param file: Path to a Google service account private key file.
        :param info: Google service account private key file contents.
        :raises ValueError: If neither or both of ``file`` and ``info`` are given.
        """
        # Exactly one source must be provided: both missing or both present is an error.
        if (file is None) == (info is None):
            raise ValueError("Must specify exactly one of file or info")

        if file is not None:
            # Decode the JSON key file as UTF-8 explicitly instead of relying on the locale's
            # default encoding, which differs between platforms.
            with open(file, "r", encoding="utf-8") as f:
                self._info = json.load(f)
        else:
            # Keep a private copy so later mutation of the caller's dict cannot affect us.
            self._info = copy.deepcopy(info)

    def get_credentials(self) -> Credentials:
        """
        Build a :py:class:`Credentials` object carrying the service account key contents.

        The key contents are deep-copied into ``custom_fields`` under the ``"info"`` key, so callers
        cannot mutate the provider's stored copy.

        :return: Credentials with empty key fields, no token, no expiration, and the key contents.
        """
        return Credentials(
            access_key="",
            secret_key="",
            token=None,
            expiration=None,
            custom_fields={"info": copy.deepcopy(self._info)},
        )

    def refresh_credentials(self) -> None:
        """
        Do nothing; the stored key contents are not refreshed here.
        """
        pass
149
150
[docs]
class GoogleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Google Cloud Storage.

    All operations are available through a ``google.cloud.storage`` client. When a Rust client is
    requested via the ``rust_client`` option and can be created, object reads/writes and path-based
    file transfers are delegated to it wherever the individual methods allow.
    """

    def __init__(
        self,
        project_id: str = os.getenv("GOOGLE_CLOUD_PROJECT_ID", ""),
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ):
        """
        Initializes the :py:class:`GoogleStorageProvider` with the project ID and optional credentials provider.

        :param project_id: The Google Cloud project ID.
        :param endpoint_url: The custom endpoint URL for the GCS service.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve GCS credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        :param kwargs: Optional settings read by this provider:
            ``skip_signature`` (use anonymous credentials),
            ``multipart_threshold``, ``multipart_chunksize``, ``io_chunksize``, ``max_concurrency``
            (transfer tuning), and ``rust_client`` (options for the Rust client; its presence enables
            the Rust client). When ``rust_client`` is given, ``max_pool_connections``,
            ``max_concurrency``, ``multipart_chunksize``, ``read_timeout`` and ``connect_timeout``
            from ``kwargs`` override the same keys inside it.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._project_id = project_id
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._skip_signature = kwargs.get("skip_signature", False)
        self._gcs_client = self._create_gcs_client()
        self._multipart_threshold = kwargs.get("multipart_threshold", DEFAULT_MULTIPART_THRESHOLD)
        self._multipart_chunksize = kwargs.get("multipart_chunksize", DEFAULT_MULTIPART_CHUNKSIZE)
        self._io_chunksize = kwargs.get("io_chunksize", DEFAULT_IO_CHUNKSIZE)
        self._max_concurrency = kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)
        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the rust client options from the kwargs. A ``None`` value (key present but
            # empty) is treated as "no options" so the overrides below cannot fail on it.
            rust_client_options = copy.deepcopy(kwargs["rust_client"]) or {}
            for option in (
                "max_pool_connections",
                "max_concurrency",
                "multipart_chunksize",
                "read_timeout",
                "connect_timeout",
            ):
                if option in kwargs:
                    rust_client_options[option] = kwargs[option]
            self._rust_client = self._create_rust_client(rust_client_options)

    def _create_gcs_client(self) -> storage.Client:
        """
        Build the ``google.cloud.storage`` client matching the configured credentials.

        Credential selection, in order: anonymous credentials when ``skip_signature`` is set;
        no explicit credentials when there is no credentials provider; identity pool credentials;
        service account credentials; otherwise an OAuth 2.0 token taken from the provider.

        :return: A new storage client.
        """
        client_options = {}
        if self._endpoint_url:
            client_options["api_endpoint"] = self._endpoint_url

        # Use anonymous credentials for public buckets when skip_signature is enabled
        if self._skip_signature:
            return storage.Client(
                project=self._project_id,
                credentials=auth_credentials.AnonymousCredentials(),
                client_options=client_options,
            )

        provider = self._credentials_provider
        if not provider:
            return storage.Client(project=self._project_id, client_options=client_options)

        if isinstance(provider, GoogleIdentityPoolCredentialsProvider):
            # Use Workload Identity Federation (WIF)
            pool_credentials = provider.get_credentials()
            google_credentials = identity_pool.Credentials(
                audience=pool_credentials.get_custom_field("audience"),
                subject_token_type="urn:ietf:params:oauth:token-type:id_token",
                subject_token_supplier=StringTokenSupplier(pool_credentials.get_custom_field("token")),
            )
        elif isinstance(provider, GoogleServiceAccountCredentialsProvider):
            # Use service account key.
            google_credentials = service_account.Credentials.from_service_account_info(
                info=provider.get_credentials().get_custom_field("info")
            )
        else:
            # Use OAuth 2.0 token
            google_credentials = OAuth2Credentials(token=provider.get_credentials().token)

        return storage.Client(
            project=self._project_id, credentials=google_credentials, client_options=client_options
        )

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        """
        Try to build the Rust client; return ``None`` when it is unsupported or cannot be created.

        Missing options are filled from environment variables, ``skip_signature``, the bucket of
        ``base_path`` and the default pool size. A custom endpoint URL or identity pool credentials
        disable the Rust client. Any failure while constructing it is logged and results in ``None``
        so the Python client is used instead.

        :param rust_client_options: Options for the Rust client; not modified by this method.
        :return: A ``RustClient`` instance, or ``None``.
        """
        if self._endpoint_url:
            logger.warning("Rust client for GCS does not support customized endpoint URL, skipping rust client")
            return None

        configs = dict(rust_client_options) if rust_client_options else {}

        # Extract and parse retry configuration
        retry_config = parse_retry_config(configs)

        # Fill credentials-related options from the environment when not set explicitly.
        # Order matters: GOOGLE_SERVICE_ACCOUNT wins over GOOGLE_SERVICE_ACCOUNT_PATH.
        for config_key, env_var in (
            ("application_credentials", "GOOGLE_APPLICATION_CREDENTIALS"),
            ("service_account_key", "GOOGLE_SERVICE_ACCOUNT_KEY"),
            ("service_account_path", "GOOGLE_SERVICE_ACCOUNT"),
            ("service_account_path", "GOOGLE_SERVICE_ACCOUNT_PATH"),
        ):
            if config_key not in configs and os.getenv(env_var):
                configs[config_key] = os.getenv(env_var)

        if self._skip_signature and "skip_signature" not in configs:
            configs["skip_signature"] = True

        if "bucket" not in configs:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        if "max_pool_connections" not in configs:
            configs["max_pool_connections"] = DEFAULT_MAX_POOL_CONNECTIONS

        client_kwargs: dict[str, Any] = {"provider": PROVIDER, "configs": configs, "retry": retry_config}

        if isinstance(self._credentials_provider, GoogleIdentityPoolCredentialsProvider):
            # Workload Identity Federation (WIF) is not supported by the rust client:
            # https://github.com/apache/arrow-rs-object-store/issues/258
            logger.warning("Rust client for GCS doesn't support Workload Identity Federation, skipping rust client")
            return None

        if isinstance(self._credentials_provider, GoogleServiceAccountCredentialsProvider):
            # Use service account key: passed through the configs, not as a credentials provider.
            configs["service_account_key"] = json.dumps(
                self._credentials_provider.get_credentials().get_custom_field("info")
            )
        else:
            client_kwargs["credentials_provider"] = self._credentials_provider

        # Construction failures fall back to the Python client for every credential type.
        try:
            return RustClient(**client_kwargs)
        except Exception as e:
            logger.warning(f"Failed to create rust client for GCS: {e}, falling back to Python client")
            return None

    def _refresh_gcs_client_if_needed(self) -> None:
        """
        Refreshes the GCS client if the current credentials are expired.
        """
        provider = self._credentials_provider
        if not provider:
            return
        if provider.get_credentials().is_expired():
            provider.refresh_credentials()
            self._gcs_client = self._create_gcs_client()

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual GCS operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the GCS bucket involved in the operation.
        :param key: The key of the object within the GCS bucket.

        :return: The result of the GCS operation, typically the return value of the `func` callable.
        :raises FileNotFoundError: On a 404 from either client.
        :raises PreconditionFailedError: On a 412 from the Python client.
        :raises NotModifiedError: On a 304 from the Python client.
        :raises PermissionError: On a 403 from the Rust client.
        :raises RetryableError: On ``NoSuchUpload`` responses, Rust retryable errors, and other Rust client errors.
        :raises RuntimeError: On any other failure.
        """
        try:
            return func()
        except GoogleAPICallError as error:
            status_code = error.code or -1
            error_info = f"status_code: {status_code}, message: {error.message}"
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            if status_code == 412:
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            if status_code == 304:
                # for if_none_match with a specific etag condition.
                raise NotModifiedError(f"Object {bucket}/{key} has not been modified.") from error
            raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except InvalidResponse as error:
            response_text = error.response.text
            error_details = f"error: {error}, error_response_text: {response_text}"
            # Check for NoSuchUpload within the response text
            if "NoSuchUpload" in response_text:
                raise RetryableError(f"Multipart upload failed for {bucket}/{key}, {error_details}") from error
            raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_details}") from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except RustClientError as error:
            # The Rust client error carries (message, status_code) as its first two args.
            message = error.args[0]
            status_code = error.args[1]
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {message}") from error
            if status_code == 403:
                raise PermissionError(
                    f"Permission denied to {operation} object(s) at {bucket}/{key}. {message}"
                ) from error
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key}. {message}. status_code: {status_code}"
            ) from error
        except Exception as error:
            error_details = str(error)
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}. error_type: {type(error).__name__}, {error_details}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to Google Cloud Storage.

        :param path: The path to the object to upload.
        :param body: The content of the object to upload.
        :param if_match: Optional object generation (as a string); the upload only succeeds if it matches.
        :param if_none_match: Optional object generation (as a string); the upload only succeeds if it
            does not match. The wildcard ``"*"`` is not supported.
        :param attributes: Optional attributes to attach to the object.
        :return: The number of bytes in ``body``.
        """
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> int:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            preconditions = {}
            if if_match:
                preconditions["if_generation_match"] = int(if_match)  # 412 error code
            if if_none_match:
                if if_none_match == "*":
                    raise NotImplementedError("if_none_match='*' is not supported for GCS")
                preconditions["if_generation_not_match"] = int(if_none_match)  # 304 error code

            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                blob.metadata = validated_attributes

            # The Rust path is only taken for plain uploads: no preconditions, no attributes.
            # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
            use_rust_client = (
                self._rust_client and not path.endswith("/") and not preconditions and not validated_attributes
            )
            if use_rust_client:
                run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                blob.upload_from_string(body, **preconditions)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Download an object's content, optionally restricted to a byte range.

        :param path: The path to the object.
        :param byte_range: Optional range; ``offset`` is the first byte and ``size`` the number of bytes.
        :return: The content returned by whichever client served the request.
        """
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bytes:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            if not byte_range:
                if self._rust_client:
                    return run_async_rust_client_method(self._rust_client, "get", key)
                return blob.download_as_bytes(single_shot_download=True)

            # Both clients are given the index of the last byte to read.
            first_byte = byte_range.offset
            last_byte = byte_range.offset + byte_range.size - 1
            if self._rust_client:
                return run_async_rust_client_method(self._rust_client, "get", key, first_byte, last_byte)
            return blob.download_as_bytes(start=first_byte, end=last_byte, single_shot_download=True)

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        """
        Copy an object using the GCS rewrite API.

        :param src_path: The path of the object to copy from.
        :param dest_path: The path of the object to copy to.
        :return: The content length of the source object, as read before the copy.
        """
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)
        self._refresh_gcs_client_if_needed()

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            source_blob = self._gcs_client.bucket(src_bucket).blob(src_key)
            destination_blob = self._gcs_client.bucket(dest_bucket).blob(dest_key)

            # Keep calling rewrite with the returned token until no token is handed back.
            rewrite_token = None
            while True:
                rewrite_token, _, _ = destination_blob.rewrite(source=source_blob, token=rewrite_token)
                if rewrite_token is None:
                    break

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=src_bucket, key=src_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Delete a single object.

        :param path: The path of the object to delete.
        :param if_match: Optional object generation (as a string) used as a delete precondition.
        """
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> None:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            if if_match:
                # If if_match is provided, use it as a precondition
                blob.delete(if_generation_match=int(if_match))
            else:
                # No if_match check needed, just delete
                blob.delete()

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _delete_objects(self, paths: list[str]) -> None:
        """
        Delete many objects, grouped per bucket and sent in batches of at most 100 deletes.

        :param paths: Paths of the objects to delete; an empty list is a no-op.
        """
        if not paths:
            return

        by_bucket: dict[str, list[str]] = {}
        for p in paths:
            bucket, key = split_path(p)
            by_bucket.setdefault(bucket, []).append(key)
        self._refresh_gcs_client_if_needed()

        GCS_BATCH_LIMIT = 100

        def _invoke_api() -> None:
            for bucket_name, keys in by_bucket.items():
                bucket_obj = self._gcs_client.bucket(bucket_name)
                for start in range(0, len(keys), GCS_BATCH_LIMIT):
                    # Deletes issued inside the batch context are sent together.
                    with self._gcs_client.batch():
                        for batch_key in keys[start : start + GCS_BATCH_LIMIT]:
                            bucket_obj.blob(batch_key).delete()

        # Error messages name every bucket and the number of keys per bucket.
        bucket_desc = "(" + "|".join(by_bucket) + ")"
        key_desc = "(" + "|".join(str(len(keys)) for keys in by_bucket.values()) + " keys)"
        self._translate_errors(_invoke_api, operation="DELETE_MANY", bucket=bucket_desc, key=key_desc)

    def _is_dir(self, path: str) -> bool:
        """
        Tell whether anything exists under ``path`` treated as a directory prefix.

        :param path: The path to check; the delimiter is appended before listing.
        :return: ``True`` if the listing yields at least one object or one common prefix.
        """
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bool:
            bucket_obj = self._gcs_client.bucket(bucket)
            # List objects with the given prefix
            blobs = bucket_obj.list_blobs(
                prefix=key,
                delimiter="/",
            )
            # Check if there are any contents or common prefixes
            return any(True for _ in blobs) or any(True for _ in blobs.prefixes)

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _make_symlink(self, path: str, target: str) -> None:
        """
        Create a symlink object: an empty object whose ``msc-symlink-target`` metadata holds the target.

        :param path: The path of the symlink object to create.
        :param target: The path the symlink points to; must be in the same bucket as ``path``.
        :raises ValueError: If ``path`` and ``target`` are in different buckets.
        """
        bucket_name, key = split_path(path)
        target_bucket, target_key = split_path(target)
        if bucket_name != target_bucket:
            raise ValueError(f"Cannot create cross-bucket symlink: '{bucket_name}' -> '{target_bucket}'.")
        relative_target = ObjectMetadata.encode_symlink_target(key, target_key)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> None:
            blob = self._gcs_client.bucket(bucket_name).blob(key)
            blob.metadata = {"msc-symlink-target": relative_target}
            blob.upload_from_string(b"")

        self._translate_errors(_invoke_api, operation="PUT", bucket=bucket_name, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Fetch metadata for an object, or synthesize directory metadata for a prefix.

        :param path: The path to inspect. A trailing ``/`` or an empty key is treated as a directory.
        :param strict: When ``True`` and no object exists at ``path``, check whether ``path`` is a
            directory before giving up.
        :return: Metadata for the object or directory. For objects, ``etag`` is the blob generation.
        :raises FileNotFoundError: If neither an object nor (where checked) a directory exists.
        """
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If path ends with "/" or empty key name is provided, then assume it's a "directory",
            # which metadata is not guaranteed to exist for cases such as
            # "virtual prefix" that was never explicitly created.
            if not self._is_dir(path):
                raise FileNotFoundError(f"Directory {path} does not exist.")
            return ObjectMetadata(
                key=path, type="directory", content_length=0, last_modified=AWARE_DATETIME_MIN, etag=None
            )

        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> ObjectMetadata:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.get_blob(key)
            if not blob:
                raise NotFound(f"Blob {key} not found in bucket {bucket}")
            user_metadata = dict(blob.metadata) if blob.metadata else None
            symlink_target = user_metadata.get("msc-symlink-target") if user_metadata else None
            return ObjectMetadata(
                key=path,
                content_length=blob.size or 0,
                content_type=blob.content_type,
                last_modified=blob.updated or AWARE_DATETIME_MIN,
                etag=str(blob.generation),
                metadata=user_metadata,
                symlink_target=symlink_target,
            )

        try:
            return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
        except FileNotFoundError:
            if strict:
                # If the object does not exist on the given path, we will append a trailing slash and
                # check if the path is a directory.
                path = self._append_delimiter(path)
                if self._is_dir(path):
                    return ObjectMetadata(
                        key=path,
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )
            raise

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        symlink_handling: SymlinkHandling = SymlinkHandling.FOLLOW,
    ) -> Iterator[ObjectMetadata]:
        """
        Lazily list objects under a prefix in the order returned by GCS.

        :param path: The bucket and prefix to list.
        :param start_after: Only keys strictly greater than this key (relative to the bucket) are yielded.
        :param end_at: Only keys less than or equal to this key (relative to the bucket) are yielded.
        :param include_directories: List one level only (``/`` delimiter) and also yield directory entries.
        :param symlink_handling: Accepted for interface compatibility; not consulted by this method.
        :return: An iterator of object (and optionally directory) metadata.
        """
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            bucket_obj = self._gcs_client.bucket(bucket)
            # This is ≥ instead of >, so the start_after key itself is filtered out below.
            list_kwargs: dict[str, Any] = {"prefix": prefix, "start_offset": start_after}
            if include_directories:
                list_kwargs["delimiter"] = "/"
            blobs = bucket_obj.list_blobs(**list_kwargs)

            # GCS guarantees lexicographical order.
            for blob in blobs:
                key = blob.name
                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    if key.endswith("/"):
                        if include_directories:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key.rstrip("/")),
                                type="directory",
                                content_length=0,
                                last_modified=blob.updated,
                            )
                    else:
                        user_metadata = blob.metadata
                        symlink_target = user_metadata.get("msc-symlink-target") if user_metadata else None
                        # NOTE(review): this reports the blob's etag, whereas _get_object_metadata
                        # reports the blob generation as ``etag`` — confirm which one callers expect.
                        yield ObjectMetadata(
                            key=os.path.join(bucket, key),
                            content_length=blob.size,
                            content_type=blob.content_type,
                            last_modified=blob.updated,
                            etag=blob.etag,
                            symlink_target=symlink_target,
                        )
                elif start_after != key:
                    # Out of range and not the start_after key itself: past end_at, stop listing.
                    return

            # The directories must be accessed last.
            if include_directories:
                for directory in blobs.prefixes:
                    prefix_key = directory.rstrip("/")
                    # Filter by start_after and end_at if specified
                    if (start_after is None or start_after < prefix_key) and (end_at is None or prefix_key <= end_at):
                        yield ObjectMetadata(
                            key=os.path.join(bucket, prefix_key),
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )

        # NOTE(review): _invoke_api is a generator function, so this call only creates the generator;
        # exceptions raised while the caller iterates are not passed through _translate_errors.
        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    @property
    def supports_parallel_listing(self) -> bool:
        """
        Whether this provider supports parallel listing; always ``True`` here.
        """
        return True

    @staticmethod
    def _discard_temp_file(temp_file_path: str) -> None:
        """
        Remove a temporary file, ignoring the case where it is already gone.

        :param temp_file_path: Path of the temporary file to remove.
        """
        try:
            os.unlink(temp_file_path)
        except FileNotFoundError:
            pass

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        """
        Upload a local file or a file-like object.

        Content at or below the multipart threshold is uploaded in one request; larger content is
        uploaded in chunks. Large file-like objects are first spooled to a temporary file, which is
        always removed afterwards. ``io.StringIO`` content is encoded as UTF-8.

        :param remote_path: The destination object path.
        :param f: A local file path, or a seekable file-like object.
        :param attributes: Optional attributes to attach to the object.
        :return: The size of the source: the file size for a path, or the end position for a file object.
        """
        bucket, key = split_path(remote_path)
        file_size: int = 0
        self._refresh_gcs_client_if_needed()

        if isinstance(f, str):
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._multipart_threshold:
                # NOTE: the Rust upload below is not routed through _translate_errors.
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    bucket_obj = self._gcs_client.bucket(bucket)
                    blob = bucket_obj.blob(key)
                    # GCS will raise an error if blob.metadata is None
                    validated_attributes = validate_attributes(attributes)
                    if validated_attributes is not None:
                        blob.metadata = validated_attributes
                    transfer_manager.upload_chunks_concurrently(
                        f,
                        blob,
                        chunk_size=self._multipart_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            # Measure the stream by seeking to its end, then rewind for reading.
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(remote_path, f.read().encode("utf-8"), attributes=attributes)
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    blob.metadata = validated_attributes

                # transfer manager does not support uploading a file object
                fp = tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=".")
                temp_file_path = fp.name
                try:
                    with fp:
                        payload = f.read()
                        if isinstance(f, io.StringIO):
                            # Same encoding as the small-file path, independent of the locale.
                            payload = payload.encode("utf-8")
                        fp.write(payload)

                    transfer_manager.upload_chunks_concurrently(
                        temp_file_path,
                        blob,
                        chunk_size=self._multipart_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )
                finally:
                    # Remove the spool file whether or not the upload succeeded.
                    self._discard_temp_file(temp_file_path)

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """
        Download an object to a local path or into a file-like object.

        Objects at or below the multipart threshold are fetched in one request; larger objects are
        downloaded in chunks. Downloads to a path done by the Python client are written to a temporary
        file in the destination directory and renamed into place; the temporary file is removed if the
        download fails. ``io.StringIO`` targets receive the content decoded as UTF-8.

        :param remote_path: The path of the object to download.
        :param f: A local file path, or a writable file-like object.
        :param metadata: Metadata of the object, if already known; fetched when omitted.
        :return: The content length taken from the object's metadata.
        """
        self._refresh_gcs_client_if_needed()

        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            parent_dir = os.path.dirname(f)
            if parent_dir:
                safe_makedirs(parent_dir)

            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                # NOTE: the Rust download below is not routed through _translate_errors.
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    fp = tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=parent_dir, prefix=".")
                    temp_file_path = fp.name
                    try:
                        with fp:
                            fp.write(self._get_object(remote_path))
                        os.rename(src=temp_file_path, dst=f)
                    except BaseException:
                        # Do not leave a partial hidden file next to the destination.
                        self._discard_temp_file(temp_file_path)
                        raise
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download_multipart_to_file", key, f)
                else:
                    # Reserve a temporary path in the destination directory for the chunked download.
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=parent_dir, prefix=".") as fp:
                        temp_file_path = fp.name
                    try:
                        transfer_manager.download_chunks_concurrently(
                            blob,
                            temp_file_path,
                            chunk_size=self._io_chunksize,
                            max_workers=self._max_concurrency,
                            worker_type=transfer_manager.THREAD,
                        )
                        os.rename(src=temp_file_path, dst=f)
                    except BaseException:
                        self._discard_temp_file(temp_file_path)
                        raise

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                response = self._get_object(remote_path)
                # Python client returns `bytes`, but Rust client returns an object that implements the buffer protocol,
                # so we need to check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)

                # transfer manager does not support downloading to a file object
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                try:
                    transfer_manager.download_chunks_concurrently(
                        blob,
                        temp_file_path,
                        chunk_size=self._io_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                    with open(temp_file_path, "rb") as downloaded:
                        content = downloaded.read()
                    if isinstance(f, io.StringIO):
                        # Decode exactly like the small-file path: UTF-8, no newline translation.
                        f.write(content.decode("utf-8"))
                    else:
                        f.write(content)
                finally:
                    # Remove the spool file whether or not the download succeeded.
                    self._discard_temp_file(temp_file_path)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)