Source code for multistorageclient.providers.azure

  1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16import io
 17import os
 18import tempfile
 19from collections.abc import Callable, Iterator
 20from typing import IO, Any, Optional, TypeVar, Union
 21
 22from azure.core import MatchConditions
 23from azure.core.exceptions import AzureError, HttpResponseError
 24from azure.storage.blob import BlobPrefix, BlobServiceClient
 25
 26from ..telemetry import Telemetry
 27from ..types import (
 28    AWARE_DATETIME_MIN,
 29    Credentials,
 30    CredentialsProvider,
 31    ObjectMetadata,
 32    PreconditionFailedError,
 33    Range,
 34)
 35from ..utils import split_path, validate_attributes
 36from .base import BaseStorageProvider
 37
# Generic return type used by _translate_errors so the wrapped callable's
# result type is preserved for callers.
_T = TypeVar("_T")

# Provider name passed to BaseStorageProvider as provider_name (used for
# telemetry / error attribution).
PROVIDER = "azure"
 41
 42
class StaticAzureCredentialsProvider(CredentialsProvider):
    """
    Credentials provider that always hands out a fixed Azure connection string.

    Implements :py:class:`multistorageclient.types.CredentialsProvider` for
    deployments where the Azure Blob Storage connection string is known up
    front and never rotates.
    """

    # Azure Blob Storage connection string captured at construction time.
    _connection: str

    def __init__(self, connection: str):
        """
        Store the static connection string returned by every credential request.

        :param connection: The connection string for Azure Blob Storage authentication.
        """
        self._connection = connection

    def get_credentials(self) -> Credentials:
        """Return the static connection string wrapped in a :py:class:`Credentials` object."""
        # The connection string travels in ``access_key``; the remaining
        # fields are unused for connection-string authentication.
        static_credentials = Credentials(
            access_key=self._connection,
            secret_key="",
            token=None,
            expiration=None,
        )
        return static_credentials

    def refresh_credentials(self) -> None:
        """No-op: static credentials never expire, so there is nothing to refresh."""
        pass
68 69
class AzureBlobStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Azure Blob Storage.
    """

    def __init__(
        self,
        endpoint_url: str,
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: dict[str, Any],
    ):
        """
        Initializes the :py:class:`AzureBlobStorageProvider` with the endpoint URL and optional credentials provider.

        :param endpoint_url: The Azure storage account URL.
        :param base_path: The root prefix path within the container where all operations will be scoped.
        :param credentials_provider: The provider to retrieve Azure credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._account_url = endpoint_url
        self._credentials_provider = credentials_provider
        # Only forward SDK client options we know about; everything else in
        # kwargs is ignored.
        # https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/storage/azure-storage-blob#optional-configuration
        client_optional_configuration_keys = {
            "retry_total",
            "retry_connect",
            "retry_read",
            "retry_status",
            "connection_timeout",
            "read_timeout",
        }
        self._client_optional_configuration = {
            key: value for key, value in kwargs.items() if key in client_optional_configuration_keys
        }
        self._blob_service_client = self._create_blob_service_client()

    def _create_blob_service_client(self) -> BlobServiceClient:
        """
        Creates and configures the Azure BlobServiceClient using the current credentials.

        :return: The configured BlobServiceClient.
        """
        if self._credentials_provider:
            # The credentials provider carries the full connection string in
            # ``access_key`` (see StaticAzureCredentialsProvider).
            credentials = self._credentials_provider.get_credentials()
            return BlobServiceClient.from_connection_string(
                credentials.access_key, **self._client_optional_configuration
            )
        else:
            return BlobServiceClient(account_url=self._account_url, **self._client_optional_configuration)

    def _refresh_blob_service_client_if_needed(self) -> None:
        """
        Refreshes the BlobServiceClient if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._blob_service_client = self._create_blob_service_client()

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        container: str,
        blob: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual Azure Blob Storage operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param container: The name of the Azure container involved in the operation.
        :param blob: The name of the blob within the Azure container.

        :return: The result of the Azure Blob Storage operation, typically the return value of the `func` callable.

        :raises FileNotFoundError: If the service responds with HTTP 404.
        :raises PreconditionFailedError: If the service responds with HTTP 412 (ETag mismatch).
        :raises RuntimeError: For any other Azure or unexpected error.
        """
        try:
            return func()
        except HttpResponseError as error:
            status_code = error.status_code if error.status_code else -1
            error_info = f"status_code: {error.status_code}, reason: {error.reason}"
            if status_code == 404:
                raise FileNotFoundError(f"Object {container}/{blob} does not exist.")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                # raised when If-Match or If-Modified fails
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {container}/{blob}. {error_info}"
                ) from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {container}/{blob}. {error_info}") from error
        except AzureError as error:
            error_info = f"message: {error.message}"
            raise RuntimeError(f"Failed to {operation} object(s) at {container}/{blob}. {error_info}") from error
        except Exception as error:
            raise RuntimeError(
                f"Failed to {operation} object(s) at {container}/{blob}. error_type: {type(error).__name__}, error: {error}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to Azure Blob Storage.

        :param path: The path to the object to upload.
        :param body: The content of the object to upload.
        :param if_match: Optional ETag to match against the object.
        :param if_none_match: Optional ETag to match against the object.
            Note: if both ``if_match`` and ``if_none_match`` are provided,
            ``if_none_match`` takes precedence (only one match condition can
            be sent per request).
        :param attributes: Optional attributes to attach to the object.

        :return: The number of bytes uploaded (``len(body)``).

        :raises NotImplementedError: If ``if_none_match == "*"`` (not supported by Azure).
        """
        container_name, blob_name = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> int:
            blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)

            kwargs = {
                "data": body,
                "overwrite": True,
            }

            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                kwargs["metadata"] = validated_attributes

            if if_match:
                kwargs["match_condition"] = MatchConditions.IfNotModified
                kwargs["etag"] = if_match

            if if_none_match:
                if if_none_match == "*":
                    raise NotImplementedError("if_none_match='*' is not supported for Azure")
                kwargs["match_condition"] = MatchConditions.IfModified
                kwargs["etag"] = if_none_match

            blob_client.upload_blob(**kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", container=container_name, blob=blob_name)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Downloads an object (or a byte range of it) from Azure Blob Storage.

        :param path: The path to the object to download.
        :param byte_range: Optional byte range (offset + size) to download.

        :return: The object (or range) content as bytes.
        """
        container_name, blob_name = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> bytes:
            blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
            if byte_range:
                stream = blob_client.download_blob(offset=byte_range.offset, length=byte_range.size)
            else:
                stream = blob_client.download_blob()
            return stream.readall()

        return self._translate_errors(_invoke_api, operation="GET", container=container_name, blob=blob_name)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        """
        Copies an object within Azure Blob Storage via a server-side copy.

        :param src_path: The path to the source object.
        :param dest_path: The path to the destination object.

        :return: The content length of the source object in bytes.
        """
        src_container, src_blob = split_path(src_path)
        dest_container, dest_blob = split_path(dest_path)
        self._refresh_blob_service_client_if_needed()

        # Also validates the source exists (raises FileNotFoundError otherwise).
        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            src_blob_client = self._blob_service_client.get_blob_client(container=src_container, blob=src_blob)
            dest_blob_client = self._blob_service_client.get_blob_client(container=dest_container, blob=dest_blob)
            # NOTE(review): start_copy_from_url is asynchronous on the service
            # side; this does not wait for the copy to complete.
            dest_blob_client.start_copy_from_url(src_blob_client.url)

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", container=src_container, blob=src_blob)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Deletes an object from Azure Blob Storage.

        :param path: The path to the object to delete.
        :param if_match: Optional ETag; when provided, the deletion is
            conditional on the blob's current ETag matching it.
        """
        container_name, blob_name = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> None:
            blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
            # If if_match is provided, use if_match for conditional deletion
            if if_match:
                blob_client.delete_blob(etag=if_match, match_condition=MatchConditions.IfNotModified)
            else:
                # No if_match provided, perform unconditional deletion
                blob_client.delete_blob()

        return self._translate_errors(_invoke_api, operation="DELETE", container=container_name, blob=blob_name)

    def _is_dir(self, path: str) -> bool:
        """
        Checks whether the given path behaves like a directory, i.e. at least
        one blob exists under the path treated as a prefix.

        :param path: The path to check (a trailing "/" is appended if missing).

        :return: True if any blob exists under the prefix, False otherwise.
        """
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        container_name, prefix = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> bool:
            # List objects with the given prefix
            container_client = self._blob_service_client.get_container_client(container=container_name)
            blobs = container_client.walk_blobs(name_starts_with=prefix, delimiter="/")
            # Check if there are any contents or common prefixes
            return any(True for _ in blobs)

        return self._translate_errors(_invoke_api, operation="LIST", container=container_name, blob=prefix)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Retrieves metadata for an object or virtual directory.

        :param path: The path to the object or directory.
        :param strict: When True and the object does not exist, fall back to
            checking whether the path is a virtual directory before raising.

        :return: The object's metadata, or synthetic directory metadata.

        :raises FileNotFoundError: If neither an object nor a directory exists at the path.
        """
        container_name, blob_name = split_path(path)
        if path.endswith("/") or (container_name and not blob_name):
            # If path ends with "/" or empty blob name is provided, then assume it's a "directory",
            # which metadata is not guaranteed to exist for cases such as
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=self._append_delimiter(path),
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            self._refresh_blob_service_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                properties = blob_client.get_blob_properties()
                return ObjectMetadata(
                    key=path,
                    content_length=properties.size,
                    content_type=properties.content_settings.content_type,
                    last_modified=properties.last_modified,
                    etag=properties.etag.strip('"') if properties.etag else "",
                    metadata=dict(properties.metadata) if properties.metadata else None,
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", container=container_name, blob=blob_name)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        """
        Lists objects under the given prefix, optionally bounded by key range.

        :param path: The prefix path to list under.
        :param start_after: Only yield keys strictly greater than this value.
        :param end_at: Only yield keys less than or equal to this value.
        :param include_directories: When True, list with "/" delimiter and
            yield directory entries for common prefixes.

        :return: An iterator of object metadata, in lexicographical key order.
        """
        container_name, prefix = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            container_client = self._blob_service_client.get_container_client(container=container_name)
            # Azure has no start key option like other object stores.
            if include_directories:
                blobs = container_client.walk_blobs(name_starts_with=prefix, delimiter="/")
            else:
                blobs = container_client.list_blobs(name_starts_with=prefix)
            # Azure guarantees lexicographical order.
            for blob in blobs:
                if isinstance(blob, BlobPrefix):
                    yield ObjectMetadata(
                        key=os.path.join(container_name, blob.name.rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )
                else:
                    key = blob.name
                    if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                        if key.endswith("/"):
                            # Zero-byte "folder marker" blobs are surfaced as
                            # directories only when requested.
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(container_name, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=blob.last_modified,
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(container_name, key),
                                content_length=blob.size,
                                content_type=blob.content_settings.content_type,
                                last_modified=blob.last_modified,
                                etag=blob.etag.strip('"') if blob.etag else "",
                            )
                    elif end_at is not None and end_at < key:
                        # Keys are sorted, so no further key can be <= end_at.
                        return

        return self._translate_errors(_invoke_api, operation="LIST", container=container_name, blob=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        """
        Uploads a local file (by path) or a file-like object to Azure Blob Storage.

        :param remote_path: The destination path of the object.
        :param f: A local file path, or a file-like object (StringIO is
            encoded as UTF-8 before upload).
        :param attributes: Optional attributes to attach to the object.

        :return: The number of bytes uploaded.
        """
        container_name, blob_name = split_path(remote_path)
        file_size: int = 0
        self._refresh_blob_service_client_if_needed()

        validated_attributes = validate_attributes(attributes)
        if isinstance(f, str):
            file_size = os.path.getsize(f)

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                with open(f, "rb") as data:
                    blob_client.upload_blob(data, overwrite=True, metadata=validated_attributes or {})

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", container=container_name, blob=blob_name)
        else:
            # Convert StringIO to BytesIO before upload
            if isinstance(f, io.StringIO):
                fp: IO = io.BytesIO(f.getvalue().encode("utf-8"))  # type: ignore
            else:
                fp = f

            # Measure size by seeking to the end, then rewind for upload.
            fp.seek(0, io.SEEK_END)
            file_size = fp.tell()
            fp.seek(0)

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                blob_client.upload_blob(fp, overwrite=True, metadata=validated_attributes or {})

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", container=container_name, blob=blob_name)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """
        Downloads an object to a local file (by path) or into a file-like object.

        When ``f`` is a path, the object is first written to a temporary file
        in the destination directory and atomically renamed into place, so a
        failed download never leaves a partial file at the destination.

        :param remote_path: The path of the object to download.
        :param f: A local file path, or a writable file-like object (StringIO
            receives UTF-8 decoded text).
        :param metadata: Optional pre-fetched metadata; fetched if not given.

        :return: The content length of the object in bytes.
        """
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        container_name, blob_name = split_path(remote_path)
        self._refresh_blob_service_client_if_needed()

        if isinstance(f, str):
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                temp_file_path: Optional[str] = None
                try:
                    # Write to a hidden temp file in the same directory so the
                    # final os.rename is atomic on the same filesystem.
                    with tempfile.NamedTemporaryFile(
                        mode="wb", delete=False, dir=os.path.dirname(f), prefix="."
                    ) as fp:
                        temp_file_path = fp.name
                        stream = blob_client.download_blob()
                        fp.write(stream.readall())
                    os.rename(src=temp_file_path, dst=f)
                except Exception:
                    # Fix: don't leak a partially-written temp file when the
                    # download or write fails before the rename.
                    if temp_file_path is not None and os.path.exists(temp_file_path):
                        os.unlink(temp_file_path)
                    raise

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", container=container_name, blob=blob_name)
        else:

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                stream = blob_client.download_blob()
                if isinstance(f, io.StringIO):
                    f.write(stream.readall().decode("utf-8"))
                else:
                    f.write(stream.readall())

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", container=container_name, blob=blob_name)