# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os
import tempfile
import time
from collections.abc import Callable, Iterator, Sequence, Sized
from typing import IO, Any, Optional, TypeVar, Union

import opentelemetry.metrics as api_metrics
from azure.core import MatchConditions
from azure.core.exceptions import AzureError, HttpResponseError
from azure.storage.blob import BlobPrefix, BlobServiceClient

from ..telemetry import Telemetry
from ..telemetry.attributes.base import AttributesProvider
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
)
from ..utils import split_path, validate_attributes
from .base import BaseStorageProvider

_T = TypeVar("_T")

PROVIDER = "azure"

class StaticAzureCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static Azure credentials.
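
    An illustrative sketch of typical usage (the connection string below is a placeholder for a
    real Azure Storage connection string):

    .. code-block:: python

        provider = StaticAzureCredentialsProvider(
            connection="DefaultEndpointsProtocol=https;AccountName=myaccount;AccountKey=<key>;EndpointSuffix=core.windows.net"
        )
        credentials = provider.get_credentials()
        # The connection string is exposed as the access_key field of the credentials.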
    """

    _connection: str

    def __init__(self, connection: str):
        """
        Initializes the :py:class:`StaticAzureCredentialsProvider` with the provided connection string.

        :param connection: The connection string for Azure Blob Storage authentication.
        """
        self._connection = connection

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key=self._connection,
            secret_key="",
            token=None,
            expiration=None,
        )

    def refresh_credentials(self) -> None:
        pass


class AzureBlobStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Azure Blob Storage.
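
    An illustrative sketch of typical construction (the account URL and values below are
    placeholders; the optional keyword arguments shown are among the retry/timeout settings this
    provider forwards to the underlying ``BlobServiceClient``):

    .. code-block:: python

        provider = AzureBlobStorageProvider(
            endpoint_url="https://myaccount.blob.core.windows.net",
            base_path="my-container",
            retry_total=5,
            connection_timeout=20,
        )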
    """

    def __init__(
        self,
        endpoint_url: str,
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        metric_counters: dict[Telemetry.CounterName, api_metrics.Counter] = {},
        metric_gauges: dict[Telemetry.GaugeName, api_metrics._Gauge] = {},
        metric_attributes_providers: Sequence[AttributesProvider] = (),
        **kwargs: dict[str, Any],
    ):
        """
        Initializes the :py:class:`AzureBlobStorageProvider` with the endpoint URL and an optional credentials provider.

        :param endpoint_url: The Azure storage account URL.
        :param base_path: The root prefix path within the container where all operations will be scoped.
        :param credentials_provider: The provider to retrieve Azure credentials.
        :param metric_counters: Metric counters.
        :param metric_gauges: Metric gauges.
        :param metric_attributes_providers: Metric attributes providers.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            metric_counters=metric_counters,
            metric_gauges=metric_gauges,
            metric_attributes_providers=metric_attributes_providers,
        )

        self._account_url = endpoint_url
        self._credentials_provider = credentials_provider
        # https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/storage/azure-storage-blob#optional-configuration
        client_optional_configuration_keys = {
            "retry_total",
            "retry_connect",
            "retry_read",
            "retry_status",
            "connection_timeout",
            "read_timeout",
        }
        self._client_optional_configuration = {
            key: value for key, value in kwargs.items() if key in client_optional_configuration_keys
        }
        self._blob_service_client = self._create_blob_service_client()

    def _create_blob_service_client(self) -> BlobServiceClient:
        """
        Creates and configures the Azure BlobServiceClient using the current credentials.

        :return: The configured BlobServiceClient.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            return BlobServiceClient.from_connection_string(
                credentials.access_key, **self._client_optional_configuration
            )
        else:
            return BlobServiceClient(account_url=self._account_url, **self._client_optional_configuration)

    def _refresh_blob_service_client_if_needed(self) -> None:
        """
        Refreshes the BlobServiceClient if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._blob_service_client = self._create_blob_service_client()

    def _collect_metrics(
        self,
        func: Callable[[], _T],
        operation: str,
        container: str,
        blob: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> _T:
        """
        Collects and records performance metrics around Azure operations such as PUT, GET, DELETE, etc.

        This method wraps an Azure operation and measures the time it takes to complete, along with recording
        the size of the object if applicable. It translates Azure errors into standard exceptions and records
        duration and object size metrics even when the operation fails.
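
        An illustrative sketch of how an operation is wrapped (the client and names here are
        placeholders, not part of the real call sites):

        .. code-block:: python

            data = self._collect_metrics(
                lambda: blob_client.download_blob().readall(),
                operation="GET",
                container="my-container",
                blob="path/to/blob",
            )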

        :param func: The function that performs the actual Azure operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param container: The name of the Azure container involved in the operation.
        :param blob: The name of the blob within the Azure container.
        :param put_object_size: The size of the object being uploaded, if applicable (for PUT operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).

        :return: The result of the Azure operation, typically the return value of the `func` callable.
        """
        start_time = time.time()
        status_code = 200

        object_size = None
        if operation == "PUT":
            object_size = put_object_size
        elif operation == "GET" and get_object_size:
            object_size = get_object_size

        try:
            result = func()
            if operation == "GET" and object_size is None and isinstance(result, Sized):
                object_size = len(result)
            return result
        except HttpResponseError as error:
            status_code = error.status_code if error.status_code else -1
            error_info = f"status_code: {error.status_code}, reason: {error.reason}"
            if status_code == 404:
                raise FileNotFoundError(f"Object {container}/{blob} does not exist.")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                # Raised when an If-Match or If-None-Match precondition fails.
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {container}/{blob}. {error_info}"
                ) from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {container}/{blob}. {error_info}") from error
        except AzureError as error:
            status_code = -1
            error_info = f"message: {error.message}"
            raise RuntimeError(f"Failed to {operation} object(s) at {container}/{blob}. {error_info}") from error
        except Exception as error:
            status_code = -1
            raise RuntimeError(
                f"Failed to {operation} object(s) at {container}/{blob}. error_type: {type(error).__name__}, error: {error}"
            ) from error
        finally:
            elapsed_time = time.time() - start_time
            self._metric_helper.record_duration(
                elapsed_time,
                provider=self._provider_name,
                operation=operation,
                bucket=container,
                status_code=status_code,
            )
            if object_size:
                self._metric_helper.record_object_size(
                    object_size,
                    provider=self._provider_name,
                    operation=operation,
                    bucket=container,
                    status_code=status_code,
                )

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to Azure Blob Storage.

        :param path: The path to the object to upload.
        :param body: The content of the object to upload.
        :param if_match: Optional ETag; the upload only succeeds if the object's current ETag matches this value.
        :param if_none_match: Optional ETag; the upload only succeeds if the object's current ETag does not match this value. The wildcard ``"*"`` is not supported for Azure.
        :param attributes: Optional attributes to attach to the object.

        :return: The size of the uploaded object in bytes.
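
        An illustrative sketch (the path and ETag here are placeholders):

        .. code-block:: python

            # Unconditional upload; overwrites any existing blob.
            provider._put_object("my-container/path/to/blob", b"payload")

            # Conditional upload: only succeeds if the blob's current ETag still
            # matches; otherwise PreconditionFailedError is raised.
            provider._put_object(
                "my-container/path/to/blob",
                b"payload",
                if_match="0x8DCC0FFEE123456",
            )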
        """
        container_name, blob_name = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> int:
            blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)

            kwargs = {
                "data": body,
                "overwrite": True,
            }

            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                kwargs["metadata"] = validated_attributes

            if if_match:
                kwargs["match_condition"] = MatchConditions.IfNotModified
                kwargs["etag"] = if_match

            if if_none_match:
                if if_none_match == "*":
                    raise NotImplementedError("if_none_match='*' is not supported for Azure")
                kwargs["match_condition"] = MatchConditions.IfModified
                kwargs["etag"] = if_none_match

            blob_client.upload_blob(**kwargs)

            return len(body)

        return self._collect_metrics(
            _invoke_api, operation="PUT", container=container_name, blob=blob_name, put_object_size=len(body)
        )

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        container_name, blob_name = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> bytes:
            blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
            if byte_range:
                stream = blob_client.download_blob(offset=byte_range.offset, length=byte_range.size)
            else:
                stream = blob_client.download_blob()
            return stream.readall()

        return self._collect_metrics(_invoke_api, operation="GET", container=container_name, blob=blob_name)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_container, src_blob = split_path(src_path)
        dest_container, dest_blob = split_path(dest_path)
        self._refresh_blob_service_client_if_needed()

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            src_blob_client = self._blob_service_client.get_blob_client(container=src_container, blob=src_blob)
            dest_blob_client = self._blob_service_client.get_blob_client(container=dest_container, blob=dest_blob)
            dest_blob_client.start_copy_from_url(src_blob_client.url)

            return src_object.content_length

        return self._collect_metrics(
            _invoke_api,
            operation="COPY",
            container=src_container,
            blob=src_blob,
            put_object_size=src_object.content_length,
        )

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        container_name, blob_name = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> None:
            blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
            # If if_match is provided, use it for conditional deletion.
            if if_match:
                blob_client.delete_blob(etag=if_match, match_condition=MatchConditions.IfNotModified)
            else:
                # No if_match provided, perform unconditional deletion.
                blob_client.delete_blob()

        return self._collect_metrics(_invoke_api, operation="DELETE", container=container_name, blob=blob_name)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory.
        path = self._append_delimiter(path)

        container_name, prefix = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> bool:
            # List objects with the given prefix.
            container_client = self._blob_service_client.get_container_client(container=container_name)
            blobs = container_client.walk_blobs(name_starts_with=prefix, delimiter="/")
            # Check if there are any contents or common prefixes.
            return any(True for _ in blobs)

        return self._collect_metrics(_invoke_api, operation="LIST", container=container_name, blob=prefix)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        container_name, blob_name = split_path(path)
        if path.endswith("/") or (container_name and not blob_name):
            # If the path ends with "/" or the blob name is empty, assume it refers to a
            # "directory", whose metadata is not guaranteed to exist, e.g. a "virtual prefix"
            # that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=self._append_delimiter(path),
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            self._refresh_blob_service_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                properties = blob_client.get_blob_properties()
                return ObjectMetadata(
                    key=path,
                    content_length=properties.size,
                    content_type=properties.content_settings.content_type,
                    last_modified=properties.last_modified,
                    etag=properties.etag.strip('"') if properties.etag else "",
                    metadata=dict(properties.metadata) if properties.metadata else None,
                )

            try:
                return self._collect_metrics(_invoke_api, operation="HEAD", container=container_name, blob=blob_name)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        container_name, prefix = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            container_client = self._blob_service_client.get_container_client(container=container_name)
            # Azure has no "start key" option like other object stores, so start_after and end_at
            # are applied client-side while iterating below.
            if include_directories:
                blobs = container_client.walk_blobs(name_starts_with=prefix, delimiter="/")
            else:
                blobs = container_client.list_blobs(name_starts_with=prefix)
            # Azure guarantees lexicographical order.
            for blob in blobs:
                if isinstance(blob, BlobPrefix):
                    yield ObjectMetadata(
                        key=os.path.join(container_name, blob.name.rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )
                else:
                    key = blob.name
                    if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                        if key.endswith("/"):
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(container_name, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=blob.last_modified,
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(container_name, key),
                                content_length=blob.size,
                                content_type=blob.content_settings.content_type,
                                last_modified=blob.last_modified,
                                etag=blob.etag.strip('"') if blob.etag else "",
                            )
                    elif end_at is not None and end_at < key:
                        return

        return self._collect_metrics(_invoke_api, operation="LIST", container=container_name, blob=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        container_name, blob_name = split_path(remote_path)
        file_size: int = 0
        self._refresh_blob_service_client_if_needed()

        validated_attributes = validate_attributes(attributes)
        if isinstance(f, str):
            file_size = os.path.getsize(f)

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                with open(f, "rb") as data:
                    blob_client.upload_blob(data, overwrite=True, metadata=validated_attributes or {})

                return file_size

            return self._collect_metrics(
                _invoke_api, operation="PUT", container=container_name, blob=blob_name, put_object_size=file_size
            )
        else:
            # Convert StringIO to BytesIO before upload.
            if isinstance(f, io.StringIO):
                fp: IO = io.BytesIO(f.getvalue().encode("utf-8"))  # type: ignore
            else:
                fp = f

            fp.seek(0, io.SEEK_END)
            file_size = fp.tell()
            fp.seek(0)

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                blob_client.upload_blob(fp, overwrite=True, metadata=validated_attributes or {})

                return file_size

            return self._collect_metrics(
                _invoke_api, operation="PUT", container=container_name, blob=blob_name, put_object_size=file_size
            )

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        container_name, blob_name = split_path(remote_path)
        self._refresh_blob_service_client_if_needed()

        if isinstance(f, str):
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    stream = blob_client.download_blob()
                    fp.write(stream.readall())
                os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api,
                operation="GET",
                container=container_name,
                blob=blob_name,
                get_object_size=metadata.content_length,
            )
        else:

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                stream = blob_client.download_blob()
                if isinstance(f, io.StringIO):
                    f.write(stream.readall().decode("utf-8"))
                else:
                    f.write(stream.readall())

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api,
                operation="GET",
                container=container_name,
                blob=blob_name,
                get_object_size=metadata.content_length,
            )