# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import copy
import io
import json
import logging
import os
import tempfile
from collections.abc import Callable, Iterator
from typing import IO, Any, Optional, TypeVar, Union

from google.api_core.exceptions import GoogleAPICallError, NotFound
from google.auth import credentials as auth_credentials
from google.auth import identity_pool
from google.cloud import storage
from google.cloud.storage import transfer_manager
from google.cloud.storage.exceptions import InvalidResponse
from google.oauth2 import service_account
from google.oauth2.credentials import Credentials as OAuth2Credentials

from multistorageclient_rust import RustClient, RustClientError, RustRetryableError

from ..rust_utils import parse_retry_config, run_async_rust_client_method
from ..telemetry import Telemetry
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    NotModifiedError,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import (
    safe_makedirs,
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

PROVIDER = "gcs"

MiB = 1024 * 1024

DEFAULT_MULTIPART_THRESHOLD = 64 * MiB
DEFAULT_MULTIPART_CHUNKSIZE = 32 * MiB
DEFAULT_IO_CHUNKSIZE = 32 * MiB
PYTHON_MAX_CONCURRENCY = 8

logger = logging.getLogger(__name__)


class StringTokenSupplier(identity_pool.SubjectTokenSupplier):
    """
    Supply a string token to the Google Identity Pool.
    """

    def __init__(self, token: str):
        self._token = token

    def get_subject_token(self, context, request):
        return self._token


class GoogleIdentityPoolCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's identity pool credentials.
    """

    def __init__(self, audience: str, token_supplier: str):
        """
        Initializes the :py:class:`GoogleIdentityPoolCredentialsProvider` with the audience and token supplier.

        :param audience: The audience for the Google Identity Pool.
        :param token_supplier: The subject token string supplied to the Google Identity Pool.
        """
        self._audience = audience
        self._token_supplier = token_supplier

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key="",
            secret_key="",
            token="",
            expiration=None,
            custom_fields={"audience": self._audience, "token": self._token_supplier},
        )

    def refresh_credentials(self) -> None:
        pass


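# Illustrative usage sketch (all values below are placeholders, not part of this module):
# a Workload Identity Federation setup supplies the raw ID token as a string, and the
# provider wraps it in a StringTokenSupplier when the GCS client is created.
#
#     credentials_provider = GoogleIdentityPoolCredentialsProvider(
#         audience="//iam.googleapis.com/projects/12345/locations/global/workloadIdentityPools/my-pool/providers/my-provider",
#         token_supplier=open("/var/run/secrets/id_token").read(),
#     )

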
class GoogleServiceAccountCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's service account credentials.
    """

    #: Google service account private key file contents.
    _info: dict[str, Any]

    def __init__(self, file: Optional[str] = None, info: Optional[dict[str, Any]] = None):
        """
        Initializes the :py:class:`GoogleServiceAccountCredentialsProvider` with either a path to a
        `Google service account private key <https://docs.cloud.google.com/iam/docs/keys-create-delete#creating>`_ file
        or the file contents.

        :param file: Path to a Google service account private key file.
        :param info: Google service account private key file contents.
        """
        if all(_ is None for _ in (file, info)) or all(_ is not None for _ in (file, info)):
            raise ValueError("Must specify exactly one of file or info")

        if file is not None:
            with open(file, "r") as f:
                self._info = json.load(f)
        elif info is not None:
            self._info = copy.deepcopy(info)

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key="",
            secret_key="",
            token=None,
            expiration=None,
            custom_fields={"info": copy.deepcopy(self._info)},
        )

    def refresh_credentials(self) -> None:
        pass


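# Illustrative usage sketch (paths and contents are placeholders): the provider accepts either
# a key file path or the already-parsed key contents, but not both.
#
#     from_file = GoogleServiceAccountCredentialsProvider(file="/etc/gcp/service-account.json")
#     from_info = GoogleServiceAccountCredentialsProvider(info={"type": "service_account", ...})

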
class GoogleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Google Cloud Storage.
    """

    def __init__(
        self,
        project_id: str = os.getenv("GOOGLE_CLOUD_PROJECT_ID", ""),
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ):
        """
        Initializes the :py:class:`GoogleStorageProvider` with the project ID and optional credentials provider.

        :param project_id: The Google Cloud project ID.
        :param endpoint_url: The custom endpoint URL for the GCS service.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve GCS credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._project_id = project_id
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._skip_signature = kwargs.get("skip_signature", False)
        self._gcs_client = self._create_gcs_client()
        self._multipart_threshold = kwargs.get("multipart_threshold", DEFAULT_MULTIPART_THRESHOLD)
        self._multipart_chunksize = kwargs.get("multipart_chunksize", DEFAULT_MULTIPART_CHUNKSIZE)
        self._io_chunksize = kwargs.get("io_chunksize", DEFAULT_IO_CHUNKSIZE)
        self._max_concurrency = kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)
        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the rust client options from the kwargs
            rust_client_options = copy.deepcopy(kwargs["rust_client"])
            if "max_concurrency" in kwargs:
                rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
            if "multipart_chunksize" in kwargs:
                rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
            if "read_timeout" in kwargs:
                rust_client_options["read_timeout"] = kwargs["read_timeout"]
            if "connect_timeout" in kwargs:
                rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
            self._rust_client = self._create_rust_client(rust_client_options)

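    # Illustrative constructor sketch (bucket, prefix, and tuning values are placeholders):
    # the optional kwargs consumed above map directly onto the Python and Rust transfer settings.
    #
    #     provider = GoogleStorageProvider(
    #         project_id="my-project",
    #         base_path="my-bucket/datasets",
    #         credentials_provider=GoogleServiceAccountCredentialsProvider(file="/etc/gcp/service-account.json"),
    #         multipart_threshold=128 * MiB,
    #         multipart_chunksize=64 * MiB,
    #         max_concurrency=16,
    #         rust_client={},  # enable the Rust client with default options
    #     )
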
    def _create_gcs_client(self) -> storage.Client:
        client_options = {}
        if self._endpoint_url:
            client_options["api_endpoint"] = self._endpoint_url

        # Use anonymous credentials for public buckets when skip_signature is enabled
        if self._skip_signature:
            return storage.Client(
                project=self._project_id,
                credentials=auth_credentials.AnonymousCredentials(),
                client_options=client_options,
            )

        if self._credentials_provider:
            if isinstance(self._credentials_provider, GoogleIdentityPoolCredentialsProvider):
                audience = self._credentials_provider.get_credentials().get_custom_field("audience")
                token = self._credentials_provider.get_credentials().get_custom_field("token")

                # Use Workload Identity Federation (WIF)
                identity_pool_credentials = identity_pool.Credentials(
                    audience=audience,
                    subject_token_type="urn:ietf:params:oauth:token-type:id_token",
                    subject_token_supplier=StringTokenSupplier(token),
                )
                return storage.Client(
                    project=self._project_id, credentials=identity_pool_credentials, client_options=client_options
                )
            elif isinstance(self._credentials_provider, GoogleServiceAccountCredentialsProvider):
                # Use service account key.
                service_account_credentials = service_account.Credentials.from_service_account_info(
                    info=self._credentials_provider.get_credentials().get_custom_field("info")
                )
                return storage.Client(
                    project=self._project_id, credentials=service_account_credentials, client_options=client_options
                )
            else:
                # Use OAuth 2.0 token
                token = self._credentials_provider.get_credentials().token
                creds = OAuth2Credentials(token=token)
                return storage.Client(project=self._project_id, credentials=creds, client_options=client_options)
        else:
            return storage.Client(project=self._project_id, client_options=client_options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        if self._endpoint_url:
            logger.warning("Rust client for GCS does not support customized endpoint URL, skipping rust client")
            return None

        configs = dict(rust_client_options) if rust_client_options else {}

        # Extract and parse retry configuration
        retry_config = parse_retry_config(configs)

        if "application_credentials" not in configs and os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
            configs["application_credentials"] = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
        if "service_account_key" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY"):
            configs["service_account_key"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY")
        if "service_account_path" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT")
        if "service_account_path" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH")

        if self._skip_signature and "skip_signature" not in configs:
            configs["skip_signature"] = True

        if "bucket" not in configs:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        if self._credentials_provider:
            if isinstance(self._credentials_provider, GoogleIdentityPoolCredentialsProvider):
                # Workload Identity Federation (WIF) is not supported by the rust client:
                # https://github.com/apache/arrow-rs-object-store/issues/258
                logger.warning("Rust client for GCS doesn't support Workload Identity Federation, skipping rust client")
                return None
            if isinstance(self._credentials_provider, GoogleServiceAccountCredentialsProvider):
                # Use service account key.
                configs["service_account_key"] = json.dumps(
                    self._credentials_provider.get_credentials().get_custom_field("info")
                )
                return RustClient(
                    provider=PROVIDER,
                    configs=configs,
                    retry=retry_config,
                )
        try:
            return RustClient(
                provider=PROVIDER,
                configs=configs,
                credentials_provider=self._credentials_provider,
                retry=retry_config,
            )
        except Exception as e:
            logger.warning(f"Failed to create rust client for GCS: {e}, falling back to Python client")
            return None

    def _refresh_gcs_client_if_needed(self) -> None:
        """
        Refreshes the GCS client if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._gcs_client = self._create_gcs_client()

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual GCS operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the GCS bucket involved in the operation.
        :param key: The key of the object within the GCS bucket.

        :return: The result of the GCS operation, typically the return value of the `func` callable.
        """
        try:
            return func()
        except GoogleAPICallError as error:
            status_code = error.code if error.code else -1
            error_info = f"status_code: {status_code}, message: {error.message}"
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 304:
                # for if_none_match with a specific etag condition.
                raise NotModifiedError(f"Object {bucket}/{key} has not been modified.") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except InvalidResponse as error:
            response_text = error.response.text
            error_details = f"error: {error}, error_response_text: {response_text}"
            # Check for NoSuchUpload within the response text
            if "NoSuchUpload" in response_text:
                raise RetryableError(f"Multipart upload failed for {bucket}/{key}, {error_details}") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_details}") from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except RustClientError as error:
            message = error.args[0]
            status_code = error.args[1]
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {message}") from error
            elif status_code == 403:
                raise PermissionError(
                    f"Permission denied to {operation} object(s) at {bucket}/{key}. {message}"
                ) from error
            else:
                raise RetryableError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {message}. status_code: {status_code}"
                ) from error
        except Exception as error:
            error_details = str(error)
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}. error_type: {type(error).__name__}, {error_details}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to Google Cloud Storage.

        :param path: The path to the object to upload.
        :param body: The content of the object to upload.
        :param if_match: Optional ETag (GCS generation) the object's current version must match.
        :param if_none_match: Optional ETag (GCS generation) the object's current version must not match.
        :param attributes: Optional attributes to attach to the object.
        """
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> int:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            kwargs = {}

            if if_match:
                kwargs["if_generation_match"] = int(if_match)  # 412 error code
            if if_none_match:
                if if_none_match == "*":
                    raise NotImplementedError("if_none_match='*' is not supported for GCS")
                else:
                    kwargs["if_generation_not_match"] = int(if_none_match)  # 304 error code

            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                blob.metadata = validated_attributes

            if (
                self._rust_client
                # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
                and not path.endswith("/")
                and not kwargs
                and not validated_attributes
            ):
                run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                blob.upload_from_string(body, **kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

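    # Illustrative sketch (bucket and key are placeholders): this provider uses the GCS generation
    # number as the ETag, so a conditional overwrite passes it back through if_match, and a
    # generation mismatch surfaces as PreconditionFailedError (HTTP 412).
    #
    #     current = provider._get_object_metadata("my-bucket/configs/app.json")
    #     provider._put_object("my-bucket/configs/app.json", b"new contents", if_match=current.etag)
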
    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bytes:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)
            if byte_range:
                if self._rust_client:
                    return run_async_rust_client_method(
                        self._rust_client, "get", key, byte_range.offset, byte_range.offset + byte_range.size - 1
                    )
                else:
                    return blob.download_as_bytes(
                        start=byte_range.offset, end=byte_range.offset + byte_range.size - 1, single_shot_download=True
                    )
            else:
                if self._rust_client:
                    return run_async_rust_client_method(self._rust_client, "get", key)
                else:
                    return blob.download_as_bytes(single_shot_download=True)

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

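    # Illustrative sketch (bucket and key are placeholders): a ranged read fetches `size` bytes
    # starting at `offset`; the end position handed to the client above is inclusive, hence the -1.
    # The keyword form of the Range constructor is assumed here for readability.
    #
    #     first_kib = provider._get_object("my-bucket/data/shard.bin", byte_range=Range(offset=0, size=1024))
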
    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)
        self._refresh_gcs_client_if_needed()

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            source_bucket_obj = self._gcs_client.bucket(src_bucket)
            source_blob = source_bucket_obj.blob(src_key)

            destination_bucket_obj = self._gcs_client.bucket(dest_bucket)
            destination_blob = destination_bucket_obj.blob(dest_key)

            rewrite_tokens = [None]
            while len(rewrite_tokens) > 0:
                rewrite_token = rewrite_tokens.pop()
                next_rewrite_token, _, _ = destination_blob.rewrite(source=source_blob, token=rewrite_token)
                if next_rewrite_token is not None:
                    rewrite_tokens.append(next_rewrite_token)

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=src_bucket, key=src_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> None:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            # If if_match is provided, use it as a precondition
            if if_match:
                generation = int(if_match)
                blob.delete(if_generation_match=generation)
            else:
                # No if_match check needed, just delete
                blob.delete()

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bool:
            bucket_obj = self._gcs_client.bucket(bucket)
            # List objects with the given prefix
            blobs = bucket_obj.list_blobs(
                prefix=key,
                delimiter="/",
            )
            # Check if there are any contents or common prefixes
            return any(True for _ in blobs) or any(True for _ in blobs.prefixes)

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If the path ends with "/" or an empty key name is provided, assume it's a "directory",
            # whose metadata is not guaranteed to exist, for example a "virtual prefix"
            # that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path, type="directory", content_length=0, last_modified=AWARE_DATETIME_MIN, etag=None
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            self._refresh_gcs_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.get_blob(key)
                if not blob:
                    raise NotFound(f"Blob {key} not found in bucket {bucket}")
                return ObjectMetadata(
                    key=path,
                    content_length=blob.size or 0,
                    content_type=blob.content_type,
                    last_modified=blob.updated or AWARE_DATETIME_MIN,
                    etag=str(blob.generation),
                    metadata=dict(blob.metadata) if blob.metadata else None,
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist at the given path, append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            bucket_obj = self._gcs_client.bucket(bucket)
            if include_directories:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                    delimiter="/",
                )
            else:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                )

            # GCS guarantees lexicographical order.
            for blob in blobs:
                key = blob.name
                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    if key.endswith("/"):
                        if include_directories:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key.rstrip("/")),
                                type="directory",
                                content_length=0,
                                last_modified=blob.updated,
                            )
                    else:
                        yield ObjectMetadata(
                            key=os.path.join(bucket, key),
                            content_length=blob.size,
                            content_type=blob.content_type,
                            last_modified=blob.updated,
                            etag=blob.etag,
                        )
                elif start_after != key:
                    return

            # The directories must be accessed last.
            if include_directories:
                for directory in blobs.prefixes:
                    prefix_key = directory.rstrip("/")
                    # Filter by start_after and end_at if specified
                    if (start_after is None or start_after < prefix_key) and (end_at is None or prefix_key <= end_at):
                        yield ObjectMetadata(
                            key=os.path.join(bucket, prefix_key),
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

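    # Illustrative sketch (bucket and keys are placeholders): start_after and end_at are full
    # paths that get re-split relative to the bucket; the filter above keeps keys with
    # start_after < key <= end_at, so start_after is exclusive and end_at is inclusive.
    #
    #     entries = provider._list_objects(
    #         "my-bucket/logs/",
    #         start_after="my-bucket/logs/2024-01-01",
    #         end_at="my-bucket/logs/2024-12-31",
    #     )
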
    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        bucket, key = split_path(remote_path)
        file_size: int = 0
        self._refresh_gcs_client_if_needed()

        if isinstance(f, str):
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    bucket_obj = self._gcs_client.bucket(bucket)
                    blob = bucket_obj.blob(key)
                    # GCS will raise an error if blob.metadata is None
                    validated_attributes = validate_attributes(attributes)
                    if validated_attributes is not None:
                        blob.metadata = validated_attributes
                    transfer_manager.upload_chunks_concurrently(
                        f,
                        blob,
                        chunk_size=self._multipart_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(remote_path, f.read().encode("utf-8"), attributes=attributes)
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    blob.metadata = validated_attributes
                if isinstance(f, io.StringIO):
                    mode = "w"
                else:
                    mode = "wb"

                # transfer manager does not support uploading a file object
                with tempfile.NamedTemporaryFile(mode=mode, delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    fp.write(f.read())

                transfer_manager.upload_chunks_concurrently(
                    temp_file_path,
                    blob,
                    chunk_size=self._multipart_chunksize,
                    max_workers=self._max_concurrency,
                    worker_type=transfer_manager.THREAD,
                )

                os.unlink(temp_file_path)

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        self._refresh_gcs_client_if_needed()

        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            if os.path.dirname(f):
                safe_makedirs(os.path.dirname(f))
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download_multipart_to_file", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        transfer_manager.download_chunks_concurrently(
                            blob,
                            temp_file_path,
                            chunk_size=self._io_chunksize,
                            max_workers=self._max_concurrency,
                            worker_type=transfer_manager.THREAD,
                        )
                    os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                response = self._get_object(remote_path)
                # The Python client returns `bytes`, but the Rust client returns an object that implements
                # the buffer protocol, so we need to check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)

                # transfer manager does not support downloading to a file object
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    transfer_manager.download_chunks_concurrently(
                        blob,
                        temp_file_path,
                        chunk_size=self._io_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                if isinstance(f, io.StringIO):
                    with open(temp_file_path, "r") as fp:
                        f.write(fp.read())
                else:
                    with open(temp_file_path, "rb") as fp:
                        f.write(fp.read())

                os.unlink(temp_file_path)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)