# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import copy
import io
import json
import logging
import os
import tempfile
from collections.abc import Callable, Iterator
from typing import IO, Any, Optional, TypeVar, Union

from google.api_core.exceptions import GoogleAPICallError, NotFound
from google.auth import credentials as auth_credentials
from google.auth import identity_pool
from google.cloud import storage
from google.cloud.storage import transfer_manager
from google.cloud.storage.exceptions import InvalidResponse
from google.oauth2 import service_account
from google.oauth2.credentials import Credentials as OAuth2Credentials

from multistorageclient_rust import RustClient, RustClientError, RustRetryableError

from ..rust_utils import parse_retry_config, run_async_rust_client_method
from ..telemetry import Telemetry
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    NotModifiedError,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import (
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

PROVIDER = "gcs"

MiB = 1024 * 1024

DEFAULT_MULTIPART_THRESHOLD = 64 * MiB
DEFAULT_MULTIPART_CHUNKSIZE = 32 * MiB
DEFAULT_IO_CHUNKSIZE = 32 * MiB
PYTHON_MAX_CONCURRENCY = 8

logger = logging.getLogger(__name__)


class StringTokenSupplier(identity_pool.SubjectTokenSupplier):
    """
    Supply a string token to the Google Identity Pool.
    """

    def __init__(self, token: str):
        self._token = token

    def get_subject_token(self, context, request):
        return self._token


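# Example (illustrative sketch, not part of the library API): constructing the identity pool
# credentials provider below for Workload Identity Federation. The audience and token values
# are placeholders; per the usage in this module, the token is an externally issued ID token
# string.
#
#   credentials_provider = GoogleIdentityPoolCredentialsProvider(
#       audience="//iam.googleapis.com/projects/123456/locations/global/workloadIdentityPools/my-pool/providers/my-provider",
#       token_supplier="<subject-id-token>",
#   )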
class GoogleIdentityPoolCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's identity pool credentials.
    """

    def __init__(self, audience: str, token_supplier: str):
        """
        Initializes the :py:class:`GoogleIdentityPoolCredentialsProvider` with the audience and token supplier.

        :param audience: The audience for the Google Identity Pool.
        :param token_supplier: The string token supplied to the Google Identity Pool.
        """
        self._audience = audience
        self._token_supplier = token_supplier

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key="",
            secret_key="",
            token="",
            expiration=None,
            custom_fields={"audience": self._audience, "token": self._token_supplier},
        )

    def refresh_credentials(self) -> None:
        pass


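# Example (illustrative sketch, not part of the library API): constructing the service account
# credentials provider below from a key file. The path is a placeholder; alternatively, pass the
# parsed key contents via ``info=``.
#
#   credentials_provider = GoogleServiceAccountCredentialsProvider(
#       file="/path/to/service-account-key.json",
#   )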
class GoogleServiceAccountCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's service account credentials.
    """

    #: Google service account private key file contents.
    _info: dict[str, Any]

    def __init__(self, file: Optional[str] = None, info: Optional[dict[str, Any]] = None):
        """
        Initializes the :py:class:`GoogleServiceAccountCredentialsProvider` with either a path to a
        `Google service account private key <https://docs.cloud.google.com/iam/docs/keys-create-delete#creating>`_ file
        or the file contents.

        :param file: Path to a Google service account private key file.
        :param info: Google service account private key file contents.
        """
        if all(_ is None for _ in (file, info)) or all(_ is not None for _ in (file, info)):
            raise ValueError("Must specify exactly one of file or info")

        if file is not None:
            with open(file, "r") as f:
                self._info = json.load(f)
        elif info is not None:
            self._info = copy.deepcopy(info)

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key="",
            secret_key="",
            token=None,
            expiration=None,
            custom_fields={"info": copy.deepcopy(self._info)},
        )

    def refresh_credentials(self) -> None:
        pass


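# Example (illustrative sketch): constructing the storage provider below directly. In normal use
# it is typically built by the library from the resolved MSC config; the project ID, bucket path,
# credentials, and tuning values here are placeholders. Transfer tuning options such as
# ``multipart_threshold`` and ``max_concurrency`` are read from ``**kwargs`` (see ``__init__``).
#
#   provider = GoogleStorageProvider(
#       project_id="my-project",
#       base_path="my-bucket/prefix",
#       credentials_provider=GoogleServiceAccountCredentialsProvider(
#           file="/path/to/service-account-key.json",
#       ),
#       multipart_threshold=64 * MiB,
#       max_concurrency=8,
#   )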
class GoogleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Google Cloud Storage.
    """

    def __init__(
        self,
        project_id: str = os.getenv("GOOGLE_CLOUD_PROJECT_ID", ""),
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ):
        """
        Initializes the :py:class:`GoogleStorageProvider` with the project ID and optional credentials provider.

        :param project_id: The Google Cloud project ID.
        :param endpoint_url: The custom endpoint URL for the GCS service.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve GCS credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._project_id = project_id
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._skip_signature = kwargs.get("skip_signature", False)
        self._gcs_client = self._create_gcs_client()
        self._multipart_threshold = kwargs.get("multipart_threshold", DEFAULT_MULTIPART_THRESHOLD)
        self._multipart_chunksize = kwargs.get("multipart_chunksize", DEFAULT_MULTIPART_CHUNKSIZE)
        self._io_chunksize = kwargs.get("io_chunksize", DEFAULT_IO_CHUNKSIZE)
        self._max_concurrency = kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)
        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the rust client options from the kwargs
            rust_client_options = copy.deepcopy(kwargs["rust_client"])
            if "max_concurrency" in kwargs:
                rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
            if "multipart_chunksize" in kwargs:
                rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
            if "read_timeout" in kwargs:
                rust_client_options["read_timeout"] = kwargs["read_timeout"]
            if "connect_timeout" in kwargs:
                rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
            if "service_account_key" not in rust_client_options and isinstance(
                self._credentials_provider, GoogleServiceAccountCredentialsProvider
            ):
                rust_client_options["service_account_key"] = json.dumps(
                    self._credentials_provider.get_credentials().get_custom_field("info")
                )
            self._rust_client = self._create_rust_client(rust_client_options)

    def _create_gcs_client(self) -> storage.Client:
        client_options = {}
        if self._endpoint_url:
            client_options["api_endpoint"] = self._endpoint_url

        # Use anonymous credentials for public buckets when skip_signature is enabled
        if self._skip_signature:
            return storage.Client(
                project=self._project_id,
                credentials=auth_credentials.AnonymousCredentials(),
                client_options=client_options,
            )

        if self._credentials_provider:
            if isinstance(self._credentials_provider, GoogleIdentityPoolCredentialsProvider):
                audience = self._credentials_provider.get_credentials().get_custom_field("audience")
                token = self._credentials_provider.get_credentials().get_custom_field("token")

                # Use Workload Identity Federation (WIF)
                identity_pool_credentials = identity_pool.Credentials(
                    audience=audience,
                    subject_token_type="urn:ietf:params:oauth:token-type:id_token",
                    subject_token_supplier=StringTokenSupplier(token),
                )
                return storage.Client(
                    project=self._project_id, credentials=identity_pool_credentials, client_options=client_options
                )
            elif isinstance(self._credentials_provider, GoogleServiceAccountCredentialsProvider):
                # Use service account key.
                service_account_credentials = service_account.Credentials.from_service_account_info(
                    info=self._credentials_provider.get_credentials().get_custom_field("info")
                )
                return storage.Client(
                    project=self._project_id, credentials=service_account_credentials, client_options=client_options
                )
            else:
                # Use OAuth 2.0 token
                token = self._credentials_provider.get_credentials().token
                creds = OAuth2Credentials(token=token)
                return storage.Client(project=self._project_id, credentials=creds, client_options=client_options)
        else:
            return storage.Client(project=self._project_id, client_options=client_options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        if self._credentials_provider or self._endpoint_url:
            # Workload Identity Federation (WIF) is not supported by the rust client:
            # https://github.com/apache/arrow-rs-object-store/issues/258
            logger.warning(
                "Rust client for GCS only supports Application Default Credentials or Service Account Key, skipping rust client"
            )
            return None
        configs = dict(rust_client_options) if rust_client_options else {}

        # Extract and parse retry configuration
        retry_config = parse_retry_config(configs)

        if "application_credentials" not in configs and os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
            configs["application_credentials"] = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
        if "service_account_key" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY"):
            configs["service_account_key"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY")
        if "service_account_path" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT")
        if "service_account_path" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH")

        if self._skip_signature and "skip_signature" not in configs:
            configs["skip_signature"] = True

        if "bucket" not in configs:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        return RustClient(
            provider=PROVIDER,
            configs=configs,
            retry=retry_config,
        )

    def _refresh_gcs_client_if_needed(self) -> None:
        """
        Refreshes the GCS client if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._gcs_client = self._create_gcs_client()

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual GCS operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the GCS bucket involved in the operation.
        :param key: The key of the object within the GCS bucket.

        :return: The result of the GCS operation, typically the return value of the `func` callable.
        """
        try:
            return func()
        except GoogleAPICallError as error:
            status_code = error.code if error.code else -1
            error_info = f"status_code: {status_code}, message: {error.message}"
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 304:
                # for if_none_match with a specific etag condition.
                raise NotModifiedError(f"Object {bucket}/{key} has not been modified.") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except InvalidResponse as error:
            response_text = error.response.text
            error_details = f"error: {error}, error_response_text: {response_text}"
            # Check for NoSuchUpload within the response text
            if "NoSuchUpload" in response_text:
                raise RetryableError(f"Multipart upload failed for {bucket}/{key}, {error_details}") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_details}") from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except RustClientError as error:
            message = error.args[0]
            status_code = error.args[1]
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {message}") from error
            elif status_code == 403:
                raise PermissionError(
                    f"Permission denied to {operation} object(s) at {bucket}/{key}. {message}"
                ) from error
            else:
                raise RetryableError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {message}. status_code: {status_code}"
                ) from error
        except Exception as error:
            error_details = str(error)
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}. error_type: {type(error).__name__}, {error_details}"
            ) from error
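
    # For reference (illustrative summary derived from the handlers above):
    #   GCS 404 -> FileNotFoundError, 412 -> PreconditionFailedError, 304 -> NotModifiedError,
    #   other GCS errors -> RuntimeError; "NoSuchUpload" upload responses -> RetryableError;
    #   Rust 404 -> FileNotFoundError, 403 -> PermissionError, other Rust errors -> RetryableError.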

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to Google Cloud Storage.

        :param path: The path to the object to upload.
        :param body: The content of the object to upload.
        :param if_match: Optional generation number; the upload only succeeds if the object's current generation matches.
        :param if_none_match: Optional generation number; the upload only succeeds if the object's current generation does not match.
        :param attributes: Optional attributes to attach to the object.
        """
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> int:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            kwargs = {}

            if if_match:
                kwargs["if_generation_match"] = int(if_match)  # 412 error code
            if if_none_match:
                if if_none_match == "*":
                    raise NotImplementedError("if_none_match='*' is not supported for GCS")
                else:
                    kwargs["if_generation_not_match"] = int(if_none_match)  # 304 error code

            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                blob.metadata = validated_attributes

            if (
                self._rust_client
                # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
                and not path.endswith("/")
                and not kwargs
                and not validated_attributes
            ):
                run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                blob.upload_from_string(body, **kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bytes:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)
            if byte_range:
                if self._rust_client:
                    return run_async_rust_client_method(
                        self._rust_client, "get", key, byte_range.offset, byte_range.offset + byte_range.size - 1
                    )
                else:
                    return blob.download_as_bytes(
                        start=byte_range.offset, end=byte_range.offset + byte_range.size - 1, single_shot_download=True
                    )
            else:
                if self._rust_client:
                    return run_async_rust_client_method(self._rust_client, "get", key)
                else:
                    return blob.download_as_bytes(single_shot_download=True)

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
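
    # Note on range semantics (illustrative): ``byte_range`` carries an offset and a size,
    # which the code above converts to an inclusive end position. For example, a range with
    # offset=0 and size=10 requests bytes 0 through 9.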

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)
        self._refresh_gcs_client_if_needed()

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            source_bucket_obj = self._gcs_client.bucket(src_bucket)
            source_blob = source_bucket_obj.blob(src_key)

            destination_bucket_obj = self._gcs_client.bucket(dest_bucket)
            destination_blob = destination_bucket_obj.blob(dest_key)

            rewrite_tokens = [None]
            while len(rewrite_tokens) > 0:
                rewrite_token = rewrite_tokens.pop()
                next_rewrite_token, _, _ = destination_blob.rewrite(source=source_blob, token=rewrite_token)
                if next_rewrite_token is not None:
                    rewrite_tokens.append(next_rewrite_token)

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=src_bucket, key=src_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> None:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            # If if_match is provided, use it as a precondition
            if if_match:
                generation = int(if_match)
                blob.delete(if_generation_match=generation)
            else:
                # No if_match check needed, just delete
                blob.delete()

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bool:
            bucket_obj = self._gcs_client.bucket(bucket)
            # List objects with the given prefix
            blobs = bucket_obj.list_blobs(
                prefix=key,
                delimiter="/",
            )
            # Check if there are any contents or common prefixes
            return any(True for _ in blobs) or any(True for _ in blobs.prefixes)

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If path ends with "/" or an empty key name is provided, then assume it's a "directory",
            # whose metadata is not guaranteed to exist for cases such as a
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path, type="directory", content_length=0, last_modified=AWARE_DATETIME_MIN, etag=None
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            self._refresh_gcs_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.get_blob(key)
                if not blob:
                    raise NotFound(f"Blob {key} not found in bucket {bucket}")
                return ObjectMetadata(
                    key=path,
                    content_length=blob.size or 0,
                    content_type=blob.content_type,
                    last_modified=blob.updated or AWARE_DATETIME_MIN,
                    etag=str(blob.generation),
                    metadata=dict(blob.metadata) if blob.metadata else None,
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            bucket_obj = self._gcs_client.bucket(bucket)
            if include_directories:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                    delimiter="/",
                )
            else:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                )

            # GCS guarantees lexicographical order.
            for blob in blobs:
                key = blob.name
                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    if key.endswith("/"):
                        if include_directories:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key.rstrip("/")),
                                type="directory",
                                content_length=0,
                                last_modified=blob.updated,
                            )
                    else:
                        yield ObjectMetadata(
                            key=os.path.join(bucket, key),
                            content_length=blob.size,
                            content_type=blob.content_type,
                            last_modified=blob.updated,
                            etag=blob.etag,
                        )
                elif start_after != key:
                    return

            # The directories must be accessed last.
            if include_directories:
                for directory in blobs.prefixes:
                    yield ObjectMetadata(
                        key=os.path.join(bucket, directory.rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)
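
    # Note on listing bounds (illustrative): the filter above is exclusive of ``start_after``
    # and inclusive of ``end_at``, i.e. only keys k with start_after < k <= end_at are
    # yielded, in lexicographical order.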

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        bucket, key = split_path(remote_path)
        file_size: int = 0
        self._refresh_gcs_client_if_needed()

        if isinstance(f, str):
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    bucket_obj = self._gcs_client.bucket(bucket)
                    blob = bucket_obj.blob(key)
                    # GCS will raise an error if blob.metadata is set to None
                    validated_attributes = validate_attributes(attributes)
                    if validated_attributes is not None:
                        blob.metadata = validated_attributes
                    transfer_manager.upload_chunks_concurrently(
                        f,
                        blob,
                        chunk_size=self._multipart_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(remote_path, f.read().encode("utf-8"), attributes=attributes)
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    blob.metadata = validated_attributes
                if isinstance(f, io.StringIO):
                    mode = "w"
                else:
                    mode = "wb"

                # transfer manager does not support uploading a file object
                with tempfile.NamedTemporaryFile(mode=mode, delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    fp.write(f.read())

                transfer_manager.upload_chunks_concurrently(
                    temp_file_path,
                    blob,
                    chunk_size=self._multipart_chunksize,
                    max_workers=self._max_concurrency,
                    worker_type=transfer_manager.THREAD,
                )

                os.unlink(temp_file_path)

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

698 def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
699 self._refresh_gcs_client_if_needed()
700
701 if metadata is None:
702 metadata = self._get_object_metadata(remote_path)
703
704 bucket, key = split_path(remote_path)
705
706 if isinstance(f, str):
707 if os.path.dirname(f):
708 os.makedirs(os.path.dirname(f), exist_ok=True)
709 # Download small files
710 if metadata.content_length <= self._multipart_threshold:
711 if self._rust_client:
712 run_async_rust_client_method(self._rust_client, "download", key, f)
713 else:
714 with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
715 temp_file_path = fp.name
716 fp.write(self._get_object(remote_path))
717 os.rename(src=temp_file_path, dst=f)
718 return metadata.content_length
719
720 # Download large files using transfer manager
721 def _invoke_api() -> int:
722 bucket_obj = self._gcs_client.bucket(bucket)
723 blob = bucket_obj.blob(key)
724 if self._rust_client:
725 run_async_rust_client_method(self._rust_client, "download_multipart_to_file", key, f)
726 else:
727 with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
728 temp_file_path = fp.name
729 transfer_manager.download_chunks_concurrently(
730 blob,
731 temp_file_path,
732 chunk_size=self._io_chunksize,
733 max_workers=self._max_concurrency,
734 worker_type=transfer_manager.THREAD,
735 )
736 os.rename(src=temp_file_path, dst=f)
737
738 return metadata.content_length
739
740 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
741 else:
742 # Download small files
743 if metadata.content_length <= self._multipart_threshold:
744 response = self._get_object(remote_path)
745 # Python client returns `bytes`, but Rust client returns a object implements buffer protocol,
746 # so we need to check whether `.decode()` is available.
747 if isinstance(f, io.StringIO):
748 if hasattr(response, "decode"):
749 f.write(response.decode("utf-8"))
750 else:
751 f.write(codecs.decode(memoryview(response), "utf-8"))
752 else:
753 f.write(response)
754 return metadata.content_length
755
756 # Download large files using transfer manager
757 def _invoke_api() -> int:
758 bucket_obj = self._gcs_client.bucket(bucket)
759 blob = bucket_obj.blob(key)
760
761 # transfer manager does not support downloading to a file object
762 with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=".") as fp:
763 temp_file_path = fp.name
764 transfer_manager.download_chunks_concurrently(
765 blob,
766 temp_file_path,
767 chunk_size=self._io_chunksize,
768 max_workers=self._max_concurrency,
769 worker_type=transfer_manager.THREAD,
770 )
771
772 if isinstance(f, io.StringIO):
773 with open(temp_file_path, "r") as fp:
774 f.write(fp.read())
775 else:
776 with open(temp_file_path, "rb") as fp:
777 f.write(fp.read())
778
779 os.unlink(temp_file_path)
780
781 return metadata.content_length
782
783 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)