# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import copy
import io
import json
import logging
import os
import tempfile
from collections.abc import Callable, Iterator
from typing import IO, Any, Optional, TypeVar, Union

from google.api_core.exceptions import GoogleAPICallError, NotFound
from google.auth import credentials as auth_credentials
from google.auth import identity_pool
from google.cloud import storage
from google.cloud.storage import transfer_manager
from google.cloud.storage.exceptions import InvalidResponse
from google.oauth2 import service_account
from google.oauth2.credentials import Credentials as OAuth2Credentials

from multistorageclient_rust import RustClient, RustClientError, RustRetryableError

from ..rust_utils import run_async_rust_client_method
from ..telemetry import Telemetry
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    NotModifiedError,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import (
    split_path,
    validate_attributes,
)
from .base import BaseStorageProvider

_T = TypeVar("_T")

PROVIDER = "gcs"

MiB = 1024 * 1024

DEFAULT_MULTIPART_THRESHOLD = 64 * MiB
DEFAULT_MULTIPART_CHUNKSIZE = 32 * MiB
DEFAULT_IO_CHUNKSIZE = 32 * MiB
PYTHON_MAX_CONCURRENCY = 8

logger = logging.getLogger(__name__)


class StringTokenSupplier(identity_pool.SubjectTokenSupplier):
    """
    Supply a string token to the Google Identity Pool.
    """

    def __init__(self, token: str):
        self._token = token

    def get_subject_token(self, context, request):
        return self._token


class GoogleIdentityPoolCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's identity pool credentials.
    """

    def __init__(self, audience: str, token_supplier: str):
        """
        Initializes the :py:class:`GoogleIdentityPoolCredentialsProvider` with the audience and token supplier.

        :param audience: The audience for the Google Identity Pool.
        :param token_supplier: The token supplier for the Google Identity Pool.
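
        Example (illustrative; the audience and token values below are placeholders)::

            provider = GoogleIdentityPoolCredentialsProvider(
                audience="//iam.googleapis.com/projects/PROJECT_NUMBER/locations/global/workloadIdentityPools/POOL/providers/PROVIDER",
                token_supplier="<subject-token>",
            )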
        """
        self._audience = audience
        self._token_supplier = token_supplier

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key="",
            secret_key="",
            token="",
            expiration=None,
            custom_fields={"audience": self._audience, "token": self._token_supplier},
        )

    def refresh_credentials(self) -> None:
        pass


class GoogleServiceAccountCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's service account credentials.
    """

    #: Google service account private key file contents.
    _info: dict[str, Any]

    def __init__(self, file: Optional[str] = None, info: Optional[dict[str, Any]] = None):
        """
        Initializes the :py:class:`GoogleServiceAccountCredentialsProvider` with either a path to a
        `Google service account private key <https://docs.cloud.google.com/iam/docs/keys-create-delete#creating>`_ file
        or the file contents.

        :param file: Path to a Google service account private key file.
        :param info: Google service account private key file contents.
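
        Example (illustrative path; pass exactly one of ``file`` or ``info``)::

            provider = GoogleServiceAccountCredentialsProvider(file="/path/to/service-account-key.json")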
        """
        if all(_ is None for _ in (file, info)) or all(_ is not None for _ in (file, info)):
            raise ValueError("Must specify exactly one of file or info")

        if file is not None:
            with open(file, "r") as f:
                self._info = json.load(f)
        elif info is not None:
            self._info = copy.deepcopy(info)

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key="",
            secret_key="",
            token=None,
            expiration=None,
            custom_fields={"info": copy.deepcopy(self._info)},
        )

    def refresh_credentials(self) -> None:
        pass


class GoogleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Google Cloud Storage.
    """

    def __init__(
        self,
        project_id: str = os.getenv("GOOGLE_CLOUD_PROJECT_ID", ""),
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ):
        """
        Initializes the :py:class:`GoogleStorageProvider` with the project ID and optional credentials provider.

        :param project_id: The Google Cloud project ID.
        :param endpoint_url: The custom endpoint URL for the GCS service.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve GCS credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
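
        Example (illustrative values; the project ID and bucket name are placeholders)::

            provider = GoogleStorageProvider(
                project_id="my-project",
                base_path="my-bucket/prefix",
            )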
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._project_id = project_id
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._skip_signature = kwargs.get("skip_signature", False)
        self._gcs_client = self._create_gcs_client()
        self._multipart_threshold = kwargs.get("multipart_threshold", DEFAULT_MULTIPART_THRESHOLD)
        self._multipart_chunksize = kwargs.get("multipart_chunksize", DEFAULT_MULTIPART_CHUNKSIZE)
        self._io_chunksize = kwargs.get("io_chunksize", DEFAULT_IO_CHUNKSIZE)
        self._max_concurrency = kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY)
        self._rust_client = None
        if "rust_client" in kwargs:
            # Inherit the rust client options from the kwargs
            rust_client_options = copy.deepcopy(kwargs["rust_client"])
            if "max_concurrency" in kwargs:
                rust_client_options["max_concurrency"] = kwargs["max_concurrency"]
            if "multipart_chunksize" in kwargs:
                rust_client_options["multipart_chunksize"] = kwargs["multipart_chunksize"]
            if "read_timeout" in kwargs:
                rust_client_options["read_timeout"] = kwargs["read_timeout"]
            if "connect_timeout" in kwargs:
                rust_client_options["connect_timeout"] = kwargs["connect_timeout"]
            if "service_account_key" not in rust_client_options and isinstance(
                self._credentials_provider, GoogleServiceAccountCredentialsProvider
            ):
                rust_client_options["service_account_key"] = json.dumps(
                    self._credentials_provider.get_credentials().get_custom_field("info")
                )
            self._rust_client = self._create_rust_client(rust_client_options)

    def _create_gcs_client(self) -> storage.Client:
        client_options = {}
        if self._endpoint_url:
            client_options["api_endpoint"] = self._endpoint_url

        # Use anonymous credentials for public buckets when skip_signature is enabled
        if self._skip_signature:
            return storage.Client(
                project=self._project_id,
                credentials=auth_credentials.AnonymousCredentials(),
                client_options=client_options,
            )

        if self._credentials_provider:
            if isinstance(self._credentials_provider, GoogleIdentityPoolCredentialsProvider):
                audience = self._credentials_provider.get_credentials().get_custom_field("audience")
                token = self._credentials_provider.get_credentials().get_custom_field("token")

                # Use Workload Identity Federation (WIF)
                identity_pool_credentials = identity_pool.Credentials(
                    audience=audience,
                    subject_token_type="urn:ietf:params:oauth:token-type:id_token",
                    subject_token_supplier=StringTokenSupplier(token),
                )
                return storage.Client(
                    project=self._project_id, credentials=identity_pool_credentials, client_options=client_options
                )
            elif isinstance(self._credentials_provider, GoogleServiceAccountCredentialsProvider):
                # Use service account key.
                service_account_credentials = service_account.Credentials.from_service_account_info(
                    info=self._credentials_provider.get_credentials().get_custom_field("info")
                )
                return storage.Client(
                    project=self._project_id, credentials=service_account_credentials, client_options=client_options
                )
            else:
                # Use OAuth 2.0 token
                token = self._credentials_provider.get_credentials().token
                creds = OAuth2Credentials(token=token)
                return storage.Client(project=self._project_id, credentials=creds, client_options=client_options)
        else:
            return storage.Client(project=self._project_id, client_options=client_options)

    def _create_rust_client(self, rust_client_options: Optional[dict[str, Any]] = None):
        if self._credentials_provider or self._endpoint_url:
            # Workload Identity Federation (WIF) is not supported by the rust client:
            # https://github.com/apache/arrow-rs-object-store/issues/258
            logger.warning(
                "Rust client for GCS only supports Application Default Credentials or Service Account Key, skipping rust client"
            )
            return None
        configs = dict(rust_client_options) if rust_client_options else {}

        if "application_credentials" not in configs and os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
            configs["application_credentials"] = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
        if "service_account_key" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY"):
            configs["service_account_key"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY")
        if "service_account_path" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT")
        if "service_account_path" not in configs and os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH"):
            configs["service_account_path"] = os.getenv("GOOGLE_SERVICE_ACCOUNT_PATH")

        if self._skip_signature and "skip_signature" not in configs:
            configs["skip_signature"] = True

        if "bucket" not in configs:
            bucket, _ = split_path(self._base_path)
            configs["bucket"] = bucket

        return RustClient(
            provider=PROVIDER,
            configs=configs,
        )

    def _refresh_gcs_client_if_needed(self) -> None:
        """
        Refreshes the GCS client if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._gcs_client = self._create_gcs_client()

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates GCS and Rust client errors (e.g., timeouts and HTTP errors) into standard exception types.

        :param func: The function that performs the actual GCS operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the GCS bucket involved in the operation.
        :param key: The key of the object within the GCS bucket.

        :return: The result of the GCS operation, typically the return value of the `func` callable.
        """
        try:
            return func()
        except GoogleAPICallError as error:
            status_code = error.code if error.code else -1
            error_info = f"status_code: {status_code}, message: {error.message}"
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 304:
                # Returned when if_none_match is set to a specific etag and the object has not changed.
                raise NotModifiedError(f"Object {bucket}/{key} has not been modified.") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except InvalidResponse as error:
            response_text = error.response.text
            error_details = f"error: {error}, error_response_text: {response_text}"
            # Check for NoSuchUpload within the response text
            if "NoSuchUpload" in response_text:
                raise RetryableError(f"Multipart upload failed for {bucket}/{key}, {error_details}") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_details}") from error
        except RustRetryableError as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key} due to retryable error from Rust. "
                f"error_type: {type(error).__name__}"
            ) from error
        except RustClientError as error:
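            # The Rust client reports failures with the error message and HTTP status code as positional args.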
            message = error.args[0]
            status_code = error.args[1]
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {message}") from error
            elif status_code == 403:
                raise PermissionError(
                    f"Permission denied to {operation} object(s) at {bucket}/{key}. {message}"
                ) from error
            else:
                raise RetryableError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {message}. status_code: {status_code}"
                ) from error
        except Exception as error:
            error_details = str(error)
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}. error_type: {type(error).__name__}, {error_details}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to Google Cloud Storage.

        :param path: The path to the object to upload.
        :param body: The content of the object to upload.
        :param if_match: Optional etag (generation); the upload succeeds only if the object's current generation matches.
        :param if_none_match: Optional etag (generation); the upload succeeds only if the object's current generation does not match.
        :param attributes: Optional attributes to attach to the object.

        :return: The number of bytes written.
        """
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> int:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            kwargs = {}

            if if_match:
                kwargs["if_generation_match"] = int(if_match)  # 412 error code
            if if_none_match:
                if if_none_match == "*":
                    raise NotImplementedError("if_none_match='*' is not supported for GCS")
                else:
                    kwargs["if_generation_not_match"] = int(if_none_match)  # 304 error code

            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                blob.metadata = validated_attributes

            if (
                self._rust_client
                # Rust client doesn't support creating objects with trailing /, see https://github.com/apache/arrow-rs/issues/7026
                and not path.endswith("/")
                and not kwargs
                and not validated_attributes
            ):
                run_async_rust_client_method(self._rust_client, "put", key, body)
            else:
                blob.upload_from_string(body, **kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bytes:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)
            if byte_range:
                if self._rust_client:
                    return run_async_rust_client_method(
                        self._rust_client, "get", key, byte_range.offset, byte_range.offset + byte_range.size - 1
                    )
                else:
                    return blob.download_as_bytes(
                        start=byte_range.offset, end=byte_range.offset + byte_range.size - 1, single_shot_download=True
                    )
            else:
                if self._rust_client:
                    return run_async_rust_client_method(self._rust_client, "get", key)
                else:
                    return blob.download_as_bytes(single_shot_download=True)

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)
        self._refresh_gcs_client_if_needed()

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            source_bucket_obj = self._gcs_client.bucket(src_bucket)
            source_blob = source_bucket_obj.blob(src_key)

            destination_bucket_obj = self._gcs_client.bucket(dest_bucket)
            destination_blob = destination_bucket_obj.blob(dest_key)

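            # GCS may rewrite large objects in multiple calls; each call can return a token
            # that must be passed back until the rewrite completes (no token returned).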
            rewrite_tokens = [None]
            while len(rewrite_tokens) > 0:
                rewrite_token = rewrite_tokens.pop()
                next_rewrite_token, _, _ = destination_blob.rewrite(source=source_blob, token=rewrite_token)
                if next_rewrite_token is not None:
                    rewrite_tokens.append(next_rewrite_token)

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=src_bucket, key=src_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> None:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            # If if_match is provided, use it as a precondition
            if if_match:
                generation = int(if_match)
                blob.delete(if_generation_match=generation)
            else:
                # No if_match check needed, just delete
                blob.delete()

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bool:
            bucket_obj = self._gcs_client.bucket(bucket)
            # List objects with the given prefix
            blobs = bucket_obj.list_blobs(
                prefix=key,
                delimiter="/",
            )
            # Check if there are any contents or common prefixes
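            # Note: blobs.prefixes is only populated after the iterator has been consumed;
            # if no blob is found, the first any() exhausts the iterator before prefixes is read.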
            return any(True for _ in blobs) or any(True for _ in blobs.prefixes)

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If the path ends with "/" or the key is empty, assume it's a "directory".
            # Its metadata is not guaranteed to exist, e.g. a "virtual prefix" that was
            # never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path, type="directory", content_length=0, last_modified=AWARE_DATETIME_MIN, etag=None
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            self._refresh_gcs_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.get_blob(key)
                if not blob:
                    raise NotFound(f"Blob {key} not found in bucket {bucket}")
                return ObjectMetadata(
                    key=path,
                    content_length=blob.size or 0,
                    content_type=blob.content_type,
                    last_modified=blob.updated or AWARE_DATETIME_MIN,
                    etag=str(blob.generation),
                    metadata=dict(blob.metadata) if blob.metadata else None,
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            bucket_obj = self._gcs_client.bucket(bucket)
            if include_directories:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                    delimiter="/",
                )
            else:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                )

            # GCS guarantees lexicographical order.
            for blob in blobs:
                key = blob.name
                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    if key.endswith("/"):
                        if include_directories:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key.rstrip("/")),
                                type="directory",
                                content_length=0,
                                last_modified=blob.updated,
                            )
                    else:
                        yield ObjectMetadata(
                            key=os.path.join(bucket, key),
                            content_length=blob.size,
                            content_type=blob.content_type,
                            last_modified=blob.updated,
                            etag=blob.etag,
                        )
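                # Keys arrive in lexicographic order, so a key outside the range (other than
                # start_after itself, which start_offset may return) means we have passed end_at.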
                elif start_after != key:
                    return

            # The directories must be accessed last.
            if include_directories:
                for directory in blobs.prefixes:
                    yield ObjectMetadata(
                        key=os.path.join(bucket, directory.rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        bucket, key = split_path(remote_path)
        file_size: int = 0
        self._refresh_gcs_client_if_needed()

        if isinstance(f, str):
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload", f, key)
                else:
                    with open(f, "rb") as fp:
                        self._put_object(remote_path, fp.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                if self._rust_client and not attributes:
                    run_async_rust_client_method(self._rust_client, "upload_multipart_from_file", f, key)
                else:
                    bucket_obj = self._gcs_client.bucket(bucket)
                    blob = bucket_obj.blob(key)
                    # GCS will raise an error if blob.metadata is None
                    validated_attributes = validate_attributes(attributes)
                    if validated_attributes is not None:
                        blob.metadata = validated_attributes
                    transfer_manager.upload_chunks_concurrently(
                        f,
                        blob,
                        chunk_size=self._multipart_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(remote_path, f.read().encode("utf-8"), attributes=attributes)
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    blob.metadata = validated_attributes
                if isinstance(f, io.StringIO):
                    mode = "w"
                else:
                    mode = "wb"

                # transfer manager does not support uploading a file object
                with tempfile.NamedTemporaryFile(mode=mode, delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    fp.write(f.read())

                transfer_manager.upload_chunks_concurrently(
                    temp_file_path,
                    blob,
                    chunk_size=self._multipart_chunksize,
                    max_workers=self._max_concurrency,
                    worker_type=transfer_manager.THREAD,
                )

                os.unlink(temp_file_path)

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        self._refresh_gcs_client_if_needed()

        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download", key, f)
                else:
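                    # Write to a hidden temp file in the destination directory, then rename it
                    # into place so readers never observe a partially written file.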
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        fp.write(self._get_object(remote_path))
                    os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                if self._rust_client:
                    run_async_rust_client_method(self._rust_client, "download_multipart_to_file", key, f)
                else:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        transfer_manager.download_chunks_concurrently(
                            blob,
                            temp_file_path,
                            chunk_size=self._io_chunksize,
                            max_workers=self._max_concurrency,
                            worker_type=transfer_manager.THREAD,
                        )
                    os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                response = self._get_object(remote_path)
                # The Python client returns `bytes`, but the Rust client returns an object that
                # implements the buffer protocol, so check whether `.decode()` is available.
                if isinstance(f, io.StringIO):
                    if hasattr(response, "decode"):
                        f.write(response.decode("utf-8"))
                    else:
                        f.write(codecs.decode(memoryview(response), "utf-8"))
                else:
                    f.write(response)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)

                # transfer manager does not support downloading to a file object
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    transfer_manager.download_chunks_concurrently(
                        blob,
                        temp_file_path,
                        chunk_size=self._io_chunksize,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                if isinstance(f, io.StringIO):
                    with open(temp_file_path, "r") as fp:
                        f.write(fp.read())
                else:
                    with open(temp_file_path, "rb") as fp:
                        f.write(fp.read())

                os.unlink(temp_file_path)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)