# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os
import tempfile
import time
from datetime import datetime
from typing import IO, Any, Callable, Iterator, Optional, Union

from google.api_core.exceptions import NotFound
from google.cloud import storage
from google.oauth2.credentials import Credentials as GoogleCredentials

from ..types import CredentialsProvider, ObjectMetadata, Range
from ..utils import split_path
from .base import BaseStorageProvider

PROVIDER = "gcs"


class GoogleStorageProvider(BaseStorageProvider):
35 """
36 A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Google Cloud Storage.
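
    A minimal construction sketch (illustrative only; assumes application-default
    Google Cloud credentials and uses placeholder project/bucket names)::

        provider = GoogleStorageProvider(project_id="my-project", base_path="my-bucket")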
37 """
38
39 def __init__(
40 self, project_id: str, base_path: str = "", credentials_provider: Optional[CredentialsProvider] = None
41 ):
42 """
43 Initializes the :py:class:`GoogleStorageProvider` with the project ID and optional credentials provider.
44
45 :param project_id: The Google Cloud project ID.
46 :param base_path: The root prefix path within the bucket where all operations will be scoped.
47 :param credentials_provider: The provider to retrieve GCS credentials.
48 """
49 super().__init__(base_path=base_path, provider_name=PROVIDER)
50
51 self._project_id = project_id
52 self._credentials_provider = credentials_provider
53 self._gcs_client = self._create_gcs_client()
54
55 def _create_gcs_client(self) -> storage.Client:
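        # If a credentials provider is configured, wrap its access token in
        # google.oauth2 Credentials; otherwise build the client without explicit
        # credentials so google-cloud-storage resolves application-default credentials.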
        if self._credentials_provider:
            access_token = self._credentials_provider.get_credentials().token
            creds = GoogleCredentials(token=access_token)
            return storage.Client(project=self._project_id, credentials=creds)
        else:
            return storage.Client(project=self._project_id)

    def _refresh_gcs_client_if_needed(self) -> None:
        """
        Refreshes the GCS client if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._gcs_client = self._create_gcs_client()

    def _collect_metrics(
        self,
        func: Callable,
        operation: str,
        bucket: str,
        key: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> Any:
82 """
83 Collects and records performance metrics around GCS operations such as PUT, GET, DELETE, etc.
84
85 This method wraps an GCS operation and measures the time it takes to complete, along with recording
86 the size of the object if applicable. It handles errors like timeouts and client errors and ensures
87 proper logging of duration and object size.
88
89 :param func: The function that performs the actual GCS operation.
90 :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
91 :param bucket: The name of the GCS bucket involved in the operation.
92 :param key: The key of the object within the GCS bucket.
93 :param put_object_size: The size of the object being uploaded, if applicable (for PUT operations).
94 :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).
95
96 :return: The result of the GCS operation, typically the return value of the `func` callable.
97 """
        start_time = time.time()
        status_code = 200

        object_size = None
        if operation == "PUT":
            object_size = put_object_size
        elif operation == "GET" and get_object_size:
            object_size = get_object_size

        try:
            result = func()
            if operation == "GET" and object_size is None:
                object_size = len(result)
            return result
        except NotFound:
            status_code = 404
            raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
        except Exception as error:
            status_code = -1
            raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}") from error
        finally:
            elapsed_time = time.time() - start_time
            self._metric_helper.record_duration(
                elapsed_time, provider=PROVIDER, operation=operation, bucket=bucket, status_code=status_code
            )
            if object_size:
                self._metric_helper.record_object_size(
                    object_size, provider=PROVIDER, operation=operation, bucket=bucket, status_code=status_code
                )

    def _put_object(self, path: str, body: bytes) -> None:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> None:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)
            blob.upload_from_string(body)

        return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=len(body))

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bytes:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)
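            # Note: the `end` argument of `download_as_bytes` is inclusive, hence
            # `offset + size - 1` when translating the requested Range below.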
            if byte_range:
                return blob.download_as_bytes(start=byte_range.offset, end=byte_range.offset + byte_range.size - 1)
            return blob.download_as_bytes()

        return self._collect_metrics(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> None:
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> None:
            source_bucket_obj = self._gcs_client.bucket(src_bucket)
            source_blob = source_bucket_obj.blob(src_key)

            destination_bucket_obj = self._gcs_client.bucket(dest_bucket)
            source_bucket_obj.copy_blob(source_blob, destination_bucket_obj, dest_key)

        src_object = self._get_object_metadata(src_path)

        return self._collect_metrics(
            _invoke_api,
            operation="COPY",
            bucket=src_bucket,
            key=src_key,
            put_object_size=src_object.content_length,
        )

    def _delete_object(self, path: str) -> None:
        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> None:
            bucket_obj = self._gcs_client.bucket(bucket)
            blob = bucket_obj.blob(key)
            blob.delete()

        return self._collect_metrics(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> bool:
            bucket_obj = self._gcs_client.bucket(bucket)
            # List objects with the given prefix
            blobs = bucket_obj.list_blobs(
                prefix=key,
                delimiter="/",
            )
            # Check if there are any contents or common prefixes
            return any(True for _ in blobs) or any(True for _ in blobs.prefixes)

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str) -> ObjectMetadata:
        if path.endswith("/"):
            # If path is a "directory", then metadata is not guaranteed to exist if
            # it is a "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=datetime.min,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            bucket, key = split_path(path)
            self._refresh_gcs_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.get_blob(key)
                if not blob:
                    raise NotFound(f"Blob {key} not found in bucket {bucket}")
                return ObjectMetadata(
                    key=path,
                    content_length=blob.size or 0,
                    content_type=blob.content_type,
                    last_modified=blob.updated or datetime.min,
                    etag=blob.etag,
                )

            try:
                return self._collect_metrics(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                # If the object does not exist on the given path, we will append a trailing slash and
                # check if the path is a directory.
                path = self._append_delimiter(path)
                if self._is_dir(path):
                    return ObjectMetadata(
                        key=path,
                        type="directory",
                        content_length=0,
                        last_modified=datetime.min,
                    )
                else:
                    raise error

    def _list_objects(
        self,
        prefix: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(prefix)
        self._refresh_gcs_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            bucket_obj = self._gcs_client.bucket(bucket)
            if include_directories:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                    delimiter="/",
                )
            else:
                blobs = bucket_obj.list_blobs(
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start_offset=start_after,
                )

            # GCS guarantees lexicographical order.
            for blob in blobs:
                key = blob.name
                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    yield ObjectMetadata(
                        key=key,
                        content_length=blob.size,
                        content_type=blob.content_type,
                        last_modified=blob.updated,
                        etag=blob.etag,
                    )
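                # Keys arrive in lexicographic order, so a key past end_at ends the
                # scan; a key equal to start_after (start_offset is inclusive) is
                # skipped without terminating.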
                elif start_after != key:
                    return

            # The directories must be accessed last.
            if include_directories:
                for directory in blobs.prefixes:
                    yield ObjectMetadata(
                        key=directory.rstrip("/"),
                        type="directory",
                        content_length=0,
                        last_modified=datetime.min,
                    )

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO]) -> None:
        bucket, key = split_path(remote_path)
        self._refresh_gcs_client_if_needed()

        if isinstance(f, str):
            filesize = os.path.getsize(f)

            def _invoke_api() -> None:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                blob.upload_from_filename(f)

            return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=filesize)
        else:
            f.seek(0, io.SEEK_END)
            filesize = f.tell()
            f.seek(0)

            def _invoke_api() -> None:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
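                # Reads the entire file object into memory before uploading.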
                blob.upload_from_string(f.read())

            return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=filesize)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> None:
        self._refresh_gcs_client_if_needed()

        if not metadata:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            os.makedirs(os.path.dirname(f), exist_ok=True)

            def _invoke_api() -> None:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)

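                # Download to a temporary file in the destination directory first, then
                # rename it into place so a partially written file is never visible at
                # the final path.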
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    blob.download_to_filename(temp_file_path)
                os.rename(src=temp_file_path, dst=f)

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )
        else:

            def _invoke_api() -> None:
                bucket_obj = self._gcs_client.bucket(bucket)
                blob = bucket_obj.blob(key)
                if isinstance(f, io.TextIOBase):
                    content = blob.download_as_text()
                    f.write(content)
                else:
                    blob.download_to_file(f)

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )