# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os
import tempfile
import time
from collections.abc import Callable, Iterator, Sequence, Sized
from typing import IO, Any, Optional, TypeVar, Union

import opentelemetry.metrics as api_metrics
from azure.core import MatchConditions
from azure.core.exceptions import AzureError, HttpResponseError
from azure.storage.blob import BlobPrefix, BlobServiceClient

from ..telemetry import Telemetry
from ..telemetry.attributes.base import AttributesProvider
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
)
from ..utils import split_path, validate_attributes
from .base import BaseStorageProvider

_T = TypeVar("_T")

PROVIDER = "azure"


class StaticAzureCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static Azure credentials.
    """

    _connection: str

    def __init__(self, connection: str):
        """
        Initializes the :py:class:`StaticAzureCredentialsProvider` with the provided connection string.

        :param connection: The connection string for Azure Blob Storage authentication.
        """
        self._connection = connection

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key=self._connection,
            secret_key="",
            token=None,
            expiration=None,
        )

    def refresh_credentials(self) -> None:
        pass
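
# Usage sketch (illustrative; the connection string is a placeholder, not a
# real credential). This mirrors what _create_blob_service_client does below:
#
#   creds_provider = StaticAzureCredentialsProvider(
#       connection="DefaultEndpointsProtocol=https;AccountName=...;AccountKey=...;"
#   )
#   client = BlobServiceClient.from_connection_string(creds_provider.get_credentials().access_key)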


class AzureBlobStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Azure Blob Storage.
    """

    def __init__(
        self,
        endpoint_url: str,
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        metric_counters: dict[Telemetry.CounterName, api_metrics.Counter] = {},
        metric_gauges: dict[Telemetry.GaugeName, api_metrics._Gauge] = {},
        metric_attributes_providers: Sequence[AttributesProvider] = (),
        **kwargs: dict[str, Any],
    ):
        """
        Initializes the :py:class:`AzureBlobStorageProvider` with the endpoint URL and optional credentials provider.

        :param endpoint_url: The Azure storage account URL.
        :param base_path: The root prefix path within the container where all operations will be scoped.
        :param credentials_provider: The provider to retrieve Azure credentials.
        :param metric_counters: Metric counters.
        :param metric_gauges: Metric gauges.
        :param metric_attributes_providers: Metric attributes providers.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            metric_counters=metric_counters,
            metric_gauges=metric_gauges,
            metric_attributes_providers=metric_attributes_providers,
        )

        self._account_url = endpoint_url
        self._credentials_provider = credentials_provider
        # https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/storage/azure-storage-blob#optional-configuration
        client_optional_configuration_keys = {
            "retry_total",
            "retry_connect",
            "retry_read",
            "retry_status",
            "connection_timeout",
            "read_timeout",
        }
        self._client_optional_configuration = {
            key: value for key, value in kwargs.items() if key in client_optional_configuration_keys
        }
        self._blob_service_client = self._create_blob_service_client()

    def _create_blob_service_client(self) -> BlobServiceClient:
        """
        Creates and configures the Azure BlobServiceClient using the current credentials.

        :return: The configured BlobServiceClient.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            return BlobServiceClient.from_connection_string(
                credentials.access_key, **self._client_optional_configuration
            )
        else:
            return BlobServiceClient(account_url=self._account_url, **self._client_optional_configuration)
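
    # Configuration sketch (illustrative; the account URL is a placeholder):
    # only the whitelisted keyword arguments above are forwarded to the
    # BlobServiceClient; any other **kwargs entries are ignored.
    #
    #   provider = AzureBlobStorageProvider(
    #       endpoint_url="https://myaccount.blob.core.windows.net",
    #       retry_total=5,
    #       connection_timeout=20,
    #       read_timeout=120,
    #   )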

    def _refresh_blob_service_client_if_needed(self) -> None:
        """
        Refreshes the BlobServiceClient if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._blob_service_client = self._create_blob_service_client()

    def _collect_metrics(
        self,
        func: Callable[[], _T],
        operation: str,
        container: str,
        blob: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> _T:
        """
        Collects and records performance metrics around Azure operations such as PUT, GET, DELETE, etc.

        This method wraps an Azure operation and measures the time it takes to complete, along with recording
        the size of the object if applicable. It handles errors like timeouts and client errors and ensures
        proper logging of duration and object size.

        :param func: The function that performs the actual Azure operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param container: The name of the Azure container involved in the operation.
        :param blob: The name of the blob within the Azure container.
        :param put_object_size: The size of the object being uploaded, if applicable (for PUT operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).

        :return: The result of the Azure operation, typically the return value of the `func` callable.
        """
        start_time = time.time()
        status_code = 200

        object_size = None
        if operation == "PUT":
            object_size = put_object_size
        elif operation == "GET" and get_object_size:
            object_size = get_object_size

        try:
            result = func()
            if operation == "GET" and object_size is None and isinstance(result, Sized):
                object_size = len(result)
            return result
        except HttpResponseError as error:
            status_code = error.status_code if error.status_code else -1
            error_info = f"status_code: {error.status_code}, reason: {error.reason}"
            if status_code == 404:
                raise FileNotFoundError(f"Object {container}/{blob} does not exist.")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                # Raised when an If-Match or If-None-Match precondition fails.
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {container}/{blob}. {error_info}"
                ) from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {container}/{blob}. {error_info}") from error
        except AzureError as error:
            status_code = -1
            error_info = f"message: {error.message}"
            raise RuntimeError(f"Failed to {operation} object(s) at {container}/{blob}. {error_info}") from error
        except Exception as error:
            status_code = -1
            raise RuntimeError(
                f"Failed to {operation} object(s) at {container}/{blob}. error_type: {type(error).__name__}, error: {error}"
            ) from error
        finally:
            elapsed_time = time.time() - start_time
            self._metric_helper.record_duration(
                elapsed_time,
                provider=self._provider_name,
                operation=operation,
                bucket=container,
                status_code=status_code,
            )
            if object_size:
                self._metric_helper.record_object_size(
                    object_size,
                    provider=self._provider_name,
                    operation=operation,
                    bucket=container,
                    status_code=status_code,
                )
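
    # Wrapping sketch (illustrative): every operation below funnels its Azure
    # SDK call through _collect_metrics so that duration, object size, and
    # status code are recorded uniformly, e.g.:
    #
    #   return self._collect_metrics(
    #       _invoke_api, operation="GET", container=container_name, blob=blob_name
    #   )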

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to Azure Blob Storage.

        :param path: The path to the object to upload.
        :param body: The content of the object to upload.
        :param if_match: Optional ETag; the upload succeeds only if the object's current ETag matches.
        :param if_none_match: Optional ETag; the upload succeeds only if the object's current ETag does not match.
        :param attributes: Optional attributes to attach to the object.

        :return: The size of the uploaded object in bytes.
        """
        container_name, blob_name = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> int:
            blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)

            kwargs = {
                "data": body,
                "overwrite": True,
            }

            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                kwargs["metadata"] = validated_attributes

            if if_match:
                kwargs["match_condition"] = MatchConditions.IfNotModified
                kwargs["etag"] = if_match

            if if_none_match:
                if if_none_match == "*":
                    raise NotImplementedError("if_none_match='*' is not supported for Azure")
                kwargs["match_condition"] = MatchConditions.IfModified
                kwargs["etag"] = if_none_match

            blob_client.upload_blob(**kwargs)

            return len(body)

        return self._collect_metrics(
            _invoke_api, operation="PUT", container=container_name, blob=blob_name, put_object_size=len(body)
        )
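
    # Conditional-write sketch (illustrative; _put_object is an internal helper
    # normally reached through the public client API). A stale if_match ETag
    # surfaces as PreconditionFailedError via the 412 handling in _collect_metrics:
    #
    #   etag = provider._get_object_metadata("container/blob.bin").etag
    #   provider._put_object("container/blob.bin", b"new bytes", if_match=etag)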

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        container_name, blob_name = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> bytes:
            blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
            if byte_range:
                stream = blob_client.download_blob(offset=byte_range.offset, length=byte_range.size)
            else:
                stream = blob_client.download_blob()
            return stream.readall()

        return self._collect_metrics(_invoke_api, operation="GET", container=container_name, blob=blob_name)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_container, src_blob = split_path(src_path)
        dest_container, dest_blob = split_path(dest_path)
        self._refresh_blob_service_client_if_needed()

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            src_blob_client = self._blob_service_client.get_blob_client(container=src_container, blob=src_blob)
            dest_blob_client = self._blob_service_client.get_blob_client(container=dest_container, blob=dest_blob)
            dest_blob_client.start_copy_from_url(src_blob_client.url)

            return src_object.content_length

        return self._collect_metrics(
            _invoke_api,
            operation="COPY",
            container=src_container,
            blob=src_blob,
            put_object_size=src_object.content_length,
        )

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        container_name, blob_name = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> None:
            blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
            if if_match:
                # Delete only if the blob's current ETag still matches if_match.
                blob_client.delete_blob(etag=if_match, match_condition=MatchConditions.IfNotModified)
            else:
                # No if_match provided; perform an unconditional deletion.
                blob_client.delete_blob()

        return self._collect_metrics(_invoke_api, operation="DELETE", container=container_name, blob=blob_name)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        container_name, prefix = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> bool:
            # List objects with the given prefix
            container_client = self._blob_service_client.get_container_client(container=container_name)
            blobs = container_client.walk_blobs(name_starts_with=prefix, delimiter="/")
            # Check if there are any contents or common prefixes
            return any(True for _ in blobs)

        return self._collect_metrics(_invoke_api, operation="LIST", container=container_name, blob=prefix)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        if path.endswith("/"):
            # If path is a "directory", then metadata is not guaranteed to exist if
            # it is a "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=self._append_delimiter(path),
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            container_name, blob_name = split_path(path)
            self._refresh_blob_service_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                properties = blob_client.get_blob_properties()
                return ObjectMetadata(
                    key=path,
                    content_length=properties.size,
                    content_type=properties.content_settings.content_type,
                    last_modified=properties.last_modified,
                    etag=properties.etag.strip('"') if properties.etag else "",
                    metadata=dict(properties.metadata) if properties.metadata else None,
                )

            try:
                return self._collect_metrics(_invoke_api, operation="HEAD", container=container_name, blob=blob_name)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        prefix: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        container_name, prefix = split_path(prefix)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            container_client = self._blob_service_client.get_container_client(container=container_name)
            # Azure has no start key option like other object stores, so the
            # start_after/end_at range is filtered client-side below.
            if include_directories:
                blobs = container_client.walk_blobs(name_starts_with=prefix, delimiter="/")
            else:
                blobs = container_client.list_blobs(name_starts_with=prefix)
            # Azure guarantees lexicographical order.
            for blob in blobs:
                if isinstance(blob, BlobPrefix):
                    yield ObjectMetadata(
                        key=os.path.join(container_name, blob.name.rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )
                else:
                    key = blob.name
                    if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                        if key.endswith("/"):
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(container_name, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=blob.last_modified,
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(container_name, key),
                                content_length=blob.size,
                                content_type=blob.content_settings.content_type,
                                last_modified=blob.last_modified,
                                etag=blob.etag.strip('"') if blob.etag else "",
                            )
                    elif end_at is not None and end_at < key:
                        return

        return self._collect_metrics(_invoke_api, operation="LIST", container=container_name, blob=prefix)
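
    # Listing sketch (illustrative; container and key names are placeholders):
    # because Azure returns blobs in lexicographical order, the client-side
    # filter above yields the contiguous key range (start_after, end_at].
    #
    #   for obj in provider._list_objects(
    #       "container/logs/", start_after="logs/2024-01-01", end_at="logs/2024-12-31"
    #   ):
    #       print(obj.key, obj.content_length)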

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        container_name, blob_name = split_path(remote_path)
        file_size: int = 0
        self._refresh_blob_service_client_if_needed()

        if isinstance(f, str):
            file_size = os.path.getsize(f)

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                with open(f, "rb") as data:
                    blob_client.upload_blob(data, overwrite=True)

                return file_size

            return self._collect_metrics(
                _invoke_api, operation="PUT", container=container_name, blob=blob_name, put_object_size=file_size
            )
        else:
            # Convert StringIO to BytesIO before upload
            if isinstance(f, io.StringIO):
                fp: IO = io.BytesIO(f.getvalue().encode("utf-8"))  # type: ignore
            else:
                fp = f

            fp.seek(0, io.SEEK_END)
            file_size = fp.tell()
            fp.seek(0)

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                blob_client.upload_blob(fp, overwrite=True)

                return file_size

            return self._collect_metrics(
                _invoke_api, operation="PUT", container=container_name, blob=blob_name, put_object_size=file_size
            )
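
    # Upload sketch (illustrative): _upload_file accepts either a filesystem
    # path or a file-like object; StringIO input is re-encoded as UTF-8 bytes
    # before upload, as above.
    #
    #   provider._upload_file("container/data.txt", "/tmp/data.txt")
    #   provider._upload_file("container/data.txt", io.StringIO("hello"))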

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        container_name, blob_name = split_path(remote_path)
        self._refresh_blob_service_client_if_needed()

        if isinstance(f, str):
            os.makedirs(os.path.dirname(f), exist_ok=True)

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                # Download to a hidden temporary file in the destination directory,
                # then rename it into place so readers never see a partial file.
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    stream = blob_client.download_blob()
                    fp.write(stream.readall())
                os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api,
                operation="GET",
                container=container_name,
                blob=blob_name,
                get_object_size=metadata.content_length,
            )
        else:

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                stream = blob_client.download_blob()
                if isinstance(f, io.StringIO):
                    f.write(stream.readall().decode("utf-8"))
                else:
                    f.write(stream.readall())

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api,
                operation="GET",
                container=container_name,
                blob=blob_name,
                get_object_size=metadata.content_length,
            )