# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16import io
17import os
18import tempfile
19from collections.abc import Callable, Iterator
20from typing import IO, Any, Optional, TypeVar, Union
21
22from azure.core import MatchConditions
23from azure.core.exceptions import AzureError, HttpResponseError
24from azure.identity import DefaultAzureCredential
25from azure.storage.blob import BlobPrefix, BlobServiceClient
26
27from ..constants import DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT
28from ..telemetry import Telemetry
29from ..types import (
30 AWARE_DATETIME_MIN,
31 Credentials,
32 CredentialsProvider,
33 ObjectMetadata,
34 PreconditionFailedError,
35 Range,
36)
37from ..utils import safe_makedirs, split_path, validate_attributes
38from .base import BaseStorageProvider
39
40_T = TypeVar("_T")
41
42PROVIDER = "azure"
43AZURE_CONNECTION_STRING_KEY = "connection"
44AZURE_CREDENTIAL_KEY = "azure_credential"
45
46
[docs]
47class StaticAzureCredentialsProvider(CredentialsProvider):
48 """
49 A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static Azure credentials.
50 """
51
52 _connection: str
53
54 def __init__(self, connection: str):
55 """
56 Initializes the :py:class:`StaticAzureCredentialsProvider` with the provided connection string.
57
58 :param connection: The connection string for Azure Blob Storage authentication.
59 """
60 self._connection = connection
61
[docs]
62 def get_credentials(self) -> Credentials:
63 return Credentials(
64 access_key=self._connection,
65 secret_key="",
66 token=None,
67 expiration=None,
68 custom_fields={AZURE_CONNECTION_STRING_KEY: self._connection},
69 )
70
[docs]
71 def refresh_credentials(self) -> None:
72 pass
73
74
[docs]
75class DefaultAzureCredentialsProvider(CredentialsProvider):
76 """
77 A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that uses Azure Identity's :py:class:`azure.identity.DefaultAzureCredential` to authenticate with Blob Storage.
78
79 See :py:class:`azure.identity.DefaultAzureCredential` for provider options.
80 """
81
82 def __init__(self, **kwargs: dict[str, Any]):
83 self._credential = DefaultAzureCredential(**kwargs)
84
[docs]
85 def get_credentials(self) -> Credentials:
86 return Credentials(
87 access_key="",
88 secret_key="",
89 token=None,
90 expiration=None,
91 custom_fields={AZURE_CREDENTIAL_KEY: self._credential},
92 )
93
[docs]
94 def refresh_credentials(self) -> None:
95 pass
96
97
[docs]
98class AzureBlobStorageProvider(BaseStorageProvider):
99 """
100 A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Azure Blob Storage.
101 """
102
103 def __init__(
104 self,
105 endpoint_url: str,
106 base_path: str = "",
107 credentials_provider: Optional[CredentialsProvider] = None,
108 config_dict: Optional[dict[str, Any]] = None,
109 telemetry_provider: Optional[Callable[[], Telemetry]] = None,
110 **kwargs: dict[str, Any],
111 ):
112 """
113 Initializes the :py:class:`AzureBlobStorageProvider` with the endpoint URL and optional credentials provider.
114
115 :param endpoint_url: The Azure storage account URL.
116 :param base_path: The root prefix path within the container where all operations will be scoped.
117 :param credentials_provider: The provider to retrieve Azure credentials.
118 :param config_dict: Resolved MSC config.
119 :param telemetry_provider: A function that provides a telemetry instance.
120 """
121 super().__init__(
122 base_path=base_path,
123 provider_name=PROVIDER,
124 config_dict=config_dict,
125 telemetry_provider=telemetry_provider,
126 )
127
128 self._account_url = endpoint_url
129 self._credentials_provider = credentials_provider
130 # https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/storage/azure-storage-blob#optional-configuration
131 client_optional_configuration_keys = {
132 "retry_total",
133 "retry_connect",
134 "retry_read",
135 "retry_status",
136 "connection_timeout",
137 "read_timeout",
138 }
139 self._client_optional_configuration: dict[str, Any] = {
140 key: value for key, value in kwargs.items() if key in client_optional_configuration_keys
141 }
142 if "connection_timeout" not in self._client_optional_configuration:
143 self._client_optional_configuration["connection_timeout"] = DEFAULT_CONNECT_TIMEOUT
144 if "read_timeout" not in self._client_optional_configuration:
145 self._client_optional_configuration["read_timeout"] = DEFAULT_READ_TIMEOUT
146 self._blob_service_client = self._create_blob_service_client()
147
148 def _create_blob_service_client(self) -> BlobServiceClient:
149 """
150 Creates and configures the Azure BlobServiceClient using the current credentials.
151
152 :return: The configured BlobServiceClient.
153 """
154 if self._credentials_provider:
155 credentials = self._credentials_provider.get_credentials()
156
157 if isinstance(self._credentials_provider, StaticAzureCredentialsProvider):
158 return BlobServiceClient.from_connection_string(
159 credentials.get_custom_field(AZURE_CONNECTION_STRING_KEY), **self._client_optional_configuration
160 )
161 elif isinstance(self._credentials_provider, DefaultAzureCredentialsProvider):
162 return BlobServiceClient(
163 account_url=self._account_url,
164 credential=credentials.get_custom_field(AZURE_CREDENTIAL_KEY),
165 **self._client_optional_configuration,
166 )
167 else:
168 # Fallback to connection string if no built-in credentials provider is provided
169 return BlobServiceClient.from_connection_string(
170 credentials.access_key, **self._client_optional_configuration
171 )
172 else:
173 return BlobServiceClient(account_url=self._account_url, **self._client_optional_configuration)
174
175 def _refresh_blob_service_client_if_needed(self) -> None:
176 """
177 Refreshes the BlobServiceClient if the current credentials are expired.
178 """
179 if self._credentials_provider:
180 credentials = self._credentials_provider.get_credentials()
181 if credentials.is_expired():
182 self._credentials_provider.refresh_credentials()
183 self._blob_service_client = self._create_blob_service_client()
184
185 def _translate_errors(
186 self,
187 func: Callable[[], _T],
188 operation: str,
189 container: str,
190 blob: str,
191 ) -> _T:
192 """
193 Translates errors like timeouts and client errors.
194
195 :param func: The function that performs the actual Azure Blob Storage operation.
196 :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
197 :param container: The name of the Azure container involved in the operation.
198 :param blob: The name of the blob within the Azure container.
199
200 :return The result of the Azure Blob Storage operation, typically the return value of the `func` callable.
201 """
202 try:
203 return func()
204 except HttpResponseError as error:
205 status_code = error.status_code if error.status_code else -1
206 error_info = f"status_code: {error.status_code}, reason: {error.reason}"
207 if status_code == 404:
208 raise FileNotFoundError(f"Object {container}/{blob} does not exist.") # pylint: disable=raise-missing-from
209 elif status_code == 412:
210 # raised when If-Match or If-Modified fails
211 raise PreconditionFailedError(
212 f"Failed to {operation} object(s) at {container}/{blob}. {error_info}"
213 ) from error
214 else:
215 raise RuntimeError(f"Failed to {operation} object(s) at {container}/{blob}. {error_info}") from error
216 except AzureError as error:
217 error_info = f"message: {error.message}"
218 raise RuntimeError(f"Failed to {operation} object(s) at {container}/{blob}. {error_info}") from error
219 except FileNotFoundError:
220 raise
221 except Exception as error:
222 raise RuntimeError(
223 f"Failed to {operation} object(s) at {container}/{blob}. error_type: {type(error).__name__}, error: {error}"
224 ) from error
225
226 def _put_object(
227 self,
228 path: str,
229 body: bytes,
230 if_match: Optional[str] = None,
231 if_none_match: Optional[str] = None,
232 attributes: Optional[dict[str, str]] = None,
233 ) -> int:
234 """
235 Uploads an object to Azure Blob Storage.
236
237 :param path: The path to the object to upload.
238 :param body: The content of the object to upload.
239 :param if_match: Optional ETag to match against the object.
240 :param if_none_match: Optional ETag to match against the object.
241 :param attributes: Optional attributes to attach to the object.
242 """
243 container_name, blob_name = split_path(path)
244 self._refresh_blob_service_client_if_needed()
245
246 def _invoke_api() -> int:
247 blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
248
249 kwargs = {
250 "data": body,
251 "overwrite": True,
252 }
253
254 validated_attributes = validate_attributes(attributes)
255 if validated_attributes:
256 kwargs["metadata"] = validated_attributes
257
258 if if_match:
259 kwargs["match_condition"] = MatchConditions.IfNotModified
260 kwargs["etag"] = if_match
261
262 if if_none_match:
263 if if_none_match == "*":
264 raise NotImplementedError("if_none_match='*' is not supported for Azure")
265 kwargs["match_condition"] = MatchConditions.IfModified
266 kwargs["etag"] = if_none_match
267
268 blob_client.upload_blob(**kwargs)
269
270 return len(body)
271
272 return self._translate_errors(_invoke_api, operation="PUT", container=container_name, blob=blob_name)
273
274 def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
275 container_name, blob_name = split_path(path)
276 self._refresh_blob_service_client_if_needed()
277
278 def _invoke_api() -> bytes:
279 blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
280 if byte_range:
281 stream = blob_client.download_blob(offset=byte_range.offset, length=byte_range.size)
282 else:
283 stream = blob_client.download_blob()
284 return stream.readall()
285
286 return self._translate_errors(_invoke_api, operation="GET", container=container_name, blob=blob_name)
287
288 def _copy_object(self, src_path: str, dest_path: str) -> int:
289 src_container, src_blob = split_path(src_path)
290 dest_container, dest_blob = split_path(dest_path)
291 self._refresh_blob_service_client_if_needed()
292
293 src_object = self._get_object_metadata(src_path)
294
295 def _invoke_api() -> int:
296 src_blob_client = self._blob_service_client.get_blob_client(container=src_container, blob=src_blob)
297 dest_blob_client = self._blob_service_client.get_blob_client(container=dest_container, blob=dest_blob)
298 dest_blob_client.start_copy_from_url(src_blob_client.url)
299
300 return src_object.content_length
301
302 return self._translate_errors(_invoke_api, operation="COPY", container=src_container, blob=src_blob)
303
304 def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
305 container_name, blob_name = split_path(path)
306 self._refresh_blob_service_client_if_needed()
307
308 def _invoke_api() -> None:
309 blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
310 # If if_match is provided, use if_match for conditional deletion
311 if if_match:
312 blob_client.delete_blob(etag=if_match, match_condition=MatchConditions.IfNotModified)
313 else:
314 # No if_match provided, perform unconditional deletion
315 blob_client.delete_blob()
316
317 return self._translate_errors(_invoke_api, operation="DELETE", container=container_name, blob=blob_name)
318
319 def _delete_objects(self, paths: list[str]) -> None:
320 if not paths:
321 return
322
323 by_container: dict[str, list[str]] = {}
324 for p in paths:
325 container_name, blob_name = split_path(p)
326 by_container.setdefault(container_name, []).append(blob_name)
327 self._refresh_blob_service_client_if_needed()
328
329 AZURE_BATCH_LIMIT = 256
330
331 def _invoke_api() -> None:
332 for container_name, blob_names in by_container.items():
333 container_client = self._blob_service_client.get_container_client(container=container_name)
334 for i in range(0, len(blob_names), AZURE_BATCH_LIMIT):
335 chunk = blob_names[i : i + AZURE_BATCH_LIMIT]
336 container_client.delete_blobs(*chunk)
337
338 container_desc = "(" + "|".join(by_container) + ")"
339 blob_desc = "(" + "|".join(str(len(blob_names)) for blob_names in by_container.values()) + " keys)"
340 self._translate_errors(_invoke_api, operation="DELETE_MANY", container=container_desc, blob=blob_desc)
341
342 def _is_dir(self, path: str) -> bool:
343 # Ensure the path ends with '/' to mimic a directory
344 path = self._append_delimiter(path)
345
346 container_name, prefix = split_path(path)
347 self._refresh_blob_service_client_if_needed()
348
349 def _invoke_api() -> bool:
350 # List objects with the given prefix
351 container_client = self._blob_service_client.get_container_client(container=container_name)
352 blobs = container_client.walk_blobs(name_starts_with=prefix, delimiter="/")
353 # Check if there are any contents or common prefixes
354 return any(True for _ in blobs)
355
356 return self._translate_errors(_invoke_api, operation="LIST", container=container_name, blob=prefix)
357
358 def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
359 container_name, blob_name = split_path(path)
360 if path.endswith("/") or (container_name and not blob_name):
361 # If path ends with "/" or empty blob name is provided, then assume it's a "directory",
362 # which metadata is not guaranteed to exist for cases such as
363 # "virtual prefix" that was never explicitly created.
364 if self._is_dir(path):
365 return ObjectMetadata(
366 key=self._append_delimiter(path),
367 type="directory",
368 content_length=0,
369 last_modified=AWARE_DATETIME_MIN,
370 )
371 else:
372 raise FileNotFoundError(f"Directory {path} does not exist.")
373 else:
374 self._refresh_blob_service_client_if_needed()
375
376 def _invoke_api() -> ObjectMetadata:
377 blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
378 properties = blob_client.get_blob_properties()
379 return ObjectMetadata(
380 key=path,
381 content_length=properties.size,
382 content_type=properties.content_settings.content_type,
383 last_modified=properties.last_modified,
384 etag=properties.etag.strip('"') if properties.etag else "",
385 metadata=dict(properties.metadata) if properties.metadata else None,
386 )
387
388 try:
389 return self._translate_errors(_invoke_api, operation="HEAD", container=container_name, blob=blob_name)
390 except FileNotFoundError as error:
391 if strict:
392 # If the object does not exist on the given path, we will append a trailing slash and
393 # check if the path is a directory.
394 path = self._append_delimiter(path)
395 if self._is_dir(path):
396 return ObjectMetadata(
397 key=path,
398 type="directory",
399 content_length=0,
400 last_modified=AWARE_DATETIME_MIN,
401 )
402 raise error
403
404 def _list_objects(
405 self,
406 path: str,
407 start_after: Optional[str] = None,
408 end_at: Optional[str] = None,
409 include_directories: bool = False,
410 follow_symlinks: bool = True,
411 ) -> Iterator[ObjectMetadata]:
412 container_name, prefix = split_path(path)
413
414 # Get the prefix of the start_after and end_at paths relative to the bucket.
415 if start_after:
416 _, start_after = split_path(start_after)
417 if end_at:
418 _, end_at = split_path(end_at)
419
420 self._refresh_blob_service_client_if_needed()
421
422 def _invoke_api() -> Iterator[ObjectMetadata]:
423 container_client = self._blob_service_client.get_container_client(container=container_name)
424 # Azure has no start key option like other object stores.
425 if include_directories:
426 blobs = container_client.walk_blobs(name_starts_with=prefix, delimiter="/")
427 else:
428 blobs = container_client.list_blobs(name_starts_with=prefix)
429 # Azure guarantees lexicographical order.
430 for blob in blobs:
431 if isinstance(blob, BlobPrefix):
432 prefix_key = blob.name.rstrip("/")
433 # Filter by start_after and end_at if specified
434 if (start_after is None or start_after < prefix_key) and (end_at is None or prefix_key <= end_at):
435 yield ObjectMetadata(
436 key=os.path.join(container_name, prefix_key),
437 type="directory",
438 content_length=0,
439 last_modified=AWARE_DATETIME_MIN,
440 )
441 elif end_at is not None and end_at < prefix_key:
442 return
443 else:
444 key = blob.name
445 if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
446 if key.endswith("/"):
447 if include_directories:
448 yield ObjectMetadata(
449 key=os.path.join(container_name, key.rstrip("/")),
450 type="directory",
451 content_length=0,
452 last_modified=blob.last_modified,
453 )
454 else:
455 yield ObjectMetadata(
456 key=os.path.join(container_name, key),
457 content_length=blob.size,
458 content_type=blob.content_settings.content_type,
459 last_modified=blob.last_modified,
460 etag=blob.etag.strip('"') if blob.etag else "",
461 )
462 elif end_at is not None and end_at < key:
463 return
464
465 return self._translate_errors(_invoke_api, operation="LIST", container=container_name, blob=prefix)
466
467 @property
468 def supports_parallel_listing(self) -> bool:
469 return True
470
471 def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
472 container_name, blob_name = split_path(remote_path)
473 file_size: int = 0
474 self._refresh_blob_service_client_if_needed()
475
476 validated_attributes = validate_attributes(attributes)
477 if isinstance(f, str):
478 file_size = os.path.getsize(f)
479
480 def _invoke_api() -> int:
481 blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
482 with open(f, "rb") as data:
483 blob_client.upload_blob(data, overwrite=True, metadata=validated_attributes or {})
484
485 return file_size
486
487 return self._translate_errors(_invoke_api, operation="PUT", container=container_name, blob=blob_name)
488 else:
489 # Convert StringIO to BytesIO before upload
490 if isinstance(f, io.StringIO):
491 fp: IO = io.BytesIO(f.getvalue().encode("utf-8")) # type: ignore
492 else:
493 fp = f
494
495 fp.seek(0, io.SEEK_END)
496 file_size = fp.tell()
497 fp.seek(0)
498
499 def _invoke_api() -> int:
500 blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
501 blob_client.upload_blob(fp, overwrite=True, metadata=validated_attributes or {})
502
503 return file_size
504
505 return self._translate_errors(_invoke_api, operation="PUT", container=container_name, blob=blob_name)
506
507 def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
508 if metadata is None:
509 metadata = self._get_object_metadata(remote_path)
510
511 container_name, blob_name = split_path(remote_path)
512 self._refresh_blob_service_client_if_needed()
513
514 if isinstance(f, str):
515 if os.path.dirname(f):
516 safe_makedirs(os.path.dirname(f))
517
518 def _invoke_api() -> int:
519 blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
520 with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
521 temp_file_path = fp.name
522 stream = blob_client.download_blob()
523 fp.write(stream.readall())
524 os.rename(src=temp_file_path, dst=f)
525
526 return metadata.content_length
527
528 return self._translate_errors(_invoke_api, operation="GET", container=container_name, blob=blob_name)
529 else:
530
531 def _invoke_api() -> int:
532 blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
533 stream = blob_client.download_blob()
534 if isinstance(f, io.StringIO):
535 f.write(stream.readall().decode("utf-8"))
536 else:
537 f.write(stream.readall())
538
539 return metadata.content_length
540
541 return self._translate_errors(_invoke_api, operation="GET", container=container_name, blob=blob_name)