Source code for multistorageclient.providers.azure

  1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16import io
 17import os
 18import tempfile
 19from collections.abc import Callable, Iterator
 20from datetime import datetime, timedelta, timezone
 21from typing import IO, Any, Optional, TypeVar, Union
 22from urllib.parse import urlparse
 23
 24from azure.core import MatchConditions
 25from azure.core.exceptions import AzureError, HttpResponseError
 26from azure.identity import DefaultAzureCredential
 27from azure.storage.blob import BlobPrefix, BlobServiceClient, generate_blob_sas
 28from azure.storage.blob._models import BlobSasPermissions
 29
 30from ..constants import DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT
 31from ..signers.base import URLSigner
 32from ..telemetry import Telemetry
 33from ..types import (
 34    AWARE_DATETIME_MIN,
 35    Credentials,
 36    CredentialsProvider,
 37    ObjectMetadata,
 38    PreconditionFailedError,
 39    Range,
 40    SignerType,
 41)
 42from ..utils import safe_makedirs, split_path, validate_attributes
 43from .base import BaseStorageProvider
 44
# Generic return type used by the error-translation wrapper below.
_T = TypeVar("_T")

PROVIDER = "azure"

# Keys under which credential material is stored in Credentials.custom_fields.
AZURE_CONNECTION_STRING_KEY = "connection"
AZURE_CREDENTIAL_KEY = "azure_credential"

# Default SAS URL lifetime in seconds.
DEFAULT_PRESIGN_EXPIRES_IN = 3600

# How long before delegation key expiry we treat the cached key as stale.
_DELEGATION_KEY_REFRESH_BUFFER = timedelta(minutes=5)

# Azure's maximum allowed delegation key lifetime is 7 days.
_DELEGATION_KEY_LIFETIME = timedelta(days=7)
 59
 60def _sas_permissions_for_method(method: str) -> BlobSasPermissions:
 61    """Return the minimal :class:`BlobSasPermissions` needed for *method*."""
 62    m = method.upper()
 63    if m in ("PUT", "POST"):
 64        return BlobSasPermissions(write=True, create=True)
 65    elif m == "DELETE":
 66        return BlobSasPermissions(delete=True)
 67    else:
 68        # GET, HEAD, and any unrecognised method → read-only
 69        return BlobSasPermissions(read=True)
 70
 71
 72def _parse_account_name_from_url(account_url: str) -> str:
 73    """Extract the storage account name from an Azure Blob Storage account URL."""
 74    hostname = urlparse(account_url).hostname
 75    if hostname is None:
 76        raise ValueError(f"Invalid Azure account URL: {account_url!r}")
 77    return hostname.split(".")[0]
 78
 79
 80def _parse_connection_string(conn_str: str) -> dict[str, str]:
 81    """Parse an Azure connection string (``AccountName=foo;AccountKey=bar;...``) into a dict."""
 82    return dict(part.split("=", 1) for part in conn_str.split(";") if "=" in part)
 83
 84
class AzureURLSigner(URLSigner):
    """
    Generates Azure Blob Storage SAS (Shared Access Signature) URLs.

    Two signing paths are supported, depending on which credential is provided:

    * **Account key** – uses a static storage account key (parsed from a connection string).
    * **User delegation key** – uses a time-limited key obtained via Azure Identity (e.g. workload
      identity, managed identity). Callers are responsible for refreshing the signer when the
      delegation key approaches expiry; see :py:meth:`AzureBlobStorageProvider._generate_presigned_url`.
    """

    def __init__(
        self,
        account_name: str,
        account_url: str,
        *,
        account_key: Optional[str] = None,
        user_delegation_key: Optional[Any] = None,
        expires_in: int = DEFAULT_PRESIGN_EXPIRES_IN,
    ) -> None:
        if account_key is None and user_delegation_key is None:
            raise ValueError("Either account_key or user_delegation_key must be provided.")
        self._account_name = account_name
        # Normalise the account URL so joining with container/blob never doubles slashes.
        self._account_url = account_url.rstrip("/")
        self._account_key = account_key
        self._user_delegation_key = user_delegation_key
        self._expires_in = expires_in

    def generate_presigned_url(self, path: str, *, method: str = "GET") -> str:
        """
        Generate a SAS URL for the given blob path.

        :param path: Path in the form ``container/blob/name``.
        :param method: HTTP method requested by the caller.
        :return: A fully-qualified SAS URL.
        """
        container, blob = split_path(path)
        sas_expiry = datetime.now(timezone.utc) + timedelta(seconds=self._expires_in)

        signing_kwargs: dict[str, Any] = {
            "account_name": self._account_name,
            "container_name": container,
            "blob_name": blob,
            "permission": _sas_permissions_for_method(method),
            "expiry": sas_expiry,
        }
        # Exactly one of the two signing credentials is set (enforced in __init__).
        if self._account_key is None:
            signing_kwargs["user_delegation_key"] = self._user_delegation_key
        else:
            signing_kwargs["account_key"] = self._account_key

        token = generate_blob_sas(**signing_kwargs)
        return f"{self._account_url}/{container}/{blob}?{token}"
class StaticAzureCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static Azure credentials.
    """

    # The raw Azure connection string; never rotated, so refresh is a no-op.
    _connection: str

    def __init__(self, connection: str):
        """
        Initializes the :py:class:`StaticAzureCredentialsProvider` with the provided connection string.

        :param connection: The connection string for Azure Blob Storage authentication.
        """
        self._connection = connection

    def get_credentials(self) -> Credentials:
        """Return the connection string wrapped in a :py:class:`Credentials` object.

        The connection string is exposed both as ``access_key`` and under the
        ``AZURE_CONNECTION_STRING_KEY`` custom field, which is what the blob
        service client factory reads.
        """
        return Credentials(
            access_key=self._connection,
            secret_key="",
            token=None,
            expiration=None,
            custom_fields={AZURE_CONNECTION_STRING_KEY: self._connection},
        )

    def refresh_credentials(self) -> None:
        """No-op: static connection strings never expire."""
        pass
169 170
class DefaultAzureCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that uses Azure Identity's :py:class:`azure.identity.DefaultAzureCredential` to authenticate with Blob Storage.

    See :py:class:`azure.identity.DefaultAzureCredential` for provider options.
    """

    def __init__(self, **kwargs: dict[str, Any]):
        # kwargs are forwarded verbatim to DefaultAzureCredential (e.g. to
        # exclude specific credential sources from the chain).
        self._credential = DefaultAzureCredential(**kwargs)

    def get_credentials(self) -> Credentials:
        """Return a :py:class:`Credentials` carrying the live credential object.

        The credential itself is placed under the ``AZURE_CREDENTIAL_KEY``
        custom field; token acquisition and renewal are handled internally by
        azure-identity, so no expiration is reported here.
        """
        return Credentials(
            access_key="",
            secret_key="",
            token=None,
            expiration=None,
            custom_fields={AZURE_CREDENTIAL_KEY: self._credential},
        )

    def refresh_credentials(self) -> None:
        """No-op: DefaultAzureCredential refreshes its own tokens."""
        pass
192 193
class AzureBlobStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Azure Blob Storage.
    """

    def __init__(
        self,
        endpoint_url: str,
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: dict[str, Any],
    ):
        """
        Initializes the :py:class:`AzureBlobStorageProvider` with the endpoint URL and optional credentials provider.

        :param endpoint_url: The Azure storage account URL.
        :param base_path: The root prefix path within the container where all operations will be scoped.
        :param credentials_provider: The provider to retrieve Azure credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._account_url = endpoint_url
        self._credentials_provider = credentials_provider
        # Cache static connection-string signing material used for per-request signers.
        self._account_key_signing_material: Optional[tuple[str, str]] = None
        # Cached delegation key and its expiry for DefaultAzureCredentialsProvider.
        self._delegation_user_key: Optional[Any] = None
        self._delegation_signer_expiry: Optional[datetime] = None
        # Only the SDK's documented optional client settings are forwarded:
        # https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/storage/azure-storage-blob#optional-configuration
        client_optional_configuration_keys = {
            "retry_total",
            "retry_connect",
            "retry_read",
            "retry_status",
            "connection_timeout",
            "read_timeout",
        }
        self._client_optional_configuration: dict[str, Any] = {
            key: value for key, value in kwargs.items() if key in client_optional_configuration_keys
        }
        # Apply MSC-wide default timeouts unless explicitly overridden.
        if "connection_timeout" not in self._client_optional_configuration:
            self._client_optional_configuration["connection_timeout"] = DEFAULT_CONNECT_TIMEOUT
        if "read_timeout" not in self._client_optional_configuration:
            self._client_optional_configuration["read_timeout"] = DEFAULT_READ_TIMEOUT
        self._blob_service_client = self._create_blob_service_client()

    def _create_blob_service_client(self) -> BlobServiceClient:
        """
        Creates and configures the Azure BlobServiceClient using the current credentials.

        :return: The configured BlobServiceClient.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()

            if isinstance(self._credentials_provider, StaticAzureCredentialsProvider):
                return BlobServiceClient.from_connection_string(
                    credentials.get_custom_field(AZURE_CONNECTION_STRING_KEY), **self._client_optional_configuration
                )
            elif isinstance(self._credentials_provider, DefaultAzureCredentialsProvider):
                return BlobServiceClient(
                    account_url=self._account_url,
                    credential=credentials.get_custom_field(AZURE_CREDENTIAL_KEY),
                    **self._client_optional_configuration,
                )
            else:
                # Fallback to connection string if no built-in credentials provider is provided
                return BlobServiceClient.from_connection_string(
                    credentials.access_key, **self._client_optional_configuration
                )
        else:
            # No credentials at all: anonymous client against the account URL.
            return BlobServiceClient(account_url=self._account_url, **self._client_optional_configuration)

    def _refresh_blob_service_client_if_needed(self) -> None:
        """
        Refreshes the BlobServiceClient if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._blob_service_client = self._create_blob_service_client()

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        container: str,
        blob: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual Azure Blob Storage operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param container: The name of the Azure container involved in the operation.
        :param blob: The name of the blob within the Azure container.

        :return: The result of the Azure Blob Storage operation, typically the return value of the `func` callable.
        :raises FileNotFoundError: If the blob does not exist (HTTP 404).
        :raises PreconditionFailedError: If an etag precondition failed (HTTP 412).
        :raises RuntimeError: For any other Azure or unexpected error.
        """
        try:
            return func()
        except HttpResponseError as error:
            status_code = error.status_code if error.status_code else -1
            error_info = f"status_code: {error.status_code}, reason: {error.reason}"
            if status_code == 404:
                raise FileNotFoundError(f"Object {container}/{blob} does not exist.")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                # raised when If-Match or If-Modified fails
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {container}/{blob}. {error_info}"
                ) from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {container}/{blob}. {error_info}") from error
        except AzureError as error:
            error_info = f"message: {error.message}"
            raise RuntimeError(f"Failed to {operation} object(s) at {container}/{blob}. {error_info}") from error
        except FileNotFoundError:
            # Let FileNotFoundError raised by nested calls propagate unchanged.
            raise
        except Exception as error:
            raise RuntimeError(
                f"Failed to {operation} object(s) at {container}/{blob}. error_type: {type(error).__name__}, error: {error}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to Azure Blob Storage.

        :param path: The path to the object to upload.
        :param body: The content of the object to upload.
        :param if_match: Optional ETag to match against the object.
        :param if_none_match: Optional ETag to match against the object.
        :param attributes: Optional attributes to attach to the object.
        :return: The number of bytes written (``len(body)``).
        """
        container_name, blob_name = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> int:
            blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)

            kwargs = {
                "data": body,
                "overwrite": True,
            }

            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                kwargs["metadata"] = validated_attributes

            # if_match → only overwrite when the stored etag still matches.
            if if_match:
                kwargs["match_condition"] = MatchConditions.IfNotModified
                kwargs["etag"] = if_match

            # if_none_match → only write when the stored etag differs; Azure has
            # no equivalent of the "*" (create-only) form, hence the rejection.
            if if_none_match:
                if if_none_match == "*":
                    raise NotImplementedError("if_none_match='*' is not supported for Azure")
                kwargs["match_condition"] = MatchConditions.IfModified
                kwargs["etag"] = if_none_match

            blob_client.upload_blob(**kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", container=container_name, blob=blob_name)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """Downloads an object (optionally a byte range of it) and returns its content."""
        container_name, blob_name = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> bytes:
            blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
            if byte_range:
                stream = blob_client.download_blob(offset=byte_range.offset, length=byte_range.size)
            else:
                stream = blob_client.download_blob()
            return stream.readall()

        return self._translate_errors(_invoke_api, operation="GET", container=container_name, blob=blob_name)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        """Server-side copy of a blob; returns the source object's size in bytes."""
        src_container, src_blob = split_path(src_path)
        dest_container, dest_blob = split_path(dest_path)
        self._refresh_blob_service_client_if_needed()

        # Fetch the source metadata up front; also raises FileNotFoundError if
        # the source does not exist.
        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            src_blob_client = self._blob_service_client.get_blob_client(container=src_container, blob=src_blob)
            dest_blob_client = self._blob_service_client.get_blob_client(container=dest_container, blob=dest_blob)
            # NOTE(review): start_copy_from_url is asynchronous on the service
            # side; completion is not awaited here — confirm callers tolerate that.
            dest_blob_client.start_copy_from_url(src_blob_client.url)

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", container=src_container, blob=src_blob)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """Deletes a blob, optionally only when its etag matches *if_match*."""
        container_name, blob_name = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> None:
            blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
            # If if_match is provided, use if_match for conditional deletion
            if if_match:
                blob_client.delete_blob(etag=if_match, match_condition=MatchConditions.IfNotModified)
            else:
                # No if_match provided, perform unconditional deletion
                blob_client.delete_blob()

        return self._translate_errors(_invoke_api, operation="DELETE", container=container_name, blob=blob_name)

    def _delete_objects(self, paths: list[str]) -> None:
        """Batch-deletes blobs, grouped by container and chunked to Azure's batch limit."""
        if not paths:
            return

        # Group blob names by container so each batch call targets one container.
        by_container: dict[str, list[str]] = {}
        for p in paths:
            container_name, blob_name = split_path(p)
            by_container.setdefault(container_name, []).append(blob_name)
        self._refresh_blob_service_client_if_needed()

        # Azure batch requests accept at most 256 subrequests per call.
        AZURE_BATCH_LIMIT = 256

        def _invoke_api() -> None:
            for container_name, blob_names in by_container.items():
                container_client = self._blob_service_client.get_container_client(container=container_name)
                for i in range(0, len(blob_names), AZURE_BATCH_LIMIT):
                    chunk = blob_names[i : i + AZURE_BATCH_LIMIT]
                    container_client.delete_blobs(*chunk)

        # Synthesised container/blob descriptions used only for error messages.
        container_desc = "(" + "|".join(by_container) + ")"
        blob_desc = "(" + "|".join(str(len(blob_names)) for blob_names in by_container.values()) + " keys)"
        self._translate_errors(_invoke_api, operation="DELETE_MANY", container=container_desc, blob=blob_desc)

    def _is_dir(self, path: str) -> bool:
        """Returns True if at least one blob or virtual prefix exists under *path*."""
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        container_name, prefix = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> bool:
            # List objects with the given prefix
            container_client = self._blob_service_client.get_container_client(container=container_name)
            blobs = container_client.walk_blobs(name_starts_with=prefix, delimiter="/")
            # Check if there are any contents or common prefixes
            return any(True for _ in blobs)

        return self._translate_errors(_invoke_api, operation="LIST", container=container_name, blob=prefix)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Returns metadata for a blob, or a synthetic directory entry for prefix paths.

        :param path: Object path; a trailing "/" or an empty blob name marks a directory.
        :param strict: When True, a missing object triggers a directory fallback
            check before the FileNotFoundError is re-raised.
        :raises FileNotFoundError: If neither the object nor a matching directory exists.
        """
        container_name, blob_name = split_path(path)
        if path.endswith("/") or (container_name and not blob_name):
            # If path ends with "/" or empty blob name is provided, then assume it's a "directory",
            # which metadata is not guaranteed to exist for cases such as
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=self._append_delimiter(path),
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            self._refresh_blob_service_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                properties = blob_client.get_blob_properties()
                return ObjectMetadata(
                    key=path,
                    content_length=properties.size,
                    content_type=properties.content_settings.content_type,
                    last_modified=properties.last_modified,
                    etag=properties.etag.strip('"') if properties.etag else "",
                    metadata=dict(properties.metadata) if properties.metadata else None,
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", container=container_name, blob=blob_name)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        """
        Lists objects under *path*, optionally bounded by (start_after, end_at].

        :param path: Prefix path in the form ``container/prefix``.
        :param start_after: Exclusive lower bound (full path; reduced to a blob key).
        :param end_at: Inclusive upper bound (full path; reduced to a blob key).
        :param include_directories: When True, emit common prefixes as directory entries.
        :param follow_symlinks: Accepted for interface compatibility; unused here.
        """
        container_name, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            container_client = self._blob_service_client.get_container_client(container=container_name)
            # Azure has no start key option like other object stores.
            if include_directories:
                blobs = container_client.walk_blobs(name_starts_with=prefix, delimiter="/")
            else:
                blobs = container_client.list_blobs(name_starts_with=prefix)
            # Azure guarantees lexicographical order.
            for blob in blobs:
                if isinstance(blob, BlobPrefix):
                    prefix_key = blob.name.rstrip("/")
                    # Filter by start_after and end_at if specified
                    if (start_after is None or start_after < prefix_key) and (end_at is None or prefix_key <= end_at):
                        yield ObjectMetadata(
                            key=os.path.join(container_name, prefix_key),
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                    elif end_at is not None and end_at < prefix_key:
                        # Past the upper bound; ordering guarantee lets us stop early.
                        return
                else:
                    key = blob.name
                    if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                        if key.endswith("/"):
                            # Zero-byte "folder marker" blob: surface only as a directory.
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(container_name, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=blob.last_modified,
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(container_name, key),
                                content_length=blob.size,
                                content_type=blob.content_settings.content_type,
                                last_modified=blob.last_modified,
                                etag=blob.etag.strip('"') if blob.etag else "",
                            )
                    elif end_at is not None and end_at < key:
                        return

        return self._translate_errors(_invoke_api, operation="LIST", container=container_name, blob=prefix)

    def _generate_presigned_url(
        self,
        path: str,
        *,
        method: str = "GET",
        signer_type: Optional[SignerType] = None,
        signer_options: Optional[dict[str, Any]] = None,
    ) -> str:
        """
        Generate a SAS URL for a blob in Azure Blob Storage.

        :param path: Path in the form ``container/blob/name``.
        :param method: HTTP method requested by the caller.
        :param signer_type: Must be ``None`` or :py:attr:`SignerType.AZURE`.
        :param signer_options: Optional dict; supports ``expires_in`` (int, seconds).
        :return: A fully-qualified SAS URL.
        :raises ValueError: If *signer_type* is not ``None`` / ``SignerType.AZURE``, or if the
            configured credential type does not support SAS generation.
        """
        if signer_type is not None and signer_type != SignerType.AZURE:
            raise ValueError(f"Unsupported signer type for Azure provider: {signer_type!r}")

        options = signer_options or {}
        expires_in = int(options.get("expires_in", DEFAULT_PRESIGN_EXPIRES_IN))

        self._refresh_blob_service_client_if_needed()

        if isinstance(self._credentials_provider, StaticAzureCredentialsProvider):
            # Account key path: cache parsed AccountName + AccountKey, then sign per request.
            if self._account_key_signing_material is None:
                conn_str = self._credentials_provider.get_credentials().get_custom_field(AZURE_CONNECTION_STRING_KEY)
                parsed = _parse_connection_string(conn_str)
                self._account_key_signing_material = (parsed["AccountName"], parsed["AccountKey"])
            account_name, account_key = self._account_key_signing_material
            signer = AzureURLSigner(
                account_name=account_name,
                account_url=self._account_url,
                account_key=account_key,
                expires_in=expires_in,
            )

        elif isinstance(self._credentials_provider, DefaultAzureCredentialsProvider):
            # User delegation key path: refresh when the cached key is within the
            # refresh buffer of its own expiry or has not been fetched yet.
            now = datetime.now(timezone.utc)
            if (
                self._delegation_user_key is None
                or self._delegation_signer_expiry is None
                or now >= self._delegation_signer_expiry - _DELEGATION_KEY_REFRESH_BUFFER
            ):
                key_expiry = now + _DELEGATION_KEY_LIFETIME
                self._delegation_user_key = self._blob_service_client.get_user_delegation_key(
                    key_start_time=now,
                    key_expiry_time=key_expiry,
                )
                self._delegation_signer_expiry = key_expiry
            signer = AzureURLSigner(
                account_name=_parse_account_name_from_url(self._account_url),
                account_url=self._account_url,
                user_delegation_key=self._delegation_user_key,
                expires_in=expires_in,
            )

        else:
            raise ValueError(
                "Azure presigned URLs require StaticAzureCredentialsProvider (connection string) or "
                "DefaultAzureCredentialsProvider (Azure Identity). "
                f"Got: {type(self._credentials_provider).__name__!r}"
            )

        return signer.generate_presigned_url(path, method=method)

    @property
    def supports_parallel_listing(self) -> bool:
        # Listing is stateless per-prefix, so callers may shard it in parallel.
        return True

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        """
        Uploads a local file path or a file-like object to Azure Blob Storage.

        :param remote_path: Destination path in the form ``container/blob``.
        :param f: Local file path, or a binary/text file-like object.
        :param attributes: Optional attributes to attach to the blob.
        :return: The number of bytes uploaded.
        """
        container_name, blob_name = split_path(remote_path)
        file_size: int = 0
        self._refresh_blob_service_client_if_needed()

        validated_attributes = validate_attributes(attributes)
        if isinstance(f, str):
            file_size = os.path.getsize(f)

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                with open(f, "rb") as data:
                    blob_client.upload_blob(data, overwrite=True, metadata=validated_attributes or {})

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", container=container_name, blob=blob_name)
        else:
            # Convert StringIO to BytesIO before upload
            if isinstance(f, io.StringIO):
                fp: IO = io.BytesIO(f.getvalue().encode("utf-8"))  # type: ignore
            else:
                fp = f

            # Measure the stream size, then rewind so the upload starts at 0.
            fp.seek(0, io.SEEK_END)
            file_size = fp.tell()
            fp.seek(0)

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                blob_client.upload_blob(fp, overwrite=True, metadata=validated_attributes or {})

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", container=container_name, blob=blob_name)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """
        Downloads a blob to a local file path or a file-like object.

        :param remote_path: Source path in the form ``container/blob``.
        :param f: Local destination path, or a binary/text file-like object.
        :param metadata: Pre-fetched object metadata; fetched here when omitted.
        :return: The blob's content length in bytes.
        """
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        container_name, blob_name = split_path(remote_path)
        self._refresh_blob_service_client_if_needed()

        if isinstance(f, str):
            if os.path.dirname(f):
                safe_makedirs(os.path.dirname(f))

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                # Write to a hidden temp file in the destination directory, then
                # rename for an atomic replace on the same filesystem.
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    stream = blob_client.download_blob()
                    fp.write(stream.readall())
                os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", container=container_name, blob=blob_name)
        else:

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                stream = blob_client.download_blob()
                if isinstance(f, io.StringIO):
                    # Text sink: decode the payload; assumes UTF-8 content.
                    f.write(stream.readall().decode("utf-8"))
                else:
                    f.write(stream.readall())

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", container=container_name, blob=blob_name)