# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os
import tempfile
import time
from datetime import datetime
from typing import IO, Any, Callable, Iterator, Optional, Union

from azure.core.exceptions import ResourceNotFoundError
from azure.storage.blob import BlobPrefix, BlobServiceClient

from ..types import (
    Credentials,
    CredentialsProvider,
    ObjectMetadata,
    Range,
)
from ..utils import split_path
from .base import BaseStorageProvider

PROVIDER = "azure"


class StaticAzureCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static Azure credentials.
    """

    _connection: str

    def __init__(self, connection: str):
        """
        Initializes the :py:class:`StaticAzureCredentialsProvider` with the provided connection string.

        :param connection: The connection string for Azure Blob Storage authentication.
        """
        self._connection = connection

    def get_credentials(self) -> Credentials:
        # For connection-string authentication, the full connection string is
        # carried in the ``access_key`` field; the remaining fields are unused.
        return Credentials(
            access_key=self._connection,
            secret_key="",
            token=None,
            expiration=None,
        )

    def refresh_credentials(self) -> None:
        # Static credentials never expire, so there is nothing to refresh.
        pass


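# A minimal usage sketch for ``StaticAzureCredentialsProvider`` (hypothetical,
# placeholder connection string in the standard Azure format):
#
#     provider = StaticAzureCredentialsProvider(
#         connection="DefaultEndpointsProtocol=https;AccountName=...;AccountKey=...;EndpointSuffix=core.windows.net"
#     )
#     credentials = provider.get_credentials()
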
class AzureBlobStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Azure Blob Storage.
    """

    def __init__(
        self, endpoint_url: str, base_path: str = "", credentials_provider: Optional[CredentialsProvider] = None
    ):
        """
        Initializes the :py:class:`AzureBlobStorageProvider` with the endpoint URL and optional credentials provider.

        :param endpoint_url: The Azure storage account URL.
        :param base_path: The root prefix path within the container where all operations will be scoped.
        :param credentials_provider: The provider to retrieve Azure credentials.
        """
        super().__init__(base_path=base_path, provider_name=PROVIDER)

        self._account_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._blob_service_client = self._create_blob_service_client()

    def _create_blob_service_client(self) -> BlobServiceClient:
        """
        Creates and configures the Azure BlobServiceClient using the current credentials.

        :return: The configured BlobServiceClient.
        """
        if self._credentials_provider:
            # The connection string is carried in the ``access_key`` field of the
            # ``Credentials`` object (see ``StaticAzureCredentialsProvider``).
            credentials = self._credentials_provider.get_credentials()
            return BlobServiceClient.from_connection_string(credentials.access_key)
        else:
            # Without a credentials provider, construct the client from the
            # account URL alone.
            return BlobServiceClient(account_url=self._account_url)

    def _refresh_blob_service_client_if_needed(self) -> None:
        """
        Refreshes the BlobServiceClient if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._blob_service_client = self._create_blob_service_client()

    def _collect_metrics(
        self,
        func: Callable,
        operation: str,
        container: str,
        blob: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> Any:
        """
        Collects and records performance metrics around Azure operations such as PUT, GET, DELETE, etc.

        This method wraps an Azure operation and measures the time it takes to complete, along with recording
        the size of the object if applicable. It translates missing-object errors into ``FileNotFoundError``,
        wraps other failures in ``RuntimeError``, and ensures duration and object size are always recorded.

        :param func: The function that performs the actual Azure operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param container: The name of the Azure container involved in the operation.
        :param blob: The name of the blob within the Azure container.
        :param put_object_size: The size of the object being uploaded, if applicable (for PUT operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).

        :return: The result of the Azure operation, typically the return value of the `func` callable.
        """
        start_time = time.time()
        status_code = 200

        object_size = None
        if operation == "PUT":
            object_size = put_object_size
        elif operation == "GET" and get_object_size:
            object_size = get_object_size

        try:
            result = func()
            if operation == "GET" and object_size is None:
                object_size = len(result)
            return result
        except ResourceNotFoundError:
            status_code = 404
            raise FileNotFoundError(f"Object {container}/{blob} does not exist.")  # pylint: disable=raise-missing-from
        except Exception as error:
            status_code = -1
            raise RuntimeError(f"Failed to {operation} object(s) at {container}/{blob}") from error
        finally:
            elapsed_time = time.time() - start_time
            self._metric_helper.record_duration(
                elapsed_time, provider=PROVIDER, operation=operation, bucket=container, status_code=status_code
            )
            if object_size:
                self._metric_helper.record_object_size(
                    object_size, provider=PROVIDER, operation=operation, bucket=container, status_code=status_code
                )

    def _put_object(self, path: str, body: bytes) -> None:
        container_name, blob_name = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> None:
            blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
            blob_client.upload_blob(body, overwrite=True)

        return self._collect_metrics(_invoke_api, operation="PUT", container=container_name, blob=blob_name)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        container_name, blob_name = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> bytes:
            blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
            if byte_range:
                stream = blob_client.download_blob(offset=byte_range.offset, length=byte_range.size)
            else:
                stream = blob_client.download_blob()
            return stream.readall()

        return self._collect_metrics(_invoke_api, operation="GET", container=container_name, blob=blob_name)
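
    # A ranged-read sketch (hypothetical call; ``Range`` is assumed to accept
    # ``offset`` and ``size`` keyword arguments, matching the attribute access above):
    #
    #     data = provider._get_object("container/blob.bin", byte_range=Range(offset=0, size=1024))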

    def _copy_object(self, src_path: str, dest_path: str) -> None:
        src_container, src_blob = split_path(src_path)
        dest_container, dest_blob = split_path(dest_path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> None:
            src_blob_client = self._blob_service_client.get_blob_client(container=src_container, blob=src_blob)
            dest_blob_client = self._blob_service_client.get_blob_client(container=dest_container, blob=dest_blob)
            # Server-side copy: ``start_copy_from_url`` initiates the copy without
            # routing the blob contents through this client.
            dest_blob_client.start_copy_from_url(src_blob_client.url)

        src_object = self._get_object_metadata(src_path)

        return self._collect_metrics(
            _invoke_api,
            operation="COPY",
            container=src_container,
            blob=src_blob,
            put_object_size=src_object.content_length,
        )

    def _delete_object(self, path: str) -> None:
        container_name, blob_name = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> None:
            blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
            blob_client.delete_blob()

        return self._collect_metrics(_invoke_api, operation="DELETE", container=container_name, blob=blob_name)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        container_name, prefix = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> bool:
            # List objects with the given prefix
            container_client = self._blob_service_client.get_container_client(container=container_name)
            blobs = container_client.walk_blobs(name_starts_with=prefix, delimiter="/")
            # Check if there are any contents or common prefixes
            return any(True for _ in blobs)

        return self._collect_metrics(_invoke_api, operation="LIST", container=container_name, blob=prefix)

    def _get_object_metadata(self, path: str) -> ObjectMetadata:
        if path.endswith("/"):
            # If path is a "directory", then metadata is not guaranteed to exist if
            # it is a "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=self._append_delimiter(path),
                    type="directory",
                    content_length=0,
                    last_modified=datetime.min,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            container_name, blob_name = split_path(path)
            self._refresh_blob_service_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                properties = blob_client.get_blob_properties()
                return ObjectMetadata(
                    key=path,
                    content_length=properties.size,
                    content_type=properties.content_settings.content_type,
                    last_modified=properties.last_modified,
                    etag=properties.etag.strip('"') if properties.etag else "",
                )

            try:
                return self._collect_metrics(_invoke_api, operation="HEAD", container=container_name, blob=blob_name)
            except FileNotFoundError as error:
                # If the object does not exist on the given path, we will append a trailing slash and
                # check if the path is a directory.
                path = self._append_delimiter(path)
                if self._is_dir(path):
                    return ObjectMetadata(
                        key=path,
                        type="directory",
                        content_length=0,
                        last_modified=datetime.min,
                    )
                else:
                    raise error
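
    # A metadata-lookup sketch (hypothetical paths): a trailing slash requests
    # directory metadata directly, while a bare path falls back to a directory
    # check only if no blob exists at that key:
    #
    #     meta = provider._get_object_metadata("container/data/file.bin")
    #     dir_meta = provider._get_object_metadata("container/data/")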

    def _list_objects(
        self,
        prefix: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        container_name, prefix = split_path(prefix)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            container_client = self._blob_service_client.get_container_client(container=container_name)
            # Azure has no start key option like other object stores.
            if include_directories:
                blobs = container_client.walk_blobs(name_starts_with=prefix, delimiter="/")
            else:
                blobs = container_client.list_blobs(name_starts_with=prefix)
            # Azure guarantees lexicographical order.
            for blob in blobs:
                if isinstance(blob, BlobPrefix):
                    yield ObjectMetadata(
                        key=blob.name.rstrip("/"),
                        type="directory",
                        content_length=0,
                        last_modified=datetime.min,
                    )
                else:
                    key = blob.name
                    if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                        yield ObjectMetadata(
                            key=key,
                            content_length=blob.size,
                            content_type=blob.content_settings.content_type,
                            last_modified=blob.last_modified,
                            etag=blob.etag.strip('"') if blob.etag else "",
                        )
                    elif end_at is not None and end_at < key:
                        return

        return self._collect_metrics(_invoke_api, operation="LIST", container=container_name, blob=prefix)
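
    # A listing sketch (hypothetical prefix; ``start_after``/``end_at`` are filtered
    # client-side as shown above, relying on Azure's lexicographical ordering):
    #
    #     for obj in provider._list_objects("container/prefix/", include_directories=True):
    #         print(obj.key, obj.content_length)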

    def _upload_file(self, remote_path: str, f: Union[str, IO]) -> None:
        container_name, blob_name = split_path(remote_path)
        self._refresh_blob_service_client_if_needed()

        if isinstance(f, str):
            file_size = os.path.getsize(f)

            def _invoke_api() -> None:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                with open(f, "rb") as data:
                    blob_client.upload_blob(data, overwrite=True)

            return self._collect_metrics(
                _invoke_api, operation="PUT", container=container_name, blob=blob_name, put_object_size=file_size
            )
        else:
            # Convert StringIO to BytesIO before upload
            if isinstance(f, io.StringIO):
                fp: IO = io.BytesIO(f.getvalue().encode("utf-8"))  # type: ignore
            else:
                fp = f

            # Measure the payload by seeking to the end, then rewind so the upload
            # starts from the beginning of the stream.
            fp.seek(0, io.SEEK_END)
            file_size = fp.tell()
            fp.seek(0)

            def _invoke_api() -> None:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                blob_client.upload_blob(fp, overwrite=True)

            return self._collect_metrics(
                _invoke_api, operation="PUT", container=container_name, blob=blob_name, put_object_size=file_size
            )

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> None:
        if not metadata:
            metadata = self._get_object_metadata(remote_path)

        container_name, blob_name = split_path(remote_path)
        self._refresh_blob_service_client_if_needed()

        if isinstance(f, str):
            os.makedirs(os.path.dirname(f), exist_ok=True)

            def _invoke_api() -> None:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                # Download to a hidden temporary file in the destination directory,
                # then rename it into place so readers never observe a partial file.
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    stream = blob_client.download_blob()
                    fp.write(stream.readall())
                os.rename(src=temp_file_path, dst=f)

            return self._collect_metrics(
                _invoke_api,
                operation="GET",
                container=container_name,
                blob=blob_name,
                get_object_size=metadata.content_length,
            )
        else:

            def _invoke_api() -> None:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                stream = blob_client.download_blob()
                if isinstance(f, io.StringIO):
                    f.write(stream.readall().decode("utf-8"))
                else:
                    f.write(stream.readall())

            return self._collect_metrics(
                _invoke_api,
                operation="GET",
                container=container_name,
                blob=blob_name,
                get_object_size=metadata.content_length,
            )
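

# A minimal end-to-end sketch (hypothetical account URL, container, and connection
# string; the leading-underscore methods are internal hooks normally invoked through
# the public ``BaseStorageProvider`` interface):
#
#     credentials_provider = StaticAzureCredentialsProvider(
#         connection="DefaultEndpointsProtocol=https;AccountName=...;AccountKey=...",
#     )
#     storage_provider = AzureBlobStorageProvider(
#         endpoint_url="https://<account>.blob.core.windows.net",
#         credentials_provider=credentials_provider,
#     )
#     storage_provider._put_object("container/example.txt", b"hello")
#     data = storage_provider._get_object("container/example.txt")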