Source code for multistorageclient.providers.azure

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os
import tempfile
import time
from datetime import datetime
from typing import IO, Any, Callable, Iterator, Optional, Union

from azure.core.exceptions import ResourceNotFoundError
from azure.storage.blob import BlobPrefix, BlobServiceClient

from ..types import (
    Credentials,
    CredentialsProvider,
    ObjectMetadata,
    Range,
)
from ..utils import split_path
from .base import BaseStorageProvider

PROVIDER = "azure"


class StaticAzureCredentialsProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static Azure credentials.
    """

    _connection: str

    def __init__(self, connection: str):
        """
        Initializes the :py:class:`StaticAzureCredentialsProvider` with the provided connection string.

        :param connection: The connection string for Azure Blob Storage authentication.
        """
        self._connection = connection

    def get_credentials(self) -> Credentials:
        return Credentials(
            access_key=self._connection,
            secret_key="",
            token=None,
            expiration=None,
        )

    def refresh_credentials(self) -> None:
        pass
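
# A minimal usage sketch (illustrative, with a placeholder connection string;
# not part of the original module): the provider simply surfaces the connection
# string through ``Credentials.access_key``, which ``AzureBlobStorageProvider``
# below feeds to ``BlobServiceClient.from_connection_string``.
#
#     credentials_provider = StaticAzureCredentialsProvider(
#         connection=(
#             "DefaultEndpointsProtocol=https;AccountName=<account>;"
#             "AccountKey=<key>;EndpointSuffix=core.windows.net"
#         )
#     )
#     creds = credentials_provider.get_credentials()
#     # creds.access_key holds the connection string; secret_key and token are unused.

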
class AzureBlobStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with Azure Blob Storage.
    """

    def __init__(
        self, endpoint_url: str, base_path: str = "", credentials_provider: Optional[CredentialsProvider] = None
    ):
        """
        Initializes the :py:class:`AzureBlobStorageProvider` with the endpoint URL and optional credentials provider.

        :param endpoint_url: The Azure storage account URL.
        :param base_path: The root prefix path within the container where all operations will be scoped.
        :param credentials_provider: The provider to retrieve Azure credentials.
        """
        super().__init__(base_path=base_path, provider_name=PROVIDER)

        self._account_url = endpoint_url
        self._credentials_provider = credentials_provider
        self._blob_service_client = self._create_blob_service_client()

    def _create_blob_service_client(self) -> BlobServiceClient:
        """
        Creates and configures the Azure BlobServiceClient using the current credentials.

        :return: The configured BlobServiceClient.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            return BlobServiceClient.from_connection_string(credentials.access_key)
        else:
            return BlobServiceClient(account_url=self._account_url)

    def _refresh_blob_service_client_if_needed(self) -> None:
        """
        Refreshes the BlobServiceClient if the current credentials are expired.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._blob_service_client = self._create_blob_service_client()

    def _collect_metrics(
        self,
        func: Callable,
        operation: str,
        container: str,
        blob: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> Any:
        """
        Collects and records performance metrics around Azure operations such as PUT, GET, DELETE, etc.

        This method wraps an Azure operation and measures the time it takes to complete, along with recording
        the size of the object if applicable. It maps missing-object errors to ``FileNotFoundError``, wraps
        other failures in ``RuntimeError``, and records duration and object size metrics even when the call fails.

        :param func: The function that performs the actual Azure operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param container: The name of the Azure container involved in the operation.
        :param blob: The name of the blob within the Azure container.
        :param put_object_size: The size of the object being uploaded, if applicable (for PUT operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).

        :return: The result of the Azure operation, typically the return value of the `func` callable.
        """
        start_time = time.time()
        status_code = 200

        object_size = None
        if operation == "PUT":
            object_size = put_object_size
        elif operation == "GET" and get_object_size:
            object_size = get_object_size

        try:
            result = func()
            if operation == "GET" and object_size is None:
                object_size = len(result)
            return result
        except ResourceNotFoundError:
            status_code = 404
            raise FileNotFoundError(f"Object {container}/{blob} does not exist.")  # pylint: disable=raise-missing-from
        except Exception as error:
            status_code = -1
            raise RuntimeError(f"Failed to {operation} object(s) at {container}/{blob}") from error
        finally:
            elapsed_time = time.time() - start_time
            self._metric_helper.record_duration(
                elapsed_time, provider=PROVIDER, operation=operation, bucket=container, status_code=status_code
            )
            if object_size:
                self._metric_helper.record_object_size(
                    object_size, provider=PROVIDER, operation=operation, bucket=container, status_code=status_code
                )
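
    # The storage operations below all follow the same pattern: resolve the
    # container and blob names with ``split_path``, refresh the client if the
    # credentials have expired, then wrap the actual Azure SDK call in a local
    # ``_invoke_api`` closure that is handed to ``_collect_metrics`` (see
    # ``_put_object`` immediately below for the first instance).
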
132 """ 133 start_time = time.time() 134 status_code = 200 135 136 object_size = None 137 if operation == "PUT": 138 object_size = put_object_size 139 elif operation == "GET" and get_object_size: 140 object_size = get_object_size 141 142 try: 143 result = func() 144 if operation == "GET" and object_size is None: 145 object_size = len(result) 146 return result 147 except ResourceNotFoundError: 148 status_code = 404 149 raise FileNotFoundError(f"Object {container}/{blob} does not exist.") # pylint: disable=raise-missing-from 150 except Exception as error: 151 status_code = -1 152 raise RuntimeError(f"Failed to {operation} object(s) at {container}/{blob}") from error 153 finally: 154 elapsed_time = time.time() - start_time 155 self._metric_helper.record_duration( 156 elapsed_time, provider=PROVIDER, operation=operation, bucket=container, status_code=status_code 157 ) 158 if object_size: 159 self._metric_helper.record_object_size( 160 object_size, provider=PROVIDER, operation=operation, bucket=container, status_code=status_code 161 ) 162 163 def _put_object(self, path: str, body: bytes) -> None: 164 container_name, blob_name = split_path(path) 165 self._refresh_blob_service_client_if_needed() 166 167 def _invoke_api() -> None: 168 blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name) 169 blob_client.upload_blob(body, overwrite=True) 170 171 return self._collect_metrics(_invoke_api, operation="PUT", container=container_name, blob=blob_name) 172 173 def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes: 174 container_name, blob_name = split_path(path) 175 self._refresh_blob_service_client_if_needed() 176 177 def _invoke_api() -> bytes: 178 blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name) 179 if byte_range: 180 stream = blob_client.download_blob(offset=byte_range.offset, length=byte_range.size) 181 else: 182 stream = blob_client.download_blob() 183 return stream.readall() 184 185 return self._collect_metrics(_invoke_api, operation="GET", container=container_name, blob=blob_name) 186 187 def _copy_object(self, src_path: str, dest_path: str) -> None: 188 src_container, src_blob = split_path(src_path) 189 dest_container, dest_blob = split_path(dest_path) 190 self._refresh_blob_service_client_if_needed() 191 192 def _invoke_api() -> None: 193 src_blob_client = self._blob_service_client.get_blob_client(container=src_container, blob=src_blob) 194 dest_blob_client = self._blob_service_client.get_blob_client(container=dest_container, blob=dest_blob) 195 dest_blob_client.start_copy_from_url(src_blob_client.url) 196 197 src_object = self._get_object_metadata(src_path) 198 199 return self._collect_metrics( 200 _invoke_api, 201 operation="COPY", 202 container=src_container, 203 blob=src_blob, 204 put_object_size=src_object.content_length, 205 ) 206 207 def _delete_object(self, path: str) -> None: 208 container_name, blob_name = split_path(path) 209 self._refresh_blob_service_client_if_needed() 210 211 def _invoke_api() -> None: 212 blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name) 213 blob_client.delete_blob() 214 215 return self._collect_metrics(_invoke_api, operation="DELETE", container=container_name, blob=blob_name) 216 217 def _is_dir(self, path: str) -> bool: 218 # Ensure the path ends with '/' to mimic a directory 219 path = self._append_delimiter(path) 220 221 container_name, prefix = split_path(path) 222 
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> bool:
            # List objects with the given prefix
            container_client = self._blob_service_client.get_container_client(container=container_name)
            blobs = container_client.walk_blobs(name_starts_with=prefix, delimiter="/")
            # Check if there are any contents or common prefixes
            return any(True for _ in blobs)

        return self._collect_metrics(_invoke_api, operation="LIST", container=container_name, blob=prefix)

    def _get_object_metadata(self, path: str) -> ObjectMetadata:
        if path.endswith("/"):
            # If path is a "directory", then metadata is not guaranteed to exist if
            # it is a "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=self._append_delimiter(path),
                    type="directory",
                    content_length=0,
                    last_modified=datetime.min,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            container_name, blob_name = split_path(path)
            self._refresh_blob_service_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                properties = blob_client.get_blob_properties()
                return ObjectMetadata(
                    key=path,
                    content_length=properties.size,
                    content_type=properties.content_settings.content_type,
                    last_modified=properties.last_modified,
                    etag=properties.etag.strip('"') if properties.etag else "",
                )

            try:
                return self._collect_metrics(_invoke_api, operation="HEAD", container=container_name, blob=blob_name)
            except FileNotFoundError as error:
                # If the object does not exist on the given path, we will append a trailing slash and
                # check if the path is a directory.
                path = self._append_delimiter(path)
                if self._is_dir(path):
                    return ObjectMetadata(
                        key=path,
                        type="directory",
                        content_length=0,
                        last_modified=datetime.min,
                    )
                else:
                    raise error

    def _list_objects(
        self,
        prefix: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        container_name, prefix = split_path(prefix)
        self._refresh_blob_service_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            container_client = self._blob_service_client.get_container_client(container=container_name)
            # Azure has no start key option like other object stores.
            if include_directories:
                blobs = container_client.walk_blobs(name_starts_with=prefix, delimiter="/")
            else:
                blobs = container_client.list_blobs(name_starts_with=prefix)
            # Azure guarantees lexicographical order.
            for blob in blobs:
                if isinstance(blob, BlobPrefix):
                    yield ObjectMetadata(
                        key=blob.name.rstrip("/"),
                        type="directory",
                        content_length=0,
                        last_modified=datetime.min,
                    )
                else:
                    key = blob.name
                    if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                        yield ObjectMetadata(
                            key=key,
                            content_length=blob.size,
                            content_type=blob.content_settings.content_type,
                            last_modified=blob.last_modified,
                            etag=blob.etag.strip('"') if blob.etag else "",
                        )
                    elif end_at is not None and end_at < key:
                        return

        return self._collect_metrics(_invoke_api, operation="LIST", container=container_name, blob=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO]) -> None:
        container_name, blob_name = split_path(remote_path)
        self._refresh_blob_service_client_if_needed()

        if isinstance(f, str):
            file_size = os.path.getsize(f)

            def _invoke_api() -> None:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                with open(f, "rb") as data:
                    blob_client.upload_blob(data, overwrite=True)

            return self._collect_metrics(
                _invoke_api, operation="PUT", container=container_name, blob=blob_name, put_object_size=file_size
            )
        else:
            # Convert StringIO to BytesIO before upload
            if isinstance(f, io.StringIO):
                fp: IO = io.BytesIO(f.getvalue().encode("utf-8"))  # type: ignore
            else:
                fp = f

            fp.seek(0, io.SEEK_END)
            file_size = fp.tell()
            fp.seek(0)

            def _invoke_api() -> None:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                blob_client.upload_blob(fp, overwrite=True)

            return self._collect_metrics(
                _invoke_api, operation="PUT", container=container_name, blob=blob_name, put_object_size=file_size
            )

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> None:
        if not metadata:
            metadata = self._get_object_metadata(remote_path)

        container_name, blob_name = split_path(remote_path)
        self._refresh_blob_service_client_if_needed()

        if isinstance(f, str):
            os.makedirs(os.path.dirname(f), exist_ok=True)

            def _invoke_api() -> None:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
                    temp_file_path = fp.name
                    stream = blob_client.download_blob()
                    fp.write(stream.readall())
                os.rename(src=temp_file_path, dst=f)

            return self._collect_metrics(
                _invoke_api,
                operation="GET",
                container=container_name,
                blob=blob_name,
                get_object_size=metadata.content_length,
            )
        else:

            def _invoke_api() -> None:
                blob_client = self._blob_service_client.get_blob_client(container=container_name, blob=blob_name)
                stream = blob_client.download_blob()
                if isinstance(f, io.StringIO):
                    f.write(stream.readall().decode("utf-8"))
                else:
                    f.write(stream.readall())

            return self._collect_metrics(
                _invoke_api,
                operation="GET",
                container=container_name,
                blob=blob_name,
                get_object_size=metadata.content_length,
            )
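

# A minimal usage sketch (illustrative; not part of the original module). The
# account URL, connection string, and container name are placeholders, and it
# assumes the default empty ``base_path`` leaves paths in the
# ``"<container>/<blob>"`` form expected by ``split_path``. Real callers would
# normally go through the higher-level multistorageclient APIs rather than
# these underscore-prefixed methods.
if __name__ == "__main__":
    provider = AzureBlobStorageProvider(
        endpoint_url="https://<account>.blob.core.windows.net",
        credentials_provider=StaticAzureCredentialsProvider(
            connection=(
                "DefaultEndpointsProtocol=https;AccountName=<account>;"
                "AccountKey=<key>;EndpointSuffix=core.windows.net"
            )
        ),
    )
    # Round-trip a small object, then list everything under the prefix.
    provider._put_object("my-container/examples/hello.txt", b"hello world")
    print(provider._get_object("my-container/examples/hello.txt"))
    for obj in provider._list_objects("my-container/examples/"):
        print(obj.key, obj.content_length)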