# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import copy
import io
import json
import logging
import os
import tempfile
from collections.abc import Callable, Iterator
from typing import IO, Any, Optional, TypeVar, Union

from google.api_core.exceptions import GoogleAPICallError, NotFound
from google.auth import credentials as auth_credentials
from google.auth import identity_pool
from google.cloud import storage
from google.cloud.storage import transfer_manager
from google.cloud.storage.exceptions import InvalidResponse
from google.oauth2 import service_account
from google.oauth2.credentials import Credentials as OAuth2Credentials

from multistorageclient_rust import RustClient, RustClientError, RustRetryableError

from ..rust_utils import parse_retry_config, run_async_rust_client_method
from ..telemetry import Telemetry
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    NotModifiedError,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
    SymlinkHandling,
)
from ..utils import (
    safe_makedirs,
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

PROVIDER = "gcs"

MiB = 1024 * 1024

DEFAULT_MULTIPART_THRESHOLD = 64 * MiB
DEFAULT_MULTIPART_CHUNKSIZE = 32 * MiB
DEFAULT_IO_CHUNKSIZE = 32 * MiB
PYTHON_MAX_CONCURRENCY = 8

logger = logging.getLogger(__name__)


class StringTokenSupplier(identity_pool.SubjectTokenSupplier):
    """
    Supply a string token to the Google Identity Pool.
    """

    def __init__(self, token: str):
        self._token = token

    def get_subject_token(self, context, request):
        return self._token


class GoogleIdentityPoolCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's identity pool credentials.
    """

    def __init__(self, audience: str, token_supplier: str):
        """
        Initializes the :py:class:`GoogleIdentityPoolCredentialsProvider` with the audience and token supplier.

        :param audience: The audience for the Google Identity Pool.
        :param token_supplier: The subject token string to supply to the Google Identity Pool.
        """
        self._audience = audience
        self._token_supplier = token_supplier

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key="",
            secret_key="",
            token="",
            expiration=None,
            custom_fields={"audience": self._audience, "token": self._token_supplier},
        )

    def refresh_credentials(self) -> None:
        pass


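# Illustrative sketch (not executed): wiring Workload Identity Federation
# through this provider. The audience, token, project, and bucket values are
# placeholders, not a working configuration.
#
#   provider = GoogleIdentityPoolCredentialsProvider(
#       audience="//iam.googleapis.com/projects/000/locations/global/workloadIdentityPools/my-pool/providers/my-provider",
#       token_supplier="<subject-token>",
#   )
#   storage_provider = GoogleStorageProvider(
#       project_id="my-project", base_path="my-bucket", credentials_provider=provider
#   )
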
class GoogleServiceAccountCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's service account credentials.
    """

    #: Google service account private key file contents.
    _info: dict[str, Any]

    def __init__(self, file: Optional[str] = None, info: Optional[dict[str, Any]] = None):
        """
        Initializes the :py:class:`GoogleServiceAccountCredentialsProvider` with either a path to a
        `Google service account private key <https://docs.cloud.google.com/iam/docs/keys-create-delete#creating>`_ file
        or the file contents.

        :param file: Path to a Google service account private key file.
        :param info: Google service account private key file contents.
        """
        if all(_ is None for _ in (file, info)) or all(_ is not None for _ in (file, info)):
            raise ValueError("Must specify exactly one of file or info")

        if file is not None:
            with open(file, "r") as f:
                self._info = json.load(f)
        elif info is not None:
            self._info = copy.deepcopy(info)

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key="",
            secret_key="",
            token=None,
            expiration=None,
            custom_fields={"info": copy.deepcopy(self._info)},
        )

    def refresh_credentials(self) -> None:
        pass


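# Illustrative sketch (not executed): the two construction modes are mutually
# exclusive and carry the same key material. The path below is a placeholder.
#
#   provider = GoogleServiceAccountCredentialsProvider(file="/path/to/key.json")
#   # ...or, equivalently, from the parsed key file contents:
#   with open("/path/to/key.json") as f:
#       provider = GoogleServiceAccountCredentialsProvider(info=json.load(f))
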
class GoogleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Google Cloud Storage.
    """

    def __init__(
        self,
        project_id: str = os.getenv("GOOGLE_CLOUD_PROJECT_ID", ""),
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ):
        """
        Initializes the :py:class:`GoogleStorageProvider` with the project ID and optional credentials provider.

        :param project_id: The Google Cloud project ID.
        :param endpoint_url: The custom endpoint URL for the GCS service.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve GCS credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._project_id = project_id
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._skip_signature = kwargs.get("skip_signature", False)
        self._gcs_client = self._create_gcs_client()
        self._multipart_threshold = kwargs.get("multipart_threshold", DEFAULT_MULTIPART_THRESHOLD)
        self._multipart_chunksize = kwargs.get("multipart_chunksize", DEFAULT_MULTIPART_CHUNKSIZE)
        self._io_chunksize = kwargs.get("io_chunksize", DEFAULT_IO_CHUNKSIZE)
        self._max_concurrency = kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)
        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the rust client options from the kwargs
            rust_client_options = copy.deepcopy(kwargs["rust_client"])
            if "max_concurrency" in kwargs:
                rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
            if "multipart_chunksize" in kwargs:
                rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
            if "read_timeout" in kwargs:
                rust_client_options["read_timeout"] = kwargs["read_timeout"]
            if "connect_timeout" in kwargs:
                rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
            self._rust_client = self._create_rust_client(rust_client_options)

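    # Illustrative sketch (not executed): direct construction with tuning
    # options. Values are placeholders; in practice these options usually
    # arrive through the resolved MSC config rather than direct keyword
    # arguments.
    #
    #   provider = GoogleStorageProvider(
    #       project_id="my-project",
    #       base_path="my-bucket/datasets",
    #       multipart_threshold=128 * MiB,
    #       max_concurrency=16,
    #   )
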
    def _create_gcs_client(self) -> storage.Client:
        client_options = {}
        if self._endpoint_url:
            client_options["api_endpoint"] = self._endpoint_url

        # Use anonymous credentials for public buckets when skip_signature is enabled
        if self._skip_signature:
            return storage.Client(
                project=self._project_id,
                credentials=auth_credentials.AnonymousCredentials(),
                client_options=client_options,
            )

        if self._credentials_provider:
            if isinstance(self._credentials_provider, GoogleIdentityPoolCredentialsProvider):
                audience = self._credentials_provider.get_credentials().get_custom_field("audience")
                token = self._credentials_provider.get_credentials().get_custom_field("token")

                # Use Workload Identity Federation (WIF)
                identity_pool_credentials = identity_pool.Credentials(
                    audience=audience,
                    subject_token_type="urn:ietf:params:oauth:token-type:id_token",
                    subject_token_supplier=StringTokenSupplier(token),
                )
                return storage.Client(
                    project=self._project_id, credentials=identity_pool_credentials, client_options=client_options
                )
            elif isinstance(self._credentials_provider, GoogleServiceAccountCredentialsProvider):
                # Use service account key.
                service_account_credentials = service_account.Credentials.from_service_account_info(
                    info=self._credentials_provider.get_credentials().get_custom_field("info")
                )
                return storage.Client(
                    project=self._project_id, credentials=service_account_credentials, client_options=client_options
                )
            else:
                # Use OAuth 2.0 token
                token = self._credentials_provider.get_credentials().token
                creds = OAuth2Credentials(token=token)
                return storage.Client(project=self._project_id, credentials=creds, client_options=client_options)
        else:
            return storage.Client(project=self._project_id, client_options=client_options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        if self._endpoint_url:
            logger.warning("Rust client for GCS does not support customized endpoint URL, skipping rust client")
            return None

        configs = dict(rust_client_options) if rust_client_options else {}

        # Extract and parse retry configuration
        retry_config = parse_retry_config(configs)

        if "application_credentials" not in configs and os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
            configs["application_credentials"] = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
        if "service_account_key" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY"):
            configs["service_account_key"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY")
        if "service_account_path" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT")
        if "service_account_path" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH")

        if self._skip_signature and "skip_signature" not in configs:
            configs["skip_signature"] = True

        if "bucket" not in configs:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        if self._credentials_provider:
            if isinstance(self._credentials_provider, GoogleIdentityPoolCredentialsProvider):
                # Workload Identity Federation (WIF) is not supported by the rust client:
                # https://github.com/apache/arrow-rs-object-store/issues/258
                logger.warning("Rust client for GCS doesn't support Workload Identity Federation, skipping rust client")
                return None
            if isinstance(self._credentials_provider, GoogleServiceAccountCredentialsProvider):
                # Use service account key.
                configs["service_account_key"] = json.dumps(
                    self._credentials_provider.get_credentials().get_custom_field("info")
                )
                return RustClient(
                    provider=PROVIDER,
                    configs=configs,
                    retry=retry_config,
                )
        try:
            return RustClient(
                provider=PROVIDER,
                configs=configs,
                credentials_provider=self._credentials_provider,
                retry=retry_config,
            )
        except Exception as e:
            logger.warning(f"Failed to create rust client for GCS: {e}, falling back to Python client")
            return None

    def _refresh_gcs_client_if_needed(self) -> None:
        """
        Refreshes the GCS client if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._gcs_client = self._create_gcs_client()

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual GCS operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the GCS bucket involved in the operation.
        :param key: The key of the object within the GCS bucket.

        :return: The result of the GCS operation, typically the return value of the `func` callable.
        """
        try:
            return func()
        except GoogleAPICallError as error:
            status_code = error.code if error.code else -1
            error_info = f"status_code: {status_code}, message: {error.message}"
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 304:
                # for if_none_match with a specific etag condition.
                raise NotModifiedError(f"Object {bucket}/{key} has not been modified.") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except InvalidResponse as error:
            response_text = error.response.text
            error_details = f"error: {error}, error_response_text: {response_text}"
            # Check for NoSuchUpload within the response text
            if "NoSuchUpload" in response_text:
                raise RetryableError(f"Multipart upload failed for {bucket}/{key}, {error_details}") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_details}") from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except RustClientError as error:
            message = error.args[0]
            status_code = error.args[1]
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {message}") from error
            elif status_code == 403:
                raise PermissionError(
                    f"Permission denied to {operation} object(s) at {bucket}/{key}. {message}"
                ) from error
            else:
                raise RetryableError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {message}. status_code: {status_code}"
                ) from error
        except Exception as error:
            error_details = str(error)
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}. error_type: {type(error).__name__}, {error_details}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to Google Cloud Storage.

        :param path: The path to the object to upload.
        :param body: The content of the object to upload.
        :param if_match: Optional generation number; the upload only succeeds if it matches the object's current generation.
        :param if_none_match: Optional generation number; the upload only succeeds if it does not match the object's current generation.
        :param attributes: Optional attributes to attach to the object.
        """
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> int:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            kwargs = {}

            if if_match:
                kwargs["if_generation_match"] = int(if_match)  # 412 error code
            if if_none_match:
                if if_none_match == "*":
                    raise NotImplementedError("if_none_match='*' is not supported for GCS")
                else:
                    kwargs["if_generation_not_match"] = int(if_none_match)  # 304 error code

            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                blob.metadata = validated_attributes

            if (
                self._rust_client
                # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
                and not path.endswith("/")
                and not kwargs
                and not validated_attributes
            ):
                run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                blob.upload_from_string(body, **kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

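    # Illustrative sketch (not executed): a generation-guarded read-modify-write.
    # if_match/if_none_match map onto GCS generation preconditions, so the
    # "etag" here is a generation number. The path is a placeholder.
    #
    #   meta = provider._get_object_metadata("my-bucket/config.json")
    #   provider._put_object("my-bucket/config.json", b"{}", if_match=meta.etag)
    #   # raises PreconditionFailedError (HTTP 412) if the object changed meanwhile
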
    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bytes:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)
            if byte_range:
                if self._rust_client:
                    return run_async_rust_client_method(
                        self._rust_client, "get", key, byte_range.offset, byte_range.offset + byte_range.size - 1
                    )
                else:
                    return blob.download_as_bytes(
                        start=byte_range.offset, end=byte_range.offset + byte_range.size - 1, single_shot_download=True
                    )
            else:
                if self._rust_client:
                    return run_async_rust_client_method(self._rust_client, "get", key)
                else:
                    return blob.download_as_bytes(single_shot_download=True)

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

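    # Note the inclusive range arithmetic above: a Range with a given offset and
    # size maps to byte positions [offset, offset + size - 1], so e.g. offset=0,
    # size=4 fetches bytes 0-3 of the object.
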
    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)
        self._refresh_gcs_client_if_needed()

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            source_bucket_obj = self._gcs_client.bucket(src_bucket)
            source_blob = source_bucket_obj.blob(src_key)

            destination_bucket_obj = self._gcs_client.bucket(dest_bucket)
            destination_blob = destination_bucket_obj.blob(dest_key)

            # Large objects are copied via multiple rewrite calls: each call
            # returns a continuation token, which is passed back in until GCS
            # reports the copy is complete (token is None).
            rewrite_tokens = [None]
            while len(rewrite_tokens) > 0:
                rewrite_token = rewrite_tokens.pop()
                next_rewrite_token, _, _ = destination_blob.rewrite(source=source_blob, token=rewrite_token)
                if next_rewrite_token is not None:
                    rewrite_tokens.append(next_rewrite_token)

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=src_bucket, key=src_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> None:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            # If if_match is provided, use it as a precondition
            if if_match:
                generation = int(if_match)
                blob.delete(if_generation_match=generation)
            else:
                # No if_match check needed, just delete
                blob.delete()

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _delete_objects(self, paths: list[str]) -> None:
        if not paths:
            return

        by_bucket: dict[str, list[str]] = {}
        for p in paths:
            bucket, key = split_path(p)
            by_bucket.setdefault(bucket, []).append(key)
        self._refresh_gcs_client_if_needed()

        GCS_BATCH_LIMIT = 100

        def _invoke_api() -> None:
            for bucket, keys in by_bucket.items():
                bucket_obj = self._gcs_client.bucket(bucket)
                for i in range(0, len(keys), GCS_BATCH_LIMIT):
                    chunk = keys[i : i + GCS_BATCH_LIMIT]
                    with self._gcs_client.batch():
                        for k in chunk:
                            bucket_obj.blob(k).delete()

        bucket_desc = "(" + "|".join(by_bucket) + ")"
        key_desc = "(" + "|".join(str(len(keys)) for keys in by_bucket.values()) + " keys)"
        self._translate_errors(_invoke_api, operation="DELETE_MANY", bucket=bucket_desc, key=key_desc)

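    # Worked example of the batching above: deleting 250 keys from one bucket
    # issues ceil(250 / 100) = 3 batch requests of 100, 100, and 50 deletes,
    # since a GCS batch accepts at most GCS_BATCH_LIMIT sub-requests.
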
    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bool:
            bucket_obj = self._gcs_client.bucket(bucket)
            # List objects with the given prefix
            blobs = bucket_obj.list_blobs(
                prefix=key,
                delimiter="/",
            )
            # Check if there are any contents or common prefixes. Note that
            # blobs.prefixes is only populated once the iterator has been
            # consumed, which the first any() guarantees when it finds no blobs.
            return any(True for _ in blobs) or any(True for _ in blobs.prefixes)

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _make_symlink(self, path: str, target: str) -> None:
        bucket_name, key = split_path(path)
        target_bucket, target_key = split_path(target)
        if bucket_name != target_bucket:
            raise ValueError(f"Cannot create cross-bucket symlink: '{bucket_name}' -> '{target_bucket}'.")
        relative_target = ObjectMetadata.encode_symlink_target(key, target_key)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> None:
            bucket = self._gcs_client.bucket(bucket_name)
            blob = bucket.blob(key)
            blob.metadata = {"msc-symlink-target": relative_target}
            blob.upload_from_string(b"")

        self._translate_errors(_invoke_api, operation="PUT", bucket=bucket_name, key=key)

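    # A symlink is thus a zero-byte object whose "msc-symlink-target" metadata
    # entry holds the target key (encoded relative to the link's own key);
    # readers surface it through the symlink_target field on ObjectMetadata.
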
    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If the path ends with "/" or an empty key name is provided, assume it's a "directory",
            # whose metadata is not guaranteed to exist, e.g. a "virtual prefix"
            # that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path, type="directory", content_length=0, last_modified=AWARE_DATETIME_MIN, etag=None
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            self._refresh_gcs_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.get_blob(key)
                if not blob:
                    raise NotFound(f"Blob {key} not found in bucket {bucket}")
                user_metadata = dict(blob.metadata) if blob.metadata else None
                symlink_target = user_metadata.get("msc-symlink-target") if user_metadata else None
                return ObjectMetadata(
                    key=path,
                    content_length=blob.size or 0,
                    content_type=blob.content_type,
                    last_modified=blob.updated or AWARE_DATETIME_MIN,
                    etag=str(blob.generation),
                    metadata=user_metadata,
                    symlink_target=symlink_target,
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        symlink_handling: SymlinkHandling = SymlinkHandling.FOLLOW,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            bucket_obj = self._gcs_client.bucket(bucket)
            if include_directories:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                    delimiter="/",
                )
            else:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                )

            # GCS guarantees lexicographical order.
            for blob in blobs:
                key = blob.name
                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    if key.endswith("/"):
                        if include_directories:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key.rstrip("/")),
                                type="directory",
                                content_length=0,
                                last_modified=blob.updated,
                            )
                    else:
                        user_metadata = blob.metadata
                        symlink_target = user_metadata.get("msc-symlink-target") if user_metadata else None
                        yield ObjectMetadata(
                            key=os.path.join(bucket, key),
                            content_length=blob.size,
                            content_type=blob.content_type,
                            last_modified=blob.updated,
                            etag=blob.etag,
                            symlink_target=symlink_target,
                        )
                elif start_after != key:
                    return

            # The directories must be accessed last.
            if include_directories:
                for directory in blobs.prefixes:
                    prefix_key = directory.rstrip("/")
                    # Filter by start_after and end_at if specified
                    if (start_after is None or start_after < prefix_key) and (end_at is None or prefix_key <= end_at):
                        yield ObjectMetadata(
                            key=os.path.join(bucket, prefix_key),
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

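    # Worked example of the range semantics above (placeholder keys): with keys
    # "a", "b", and "c" under the prefix, start_after="a" and end_at="c" yields
    # "b" and "c" only, i.e. the listing covers the interval (start_after, end_at].
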
    @property
    def supports_parallel_listing(self) -> bool:
        return True

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        bucket, key = split_path(remote_path)
        file_size: int = 0
        self._refresh_gcs_client_if_needed()

        if isinstance(f, str):
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    bucket_obj = self._gcs_client.bucket(bucket)
                    blob = bucket_obj.blob(key)
                    # GCS will raise an error if blob.metadata is None
                    validated_attributes = validate_attributes(attributes)
                    if validated_attributes is not None:
                        blob.metadata = validated_attributes
                    transfer_manager.upload_chunks_concurrently(
                        f,
                        blob,
                        chunk_size=self._multipart_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(remote_path, f.read().encode("utf-8"), attributes=attributes)
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    blob.metadata = validated_attributes
                if isinstance(f, io.StringIO):
                    mode = "w"
                else:
                    mode = "wb"

                # transfer manager does not support uploading a file object
                with tempfile.NamedTemporaryFile(mode=mode, delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    fp.write(f.read())

                transfer_manager.upload_chunks_concurrently(
                    temp_file_path,
                    blob,
                    chunk_size=self._multipart_chunksize,
                    max_workers=self._max_concurrency,
                    worker_type=transfer_manager.THREAD,
                )

                os.unlink(temp_file_path)

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        self._refresh_gcs_client_if_needed()

        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            if os.path.dirname(f):
                safe_makedirs(os.path.dirname(f))
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download_multipart_to_file", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        transfer_manager.download_chunks_concurrently(
                            blob,
                            temp_file_path,
                            chunk_size=self._io_chunksize,
                            max_workers=self._max_concurrency,
                            worker_type=transfer_manager.THREAD,
                        )
                    os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                response = self._get_object(remote_path)
                # Python client returns `bytes`, but Rust client returns an object that implements the buffer protocol,
                # so we need to check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)

                # transfer manager does not support downloading to a file object
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    transfer_manager.download_chunks_concurrently(
                        blob,
                        temp_file_path,
                        chunk_size=self._io_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                if isinstance(f, io.StringIO):
                    with open(temp_file_path, "r") as fp:
                        f.write(fp.read())
                else:
                    with open(temp_file_path, "rb") as fp:
                        f.write(fp.read())

                os.unlink(temp_file_path)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)