Source code for multistorageclient.providers.azure

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os
import tempfile
from collections.abc import Callable, Iterator
from datetime import datetime, timedelta, timezone
from typing import IO, Any, Optional, TypeVar, Union
from urllib.parse import urlparse

from azure.core import MatchConditions
from azure.core.exceptions import AzureError, HttpResponseError
from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobPrefix, BlobServiceClient, generate_blob_sas
from azure.storage.blob._models import BlobSasPermissions

from ..constants import DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT
from ..signers.base import URLSigner
from ..telemetry import Telemetry
from ..types import (
    AWARE_DATETIME_MIN,
    Credentials,
    CredentialsProvider,
    ObjectMetadata,
    PreconditionFailedError,
    Range,
    SignerType,
    SymlinkHandling,
)
from ..utils import safe_makedirs, split_path, validate_attributes
from .base import BaseStorageProvider

_T = TypeVar("_T")

PROVIDER = "azure"
AZURE_CONNECTION_STRING_KEY = "connection"
AZURE_CREDENTIAL_KEY = "azure_credential"

MiB = 1024 * 1024

MULTIPART_THRESHOLD = 64 * MiB
MULTIPART_CHUNKSIZE = 32 * MiB
IO_CHUNKSIZE = 32 * MiB
PYTHON_MAX_CONCURRENCY = 8

# Azure REST API only returns ``Content-MD5`` for GET ranges up to 4 MiB,
# so download chunks must stay at or below this limit when ``validate_content`` is enabled.
AZURE_CONTENT_MD5_RANGE_LIMIT_BYTES = 4 * MiB

DEFAULT_PRESIGN_EXPIRES_IN = 3600

# How long before delegation key expiry we treat the cached key as stale.
_DELEGATION_KEY_REFRESH_BUFFER = timedelta(minutes=5)

# Azure's maximum allowed delegation key lifetime is 7 days.
_DELEGATION_KEY_LIFETIME = timedelta(days=7)


def _sas_permissions_for_method(method: str) -> BlobSasPermissions:
    """Return the minimal :class:`BlobSasPermissions` needed for *method*."""
    m = method.upper()
    if m in ("PUT", "POST"):
        return BlobSasPermissions(write=True, create=True)
    elif m == "DELETE":
        return BlobSasPermissions(delete=True)
    else:
        # GET, HEAD, and any unrecognised method → read-only
        return BlobSasPermissions(read=True)

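# For illustration, the helper above maps HTTP verbs to the narrowest SAS permission set:
#
#     _sas_permissions_for_method("PUT")     -> BlobSasPermissions(write=True, create=True)
#     _sas_permissions_for_method("DELETE")  -> BlobSasPermissions(delete=True)
#     _sas_permissions_for_method("HEAD")    -> BlobSasPermissions(read=True)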

def _parse_account_name_from_url(account_url: str) -> str:
    """Extract the storage account name from an Azure Blob Storage account URL."""
    hostname = urlparse(account_url).hostname
    if hostname is None:
        raise ValueError(f"Invalid Azure account URL: {account_url!r}")
    return hostname.split(".")[0]


def _parse_connection_string(conn_str: str) -> dict[str, str]:
    """Parse an Azure connection string (``AccountName=foo;AccountKey=bar;...``) into a dict."""
    return dict(part.split("=", 1) for part in conn_str.split(";") if "=" in part)

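# For illustration, with made-up placeholder values, the two parsing helpers above behave as follows:
#
#     _parse_account_name_from_url("https://mystorageacct.blob.core.windows.net")
#         -> "mystorageacct"
#     _parse_connection_string("AccountName=mystorageacct;AccountKey=abc123;EndpointSuffix=core.windows.net")
#         -> {"AccountName": "mystorageacct", "AccountKey": "abc123", "EndpointSuffix": "core.windows.net"}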

class AzureURLSigner(URLSigner):
    """
    Generates Azure Blob Storage SAS (Shared Access Signature) URLs.

    Supports two signing paths depending on which credential is provided:

    * **Account key** – uses a static storage account key (parsed from a connection string).
    * **User delegation key** – uses a time-limited key obtained via Azure Identity (e.g. workload
      identity, managed identity). Callers are responsible for refreshing the signer when the
      delegation key approaches expiry; see :py:meth:`AzureBlobStorageProvider._generate_presigned_url`.
    """

    def __init__(
        self,
        account_name: str,
        account_url: str,
        *,
        account_key: Optional[str] = None,
        user_delegation_key: Optional[Any] = None,
        expires_in: int = DEFAULT_PRESIGN_EXPIRES_IN,
    ) -> None:
        if account_key is None and user_delegation_key is None:
            raise ValueError("Either account_key or user_delegation_key must be provided.")
        self._account_name = account_name
        self._account_url = account_url.rstrip("/")
        self._account_key = account_key
        self._user_delegation_key = user_delegation_key
        self._expires_in = expires_in

    def generate_presigned_url(self, path: str, *, method: str = "GET") -> str:
        """
        Generate a SAS URL for the given blob path.

        :param path: Path in the form ``container/blob/name``.
        :param method: HTTP method requested by the caller.
        :return: A fully-qualified SAS URL.
        """
        container_name, blob_name = split_path(path)
        expiry = datetime.now(timezone.utc) + timedelta(seconds=self._expires_in)

        sas_kwargs: dict[str, Any] = {
            "account_name": self._account_name,
            "container_name": container_name,
            "blob_name": blob_name,
            "permission": _sas_permissions_for_method(method),
            "expiry": expiry,
        }

        if self._account_key is not None:
            sas_kwargs["account_key"] = self._account_key
        else:
            sas_kwargs["user_delegation_key"] = self._user_delegation_key

        sas_token = generate_blob_sas(**sas_kwargs)
        blob_url = f"{self._account_url}/{container_name}/{blob_name}"
        return f"{blob_url}?{sas_token}"

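# A minimal usage sketch for the signer (illustrative only; the account name, URL, and key
# are placeholders, and in normal use AzureBlobStorageProvider constructs this signer internally):
#
#     signer = AzureURLSigner(
#         account_name="mystorageacct",
#         account_url="https://mystorageacct.blob.core.windows.net",
#         account_key="<base64-account-key>",
#         expires_in=900,
#     )
#     url = signer.generate_presigned_url("my-container/path/to/blob.bin", method="GET")
#     # ``url`` is a read-only SAS URL valid for 15 minutes.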

class StaticAzureCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static Azure credentials.
    """

    _connection: str

    def __init__(self, connection: str):
        """
        Initializes the :py:class:`StaticAzureCredentialsProvider` with the provided connection string.

        :param connection: The connection string for Azure Blob Storage authentication.
        """
        self._connection = connection

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key=self._connection,
            secret_key="",
            token=None,
            expiration=None,
            custom_fields={AZURE_CONNECTION_STRING_KEY: self._connection},
        )

    def refresh_credentials(self) -> None:
        pass


class DefaultAzureCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that uses Azure Identity's :py:class:`azure.identity.DefaultAzureCredential` to authenticate with Blob Storage.

    See :py:class:`azure.identity.DefaultAzureCredential` for provider options.
    """

    def __init__(self, **kwargs: dict[str, Any]):
        self._credential = DefaultAzureCredential(**kwargs)

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key="",
            secret_key="",
            token=None,
            expiration=None,
            custom_fields={AZURE_CREDENTIAL_KEY: self._credential},
        )

    def refresh_credentials(self) -> None:
        pass

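# A minimal sketch of the two credential paths (illustrative only; the connection string is a placeholder):
#
#     # Static connection string:
#     static_creds = StaticAzureCredentialsProvider(
#         connection="AccountName=mystorageacct;AccountKey=abc123;EndpointSuffix=core.windows.net"
#     )
#
#     # Azure Identity (workload identity, managed identity, developer sign-in, ...):
#     identity_creds = DefaultAzureCredentialsProvider()
#
#     # Either instance can be passed to AzureBlobStorageProvider below via ``credentials_provider=...``.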

class AzureBlobStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Azure Blob Storage.
    """

    def __init__(
        self,
        endpoint_url: str,
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ):
        """
        Initializes the :py:class:`AzureBlobStorageProvider` with the endpoint URL and optional credentials provider.

        :param endpoint_url: The Azure storage account URL.
        :param base_path: The root prefix path within the container where all operations will be scoped.
        :param credentials_provider: The provider to retrieve Azure credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        :param kwargs: Additional options including:
            - ``multipart_threshold`` (int): File size threshold (bytes) for switching to parallel chunked transfers. Defaults to 64 MiB.
            - ``multipart_chunksize`` (int): Block size (bytes) for chunked uploads. Defaults to 32 MiB.
            - ``io_chunksize`` (int): Chunk size (bytes) for chunked downloads. Defaults to 32 MiB.
            - ``max_concurrency`` (int): Number of parallel threads for chunked transfers. Defaults to 8.
            - ``validate_content`` (bool): Opt-in client-side MD5 verification. Defaults to False.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._account_url = endpoint_url
        self._credentials_provider = credentials_provider
        # Cache static connection-string signing material used for per-request signers.
        self._account_key_signing_material: Optional[tuple[str, str]] = None
        # Cached delegation key and its expiry for DefaultAzureCredentialsProvider.
        self._delegation_user_key: Optional[Any] = None
        self._delegation_signer_expiry: Optional[datetime] = None
        self._multipart_threshold = int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD))
        self._multipart_chunksize = int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE))
        self._io_chunksize = int(kwargs.get("io_chunksize", IO_CHUNKSIZE))
        self._max_concurrency = int(kwargs.get("max_concurrency", PYTHON_MAX_CONCURRENCY))
        self._validate_content = kwargs.get("validate_content", False)
        if not isinstance(self._validate_content, bool):
            raise ValueError("Option 'validate_content' must be a boolean.")
        if self._validate_content and self._io_chunksize > AZURE_CONTENT_MD5_RANGE_LIMIT_BYTES:
            raise ValueError(
                "Option 'validate_content=True' requires 'io_chunksize' to be "
                f"<= {AZURE_CONTENT_MD5_RANGE_LIMIT_BYTES} bytes (4 MiB) because Azure only "
                f"returns Content-MD5 for GET ranges within that limit. Got io_chunksize={self._io_chunksize}."
            )

        # https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/storage/azure-storage-blob#optional-configuration
        client_optional_configuration_keys = {
            "retry_total",
            "retry_connect",
            "retry_read",
            "retry_status",
            "connection_timeout",
            "read_timeout",
        }
        self._client_optional_configuration: dict[str, Any] = {
            key: value for key, value in kwargs.items() if key in client_optional_configuration_keys
        }
        if "connection_timeout" not in self._client_optional_configuration:
            self._client_optional_configuration["connection_timeout"] = DEFAULT_CONNECT_TIMEOUT
        if "read_timeout" not in self._client_optional_configuration:
            self._client_optional_configuration["read_timeout"] = DEFAULT_READ_TIMEOUT

        self._transfer_configuration: dict[str, Any] = {
            "max_single_put_size": self._multipart_threshold,
            "max_block_size": self._multipart_chunksize,
            "max_single_get_size": self._multipart_threshold,
            "max_chunk_get_size": self._io_chunksize,
        }

        self._blob_service_client = self._create_blob_service_client()

    def _create_blob_service_client(self) -> BlobServiceClient:
        """
        Creates and configures the Azure BlobServiceClient using the current credentials.

        :return: The configured BlobServiceClient.
        """
        combined_config = {**self._client_optional_configuration, **self._transfer_configuration}

        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()

            if isinstance(self._credentials_provider, StaticAzureCredentialsProvider):
                return BlobServiceClient.from_connection_string(
                    credentials.get_custom_field(AZURE_CONNECTION_STRING_KEY), **combined_config
                )
            elif isinstance(self._credentials_provider, DefaultAzureCredentialsProvider):
                return BlobServiceClient(
                    account_url=self._account_url,
                    credential=credentials.get_custom_field(AZURE_CREDENTIAL_KEY),
                    **combined_config,
                )
            else:
                # Fall back to a connection string if no built-in credentials provider is given.
                return BlobServiceClient.from_connection_string(credentials.access_key, **combined_config)
        else:
            return BlobServiceClient(account_url=self._account_url, **combined_config)

    def _refresh_blob_service_client_if_needed(self) -> None:
        """
        Refreshes the BlobServiceClient if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._blob_service_client = self._create_blob_service_client()

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        container: str,
        blob: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual Azure Blob Storage operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param container: The name of the Azure container involved in the operation.
        :param blob: The name of the blob within the Azure container.

        :return: The result of the Azure Blob Storage operation, typically the return value of the ``func`` callable.
342 """ 343 try: 344 return func() 345 except HttpResponseError as error: 346 status_code = error.status_code if error.status_code else -1 347 error_info = f"status_code: {error.status_code}, reason: {error.reason}" 348 if status_code == 404: 349 raise FileNotFoundError(f"Object {container}/{blob} does not exist.") # pylint: disable=raise-missing-from 350 elif status_code == 412: 351 # raised when If-Match or If-Modified fails 352 raise PreconditionFailedError( 353 f"Failed to {operation} object(s) at {container}/{blob}. {error_info}" 354 ) from error 355 else: 356 raise RuntimeError(f"Failed to {operation} object(s) at {container}/{blob}. {error_info}") from error 357 except AzureError as error: 358 error_info = f"message: {error.message}" 359 raise RuntimeError(f"Failed to {operation} object(s) at {container}/{blob}. {error_info}") from error 360 except FileNotFoundError: 361 raise 362 except Exception as error: 363 raise RuntimeError( 364 f"Failed to {operation} object(s) at {container}/{blob}. error_type: {type(error).__name__}, error: {error}" 365 ) from error 366 367 def _put_object( 368 self, 369 path: str, 370 body: bytes, 371 if_match: Optional[str] = None, 372 if_none_match: Optional[str] = None, 373 attributes: Optional[dict[str, str]] = None, 374 ) -> int: 375 """ 376 Uploads an object to Azure Blob Storage. 377 378 :param path: The path to the object to upload. 379 :param body: The content of the object to upload. 380 :param if_match: Optional ETag to match against the object. 381 :param if_none_match: Optional ETag to match against the object. 382 :param attributes: Optional attributes to attach to the object. 383 """ 384 container_name, blob_name = split_path(path) 385 self._refresh_blob_service_client_if_needed() 386 387 def _invoke_api() -> int: 388 blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name) 389 390 kwargs: dict[str, Any] = { 391 "data": body, 392 "overwrite": True, 393 "max_concurrency": self._max_concurrency, 394 "validate_content": self._validate_content, 395 } 396 397 validated_attributes = validate_attributes(attributes) 398 if validated_attributes: 399 kwargs["metadata"] = validated_attributes 400 401 if if_match: 402 kwargs["match_condition"] = MatchConditions.IfNotModified 403 kwargs["etag"] = if_match 404 405 if if_none_match: 406 if if_none_match == "*": 407 raise NotImplementedError("if_none_match='*' is not supported for Azure") 408 kwargs["match_condition"] = MatchConditions.IfModified 409 kwargs["etag"] = if_none_match 410 411 blob_client.upload_blob(**kwargs) 412 413 return len(body) 414 415 return self._translate_errors(_invoke_api, operation="PUT", container=container_name, blob=blob_name) 416 417 def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes: 418 container_name, blob_name = split_path(path) 419 self._refresh_blob_service_client_if_needed() 420 421 def _invoke_api() -> bytes: 422 blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name) 423 if byte_range: 424 stream = blob_client.download_blob( 425 offset=byte_range.offset, 426 length=byte_range.size, 427 validate_content=self._validate_content, 428 ) 429 else: 430 stream = blob_client.download_blob( 431 max_concurrency=self._max_concurrency, 432 validate_content=self._validate_content, 433 ) 434 return stream.readall() 435 436 return self._translate_errors(_invoke_api, operation="GET", container=container_name, blob=blob_name) 437 438 def _copy_object(self, src_path: str, 
        src_container, src_blob = split_path(src_path)
        dest_container, dest_blob = split_path(dest_path)
        self._refresh_blob_service_client_if_needed()

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            src_blob_client = self._blob_service_client.get_blob_client(container=src_container, blob=src_blob)
            dest_blob_client = self._blob_service_client.get_blob_client(container=dest_container, blob=dest_blob)
            dest_blob_client.start_copy_from_url(src_blob_client.url)

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", container=src_container, blob=src_blob)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        container_name, blob_name = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> None:
            blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
            # If if_match is provided, use if_match for conditional deletion
            if if_match:
                blob_client.delete_blob(etag=if_match, match_condition=MatchConditions.IfNotModified)
            else:
                # No if_match provided, perform unconditional deletion
                blob_client.delete_blob()

        return self._translate_errors(_invoke_api, operation="DELETE", container=container_name, blob=blob_name)

    def _delete_objects(self, paths: list[str]) -> None:
        if not paths:
            return

        by_container: dict[str, list[str]] = {}
        for p in paths:
            container_name, blob_name = split_path(p)
            by_container.setdefault(container_name, []).append(blob_name)
        self._refresh_blob_service_client_if_needed()

        AZURE_BATCH_LIMIT = 256

        def _invoke_api() -> None:
            for container_name, blob_names in by_container.items():
                container_client = self._blob_service_client.get_container_client(container=container_name)
                for i in range(0, len(blob_names), AZURE_BATCH_LIMIT):
                    chunk = blob_names[i : i + AZURE_BATCH_LIMIT]
                    container_client.delete_blobs(*chunk)

        container_desc = "(" + "|".join(by_container) + ")"
        blob_desc = "(" + "|".join(str(len(blob_names)) for blob_names in by_container.values()) + " keys)"
        self._translate_errors(_invoke_api, operation="DELETE_MANY", container=container_desc, blob=blob_desc)

    def _is_dir(self, path: str) -> bool:
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        container_name, prefix = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> bool:
            # List objects with the given prefix
            container_client = self._blob_service_client.get_container_client(container=container_name)
            blobs = container_client.walk_blobs(name_starts_with=prefix, delimiter="/")
            # Check if there are any contents or common prefixes
            return any(True for _ in blobs)

        return self._translate_errors(_invoke_api, operation="LIST", container=container_name, blob=prefix)

    def _make_symlink(self, path: str, target: str) -> None:
        container_name, blob_name = split_path(path)
        target_container, target_key = split_path(target)
        if container_name != target_container:
            raise ValueError(f"Cannot create cross-container symlink: '{container_name}' -> '{target_container}'.")
        relative_target = ObjectMetadata.encode_symlink_target(blob_name, target_key)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> None:
            blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
            blob_client.upload_blob(
                data=b"",
                overwrite=True,
                metadata={"msc_symlink_target": relative_target},
            )

        self._translate_errors(_invoke_api, operation="PUT", container=container_name, blob=blob_name)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        container_name, blob_name = split_path(path)
        if path.endswith("/") or (container_name and not blob_name):
            # If the path ends with "/" or the blob name is empty, assume it's a "directory",
            # whose metadata is not guaranteed to exist (e.g. a "virtual prefix" that was
            # never explicitly created).
            if self._is_dir(path):
                return ObjectMetadata(
                    key=self._append_delimiter(path),
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            self._refresh_blob_service_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                properties = blob_client.get_blob_properties()
                user_metadata = dict(properties.metadata) if properties.metadata else None
                symlink_target = user_metadata.get("msc_symlink_target") if user_metadata else None
                return ObjectMetadata(
                    key=path,
                    content_length=properties.size,
                    content_type=properties.content_settings.content_type,
                    last_modified=properties.last_modified,
                    etag=properties.etag.strip('"') if properties.etag else "",
                    metadata=user_metadata,
                    symlink_target=symlink_target,
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", container=container_name, blob=blob_name)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        symlink_handling: SymlinkHandling = SymlinkHandling.FOLLOW,
    ) -> Iterator[ObjectMetadata]:
        container_name, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            container_client = self._blob_service_client.get_container_client(container=container_name)
            # Azure has no start key option like other object stores.
            if include_directories:
                blobs = container_client.walk_blobs(name_starts_with=prefix, delimiter="/", include=["metadata"])
            else:
                blobs = container_client.list_blobs(name_starts_with=prefix, include=["metadata"])
            # Azure guarantees lexicographical order.
            for blob in blobs:
                if isinstance(blob, BlobPrefix):
                    prefix_key = blob.name.rstrip("/")
                    # Filter by start_after and end_at if specified
                    if (start_after is None or start_after < prefix_key) and (end_at is None or prefix_key <= end_at):
                        yield ObjectMetadata(
                            key=os.path.join(container_name, prefix_key),
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                    elif end_at is not None and end_at < prefix_key:
                        return
                else:
                    key = blob.name
                    if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                        if key.endswith("/"):
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(container_name, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=blob.last_modified,
                                )
                        else:
                            user_metadata = dict(blob.metadata) if blob.metadata else None
                            symlink_target = user_metadata.get("msc_symlink_target") if user_metadata else None
                            yield ObjectMetadata(
                                key=os.path.join(container_name, key),
                                content_length=blob.size,
                                content_type=blob.content_settings.content_type,
                                last_modified=blob.last_modified,
                                etag=blob.etag.strip('"') if blob.etag else "",
                                symlink_target=symlink_target,
                            )
                    elif end_at is not None and end_at < key:
                        return

        return self._translate_errors(_invoke_api, operation="LIST", container=container_name, blob=prefix)

    def _generate_presigned_url(
        self,
        path: str,
        *,
        method: str = "GET",
        signer_type: Optional[SignerType] = None,
        signer_options: Optional[dict[str, Any]] = None,
    ) -> str:
        """
        Generate a SAS URL for a blob in Azure Blob Storage.

        :param path: Path in the form ``container/blob/name``.
        :param method: HTTP method requested by the caller.
        :param signer_type: Must be ``None`` or :py:attr:`SignerType.AZURE`.
        :param signer_options: Optional dict; supports ``expires_in`` (int, seconds).
        :return: A fully-qualified SAS URL.
        :raises ValueError: If *signer_type* is not ``None`` / ``SignerType.AZURE``, or if the
            configured credential type does not support SAS generation.
        """
        if signer_type is not None and signer_type != SignerType.AZURE:
            raise ValueError(f"Unsupported signer type for Azure provider: {signer_type!r}")

        options = signer_options or {}
        expires_in = int(options.get("expires_in", DEFAULT_PRESIGN_EXPIRES_IN))

        self._refresh_blob_service_client_if_needed()

        if isinstance(self._credentials_provider, StaticAzureCredentialsProvider):
            # Account key path: cache parsed AccountName + AccountKey, then sign per request.
            if self._account_key_signing_material is None:
                conn_str = self._credentials_provider.get_credentials().get_custom_field(AZURE_CONNECTION_STRING_KEY)
                parsed = _parse_connection_string(conn_str)
                self._account_key_signing_material = (parsed["AccountName"], parsed["AccountKey"])
            account_name, account_key = self._account_key_signing_material
            signer = AzureURLSigner(
                account_name=account_name,
                account_url=self._account_url,
                account_key=account_key,
                expires_in=expires_in,
            )

        elif isinstance(self._credentials_provider, DefaultAzureCredentialsProvider):
            # User delegation key path: refresh when the cached key is within the
            # refresh buffer of its own expiry or has not been fetched yet.
            now = datetime.now(timezone.utc)
            if (
                self._delegation_user_key is None
                or self._delegation_signer_expiry is None
                or now >= self._delegation_signer_expiry - _DELEGATION_KEY_REFRESH_BUFFER
            ):
                key_expiry = now + _DELEGATION_KEY_LIFETIME
                self._delegation_user_key = self._blob_service_client.get_user_delegation_key(
                    key_start_time=now,
                    key_expiry_time=key_expiry,
                )
                self._delegation_signer_expiry = key_expiry
            signer = AzureURLSigner(
                account_name=_parse_account_name_from_url(self._account_url),
                account_url=self._account_url,
                user_delegation_key=self._delegation_user_key,
                expires_in=expires_in,
            )

        else:
            raise ValueError(
                "Azure presigned URLs require StaticAzureCredentialsProvider (connection string) or "
                "DefaultAzureCredentialsProvider (Azure Identity). "
                f"Got: {type(self._credentials_provider).__name__!r}"
            )

        return signer.generate_presigned_url(path, method=method)

    @property
    def supports_parallel_listing(self) -> bool:
        return True

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        container_name, blob_name = split_path(remote_path)
        file_size: int = 0
        self._refresh_blob_service_client_if_needed()

        validated_attributes = validate_attributes(attributes)
        if isinstance(f, str):
            file_size = os.path.getsize(f)

            if file_size <= self._multipart_threshold:

                def _invoke_api() -> int:
                    blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                    with open(f, "rb") as data:
                        blob_client.upload_blob(
                            data,
                            overwrite=True,
                            metadata=validated_attributes or {},
                            validate_content=self._validate_content,
                        )
                    return file_size

                return self._translate_errors(_invoke_api, operation="PUT", container=container_name, blob=blob_name)

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                with open(f, "rb") as data:
                    blob_client.upload_blob(
                        data,
                        overwrite=True,
                        metadata=validated_attributes or {},
                        max_concurrency=self._max_concurrency,
                        validate_content=self._validate_content,
                    )
                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", container=container_name, blob=blob_name)
        else:
            if isinstance(f, io.StringIO):
                fp: IO = io.BytesIO(f.getvalue().encode("utf-8"))  # type: ignore
            else:
                fp = f

            fp.seek(0, io.SEEK_END)
            file_size = fp.tell()
            fp.seek(0)

            if file_size <= self._multipart_threshold:

                def _invoke_api() -> int:
                    blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                    blob_client.upload_blob(
                        fp,
                        overwrite=True,
                        metadata=validated_attributes or {},
                        validate_content=self._validate_content,
                    )
                    return file_size

                return self._translate_errors(_invoke_api, operation="PUT", container=container_name, blob=blob_name)

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                blob_client.upload_blob(
                    fp,
                    overwrite=True,
                    metadata=validated_attributes or {},
                    max_concurrency=self._max_concurrency,
                    validate_content=self._validate_content,
                )
                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", container=container_name, blob=blob_name)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        container_name, blob_name = split_path(remote_path)
        self._refresh_blob_service_client_if_needed()

        if isinstance(f, str):
            if os.path.dirname(f):
                safe_makedirs(os.path.dirname(f))

            if metadata.content_length <= self._multipart_threshold:

                def _invoke_api() -> int:
                    blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                    temp_file_path: str | None = None
                    try:
                        with tempfile.NamedTemporaryFile(
                            mode="wb", delete=False, dir=os.path.dirname(f), prefix="."
                        ) as fp:
                            temp_file_path = fp.name
                            stream = blob_client.download_blob(validate_content=self._validate_content)
                            fp.write(stream.readall())
                        os.rename(src=temp_file_path, dst=f)
                    except BaseException:
                        if temp_file_path and os.path.exists(temp_file_path):
                            os.unlink(temp_file_path)
                        raise
                    return metadata.content_length

                return self._translate_errors(_invoke_api, operation="GET", container=container_name, blob=blob_name)

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                temp_file_path: str | None = None
                try:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                        temp_file_path = fp.name
                        stream = blob_client.download_blob(
                            max_concurrency=self._max_concurrency,
                            validate_content=self._validate_content,
                        )
                        stream.readinto(fp)
                    os.rename(src=temp_file_path, dst=f)
                except BaseException:
                    if temp_file_path and os.path.exists(temp_file_path):
                        os.unlink(temp_file_path)
                    raise
                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", container=container_name, blob=blob_name)
        else:
            if metadata.content_length <= self._multipart_threshold:

                def _invoke_api() -> int:
                    blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                    stream = blob_client.download_blob(validate_content=self._validate_content)
                    if isinstance(f, io.StringIO):
                        f.write(stream.readall().decode("utf-8"))
                    else:
                        f.write(stream.readall())
                    return metadata.content_length

                return self._translate_errors(_invoke_api, operation="GET", container=container_name, blob=blob_name)

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                stream = blob_client.download_blob(
                    max_concurrency=self._max_concurrency,
                    validate_content=self._validate_content,
                )
                if isinstance(f, io.StringIO):
                    temp_file_path: str | None = None
                    try:
                        with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=".") as tmp:
                            temp_file_path = tmp.name
                            stream.readinto(tmp)
                        with open(temp_file_path, "r") as tmp_read:
                            f.write(tmp_read.read())
                    finally:
                        if temp_file_path and os.path.exists(temp_file_path):
                            os.unlink(temp_file_path)
                else:
                    stream.readinto(f)
                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", container=container_name, blob=blob_name)
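
# A minimal construction sketch (illustrative only; the account URL and container are placeholders,
# and the underscore-prefixed methods above are normally invoked through the higher-level
# multistorageclient API rather than called directly):
#
#     provider = AzureBlobStorageProvider(
#         endpoint_url="https://mystorageacct.blob.core.windows.net",
#         base_path="my-container",
#         credentials_provider=DefaultAzureCredentialsProvider(),
#         io_chunksize=4 * MiB,
#         validate_content=True,  # requires io_chunksize <= 4 MiB (enforced in __init__)
#     )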