Source code for multistorageclient.providers.azure

  1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16import io
 17import os
 18import tempfile
 19from collections.abc import Callable, Iterator
 20from typing import IO, Any, Optional, TypeVar, Union
 21
 22from azure.core import MatchConditions
 23from azure.core.exceptions import AzureError, HttpResponseError
 24from azure.identity import DefaultAzureCredential
 25from azure.storage.blob import BlobPrefix, BlobServiceClient
 26
 27from ..constants import DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT
 28from ..telemetry import Telemetry
 29from ..types import (
 30    AWARE_DATETIME_MIN,
 31    Credentials,
 32    CredentialsProvider,
 33    ObjectMetadata,
 34    PreconditionFailedError,
 35    Range,
 36)
 37from ..utils import safe_makedirs, split_path, validate_attributes
 38from .base import BaseStorageProvider
 39
# Type variable used by AzureBlobStorageProvider._translate_errors so the
# wrapper preserves the wrapped callable's return type.
_T = TypeVar("_T")

# Provider name reported to BaseStorageProvider for this backend.
PROVIDER = "azure"
# Keys under Credentials.custom_fields used to carry Azure-specific auth material:
# a raw connection string, or an azure.identity credential object, respectively.
AZURE_CONNECTION_STRING_KEY = "connection"
AZURE_CREDENTIAL_KEY = "azure_credential"
 45
 46
class StaticAzureCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider`
    that always hands back one fixed Azure connection string.
    """

    # The immutable connection string supplied at construction time.
    _connection: str

    def __init__(self, connection: str):
        """
        Initializes the :py:class:`StaticAzureCredentialsProvider` with the provided connection string.

        :param connection: The connection string for Azure Blob Storage authentication.
        """
        self._connection = connection

    def get_credentials(self) -> Credentials:
        """Wrap the static connection string in a :py:class:`Credentials` object."""
        extra_fields = {AZURE_CONNECTION_STRING_KEY: self._connection}
        return Credentials(
            access_key=self._connection,
            secret_key="",
            token=None,
            expiration=None,
            custom_fields=extra_fields,
        )

    def refresh_credentials(self) -> None:
        """No-op: a static connection string never expires."""
        pass
class DefaultAzureCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider`
    that authenticates with Blob Storage through Azure Identity's
    :py:class:`azure.identity.DefaultAzureCredential`.

    See :py:class:`azure.identity.DefaultAzureCredential` for provider options.
    """

    def __init__(self, **kwargs: dict[str, Any]):
        """Construct the underlying ``DefaultAzureCredential`` with any keyword options."""
        self._credential = DefaultAzureCredential(**kwargs)

    def get_credentials(self) -> Credentials:
        """Expose the live credential object via ``custom_fields``; key fields stay empty."""
        extra_fields = {AZURE_CREDENTIAL_KEY: self._credential}
        return Credentials(
            access_key="",
            secret_key="",
            token=None,
            expiration=None,
            custom_fields=extra_fields,
        )

    def refresh_credentials(self) -> None:
        """No-op: ``DefaultAzureCredential`` manages its own token refresh."""
        pass
class AzureBlobStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Azure Blob Storage.
    """

    def __init__(
        self,
        endpoint_url: str,
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: dict[str, Any],
    ):
        """
        Initializes the :py:class:`AzureBlobStorageProvider` with the endpoint URL and optional credentials provider.

        :param endpoint_url: The Azure storage account URL.
        :param base_path: The root prefix path within the container where all operations will be scoped.
        :param credentials_provider: The provider to retrieve Azure credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._account_url = endpoint_url
        self._credentials_provider = credentials_provider
        # Only forward tuning options the Azure SDK client actually accepts; see
        # https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/storage/azure-storage-blob#optional-configuration
        client_optional_configuration_keys = {
            "retry_total",
            "retry_connect",
            "retry_read",
            "retry_status",
            "connection_timeout",
            "read_timeout",
        }
        self._client_optional_configuration: dict[str, Any] = {
            key: value for key, value in kwargs.items() if key in client_optional_configuration_keys
        }
        # Fall back to MSC-wide default timeouts when the caller did not override them.
        if "connection_timeout" not in self._client_optional_configuration:
            self._client_optional_configuration["connection_timeout"] = DEFAULT_CONNECT_TIMEOUT
        if "read_timeout" not in self._client_optional_configuration:
            self._client_optional_configuration["read_timeout"] = DEFAULT_READ_TIMEOUT
        self._blob_service_client = self._create_blob_service_client()

    def _create_blob_service_client(self) -> BlobServiceClient:
        """
        Creates and configures the Azure BlobServiceClient using the current credentials.

        :return: The configured BlobServiceClient.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()

            if isinstance(self._credentials_provider, StaticAzureCredentialsProvider):
                # Static provider: full connection string is carried in custom_fields.
                return BlobServiceClient.from_connection_string(
                    credentials.get_custom_field(AZURE_CONNECTION_STRING_KEY), **self._client_optional_configuration
                )
            elif isinstance(self._credentials_provider, DefaultAzureCredentialsProvider):
                # Identity-based provider: pass the live credential object to the client.
                return BlobServiceClient(
                    account_url=self._account_url,
                    credential=credentials.get_custom_field(AZURE_CREDENTIAL_KEY),
                    **self._client_optional_configuration,
                )
            else:
                # Fallback to connection string if no built-in credentials provider is provided
                # NOTE(review): assumes custom providers put a connection string in
                # access_key — confirm against the CredentialsProvider contract.
                return BlobServiceClient.from_connection_string(
                    credentials.access_key, **self._client_optional_configuration
                )
        else:
            # No credentials provider configured: construct the client from the account URL alone.
            return BlobServiceClient(account_url=self._account_url, **self._client_optional_configuration)

    def _refresh_blob_service_client_if_needed(self) -> None:
        """
        Refreshes the BlobServiceClient if the current credentials are expired.

        Called at the top of every storage operation so a long-lived provider
        picks up rotated credentials before issuing requests.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._blob_service_client = self._create_blob_service_client()

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        container: str,
        blob: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual Azure Blob Storage operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param container: The name of the Azure container involved in the operation.
        :param blob: The name of the blob within the Azure container.

        :return: The result of the Azure Blob Storage operation, typically the return value of the `func` callable.

        :raises FileNotFoundError: On HTTP 404 (blob/container missing).
        :raises PreconditionFailedError: On HTTP 412 (If-Match / If-Modified failure).
        :raises RuntimeError: For any other Azure or unexpected error.
        """
        try:
            return func()
        except HttpResponseError as error:
            status_code = error.status_code if error.status_code else -1
            error_info = f"status_code: {error.status_code}, reason: {error.reason}"
            if status_code == 404:
                raise FileNotFoundError(f"Object {container}/{blob} does not exist.")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                # raised when If-Match or If-Modified fails
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {container}/{blob}. {error_info}"
                ) from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {container}/{blob}. {error_info}") from error
        except AzureError as error:
            error_info = f"message: {error.message}"
            raise RuntimeError(f"Failed to {operation} object(s) at {container}/{blob}. {error_info}") from error
        except FileNotFoundError:
            # Re-raise untouched so callers can distinguish "missing" from "failed".
            raise
        except Exception as error:
            raise RuntimeError(
                f"Failed to {operation} object(s) at {container}/{blob}. error_type: {type(error).__name__}, error: {error}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to Azure Blob Storage.

        :param path: The path to the object to upload.
        :param body: The content of the object to upload.
        :param if_match: Optional ETag to match against the object.
        :param if_none_match: Optional ETag to match against the object.
        :param attributes: Optional attributes to attach to the object.

        :return: The number of bytes written (``len(body)``).

        :raises NotImplementedError: If ``if_none_match`` is ``"*"``.
        """
        container_name, blob_name = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> int:
            blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)

            kwargs = {
                "data": body,
                "overwrite": True,
            }

            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                kwargs["metadata"] = validated_attributes

            if if_match:
                # If-Match semantics: overwrite only when the stored ETag still matches.
                kwargs["match_condition"] = MatchConditions.IfNotModified
                kwargs["etag"] = if_match

            if if_none_match:
                if if_none_match == "*":
                    raise NotImplementedError("if_none_match='*' is not supported for Azure")
                # NOTE(review): when both if_match and if_none_match are provided,
                # this overrides the if_match condition set above — confirm intended.
                kwargs["match_condition"] = MatchConditions.IfModified
                kwargs["etag"] = if_none_match

            blob_client.upload_blob(**kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", container=container_name, blob=blob_name)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Downloads an object, optionally restricted to ``byte_range``, and returns its bytes.

        :param path: The path to the object to read.
        :param byte_range: Optional offset/size window within the blob.
        :return: The blob content (or requested slice) as bytes.
        """
        container_name, blob_name = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> bytes:
            blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
            if byte_range:
                stream = blob_client.download_blob(offset=byte_range.offset, length=byte_range.size)
            else:
                stream = blob_client.download_blob()
            return stream.readall()

        return self._translate_errors(_invoke_api, operation="GET", container=container_name, blob=blob_name)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        """
        Server-side copy of ``src_path`` to ``dest_path``.

        :return: The content length of the source object.
        :raises FileNotFoundError: If the source object does not exist
            (raised by the metadata lookup below).
        """
        src_container, src_blob = split_path(src_path)
        dest_container, dest_blob = split_path(dest_path)
        self._refresh_blob_service_client_if_needed()

        # Fetched up front: validates the source exists and supplies the return size.
        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            src_blob_client = self._blob_service_client.get_blob_client(container=src_container, blob=src_blob)
            dest_blob_client = self._blob_service_client.get_blob_client(container=dest_container, blob=dest_blob)
            # NOTE(review): start_copy_from_url initiates the copy; this does not
            # wait for server-side completion — confirm callers tolerate that.
            dest_blob_client.start_copy_from_url(src_blob_client.url)

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", container=src_container, blob=src_blob)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Deletes a single blob, optionally only when its ETag matches ``if_match``.

        :param path: The path to the object to delete.
        :param if_match: Optional ETag for conditional deletion.
        """
        container_name, blob_name = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> None:
            blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
            # If if_match is provided, use if_match for conditional deletion
            if if_match:
                blob_client.delete_blob(etag=if_match, match_condition=MatchConditions.IfNotModified)
            else:
                # No if_match provided, perform unconditional deletion
                blob_client.delete_blob()

        return self._translate_errors(_invoke_api, operation="DELETE", container=container_name, blob=blob_name)

    def _delete_objects(self, paths: list[str]) -> None:
        """
        Deletes many blobs, grouped by container and issued in batches.

        :param paths: Object paths to delete; may span multiple containers.
        """
        if not paths:
            return

        # Group blob names by container so each batch call targets one container.
        by_container: dict[str, list[str]] = {}
        for p in paths:
            container_name, blob_name = split_path(p)
            by_container.setdefault(container_name, []).append(blob_name)
        self._refresh_blob_service_client_if_needed()

        # Chunk size for each delete_blobs batch request.
        AZURE_BATCH_LIMIT = 256

        def _invoke_api() -> None:
            for container_name, blob_names in by_container.items():
                container_client = self._blob_service_client.get_container_client(container=container_name)
                for i in range(0, len(blob_names), AZURE_BATCH_LIMIT):
                    chunk = blob_names[i : i + AZURE_BATCH_LIMIT]
                    container_client.delete_blobs(*chunk)

        # Human-readable summaries for error messages, e.g. "(c1|c2)" / "(10|3 keys)".
        container_desc = "(" + "|".join(by_container) + ")"
        blob_desc = "(" + "|".join(str(len(blob_names)) for blob_names in by_container.values()) + " keys)"

        self._translate_errors(_invoke_api, operation="DELETE_MANY", container=container_desc, blob=blob_desc)

    def _is_dir(self, path: str) -> bool:
        """
        Returns True if any blob exists under ``path`` treated as a "/"-delimited prefix.
        """
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        container_name, prefix = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> bool:
            # List objects with the given prefix
            container_client = self._blob_service_client.get_container_client(container=container_name)
            blobs = container_client.walk_blobs(name_starts_with=prefix, delimiter="/")
            # Check if there are any contents or common prefixes
            return any(True for _ in blobs)

        return self._translate_errors(_invoke_api, operation="LIST", container=container_name, blob=prefix)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Returns metadata for a blob, or a synthetic "directory" entry for prefix paths.

        :param path: The object (or directory-like) path.
        :param strict: When True, a missing blob triggers a fallback directory check
            before the FileNotFoundError is re-raised.
        :raises FileNotFoundError: If neither a blob nor a directory exists at ``path``.
        """
        container_name, blob_name = split_path(path)
        if path.endswith("/") or (container_name and not blob_name):
            # If path ends with "/" or empty blob name is provided, then assume it's a "directory",
            # which metadata is not guaranteed to exist for cases such as
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=self._append_delimiter(path),
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            self._refresh_blob_service_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                properties = blob_client.get_blob_properties()
                return ObjectMetadata(
                    key=path,
                    content_length=properties.size,
                    content_type=properties.content_settings.content_type,
                    last_modified=properties.last_modified,
                    etag=properties.etag.strip('"') if properties.etag else "",
                    metadata=dict(properties.metadata) if properties.metadata else None,
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", container=container_name, blob=blob_name)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        """
        Lazily lists blobs (and optionally directory prefixes) under ``path``.

        :param path: Container/prefix path to list under.
        :param start_after: Exclusive lower bound on keys, relative to the container.
        :param end_at: Inclusive upper bound on keys, relative to the container.
        :param include_directories: When True, list with "/" delimiter and yield
            directory entries instead of descending into them.
        :param follow_symlinks: Unused for Azure; kept for interface parity.
        """
        container_name, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            container_client = self._blob_service_client.get_container_client(container=container_name)
            # Azure has no start key option like other object stores.
            if include_directories:
                blobs = container_client.walk_blobs(name_starts_with=prefix, delimiter="/")
            else:
                blobs = container_client.list_blobs(name_starts_with=prefix)
            # Azure guarantees lexicographical order.
            for blob in blobs:
                if isinstance(blob, BlobPrefix):
                    prefix_key = blob.name.rstrip("/")
                    # Filter by start_after and end_at if specified
                    if (start_after is None or start_after < prefix_key) and (end_at is None or prefix_key <= end_at):
                        yield ObjectMetadata(
                            key=os.path.join(container_name, prefix_key),
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                    elif end_at is not None and end_at < prefix_key:
                        # Past the inclusive upper bound; ordering lets us stop early.
                        return
                else:
                    key = blob.name
                    if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                        if key.endswith("/"):
                            # Zero-byte "folder marker" blob; surface it only as a directory.
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(container_name, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=blob.last_modified,
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(container_name, key),
                                content_length=blob.size,
                                content_type=blob.content_settings.content_type,
                                last_modified=blob.last_modified,
                                etag=blob.etag.strip('"') if blob.etag else "",
                            )
                    elif end_at is not None and end_at < key:
                        return

        return self._translate_errors(_invoke_api, operation="LIST", container=container_name, blob=prefix)

    @property
    def supports_parallel_listing(self) -> bool:
        """Whether list operations for this provider may be partitioned across workers."""
        return True

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        """
        Uploads a local file path or an open file-like object to ``remote_path``.

        :param remote_path: Destination object path.
        :param f: A filesystem path, or a binary/text file-like object.
        :param attributes: Optional attributes to attach to the blob as metadata.
        :return: The number of bytes uploaded.
        """
        container_name, blob_name = split_path(remote_path)
        file_size: int = 0
        self._refresh_blob_service_client_if_needed()

        validated_attributes = validate_attributes(attributes)
        if isinstance(f, str):
            file_size = os.path.getsize(f)

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                with open(f, "rb") as data:
                    blob_client.upload_blob(data, overwrite=True, metadata=validated_attributes or {})

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", container=container_name, blob=blob_name)
        else:
            # Convert StringIO to BytesIO before upload
            if isinstance(f, io.StringIO):
                fp: IO = io.BytesIO(f.getvalue().encode("utf-8"))  # type: ignore
            else:
                fp = f

            # Measure the payload by seeking to the end, then rewind for the upload.
            fp.seek(0, io.SEEK_END)
            file_size = fp.tell()
            fp.seek(0)

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                blob_client.upload_blob(fp, overwrite=True, metadata=validated_attributes or {})

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", container=container_name, blob=blob_name)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """
        Downloads ``remote_path`` into a local file path or an open file-like object.

        :param remote_path: Source object path.
        :param f: A destination filesystem path, or a binary/text file-like object.
        :param metadata: Pre-fetched object metadata; looked up here when omitted.
        :return: The content length of the downloaded object (from metadata).
        """
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        container_name, blob_name = split_path(remote_path)
        self._refresh_blob_service_client_if_needed()

        if isinstance(f, str):
            if os.path.dirname(f):
                safe_makedirs(os.path.dirname(f))

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                # Write to a hidden temp file in the same directory, then rename over
                # the target so readers never observe a partially written file.
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    stream = blob_client.download_blob()
                    fp.write(stream.readall())
                os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", container=container_name, blob=blob_name)
        else:

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                stream = blob_client.download_blob()
                # Text-mode sinks need the payload decoded; assumes UTF-8 content.
                if isinstance(f, io.StringIO):
                    f.write(stream.readall().decode("utf-8"))
                else:
                    f.write(stream.readall())

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", container=container_name, blob=blob_name)