Source code for multistorageclient.providers.gcs

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os
import tempfile
import time
from datetime import datetime
from typing import IO, Any, Callable, Iterator, Optional, Union

from google.api_core.exceptions import NotFound
from google.cloud import storage
from google.oauth2.credentials import Credentials as GoogleCredentials

from ..types import CredentialsProvider, ObjectMetadata, Range
from ..utils import split_path
from .base import BaseStorageProvider

PROVIDER = "gcs"

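
# GCS implementation of the multistorageclient StorageProvider interface, built on the
# google-cloud-storage Client/Bucket/Blob API.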
class GoogleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Google Cloud Storage.
    """

    def __init__(
        self, project_id: str, base_path: str = "", credentials_provider: Optional[CredentialsProvider] = None
    ):
        """
        Initializes the :py:class:`GoogleStorageProvider` with the project ID and optional credentials provider.

        :param project_id: The Google Cloud project ID.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve GCS credentials.
        """
        super().__init__(base_path=base_path, provider_name=PROVIDER)

        self._project_id = project_id
        self._credentials_provider = credentials_provider
        self._gcs_client = self._create_gcs_client()

    def _create_gcs_client(self) -> storage.Client:
        if self._credentials_provider:
            access_token = self._credentials_provider.get_credentials().token
            creds = GoogleCredentials(token=access_token)
            return storage.Client(project=self._project_id, credentials=creds)
        else:
            return storage.Client(project=self._project_id)

    def _refresh_gcs_client_if_needed(self) -> None:
        """
        Refreshes the GCS client if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._gcs_client = self._create_gcs_client()
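
    # Every operation below follows the same pattern: the underlying google-cloud-storage call is
    # wrapped in a zero-argument closure and handed to _collect_metrics(), which times the call,
    # maps NotFound to FileNotFoundError (and other failures to RuntimeError), and records
    # duration and object-size metrics whether the call succeeds or fails.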

    def _collect_metrics(
        self,
        func: Callable,
        operation: str,
        bucket: str,
        key: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> Any:
        """
        Collects and records performance metrics around GCS operations such as PUT, GET, DELETE, etc.

        This method wraps a GCS operation and measures the time it takes to complete, along with recording
        the size of the object if applicable. It handles errors like timeouts and client errors and ensures
        proper logging of duration and object size.

        :param func: The function that performs the actual GCS operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the GCS bucket involved in the operation.
        :param key: The key of the object within the GCS bucket.
        :param put_object_size: The size of the object being uploaded, if applicable (for PUT operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).

        :return: The result of the GCS operation, typically the return value of the `func` callable.
        """
        start_time = time.time()
        status_code = 200

        object_size = None
        if operation == "PUT":
            object_size = put_object_size
        elif operation == "GET" and get_object_size:
            object_size = get_object_size

        try:
            result = func()
            if operation == "GET" and object_size is None:
                object_size = len(result)
            return result
        except NotFound:
            status_code = 404
            raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
        except Exception as error:
            status_code = -1
            raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}") from error
        finally:
            elapsed_time = time.time() - start_time
            self._metric_helper.record_duration(
                elapsed_time, provider=PROVIDER, operation=operation, bucket=bucket, status_code=status_code
            )
            if object_size:
                self._metric_helper.record_object_size(
                    object_size, provider=PROVIDER, operation=operation, bucket=bucket, status_code=status_code
                )

    def _put_object(self, path: str, body: bytes) -> None:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> None:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)
            blob.upload_from_string(body)

        return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=len(body))

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bytes:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)
            if byte_range:
                return blob.download_as_bytes(start=byte_range.offset, end=byte_range.offset + byte_range.size - 1)
            return blob.download_as_bytes()

        return self._collect_metrics(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> None:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> None:
            source_bucket_obj = self._gcs_client.bucket(src_bucket)
            source_blob = source_bucket_obj.blob(src_key)

            destination_bucket_obj = self._gcs_client.bucket(dest_bucket)
            source_bucket_obj.copy_blob(source_blob, destination_bucket_obj, dest_key)

        src_object = self._get_object_metadata(src_path)

        return self._collect_metrics(
            _invoke_api,
            operation="COPY",
            bucket=src_bucket,
            key=src_key,
            put_object_size=src_object.content_length,
        )

    def _delete_object(self, path: str) -> None:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> None:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)
            blob.delete()

        return self._collect_metrics(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bool:
            bucket_obj = self._gcs_client.bucket(bucket)
            # List objects with the given prefix
            blobs = bucket_obj.list_blobs(
                prefix=key,
                delimiter="/",
            )
            # Check if there are any contents or common prefixes
            return any(True for _ in blobs) or any(True for _ in blobs.prefixes)

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str) -> ObjectMetadata:
        if path.endswith("/"):
            # If path is a "directory", then metadata is not guaranteed to exist if
            # it is a "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=datetime.min,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            bucket, key = split_path(path)
            self._refresh_gcs_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.get_blob(key)
                if not blob:
                    raise NotFound(f"Blob {key} not found in bucket {bucket}")
                return ObjectMetadata(
                    key=path,
                    content_length=blob.size or 0,
                    content_type=blob.content_type,
                    last_modified=blob.updated or datetime.min,
                    etag=blob.etag,
                )

            try:
                return self._collect_metrics(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                # If the object does not exist on the given path, we will append a trailing slash and
                # check if the path is a directory.
                path = self._append_delimiter(path)
                if self._is_dir(path):
                    return ObjectMetadata(
                        key=path,
                        type="directory",
                        content_length=0,
                        last_modified=datetime.min,
                    )
                else:
                    raise error

    def _list_objects(
        self,
        prefix: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(prefix)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            bucket_obj = self._gcs_client.bucket(bucket)
            if include_directories:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                    delimiter="/",
                )
            else:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                )

            # GCS guarantees lexicographical order.
            for blob in blobs:
                key = blob.name
                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    yield ObjectMetadata(
                        key=key,
                        content_length=blob.size,
                        content_type=blob.content_type,
                        last_modified=blob.updated,
                        etag=blob.etag,
                    )
                elif start_after != key:
                    return

            # The directories must be accessed last.
            if include_directories:
                for directory in blobs.prefixes:
                    yield ObjectMetadata(
                        key=directory.rstrip("/"),
                        type="directory",
                        content_length=0,
                        last_modified=datetime.min,
                    )

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO]) -> None:
        bucket, key = split_path(remote_path)
        self._refresh_gcs_client_if_needed()

        if isinstance(f, str):
            filesize = os.path.getsize(f)

            def _invoke_api() -> None:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                blob.upload_from_filename(f)

            return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=filesize)
        else:
            f.seek(0, io.SEEK_END)
            filesize = f.tell()
            f.seek(0)

            def _invoke_api() -> None:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                blob.upload_from_string(f.read())

            return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=filesize)
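
    # When downloading to a local path, the blob is first written to a temporary file created in
    # the destination directory and then renamed into place, so a partially downloaded file is
    # never visible at the final path (and the rename does not cross filesystems).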

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> None:
        self._refresh_gcs_client_if_needed()

        if not metadata:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            os.makedirs(os.path.dirname(f), exist_ok=True)

            def _invoke_api() -> None:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)

                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    blob.download_to_filename(temp_file_path)
                os.rename(src=temp_file_path, dst=f)

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )
        else:

            def _invoke_api() -> None:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                if isinstance(f, io.TextIOBase):
                    content = blob.download_as_text()
                    f.write(content)
                else:
                    blob.download_to_file(f)

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )
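

# ---------------------------------------------------------------------------
# Illustrative usage sketch appended as an example; it is not part of the
# original module. It assumes a reachable Google Cloud project and an existing
# bucket; "my-project" and "my-bucket" are hypothetical placeholders. The
# underscore-prefixed methods are normally driven through the higher-level
# multistorageclient API and are called directly here only to show the
# provider's behavior.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    provider = GoogleStorageProvider(project_id="my-project")

    # Write a small object, read part of it back, list the prefix, then clean up.
    provider._put_object("my-bucket/examples/hello.txt", b"hello, gcs")

    # Range construction mirrors the offset/size fields read in _get_object above.
    data = provider._get_object("my-bucket/examples/hello.txt", byte_range=Range(offset=0, size=5))
    print(data)  # b"hello"

    for obj in provider._list_objects("my-bucket/examples/"):
        print(obj.key, obj.content_length)

    provider._delete_object("my-bucket/examples/hello.txt")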