1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import io
17import os
18import tempfile
19from collections.abc import Callable, Iterator
20from datetime import datetime, timedelta, timezone
21from typing import IO, Any, Optional, TypeVar, Union
22from urllib.parse import urlparse
23
24from azure.core import MatchConditions
25from azure.core.exceptions import AzureError, HttpResponseError
26from azure.identity import DefaultAzureCredential
27from azure.storage.blob import BlobPrefix, BlobServiceClient, generate_blob_sas
28from azure.storage.blob._models import BlobSasPermissions
29
30from ..constants import DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT
31from ..signers.base import URLSigner
32from ..telemetry import Telemetry
33from ..types import (
34 AWARE_DATETIME_MIN,
35 Credentials,
36 CredentialsProvider,
37 ObjectMetadata,
38 PreconditionFailedError,
39 Range,
40 SignerType,
41)
42from ..utils import safe_makedirs, split_path, validate_attributes
43from .base import BaseStorageProvider
44
# Generic return type for _translate_errors.
_T = TypeVar("_T")

PROVIDER = "azure"
# Keys used to stash provider-specific material in Credentials.custom_fields
# (read back via Credentials.get_custom_field in _create_blob_service_client).
AZURE_CONNECTION_STRING_KEY = "connection"
AZURE_CREDENTIAL_KEY = "azure_credential"

# Default SAS URL lifetime, in seconds.
DEFAULT_PRESIGN_EXPIRES_IN = 3600

# How long before delegation key expiry we treat the cached key as stale.
_DELEGATION_KEY_REFRESH_BUFFER = timedelta(minutes=5)

# Azure's maximum allowed delegation key lifetime is 7 days.
_DELEGATION_KEY_LIFETIME = timedelta(days=7)
58
59
def _sas_permissions_for_method(method: str) -> BlobSasPermissions:
    """Map an HTTP *method* to the least-privilege :class:`BlobSasPermissions`."""
    normalized = method.upper()
    if normalized in {"PUT", "POST"}:
        return BlobSasPermissions(write=True, create=True)
    if normalized == "DELETE":
        return BlobSasPermissions(delete=True)
    # GET, HEAD, and any unrecognised method fall back to read-only access.
    return BlobSasPermissions(read=True)
70
71
72def _parse_account_name_from_url(account_url: str) -> str:
73 """Extract the storage account name from an Azure Blob Storage account URL."""
74 hostname = urlparse(account_url).hostname
75 if hostname is None:
76 raise ValueError(f"Invalid Azure account URL: {account_url!r}")
77 return hostname.split(".")[0]
78
79
80def _parse_connection_string(conn_str: str) -> dict[str, str]:
81 """Parse an Azure connection string (``AccountName=foo;AccountKey=bar;...``) into a dict."""
82 return dict(part.split("=", 1) for part in conn_str.split(";") if "=" in part)
83
84
class AzureURLSigner(URLSigner):
    """
    Generates Azure Blob Storage SAS (Shared Access Signature) URLs.

    Two signing paths are supported, depending on which credential is supplied:

    * **Account key** – a static storage account key (typically parsed from a connection string).
    * **User delegation key** – a time-limited key obtained via Azure Identity (e.g. workload
      identity, managed identity). Callers are responsible for refreshing the signer when the
      delegation key approaches expiry; see :py:meth:`AzureBlobStorageProvider._generate_presigned_url`.
    """

    def __init__(
        self,
        account_name: str,
        account_url: str,
        *,
        account_key: Optional[str] = None,
        user_delegation_key: Optional[Any] = None,
        expires_in: int = DEFAULT_PRESIGN_EXPIRES_IN,
    ) -> None:
        if account_key is None and user_delegation_key is None:
            raise ValueError("Either account_key or user_delegation_key must be provided.")
        self._account_name = account_name
        self._account_url = account_url.rstrip("/")
        self._account_key = account_key
        self._user_delegation_key = user_delegation_key
        self._expires_in = expires_in

    def generate_presigned_url(self, path: str, *, method: str = "GET") -> str:
        """
        Generate a SAS URL for the given blob path.

        :param path: Path in the form ``container/blob/name``.
        :param method: HTTP method requested by the caller.
        :return: A fully-qualified SAS URL.
        """
        container_name, blob_name = split_path(path)
        expiry = datetime.now(timezone.utc) + timedelta(seconds=self._expires_in)

        sas_kwargs: dict[str, Any] = dict(
            account_name=self._account_name,
            container_name=container_name,
            blob_name=blob_name,
            permission=_sas_permissions_for_method(method),
            expiry=expiry,
        )
        # Exactly one signing credential is present (enforced in __init__).
        if self._account_key is not None:
            sas_kwargs["account_key"] = self._account_key
        else:
            sas_kwargs["user_delegation_key"] = self._user_delegation_key

        token = generate_blob_sas(**sas_kwargs)
        return f"{self._account_url}/{container_name}/{blob_name}?{token}"
141
142
class StaticAzureCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static Azure credentials.
    """

    _connection: str

    def __init__(self, connection: str):
        """
        Initializes the :py:class:`StaticAzureCredentialsProvider` with the provided connection string.

        :param connection: The connection string for Azure Blob Storage authentication.
        """
        self._connection = connection

    def get_credentials(self) -> Credentials:
        """Wrap the stored connection string in a :py:class:`Credentials` object."""
        custom = {AZURE_CONNECTION_STRING_KEY: self._connection}
        return Credentials(
            access_key=self._connection,
            secret_key="",
            token=None,
            expiration=None,
            custom_fields=custom,
        )

    def refresh_credentials(self) -> None:
        """No-op: a static connection string never changes."""
169
170
class DefaultAzureCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that uses Azure Identity's :py:class:`azure.identity.DefaultAzureCredential` to authenticate with Blob Storage.

    See :py:class:`azure.identity.DefaultAzureCredential` for provider options.
    """

    # Fix: ``**kwargs`` was annotated ``dict[str, Any]``, which types every keyword
    # VALUE as a dict; ``Any`` is the correct per-value annotation for a passthrough.
    def __init__(self, **kwargs: Any):
        """
        Initializes the provider by constructing a :py:class:`azure.identity.DefaultAzureCredential`.

        :param kwargs: Keyword arguments forwarded verbatim to ``DefaultAzureCredential``.
        """
        self._credential = DefaultAzureCredential(**kwargs)

    def get_credentials(self) -> Credentials:
        """
        Expose the underlying Azure credential through ``custom_fields`` so the
        storage provider can hand it directly to ``BlobServiceClient``.
        """
        return Credentials(
            access_key="",
            secret_key="",
            token=None,
            expiration=None,
            custom_fields={AZURE_CREDENTIAL_KEY: self._credential},
        )

    def refresh_credentials(self) -> None:
        """No-op: the wrapped credential object is reused as-is."""
        pass
192
193
194class AzureBlobStorageProvider(BaseStorageProvider):
195 """
196 A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Azure Blob Storage.
197 """
198
199 def __init__(
200 self,
201 endpoint_url: str,
202 base_path: str = "",
203 credentials_provider: Optional[CredentialsProvider] = None,
204 config_dict: Optional[dict[str, Any]] = None,
205 telemetry_provider: Optional[Callable[[], Telemetry]] = None,
206 **kwargs: dict[str, Any],
207 ):
208 """
209 Initializes the :py:class:`AzureBlobStorageProvider` with the endpoint URL and optional credentials provider.
210
211 :param endpoint_url: The Azure storage account URL.
212 :param base_path: The root prefix path within the container where all operations will be scoped.
213 :param credentials_provider: The provider to retrieve Azure credentials.
214 :param config_dict: Resolved MSC config.
215 :param telemetry_provider: A function that provides a telemetry instance.
216 """
217 super().__init__(
218 base_path=base_path,
219 provider_name=PROVIDER,
220 config_dict=config_dict,
221 telemetry_provider=telemetry_provider,
222 )
223
224 self._account_url = endpoint_url
225 self._credentials_provider = credentials_provider
226 # Cache static connection-string signing material used for per-request signers.
227 self._account_key_signing_material: Optional[tuple[str, str]] = None
228 # Cached delegation key and its expiry for DefaultAzureCredentialsProvider.
229 self._delegation_user_key: Optional[Any] = None
230 self._delegation_signer_expiry: Optional[datetime] = None
231 # https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/storage/azure-storage-blob#optional-configuration
232 client_optional_configuration_keys = {
233 "retry_total",
234 "retry_connect",
235 "retry_read",
236 "retry_status",
237 "connection_timeout",
238 "read_timeout",
239 }
240 self._client_optional_configuration: dict[str, Any] = {
241 key: value for key, value in kwargs.items() if key in client_optional_configuration_keys
242 }
243 if "connection_timeout" not in self._client_optional_configuration:
244 self._client_optional_configuration["connection_timeout"] = DEFAULT_CONNECT_TIMEOUT
245 if "read_timeout" not in self._client_optional_configuration:
246 self._client_optional_configuration["read_timeout"] = DEFAULT_READ_TIMEOUT
247 self._blob_service_client = self._create_blob_service_client()
248
249 def _create_blob_service_client(self) -> BlobServiceClient:
250 """
251 Creates and configures the Azure BlobServiceClient using the current credentials.
252
253 :return: The configured BlobServiceClient.
254 """
255 if self._credentials_provider:
256 credentials = self._credentials_provider.get_credentials()
257
258 if isinstance(self._credentials_provider, StaticAzureCredentialsProvider):
259 return BlobServiceClient.from_connection_string(
260 credentials.get_custom_field(AZURE_CONNECTION_STRING_KEY), **self._client_optional_configuration
261 )
262 elif isinstance(self._credentials_provider, DefaultAzureCredentialsProvider):
263 return BlobServiceClient(
264 account_url=self._account_url,
265 credential=credentials.get_custom_field(AZURE_CREDENTIAL_KEY),
266 **self._client_optional_configuration,
267 )
268 else:
269 # Fallback to connection string if no built-in credentials provider is provided
270 return BlobServiceClient.from_connection_string(
271 credentials.access_key, **self._client_optional_configuration
272 )
273 else:
274 return BlobServiceClient(account_url=self._account_url, **self._client_optional_configuration)
275
276 def _refresh_blob_service_client_if_needed(self) -> None:
277 """
278 Refreshes the BlobServiceClient if the current credentials are expired.
279 """
280 if self._credentials_provider:
281 credentials = self._credentials_provider.get_credentials()
282 if credentials.is_expired():
283 self._credentials_provider.refresh_credentials()
284 self._blob_service_client = self._create_blob_service_client()
285
286 def _translate_errors(
287 self,
288 func: Callable[[], _T],
289 operation: str,
290 container: str,
291 blob: str,
292 ) -> _T:
293 """
294 Translates errors like timeouts and client errors.
295
296 :param func: The function that performs the actual Azure Blob Storage operation.
297 :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
298 :param container: The name of the Azure container involved in the operation.
299 :param blob: The name of the blob within the Azure container.
300
301 :return The result of the Azure Blob Storage operation, typically the return value of the `func` callable.
302 """
303 try:
304 return func()
305 except HttpResponseError as error:
306 status_code = error.status_code if error.status_code else -1
307 error_info = f"status_code: {error.status_code}, reason: {error.reason}"
308 if status_code == 404:
309 raise FileNotFoundError(f"Object {container}/{blob} does not exist.") # pylint: disable=raise-missing-from
310 elif status_code == 412:
311 # raised when If-Match or If-Modified fails
312 raise PreconditionFailedError(
313 f"Failed to {operation} object(s) at {container}/{blob}. {error_info}"
314 ) from error
315 else:
316 raise RuntimeError(f"Failed to {operation} object(s) at {container}/{blob}. {error_info}") from error
317 except AzureError as error:
318 error_info = f"message: {error.message}"
319 raise RuntimeError(f"Failed to {operation} object(s) at {container}/{blob}. {error_info}") from error
320 except FileNotFoundError:
321 raise
322 except Exception as error:
323 raise RuntimeError(
324 f"Failed to {operation} object(s) at {container}/{blob}. error_type: {type(error).__name__}, error: {error}"
325 ) from error
326
327 def _put_object(
328 self,
329 path: str,
330 body: bytes,
331 if_match: Optional[str] = None,
332 if_none_match: Optional[str] = None,
333 attributes: Optional[dict[str, str]] = None,
334 ) -> int:
335 """
336 Uploads an object to Azure Blob Storage.
337
338 :param path: The path to the object to upload.
339 :param body: The content of the object to upload.
340 :param if_match: Optional ETag to match against the object.
341 :param if_none_match: Optional ETag to match against the object.
342 :param attributes: Optional attributes to attach to the object.
343 """
344 container_name, blob_name = split_path(path)
345 self._refresh_blob_service_client_if_needed()
346
347 def _invoke_api() -> int:
348 blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
349
350 kwargs = {
351 "data": body,
352 "overwrite": True,
353 }
354
355 validated_attributes = validate_attributes(attributes)
356 if validated_attributes:
357 kwargs["metadata"] = validated_attributes
358
359 if if_match:
360 kwargs["match_condition"] = MatchConditions.IfNotModified
361 kwargs["etag"] = if_match
362
363 if if_none_match:
364 if if_none_match == "*":
365 raise NotImplementedError("if_none_match='*' is not supported for Azure")
366 kwargs["match_condition"] = MatchConditions.IfModified
367 kwargs["etag"] = if_none_match
368
369 blob_client.upload_blob(**kwargs)
370
371 return len(body)
372
373 return self._translate_errors(_invoke_api, operation="PUT", container=container_name, blob=blob_name)
374
375 def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
376 container_name, blob_name = split_path(path)
377 self._refresh_blob_service_client_if_needed()
378
379 def _invoke_api() -> bytes:
380 blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
381 if byte_range:
382 stream = blob_client.download_blob(offset=byte_range.offset, length=byte_range.size)
383 else:
384 stream = blob_client.download_blob()
385 return stream.readall()
386
387 return self._translate_errors(_invoke_api, operation="GET", container=container_name, blob=blob_name)
388
389 def _copy_object(self, src_path: str, dest_path: str) -> int:
390 src_container, src_blob = split_path(src_path)
391 dest_container, dest_blob = split_path(dest_path)
392 self._refresh_blob_service_client_if_needed()
393
394 src_object = self._get_object_metadata(src_path)
395
396 def _invoke_api() -> int:
397 src_blob_client = self._blob_service_client.get_blob_client(container=src_container, blob=src_blob)
398 dest_blob_client = self._blob_service_client.get_blob_client(container=dest_container, blob=dest_blob)
399 dest_blob_client.start_copy_from_url(src_blob_client.url)
400
401 return src_object.content_length
402
403 return self._translate_errors(_invoke_api, operation="COPY", container=src_container, blob=src_blob)
404
405 def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
406 container_name, blob_name = split_path(path)
407 self._refresh_blob_service_client_if_needed()
408
409 def _invoke_api() -> None:
410 blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
411 # If if_match is provided, use if_match for conditional deletion
412 if if_match:
413 blob_client.delete_blob(etag=if_match, match_condition=MatchConditions.IfNotModified)
414 else:
415 # No if_match provided, perform unconditional deletion
416 blob_client.delete_blob()
417
418 return self._translate_errors(_invoke_api, operation="DELETE", container=container_name, blob=blob_name)
419
420 def _delete_objects(self, paths: list[str]) -> None:
421 if not paths:
422 return
423
424 by_container: dict[str, list[str]] = {}
425 for p in paths:
426 container_name, blob_name = split_path(p)
427 by_container.setdefault(container_name, []).append(blob_name)
428 self._refresh_blob_service_client_if_needed()
429
430 AZURE_BATCH_LIMIT = 256
431
432 def _invoke_api() -> None:
433 for container_name, blob_names in by_container.items():
434 container_client = self._blob_service_client.get_container_client(container=container_name)
435 for i in range(0, len(blob_names), AZURE_BATCH_LIMIT):
436 chunk = blob_names[i : i + AZURE_BATCH_LIMIT]
437 container_client.delete_blobs(*chunk)
438
439 container_desc = "(" + "|".join(by_container) + ")"
440 blob_desc = "(" + "|".join(str(len(blob_names)) for blob_names in by_container.values()) + " keys)"
441 self._translate_errors(_invoke_api, operation="DELETE_MANY", container=container_desc, blob=blob_desc)
442
443 def _is_dir(self, path: str) -> bool:
444 # Ensure the path ends with '/' to mimic a directory
445 path = self._append_delimiter(path)
446
447 container_name, prefix = split_path(path)
448 self._refresh_blob_service_client_if_needed()
449
450 def _invoke_api() -> bool:
451 # List objects with the given prefix
452 container_client = self._blob_service_client.get_container_client(container=container_name)
453 blobs = container_client.walk_blobs(name_starts_with=prefix, delimiter="/")
454 # Check if there are any contents or common prefixes
455 return any(True for _ in blobs)
456
457 return self._translate_errors(_invoke_api, operation="LIST", container=container_name, blob=prefix)
458
    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Return metadata for the blob (or virtual directory) at *path*.

        :param path: Path in the form ``container/blob``; a trailing ``/`` or an empty
            blob name makes the path be treated as a directory.
        :param strict: When True and the blob is missing, additionally check whether
            *path* names a virtual directory before raising.
        :raises FileNotFoundError: If neither a blob nor a directory exists at *path*.
        """
        container_name, blob_name = split_path(path)
        if path.endswith("/") or (container_name and not blob_name):
            # If path ends with "/" or empty blob name is provided, then assume it's a "directory",
            # which metadata is not guaranteed to exist for cases such as
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=self._append_delimiter(path),
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            self._refresh_blob_service_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                properties = blob_client.get_blob_properties()
                return ObjectMetadata(
                    key=path,
                    content_length=properties.size,
                    content_type=properties.content_settings.content_type,
                    last_modified=properties.last_modified,
                    # Normalize the ETag by stripping any surrounding quote characters.
                    etag=properties.etag.strip('"') if properties.etag else "",
                    metadata=dict(properties.metadata) if properties.metadata else None,
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", container=container_name, blob=blob_name)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error
504
    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        """
        Lazily yield metadata for blobs under *path*.

        :param path: Prefix in the form ``container/prefix``.
        :param start_after: Only yield keys strictly greater than this key (container-relative).
        :param end_at: Only yield keys less than or equal to this key (container-relative).
        :param include_directories: When True, list with delimiter ``/`` and yield directory
            entries instead of recursing into them.
        :param follow_symlinks: Unused by this provider; kept for interface compatibility.
        """
        container_name, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            container_client = self._blob_service_client.get_container_client(container=container_name)
            # Azure has no start key option like other object stores.
            if include_directories:
                blobs = container_client.walk_blobs(name_starts_with=prefix, delimiter="/")
            else:
                blobs = container_client.list_blobs(name_starts_with=prefix)
            # Azure guarantees lexicographical order.
            for blob in blobs:
                if isinstance(blob, BlobPrefix):
                    # Common prefix, i.e. a "virtual directory" entry.
                    prefix_key = blob.name.rstrip("/")
                    # Filter by start_after and end_at if specified
                    if (start_after is None or start_after < prefix_key) and (end_at is None or prefix_key <= end_at):
                        yield ObjectMetadata(
                            key=os.path.join(container_name, prefix_key),
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                    elif end_at is not None and end_at < prefix_key:
                        # Past the end of the requested range; ordering lets us stop early.
                        return
                else:
                    key = blob.name
                    if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                        if key.endswith("/"):
                            # Zero-byte marker object denoting an explicitly created directory.
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(container_name, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=blob.last_modified,
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(container_name, key),
                                content_length=blob.size,
                                content_type=blob.content_settings.content_type,
                                last_modified=blob.last_modified,
                                etag=blob.etag.strip('"') if blob.etag else "",
                            )
                    elif end_at is not None and end_at < key:
                        # Past the end of the requested range; ordering lets us stop early.
                        return

        return self._translate_errors(_invoke_api, operation="LIST", container=container_name, blob=prefix)
567
    def _generate_presigned_url(
        self,
        path: str,
        *,
        method: str = "GET",
        signer_type: Optional[SignerType] = None,
        signer_options: Optional[dict[str, Any]] = None,
    ) -> str:
        """
        Generate a SAS URL for a blob in Azure Blob Storage.

        :param path: Path in the form ``container/blob/name``.
        :param method: HTTP method requested by the caller.
        :param signer_type: Must be ``None`` or :py:attr:`SignerType.AZURE`.
        :param signer_options: Optional dict; supports ``expires_in`` (int, seconds).
        :return: A fully-qualified SAS URL.
        :raises ValueError: If *signer_type* is not ``None`` / ``SignerType.AZURE``, or if the
            configured credential type does not support SAS generation.
        """
        if signer_type is not None and signer_type != SignerType.AZURE:
            raise ValueError(f"Unsupported signer type for Azure provider: {signer_type!r}")

        options = signer_options or {}
        expires_in = int(options.get("expires_in", DEFAULT_PRESIGN_EXPIRES_IN))

        self._refresh_blob_service_client_if_needed()

        if isinstance(self._credentials_provider, StaticAzureCredentialsProvider):
            # Account key path: cache parsed AccountName + AccountKey, then sign per request.
            if self._account_key_signing_material is None:
                conn_str = self._credentials_provider.get_credentials().get_custom_field(AZURE_CONNECTION_STRING_KEY)
                parsed = _parse_connection_string(conn_str)
                self._account_key_signing_material = (parsed["AccountName"], parsed["AccountKey"])
            account_name, account_key = self._account_key_signing_material
            signer = AzureURLSigner(
                account_name=account_name,
                account_url=self._account_url,
                account_key=account_key,
                expires_in=expires_in,
            )

        elif isinstance(self._credentials_provider, DefaultAzureCredentialsProvider):
            # User delegation key path: refresh when the cached key is within the
            # refresh buffer of its own expiry or has not been fetched yet.
            # NOTE(review): a SAS whose lifetime outlives the delegation key's remaining
            # validity will presumably stop working at key expiry — confirm callers keep
            # expires_in comfortably below _DELEGATION_KEY_LIFETIME.
            now = datetime.now(timezone.utc)
            if (
                self._delegation_user_key is None
                or self._delegation_signer_expiry is None
                or now >= self._delegation_signer_expiry - _DELEGATION_KEY_REFRESH_BUFFER
            ):
                key_expiry = now + _DELEGATION_KEY_LIFETIME
                self._delegation_user_key = self._blob_service_client.get_user_delegation_key(
                    key_start_time=now,
                    key_expiry_time=key_expiry,
                )
                self._delegation_signer_expiry = key_expiry
            signer = AzureURLSigner(
                account_name=_parse_account_name_from_url(self._account_url),
                account_url=self._account_url,
                user_delegation_key=self._delegation_user_key,
                expires_in=expires_in,
            )

        else:
            raise ValueError(
                "Azure presigned URLs require StaticAzureCredentialsProvider (connection string) or "
                "DefaultAzureCredentialsProvider (Azure Identity). "
                f"Got: {type(self._credentials_provider).__name__!r}"
            )

        return signer.generate_presigned_url(path, method=method)
639
    @property
    def supports_parallel_listing(self) -> bool:
        """Whether list operations may be split across parallel workers; always True here."""
        return True
643
644 def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
645 container_name, blob_name = split_path(remote_path)
646 file_size: int = 0
647 self._refresh_blob_service_client_if_needed()
648
649 validated_attributes = validate_attributes(attributes)
650 if isinstance(f, str):
651 file_size = os.path.getsize(f)
652
653 def _invoke_api() -> int:
654 blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
655 with open(f, "rb") as data:
656 blob_client.upload_blob(data, overwrite=True, metadata=validated_attributes or {})
657
658 return file_size
659
660 return self._translate_errors(_invoke_api, operation="PUT", container=container_name, blob=blob_name)
661 else:
662 # Convert StringIO to BytesIO before upload
663 if isinstance(f, io.StringIO):
664 fp: IO = io.BytesIO(f.getvalue().encode("utf-8")) # type: ignore
665 else:
666 fp = f
667
668 fp.seek(0, io.SEEK_END)
669 file_size = fp.tell()
670 fp.seek(0)
671
672 def _invoke_api() -> int:
673 blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
674 blob_client.upload_blob(fp, overwrite=True, metadata=validated_attributes or {})
675
676 return file_size
677
678 return self._translate_errors(_invoke_api, operation="PUT", container=container_name, blob=blob_name)
679
    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """
        Download a blob to a local path or an open file object.

        :param remote_path: Source path in the form ``container/blob``.
        :param f: Either a destination filesystem path or a writable file-like object.
        :param metadata: Optional pre-fetched object metadata; fetched via a HEAD call when omitted.
        :return: The object's content length in bytes.
        """
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        container_name, blob_name = split_path(remote_path)
        self._refresh_blob_service_client_if_needed()

        if isinstance(f, str):
            # Create the destination directory if it does not exist yet.
            if os.path.dirname(f):
                safe_makedirs(os.path.dirname(f))

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                # Download into a hidden temp file in the target directory, then rename into
                # place so readers never observe a partially written file.
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    stream = blob_client.download_blob()
                    fp.write(stream.readall())
                os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", container=container_name, blob=blob_name)
        else:

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                stream = blob_client.download_blob()
                # StringIO targets need the payload decoded; assumes UTF-8 content.
                if isinstance(f, io.StringIO):
                    f.write(stream.readall().decode("utf-8"))
                else:
                    f.write(stream.readall())

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", container=container_name, blob=blob_name)