Source code for multistorageclient.providers.azure

  1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16import io
 17import os
 18import tempfile
 19from collections.abc import Callable, Iterator
 20from typing import IO, Any, Optional, TypeVar, Union
 21
 22from azure.core import MatchConditions
 23from azure.core.exceptions import AzureError, HttpResponseError
 24from azure.storage.blob import BlobPrefix, BlobServiceClient
 25
 26from ..constants import DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT
 27from ..telemetry import Telemetry
 28from ..types import (
 29    AWARE_DATETIME_MIN,
 30    Credentials,
 31    CredentialsProvider,
 32    ObjectMetadata,
 33    PreconditionFailedError,
 34    Range,
 35)
 36from ..utils import split_path, validate_attributes
 37from .base import BaseStorageProvider
 38
 39_T = TypeVar("_T")
 40
 41PROVIDER = "azure"
 42
 43
class StaticAzureCredentialsProvider(CredentialsProvider):
    """
    Credentials provider backed by a fixed Azure connection string.

    A concrete implementation of :py:class:`multistorageclient.types.CredentialsProvider`
    whose credentials never change: the connection string is carried in the
    ``access_key`` field and refreshing is a no-op.
    """

    _connection: str

    def __init__(self, connection: str):
        """
        Store the connection string used for Azure Blob Storage authentication.

        :param connection: The Azure Blob Storage connection string.
        """
        self._connection = connection

    def get_credentials(self) -> Credentials:
        """
        Return the static credentials wrapping the connection string.

        :return: A :py:class:`Credentials` whose ``access_key`` is the connection
            string; the remaining fields are unused for connection-string auth.
        """
        static_credentials = Credentials(
            access_key=self._connection,
            secret_key="",
            token=None,
            expiration=None,
        )
        return static_credentials

    def refresh_credentials(self) -> None:
        """No-op: static credentials cannot expire, so there is nothing to refresh."""
        pass
69 70
class AzureBlobStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Azure Blob Storage.
    """

    def __init__(
        self,
        endpoint_url: str,
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: dict[str, Any],
    ):
        """
        Initializes the :py:class:`AzureBlobStorageProvider` with the endpoint URL and optional credentials provider.

        :param endpoint_url: The Azure storage account URL.
        :param base_path: The root prefix path within the container where all operations will be scoped.
        :param credentials_provider: The provider to retrieve Azure credentials.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._account_url = endpoint_url
        self._credentials_provider = credentials_provider
        # Only forward the SDK options we recognize; any other kwargs are silently dropped.
        # https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/storage/azure-storage-blob#optional-configuration
        client_optional_configuration_keys = {
            "retry_total",
            "retry_connect",
            "retry_read",
            "retry_status",
            "connection_timeout",
            "read_timeout",
        }
        self._client_optional_configuration: dict[str, Any] = {
            key: value for key, value in kwargs.items() if key in client_optional_configuration_keys
        }
        # Fall back to MSC-wide defaults when the caller did not set explicit timeouts.
        if "connection_timeout" not in self._client_optional_configuration:
            self._client_optional_configuration["connection_timeout"] = DEFAULT_CONNECT_TIMEOUT
        if "read_timeout" not in self._client_optional_configuration:
            self._client_optional_configuration["read_timeout"] = DEFAULT_READ_TIMEOUT
        self._blob_service_client = self._create_blob_service_client()

    def _create_blob_service_client(self) -> BlobServiceClient:
        """
        Creates and configures the Azure BlobServiceClient using the current credentials.

        :return: The configured BlobServiceClient.
        """
        if self._credentials_provider:
            # The credentials provider carries the connection string in `access_key`
            # (see StaticAzureCredentialsProvider elsewhere in this module).
            credentials = self._credentials_provider.get_credentials()
            return BlobServiceClient.from_connection_string(
                credentials.access_key, **self._client_optional_configuration
            )
        else:
            # No explicit credentials: let the SDK resolve auth from the account URL.
            return BlobServiceClient(account_url=self._account_url, **self._client_optional_configuration)

    def _refresh_blob_service_client_if_needed(self) -> None:
        """
        Refreshes the BlobServiceClient if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                # Rebuild the client so it picks up the refreshed credentials.
                self._credentials_provider.refresh_credentials()
                self._blob_service_client = self._create_blob_service_client()

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        container: str,
        blob: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual Azure Blob Storage operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param container: The name of the Azure container involved in the operation.
        :param blob: The name of the blob within the Azure container.

        :return: The result of the Azure Blob Storage operation, typically the return value of the `func` callable.

        :raises FileNotFoundError: If the service reports HTTP 404 for the object.
        :raises PreconditionFailedError: If the service reports HTTP 412 (ETag precondition failed).
        :raises RuntimeError: For any other Azure or unexpected error.
        """
        try:
            return func()
        except HttpResponseError as error:
            status_code = error.status_code if error.status_code else -1
            error_info = f"status_code: {error.status_code}, reason: {error.reason}"
            if status_code == 404:
                raise FileNotFoundError(f"Object {container}/{blob} does not exist.")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                # raised when If-Match or If-Modified fails
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {container}/{blob}. {error_info}"
                ) from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {container}/{blob}. {error_info}") from error
        except AzureError as error:
            error_info = f"message: {error.message}"
            raise RuntimeError(f"Failed to {operation} object(s) at {container}/{blob}. {error_info}") from error
        except FileNotFoundError:
            # `func` itself may raise FileNotFoundError (e.g. a missing local file);
            # pass it through untouched rather than wrapping it in RuntimeError below.
            raise
        except Exception as error:
            raise RuntimeError(
                f"Failed to {operation} object(s) at {container}/{blob}. error_type: {type(error).__name__}, error: {error}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads an object to Azure Blob Storage.

        :param path: The path to the object to upload.
        :param body: The content of the object to upload.
        :param if_match: Optional ETag to match against the object.
        :param if_none_match: Optional ETag to match against the object.
        :param attributes: Optional attributes to attach to the object.
        :return: The size of the uploaded body in bytes.
        """
        container_name, blob_name = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> int:
            blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)

            kwargs: dict[str, Any] = {
                "data": body,
                "overwrite": True,
            }

            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                kwargs["metadata"] = validated_attributes

            if if_match:
                # If-Match semantics: only write if the blob's ETag still matches.
                kwargs["match_condition"] = MatchConditions.IfNotModified
                kwargs["etag"] = if_match

            if if_none_match:
                if if_none_match == "*":
                    # Create-only ("fail if exists") uploads are not implemented for Azure.
                    raise NotImplementedError("if_none_match='*' is not supported for Azure")
                # NOTE(review): when both if_match and if_none_match are given,
                # this branch overwrites the if_match condition set above — confirm intended.
                kwargs["match_condition"] = MatchConditions.IfModified
                kwargs["etag"] = if_none_match

            blob_client.upload_blob(**kwargs)

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", container=container_name, blob=blob_name)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Downloads an object's content, optionally restricted to a byte range.

        :param path: The path to the object to download.
        :param byte_range: Optional offset/size range to read instead of the whole blob.
        :return: The object's content as bytes.
        """
        container_name, blob_name = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> bytes:
            blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
            if byte_range:
                stream = blob_client.download_blob(offset=byte_range.offset, length=byte_range.size)
            else:
                stream = blob_client.download_blob()
            return stream.readall()

        return self._translate_errors(_invoke_api, operation="GET", container=container_name, blob=blob_name)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        """
        Copies an object within Azure Blob Storage via a server-side copy.

        :param src_path: The path of the source object.
        :param dest_path: The path of the destination object.
        :return: The content length of the source object in bytes.
        :raises FileNotFoundError: If the source object does not exist.
        """
        src_container, src_blob = split_path(src_path)
        dest_container, dest_blob = split_path(dest_path)
        self._refresh_blob_service_client_if_needed()

        # Fetch source metadata up front; also validates the source exists.
        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            src_blob_client = self._blob_service_client.get_blob_client(container=src_container, blob=src_blob)
            dest_blob_client = self._blob_service_client.get_blob_client(container=dest_container, blob=dest_blob)
            # NOTE(review): start_copy_from_url initiates an async server-side copy;
            # completion is not awaited here — confirm callers tolerate in-progress copies.
            dest_blob_client.start_copy_from_url(src_blob_client.url)

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", container=src_container, blob=src_blob)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Deletes an object, optionally only when its ETag matches ``if_match``.

        :param path: The path of the object to delete.
        :param if_match: Optional ETag for conditional deletion.
        """
        container_name, blob_name = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> None:
            blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
            # If if_match is provided, use if_match for conditional deletion
            if if_match:
                blob_client.delete_blob(etag=if_match, match_condition=MatchConditions.IfNotModified)
            else:
                # No if_match provided, perform unconditional deletion
                blob_client.delete_blob()

        return self._translate_errors(_invoke_api, operation="DELETE", container=container_name, blob=blob_name)

    def _is_dir(self, path: str) -> bool:
        """
        Checks whether a path behaves like a directory, i.e. whether any blobs
        exist under it as a prefix.

        :param path: The path to check (a trailing '/' is appended if missing).
        :return: True if at least one blob or prefix exists under the path.
        """
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        container_name, prefix = split_path(path)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> bool:
            # List objects with the given prefix
            container_client = self._blob_service_client.get_container_client(container=container_name)
            blobs = container_client.walk_blobs(name_starts_with=prefix, delimiter="/")
            # Check if there are any contents or common prefixes
            return any(True for _ in blobs)

        return self._translate_errors(_invoke_api, operation="LIST", container=container_name, blob=prefix)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Retrieves metadata for an object, or synthesizes directory metadata for
        prefix-like paths.

        :param path: The path of the object or directory.
        :param strict: When True and the object is missing, fall back to checking
            whether the path is a directory before raising.
        :return: The object's (or synthesized directory's) metadata.
        :raises FileNotFoundError: If neither an object nor a directory exists at the path.
        """
        container_name, blob_name = split_path(path)
        if path.endswith("/") or (container_name and not blob_name):
            # If path ends with "/" or empty blob name is provided, then assume it's a "directory",
            # which metadata is not guaranteed to exist for cases such as
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=self._append_delimiter(path),
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            self._refresh_blob_service_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                properties = blob_client.get_blob_properties()
                return ObjectMetadata(
                    key=path,
                    content_length=properties.size,
                    content_type=properties.content_settings.content_type,
                    last_modified=properties.last_modified,
                    # Azure returns quoted ETags; normalize to the bare value.
                    etag=properties.etag.strip('"') if properties.etag else "",
                    metadata=dict(properties.metadata) if properties.metadata else None,
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", container=container_name, blob=blob_name)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        """
        Lists objects under a prefix in lexicographical order.

        :param path: The prefix path to list under.
        :param start_after: Exclusive lower bound key (filtered client-side, since
            Azure has no native start-key option).
        :param end_at: Inclusive upper bound key; iteration stops past it.
        :param include_directories: When True, list only one level deep and yield
            directory entries for common prefixes.
        :param follow_symlinks: Unused by this provider (kept for interface parity).
        :return: An iterator of object/directory metadata.
        """
        container_name, prefix = split_path(path)

        # Get the prefix of the start_after and end_at paths relative to the bucket.
        if start_after:
            _, start_after = split_path(start_after)
        if end_at:
            _, end_at = split_path(end_at)

        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            container_client = self._blob_service_client.get_container_client(container=container_name)
            # Azure has no start key option like other object stores.
            if include_directories:
                blobs = container_client.walk_blobs(name_starts_with=prefix, delimiter="/")
            else:
                blobs = container_client.list_blobs(name_starts_with=prefix)
            # Azure guarantees lexicographical order.
            for blob in blobs:
                if isinstance(blob, BlobPrefix):
                    # Common prefix returned by walk_blobs -> synthesize a directory entry.
                    yield ObjectMetadata(
                        key=os.path.join(container_name, blob.name.rstrip("/")),
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )
                else:
                    key = blob.name
                    if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                        if key.endswith("/"):
                            # Zero-byte "folder marker" blobs only surface as directories.
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(container_name, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=blob.last_modified,
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(container_name, key),
                                content_length=blob.size,
                                content_type=blob.content_settings.content_type,
                                last_modified=blob.last_modified,
                                etag=blob.etag.strip('"') if blob.etag else "",
                            )
                    elif end_at is not None and end_at < key:
                        # Past the inclusive upper bound; stop early (keys are sorted).
                        return

        return self._translate_errors(_invoke_api, operation="LIST", container=container_name, blob=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        """
        Uploads a file (by path) or a file-like object to Azure Blob Storage.

        :param remote_path: The destination path of the object.
        :param f: A local file path, or a binary/text file-like object.
        :param attributes: Optional attributes to attach to the object.
        :return: The uploaded content size in bytes.
        """
        container_name, blob_name = split_path(remote_path)
        file_size: int = 0
        self._refresh_blob_service_client_if_needed()

        validated_attributes = validate_attributes(attributes)
        if isinstance(f, str):
            file_size = os.path.getsize(f)

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                with open(f, "rb") as data:
                    blob_client.upload_blob(data, overwrite=True, metadata=validated_attributes or {})

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", container=container_name, blob=blob_name)
        else:
            # Convert StringIO to BytesIO before upload
            if isinstance(f, io.StringIO):
                fp: IO = io.BytesIO(f.getvalue().encode("utf-8"))  # type: ignore
            else:
                fp = f

            # Measure size by seeking to the end, then rewind for the upload.
            fp.seek(0, io.SEEK_END)
            file_size = fp.tell()
            fp.seek(0)

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                blob_client.upload_blob(fp, overwrite=True, metadata=validated_attributes or {})

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", container=container_name, blob=blob_name)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """
        Downloads an object to a local file path or into a file-like object.

        :param remote_path: The path of the object to download.
        :param f: A local destination path, or a binary/text file-like object.
        :param metadata: Optional pre-fetched object metadata; fetched if omitted.
        :return: The object's content length in bytes (from metadata).
        """
        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        container_name, blob_name = split_path(remote_path)
        self._refresh_blob_service_client_if_needed()

        if isinstance(f, str):
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                # Write to a hidden temp file in the same directory, then rename,
                # so readers never observe a partially written file.
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    stream = blob_client.download_blob()
                    fp.write(stream.readall())
                os.rename(src=temp_file_path, dst=f)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", container=container_name, blob=blob_name)
        else:

            def _invoke_api() -> int:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                stream = blob_client.download_blob()
                if isinstance(f, io.StringIO):
                    # Text sink: decode the downloaded bytes as UTF-8.
                    f.write(stream.readall().decode("utf-8"))
                else:
                    f.write(stream.readall())

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", container=container_name, blob=blob_name)