# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os
import tempfile
import time
from collections.abc import Callable, Iterator, Sequence, Sized
from typing import IO, Any, Optional, TypeVar, Union

import opentelemetry.metrics as api_metrics
from google.api_core.exceptions import GoogleAPICallError, NotFound
from google.auth import identity_pool
from google.cloud import storage
from google.cloud.storage import transfer_manager
from google.cloud.storage.exceptions import InvalidResponse
from google.oauth2.credentials import Credentials as OAuth2Credentials

from ..telemetry import Telemetry
from ..telemetry.attributes.base import AttributesProvider
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    NotModifiedError,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    RetryableError,
)
from ..utils import split_path, validate_attributes
from .base import BaseStorageProvider

_T = TypeVar("_T")

PROVIDER = "gcs"

MB = 1024 * 1024

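# Transfer tuning defaults: objects larger than the multipart threshold are moved through
# the transfer manager in chunks of the configured size, using up to the configured number
# of worker threads. Each value can be overridden per provider instance via the keyword
# arguments multipart_threshold, multipart_chunksize, io_chunk_size, and max_concurrency.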
DEFAULT_MULTIPART_THRESHOLD = 512 * MB
DEFAULT_MULTIPART_CHUNK_SIZE = 256 * MB
DEFAULT_IO_CHUNK_SIZE = 256 * MB
DEFAULT_MAX_CONCURRENCY = 8


class StringTokenSupplier(identity_pool.SubjectTokenSupplier):
    """
    Supply a string token to the Google Identity Pool.
    """

    def __init__(self, token: str):
        self._token = token

    def get_subject_token(self, context, request):
        return self._token


class GoogleIdentityPoolCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides Google's identity pool credentials.
    """

    def __init__(self, audience: str, token_supplier: str):
        """
        Initializes the :py:class:`GoogleIdentityPoolCredentialsProvider` with the audience and subject token.

        :param audience: The audience for the Google Identity Pool.
        :param token_supplier: The subject token (as a string) used to authenticate against the Google Identity Pool.
        """
        self._audience = audience
        self._token_supplier = token_supplier

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key="",
            secret_key="",
            token="",
            expiration=None,
            custom_fields={"audience": self._audience, "token": self._token_supplier},
        )

    def refresh_credentials(self) -> None:
        pass


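# Illustrative sketch (not an official example): wiring Workload Identity Federation
# credentials into the storage provider. The audience and token values are placeholders;
# in practice the subject token typically comes from an external identity provider.
#
#     credentials_provider = GoogleIdentityPoolCredentialsProvider(
#         audience="//iam.googleapis.com/projects/<project-number>/locations/global/"
#         "workloadIdentityPools/<pool-id>/providers/<provider-id>",
#         token_supplier="<subject-token>",
#     )
#     storage_provider = GoogleStorageProvider(
#         project_id="my-project",
#         credentials_provider=credentials_provider,
#     )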
class GoogleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Google Cloud Storage.
    """

    def __init__(
        self,
        project_id: str = os.getenv("GOOGLE_CLOUD_PROJECT_ID", ""),
        endpoint_url: str = "",
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        metric_counters: dict[Telemetry.CounterName, api_metrics.Counter] = {},
        metric_gauges: dict[Telemetry.GaugeName, api_metrics._Gauge] = {},
        metric_attributes_providers: Sequence[AttributesProvider] = (),
        **kwargs: Any,
    ):
        """
        Initializes the :py:class:`GoogleStorageProvider` with the project ID and optional credentials provider.

        :param project_id: The Google Cloud project ID.
        :param endpoint_url: The custom endpoint URL for the GCS service.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve GCS credentials.
        :param metric_counters: Metric counters.
        :param metric_gauges: Metric gauges.
        :param metric_attributes_providers: Metric attributes providers.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            metric_counters=metric_counters,
            metric_gauges=metric_gauges,
            metric_attributes_providers=metric_attributes_providers,
        )

        self._project_id = project_id
        self._endpoint_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._gcs_client = self._create_gcs_client()
        self._multipart_threshold = kwargs.get("multipart_threshold", DEFAULT_MULTIPART_THRESHOLD)
        self._multipart_chunksize = kwargs.get("multipart_chunksize", DEFAULT_MULTIPART_CHUNK_SIZE)
        self._io_chunk_size = kwargs.get("io_chunk_size", DEFAULT_IO_CHUNK_SIZE)
        self._max_concurrency = kwargs.get("max_concurrency", DEFAULT_MAX_CONCURRENCY)

    def _create_gcs_client(self) -> storage.Client:
        client_options = {}
        if self._endpoint_url:
            client_options["api_endpoint"] = self._endpoint_url

        if self._credentials_provider:
            if isinstance(self._credentials_provider, GoogleIdentityPoolCredentialsProvider):
                audience = self._credentials_provider.get_credentials().get_custom_field("audience")
                token = self._credentials_provider.get_credentials().get_custom_field("token")

                # Use Workload Identity Federation (WIF)
                identity_pool_credentials = identity_pool.Credentials(
                    audience=audience,
                    subject_token_type="urn:ietf:params:oauth:token-type:id_token",
                    subject_token_supplier=StringTokenSupplier(token),
                )
                return storage.Client(
                    project=self._project_id, credentials=identity_pool_credentials, client_options=client_options
                )
            else:
                # Use OAuth 2.0 token
                token = self._credentials_provider.get_credentials().token
                creds = OAuth2Credentials(token=token)
                return storage.Client(project=self._project_id, credentials=creds, client_options=client_options)
        else:
            return storage.Client(project=self._project_id, client_options=client_options)

    def _refresh_gcs_client_if_needed(self) -> None:
        """
        Refreshes the GCS client if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._gcs_client = self._create_gcs_client()

    def _collect_metrics(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> _T:
        """
        Collects and records performance metrics around GCS operations such as PUT, GET, DELETE, etc.

        This method wraps a GCS operation and measures the time it takes to complete, along with recording
        the size of the object if applicable. It handles errors like timeouts and client errors and ensures
        proper logging of duration and object size.

        :param func: The function that performs the actual GCS operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the GCS bucket involved in the operation.
        :param key: The key of the object within the GCS bucket.
        :param put_object_size: The size of the object being uploaded, if applicable (for PUT operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).

        :return: The result of the GCS operation, typically the return value of the `func` callable.
        """
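        # Status code defaults to 200 (success); the exception handlers below overwrite it
        # with the HTTP status when available, and -1 marks failures without an HTTP status.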
        start_time = time.time()
        status_code = 200

        object_size = None
        if operation == "PUT":
            object_size = put_object_size
        elif operation == "GET" and get_object_size:
            object_size = get_object_size

        try:
            result = func()
            if operation == "GET" and object_size is None and isinstance(result, Sized):
                object_size = len(result)
            return result
        except GoogleAPICallError as error:
            status_code = error.code if error.code else -1
            error_info = f"status_code: {status_code}, message: {error.message}"
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 304:
                # for if_none_match with a specific etag condition.
                raise NotModifiedError(f"Object {bucket}/{key} has not been modified.") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except InvalidResponse as error:
            status_code = error.response.status_code
            response_text = error.response.text
            error_details = f"error: {error}, error_response_text: {response_text}"
            # Check for NoSuchUpload within the response text
            if "NoSuchUpload" in response_text:
                raise RetryableError(f"Multipart upload failed for {bucket}/{key}, {error_details}") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_details}") from error
        except Exception as error:
            status_code = -1
            error_details = str(error)
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}. error_type: {type(error).__name__}, {error_details}"
            ) from error
        finally:
            elapsed_time = time.time() - start_time
            self._metric_helper.record_duration(
                elapsed_time, provider=self._provider_name, operation=operation, bucket=bucket, status_code=status_code
            )
            if object_size:
                self._metric_helper.record_object_size(
                    object_size,
                    provider=self._provider_name,
                    operation=operation,
                    bucket=bucket,
                    status_code=status_code,
                )

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to Google Cloud Storage.

        :param path: The path to the object to upload.
        :param body: The content of the object to upload.
        :param if_match: Optional generation (ETag) the object must match for the upload to succeed.
        :param if_none_match: Optional generation (ETag) the object must not match for the upload to succeed.
        :param attributes: Optional attributes to attach to the object.

        :return: The number of bytes written.
        """
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> int:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            kwargs = {}

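            # GCS write preconditions operate on object generation numbers, so if_match
            # and if_none_match are interpreted as generations rather than HTTP ETags.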
            if if_match:
                kwargs["if_generation_match"] = int(if_match)  # 412 error code
            if if_none_match:
                if if_none_match == "*":
                    raise NotImplementedError("if_none_match='*' is not supported for GCS")
                else:
                    kwargs["if_generation_not_match"] = int(if_none_match)  # 304 error code

            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                blob.metadata = validated_attributes

            blob.upload_from_string(body, **kwargs)

            return len(body)

        return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=len(body))

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bytes:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)
            if byte_range:
                return blob.download_as_bytes(start=byte_range.offset, end=byte_range.offset + byte_range.size - 1)
            return blob.download_as_bytes()

        return self._collect_metrics(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)
        self._refresh_gcs_client_if_needed()

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            source_bucket_obj = self._gcs_client.bucket(src_bucket)
            source_blob = source_bucket_obj.blob(src_key)

            destination_bucket_obj = self._gcs_client.bucket(dest_bucket)
            destination_blob = destination_bucket_obj.blob(dest_key)

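            # Server-side copy via the rewrite API: a single rewrite call may not finish
            # for large objects, so keep passing the returned continuation token until
            # no further token is returned.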
            rewrite_tokens = [None]
            while len(rewrite_tokens) > 0:
                rewrite_token = rewrite_tokens.pop()
                next_rewrite_token, _, _ = destination_blob.rewrite(source=source_blob, token=rewrite_token)
                if next_rewrite_token is not None:
                    rewrite_tokens.append(next_rewrite_token)

            return src_object.content_length

        return self._collect_metrics(
            _invoke_api,
            operation="COPY",
            bucket=src_bucket,
            key=src_key,
            put_object_size=src_object.content_length,
        )

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> None:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)

            # If if_match is provided, use it as a precondition
            if if_match:
                generation = int(if_match)
                blob.delete(if_generation_match=generation)
            else:
                # No if_match check needed, just delete
                blob.delete()

        return self._collect_metrics(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bool:
            bucket_obj = self._gcs_client.bucket(bucket)
            # List objects with the given prefix
            blobs = bucket_obj.list_blobs(
                prefix=key,
                delimiter="/",
            )
            # Check if there are any contents or common prefixes
            return any(True for _ in blobs) or any(True for _ in blobs.prefixes)

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        if path.endswith("/"):
            # If path is a "directory", then metadata is not guaranteed to exist if
            # it is a "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path, type="directory", content_length=0, last_modified=AWARE_DATETIME_MIN, etag=None
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            bucket, key = split_path(path)
            self._refresh_gcs_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.get_blob(key)
                if not blob:
                    raise NotFound(f"Blob {key} not found in bucket {bucket}")
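                # The blob's generation is surfaced as the etag so callers can feed it
                # back into if_match / if_none_match preconditions.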
                return ObjectMetadata(
                    key=path,
                    content_length=blob.size or 0,
                    content_type=blob.content_type,
                    last_modified=blob.updated or AWARE_DATETIME_MIN,
                    etag=str(blob.generation),
                    metadata=dict(blob.metadata) if blob.metadata else None,
                )

            try:
                return self._collect_metrics(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        prefix: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(prefix)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            bucket_obj = self._gcs_client.bucket(bucket)
            if include_directories:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                    delimiter="/",
                )
            else:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                )

            # GCS guarantees lexicographical order.
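            # start_offset is inclusive, so the start_after key itself is filtered out
            # below, and iteration stops early once keys pass end_at.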
            for blob in blobs:
                key = blob.name
                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    if key.endswith("/"):
                        if include_directories:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key.rstrip("/")),
                                type="directory",
                                content_length=0,
                                last_modified=blob.updated,
                            )
                    else:
                        yield ObjectMetadata(
                            key=os.path.join(bucket, key),
                            content_length=blob.size,
                            content_type=blob.content_type,
                            last_modified=blob.updated,
                            etag=blob.etag,
                        )
                elif start_after != key:
                    return

            # The directories must be accessed last.
            if include_directories:
                for directory in blobs.prefixes:
                    yield ObjectMetadata(
                        key=os.path.join(bucket, directory.rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        bucket, key = split_path(remote_path)
        file_size: int = 0
        self._refresh_gcs_client_if_needed()

        if isinstance(f, str):
            file_size = os.path.getsize(f)

            # Upload small files
            if file_size <= self._multipart_threshold:
                with open(f, "rb") as fp:
                    self._put_object(remote_path, fp.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                blob.metadata = validate_attributes(attributes)
                transfer_manager.upload_chunks_concurrently(
                    f,
                    blob,
                    chunk_size=self._multipart_chunksize,
                    max_workers=self._max_concurrency,
                    worker_type=transfer_manager.THREAD,
                )

                return file_size

            return self._collect_metrics(
                _invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=file_size
            )
        else:
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            # Upload small files
            if file_size <= self._multipart_threshold:
                if isinstance(f, io.StringIO):
                    self._put_object(remote_path, f.read().encode("utf-8"), attributes=attributes)
                else:
                    self._put_object(remote_path, f.read(), attributes=attributes)
                return file_size

            # Upload large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                validated_attributes = validate_attributes(attributes)
                if validated_attributes:
                    blob.metadata = validated_attributes
                if isinstance(f, io.StringIO):
                    mode = "w"
                else:
                    mode = "wb"

                # transfer manager does not support uploading a file object
                with tempfile.NamedTemporaryFile(mode=mode, delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    fp.write(f.read())

                transfer_manager.upload_chunks_concurrently(
                    temp_file_path,
                    blob,
                    chunk_size=self._multipart_chunksize,
                    max_workers=self._max_concurrency,
                    worker_type=transfer_manager.THREAD,
                )

                os.unlink(temp_file_path)

                return file_size

            return self._collect_metrics(
                _invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=file_size
            )

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        self._refresh_gcs_client_if_needed()

        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            os.makedirs(os.path.dirname(f), exist_ok=True)
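            # Both branches below download into a hidden temporary file in the destination
            # directory and then rename it into place, so a partially downloaded file never
            # appears at the final path.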
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    fp.write(self._get_object(remote_path))
                os.rename(src=temp_file_path, dst=f)
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)

                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    transfer_manager.download_chunks_concurrently(
                        blob,
                        temp_file_path,
                        chunk_size=self._io_chunk_size,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )
                os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )
        else:
            # Download small files
            if metadata.content_length <= self._multipart_threshold:
                if isinstance(f, io.StringIO):
                    f.write(self._get_object(remote_path).decode("utf-8"))
                else:
                    f.write(self._get_object(remote_path))
                return metadata.content_length

            # Download large files using transfer manager
            def _invoke_api() -> int:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)

                # transfer manager does not support downloading to a file object
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=".") as fp:
                    temp_file_path = fp.name
                    transfer_manager.download_chunks_concurrently(
                        blob,
                        temp_file_path,
                        chunk_size=self._io_chunk_size,
                        max_workers=self._max_concurrency,
                        worker_type=transfer_manager.THREAD,
                    )

                if isinstance(f, io.StringIO):
                    with open(temp_file_path, "r") as fp:
                        f.write(fp.read())
                else:
                    with open(temp_file_path, "rb") as fp:
                        f.write(fp.read())

                os.unlink(temp_file_path)

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )