Source code for multistorageclient.providers.azure

  1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16import io
 17import os
 18import tempfile
 19from collections.abc import Callable, Iterator
 20from typing import IO, Any, Optional, TypeVar, Union
 21
 22from azure.core import MatchConditions
 23from azure.core.exceptions import AzureError, HttpResponseError
 24from azure.storage.blob import BlobPrefix, BlobServiceClient
 25
 26from ..telemetry import Telemetry
 27from ..types import (
 28    AWARE_DATETIME_MIN,
 29    Credentials,
 30    CredentialsProvider,
 31    ObjectMetadata,
 32    PreconditionFailedError,
 33    Range,
 34)
 35from ..utils import split_path, validate_attributes
 36from .base import BaseStorageProvider
 37
 38_T = TypeVar("_T")
 39
 40PROVIDER = "azure"
 41
 42
class StaticAzureCredentialsProvider(CredentialsProvider):
    """
    A :py:class:`multistorageclient.types.CredentialsProvider` implementation that serves a fixed,
    pre-configured Azure connection string.
    """

    # Connection string used for Azure Blob Storage authentication.
    _connection: str

    def __init__(self, connection: str):
        """
        Initializes the :py:class:`StaticAzureCredentialsProvider` with the provided connection string.

        :param connection: The connection string for Azure Blob Storage authentication.
        """
        self._connection = connection

    def get_credentials(self) -> Credentials:
        # The connection string travels in the access_key field; the remaining
        # fields are unused for connection-string based authentication.
        return Credentials(
            access_key=self._connection,
            secret_key="",
            token=None,
            expiration=None,
        )

    def refresh_credentials(self) -> None:
        # Static credentials never expire; refreshing is a no-op.
        pass
68 69
class AzureBlobStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Azure Blob Storage.
    """

    def __init__(
        self,
        endpoint_url: str,
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: dict[str, Any],
    ):
        """
        Initializes the :py:class:`AzureBlobStorageProvider` with the endpoint URL and optional credentials provider.

        :param endpoint_url: The Azure storage account URL.
        :param base_path: The root prefix path within the container where all operations will be scoped.
        :param credentials_provider: The provider to retrieve Azure credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._account_url = endpoint_url
        self._credentials_provider = credentials_provider
        # https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/storage/azure-storage-blob#optional-configuration
        client_optional_configuration_keys = {
            "retry_total",
            "retry_connect",
            "retry_read",
            "retry_status",
            "connection_timeout",
            "read_timeout",
        }
        # Forward only the recognized SDK tuning options; any other kwargs are ignored.
        self._client_optional_configuration = {
            key: value for key, value in kwargs.items() if key in client_optional_configuration_keys
        }
        self._blob_service_client = self._create_blob_service_client()

    def _create_blob_service_client(self) -> BlobServiceClient:
        """
        Creates and configures the Azure BlobServiceClient using the current credentials.

        :return: The configured BlobServiceClient.
        """
        if self._credentials_provider:
            # The credentials provider carries a connection string in access_key
            # (see StaticAzureCredentialsProvider).
            credentials = self._credentials_provider.get_credentials()
            return BlobServiceClient.from_connection_string(
                credentials.access_key, **self._client_optional_configuration
            )
        else:
            # No explicit credentials: rely on the account URL (e.g. anonymous or
            # ambient authentication handled by the SDK).
            return BlobServiceClient(account_url=self._account_url, **self._client_optional_configuration)

    def _refresh_blob_service_client_if_needed(self) -> None:
        """
        Refreshes the BlobServiceClient if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._blob_service_client = self._create_blob_service_client()

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        container: str,
        blob: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual Azure Blob Storage operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param container: The name of the Azure container involved in the operation.
        :param blob: The name of the blob within the Azure container.

        :return: The result of the Azure Blob Storage operation, typically the return value of the `func` callable.
        """
        try:
            return func()
        except HttpResponseError as error:
            # Fall back to -1 when the SDK reports no status code, and use the
            # normalized value in the message (the original formatted the raw,
            # possibly-None status code, making the fallback dead code).
            status_code = error.status_code if error.status_code else -1
            error_info = f"status_code: {status_code}, reason: {error.reason}"
            if status_code == 404:
                raise FileNotFoundError(f"Object {container}/{blob} does not exist.")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                # raised when If-Match or If-Modified fails
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {container}/{blob}. {error_info}"
                ) from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {container}/{blob}. {error_info}") from error
        except AzureError as error:
            # Any other SDK-level error (e.g. transport failures, timeouts).
            error_info = f"message: {error.message}"
            raise RuntimeError(f"Failed to {operation} object(s) at {container}/{blob}. {error_info}") from error
        except FileNotFoundError:
            # Already a translated "not found"; let it propagate untouched.
            raise
        except Exception as error:
            raise RuntimeError(
                f"Failed to {operation} object(s) at {container}/{blob}. error_type: {type(error).__name__}, error: {error}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to Azure Blob Storage.

        :param path: The path to the object to upload.
        :param body: The content of the object to upload.
        :param if_match: Optional ETag to match against the object.
        :param if_none_match: Optional ETag to match against the object.
        :param attributes: Optional attributes to attach to the object.

        :return: The number of bytes written.
        """
        container_name, blob_name = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> int:
            blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)

            kwargs = {
                "data": body,
                "overwrite": True,
            }

            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                kwargs["metadata"] = validated_attributes

            if if_match:
                kwargs["match_condition"] = MatchConditions.IfNotModified
                kwargs["etag"] = if_match

            # NOTE(review): if both if_match and if_none_match are passed, the
            # if_none_match condition overwrites the if_match one below.
            if if_none_match:
                if if_none_match == "*":
                    raise NotImplementedError("if_none_match='*' is not supported for Azure")
                kwargs["match_condition"] = MatchConditions.IfModified
                kwargs["etag"] = if_none_match

            blob_client.upload_blob(**kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", container=container_name, blob=blob_name)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Downloads an object (optionally a byte range of it) and returns its content.

        :param path: The path to the object to download.
        :param byte_range: Optional offset/size range to download instead of the whole blob.
        :return: The object content as bytes.
        """
        container_name, blob_name = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> bytes:
            blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
            if byte_range:
                stream = blob_client.download_blob(offset=byte_range.offset, length=byte_range.size)
            else:
                stream = blob_client.download_blob()
            return stream.readall()

        return self._translate_errors(_invoke_api, operation="GET", container=container_name, blob=blob_name)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        """
        Copies an object within Azure Blob Storage via a server-side copy.

        :param src_path: The path of the source object.
        :param dest_path: The path of the destination object.
        :return: The content length of the source object.
        """
        src_container, src_blob = split_path(src_path)
        dest_container, dest_blob = split_path(dest_path)
        self._refresh_blob_service_client_if_needed()

        # Resolve source metadata first; raises FileNotFoundError if the source is missing.
        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            src_blob_client = self._blob_service_client.get_blob_client(container=src_container, blob=src_blob)
            dest_blob_client = self._blob_service_client.get_blob_client(container=dest_container, blob=dest_blob)
            # NOTE(review): start_copy_from_url is asynchronous on the service side;
            # the copy may still be in progress when this returns — confirm callers
            # tolerate eventual consistency here.
            dest_blob_client.start_copy_from_url(src_blob_client.url)

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", container=src_container, blob=src_blob)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Deletes an object, optionally only when its ETag matches ``if_match``.

        :param path: The path of the object to delete.
        :param if_match: Optional ETag for conditional deletion.
        """
        container_name, blob_name = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> None:
            blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
            # If if_match is provided, use if_match for conditional deletion
            if if_match:
                blob_client.delete_blob(etag=if_match, match_condition=MatchConditions.IfNotModified)
            else:
                # No if_match provided, perform unconditional deletion
                blob_client.delete_blob()

        return self._translate_errors(_invoke_api, operation="DELETE", container=container_name, blob=blob_name)

    def _is_dir(self, path: str) -> bool:
        """
        Returns True if the path behaves like a directory, i.e. at least one blob
        or virtual prefix exists under it.
        """
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        container_name, prefix = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> bool:
            # List objects with the given prefix
            container_client = self._blob_service_client.get_container_client(container=container_name)
            blobs = container_client.walk_blobs(name_starts_with=prefix, delimiter="/")
            # Check if there are any contents or common prefixes
            return any(True for _ in blobs)

        return self._translate_errors(_invoke_api, operation="LIST", container=container_name, blob=prefix)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Returns metadata for the object or virtual directory at ``path``.

        :param path: The path of the object or directory.
        :param strict: When True and the object is missing, fall back to checking
            whether ``path`` is a virtual directory before raising.
        :raises FileNotFoundError: If neither an object nor a directory exists at ``path``.
        """
        container_name, blob_name = split_path(path)
        if path.endswith("/") or (container_name and not blob_name):
            # If path ends with "/" or empty blob name is provided, then assume it's a "directory",
            # which metadata is not guaranteed to exist for cases such as
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=self._append_delimiter(path),
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            self._refresh_blob_service_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                properties = blob_client.get_blob_properties()
                return ObjectMetadata(
                    key=path,
                    content_length=properties.size,
                    content_type=properties.content_settings.content_type,
                    last_modified=properties.last_modified,
                    etag=properties.etag.strip('"') if properties.etag else "",
                    metadata=dict(properties.metadata) if properties.metadata else None,
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", container=container_name, blob=blob_name)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        """
        Lists objects under ``path``, optionally bounded by ``start_after`` (exclusive)
        and ``end_at`` (inclusive), yielding metadata in lexicographical order.
        """
        container_name, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            container_client = self._blob_service_client.get_container_client(container=container_name)
            # Azure has no start key option like other object stores.
            if include_directories:
                blobs = container_client.walk_blobs(name_starts_with=prefix, delimiter="/")
            else:
                blobs = container_client.list_blobs(name_starts_with=prefix)
            # Azure guarantees lexicographical order.
            for blob in blobs:
                if isinstance(blob, BlobPrefix):
                    # Virtual directory produced by the delimiter-based listing.
                    yield ObjectMetadata(
                        key=os.path.join(container_name, blob.name.rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )
                else:
                    key = blob.name
                    if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                        if key.endswith("/"):
                            # Zero-byte "folder marker" blob; surface it as a directory.
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(container_name, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=blob.last_modified,
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(container_name, key),
                                content_length=blob.size,
                                content_type=blob.content_settings.content_type,
                                last_modified=blob.last_modified,
                                etag=blob.etag.strip('"') if blob.etag else "",
                            )
                    elif end_at is not None and end_at < key:
                        # Results are sorted, so everything past end_at can be skipped.
                        return

        return self._translate_errors(_invoke_api, operation="LIST", container=container_name, blob=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        """
        Uploads a local file path or an open file object to Azure Blob Storage.

        :param remote_path: The destination path of the object.
        :param f: A local filesystem path or a file-like object to upload.
        :param attributes: Optional attributes to attach to the object.
        :return: The number of bytes uploaded.
        """
        container_name, blob_name = split_path(remote_path)
        file_size: int = 0
        self._refresh_blob_service_client_if_needed()

        validated_attributes = validate_attributes(attributes)
        if isinstance(f, str):
            file_size = os.path.getsize(f)

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                with open(f, "rb") as data:
                    blob_client.upload_blob(data, overwrite=True, metadata=validated_attributes or {})

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", container=container_name, blob=blob_name)
        else:
            # Convert StringIO to BytesIO before upload
            if isinstance(f, io.StringIO):
                fp: IO = io.BytesIO(f.getvalue().encode("utf-8"))  # type: ignore
            else:
                fp = f

            # Measure size by seeking to the end, then rewind for the upload.
            fp.seek(0, io.SEEK_END)
            file_size = fp.tell()
            fp.seek(0)

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                blob_client.upload_blob(fp, overwrite=True, metadata=validated_attributes or {})

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", container=container_name, blob=blob_name)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """
        Downloads an object to a local file path or an open file object.

        :param remote_path: The path of the object to download.
        :param f: A local filesystem path or a file-like object to write into.
        :param metadata: Optional pre-fetched metadata; fetched if not provided.
        :return: The content length of the downloaded object.
        """
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        container_name, blob_name = split_path(remote_path)
        self._refresh_blob_service_client_if_needed()

        if isinstance(f, str):
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                # Download into a hidden temp file in the destination directory, then
                # rename it into place so readers never observe a partial file.
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    try:
                        stream = blob_client.download_blob()
                        fp.write(stream.readall())
                    except Exception:
                        # Bug fix: with delete=False, a failed download previously
                        # leaked the orphaned temp file. Clean it up and re-raise.
                        try:
                            os.remove(temp_file_path)
                        except OSError:
                            pass
                        raise
                os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", container=container_name, blob=blob_name)
        else:

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                stream = blob_client.download_blob()
                if isinstance(f, io.StringIO):
                    # StringIO expects text; blob content is decoded as UTF-8.
                    f.write(stream.readall().decode("utf-8"))
                else:
                    f.write(stream.readall())

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", container=container_name, blob=blob_name)