Source code for multistorageclient.providers.oci

  1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16import io
 17import os
 18import tempfile
 19import time
 20from collections.abc import Callable, Iterator, Sequence, Sized
 21from typing import IO, Any, Optional, TypeVar, Union
 22
 23import oci
 24import opentelemetry.metrics as api_metrics
 25from dateutil.parser import parse as dateutil_parser
 26from oci._vendor.requests.exceptions import (
 27    ChunkedEncodingError,
 28    ConnectionError,
 29    ContentDecodingError,
 30)
 31from oci.exceptions import ServiceError
 32from oci.object_storage import ObjectStorageClient, UploadManager
 33from oci.retry import DEFAULT_RETRY_STRATEGY, RetryStrategyBuilder
 34
 35from ..telemetry import Telemetry
 36from ..telemetry.attributes.base import AttributesProvider
 37from ..types import (
 38    AWARE_DATETIME_MIN,
 39    CredentialsProvider,
 40    ObjectMetadata,
 41    PreconditionFailedError,
 42    Range,
 43    RetryableError,
 44)
 45from ..utils import split_path, validate_attributes
 46from .base import BaseStorageProvider
 47
 48_T = TypeVar("_T")
 49
 50MB = 1024 * 1024
 51
 52MULTIPART_THRESHOLD = 512 * MB
 53MULTIPART_CHUNK_SIZE = 256 * MB
 54
 55PROVIDER = "oci"
 56
 57
class OracleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with
    Oracle Cloud Infrastructure (OCI) Object Storage.
    """

    def __init__(
        self,
        namespace: str,
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        retry_strategy: Optional[dict[str, Any]] = None,
        metric_counters: dict[Telemetry.CounterName, api_metrics.Counter] = {},
        metric_gauges: dict[Telemetry.GaugeName, api_metrics._Gauge] = {},
        metric_attributes_providers: Sequence[AttributesProvider] = (),
        **kwargs: Any,
    ) -> None:
        """
        Initializes an instance of :py:class:`OracleStorageProvider`.

        :param namespace: The OCI Object Storage namespace. This is a unique identifier assigned to each tenancy.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve OCI credentials.
        :param retry_strategy: ``oci.retry.RetryStrategyBuilder`` parameters.
        :param metric_counters: Metric counters.
        :param metric_gauges: Metric gauges.
        :param metric_attributes_providers: Metric attributes providers.
        :param kwargs: Optional tuning values. ``multipart_threshold`` and ``multipart_chunksize`` (both in bytes)
            override :py:data:`MULTIPART_THRESHOLD` and :py:data:`MULTIPART_CHUNK_SIZE`.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            metric_counters=metric_counters,
            metric_gauges=metric_gauges,
            metric_attributes_providers=metric_attributes_providers,
        )

        self._namespace = namespace
        self._credentials_provider = credentials_provider
        self._retry_strategy = (
            DEFAULT_RETRY_STRATEGY
            if retry_strategy is None
            else RetryStrategyBuilder(**retry_strategy).get_retry_strategy()
        )
        self._oci_client = self._create_oci_client()
        # NOTE(review): the upload manager rebuilt in ``_refresh_oci_client_if_needed`` passes explicit
        # parallelism arguments while this one relies on the SDK defaults -- confirm which is intended.
        self._upload_manager = UploadManager(self._oci_client)
        self._multipart_threshold = int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD))
        self._multipart_chunk_size = int(kwargs.get("multipart_chunksize", MULTIPART_CHUNK_SIZE))

    def _create_oci_client(self) -> ObjectStorageClient:
        """
        Builds an ``ObjectStorageClient`` from the configuration returned by ``oci.config.from_file()``,
        using this provider's retry strategy.

        NOTE(review): the credentials provider is not consulted when building the client -- confirm intended.
        """
        config = oci.config.from_file()
        return ObjectStorageClient(config, retry_strategy=self._retry_strategy)

    def _refresh_oci_client_if_needed(self) -> None:
        """
        Refreshes the OCI client if the current credentials are expired.

        Does nothing when no credentials provider was configured.
        """
        if not self._credentials_provider:
            return

        credentials = self._credentials_provider.get_credentials()
        if credentials.is_expired():
            self._credentials_provider.refresh_credentials()
            self._oci_client = self._create_oci_client()
            self._upload_manager = UploadManager(
                self._oci_client, allow_parallel_uploads=True, parallel_process_count=4
            )

    def _collect_metrics(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> _T:
        """
        Collects and records performance metrics around object storage operations such as PUT, GET, DELETE, etc.

        This method wraps an object storage operation and measures the time it takes to complete, along with
        recording the size of the object if applicable. It translates SDK and transport errors into the
        exception types listed below and always records the duration (and object size, when known).

        :param func: The function that performs the actual object storage operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the object storage bucket involved in the operation.
        :param key: The key of the object within the object storage bucket.
        :param put_object_size: The size of the object being uploaded, if applicable (for PUT operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).

        :return: The result of the object storage operation, typically the return value of the `func` callable.

        :raises FileNotFoundError: If the service responds with status 404.
        :raises PreconditionFailedError: If the service responds with status 412.
        :raises RetryableError: If the service responds with status 429, or on connection/decoding errors.
        :raises RuntimeError: For any other service error or unexpected exception.
        """
        start_time = time.time()
        status_code = 200

        object_size = None
        if operation == "PUT":
            object_size = put_object_size
        elif operation == "GET" and get_object_size:
            object_size = get_object_size

        try:
            result = func()
            # For GETs without a known size, fall back to the length of the returned payload.
            if operation == "GET" and object_size is None and isinstance(result, Sized):
                object_size = len(result)
            return result
        except ServiceError as error:
            status_code = error.status
            request_id = error.request_id
            endpoint = error.request_endpoint
            error_info = f"request_id: {request_id}, endpoint: {endpoint}, status_code: {status_code}"

            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 429:
                raise RetryableError(
                    f"Too many request to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except (ConnectionError, ChunkedEncodingError, ContentDecodingError) as error:
            status_code = -1
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}"
            ) from error
        except Exception as error:
            status_code = -1
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error
        finally:
            # Runs on success and on failure so every attempt is measured with its final status code.
            elapsed_time = time.time() - start_time
            self._metric_helper.record_duration(
                elapsed_time, provider=self._provider_name, operation=operation, bucket=bucket, status_code=status_code
            )
            if object_size:
                self._metric_helper.record_object_size(
                    object_size,
                    provider=self._provider_name,
                    operation=operation,
                    bucket=bucket,
                    status_code=status_code,
                )

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads ``body`` to the object at ``path`` (``bucket/key``).

        :param path: Bucket and object key joined as a single path.
        :param body: The object content.
        :param if_match: Forwarded to ``put_object`` as ``if_match``.
        :param if_none_match: Forwarded to ``put_object`` as ``if_none_match``.
        :param attributes: Custom metadata, validated and sent as ``opc_meta``.

        :return: The number of bytes in ``body``.
        """
        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        # OCI only supports if_none_match=="*"
        # refer: https://docs.oracle.com/en-us/iaas/tools/python/2.150.0/api/object_storage/client/oci.object_storage.ObjectStorageClient.html?highlight=put_object#oci.object_storage.ObjectStorageClient.put_object
        def _invoke_api() -> int:
            validated_attributes = validate_attributes(attributes)
            self._oci_client.put_object(
                namespace_name=self._namespace,
                bucket_name=bucket,
                object_name=key,
                put_object_body=body,
                opc_meta=validated_attributes or {},  # Pass metadata or empty dict
                if_match=if_match,
                if_none_match=if_none_match,
            )

            return len(body)

        return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=len(body))

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Downloads the object at ``path`` and returns its content.

        :param path: Bucket and object key joined as a single path.
        :param byte_range: Optional offset/size window; when given only that byte range is requested.

        :return: The response content.
        """
        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> bytes:
            if byte_range:
                # The range header's end position is inclusive, hence the ``- 1``.
                bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
            else:
                bytes_range = None
            response = self._oci_client.get_object(
                namespace_name=self._namespace, bucket_name=bucket, object_name=key, range=bytes_range
            )
            return response.data.content  # pyright: ignore [reportOptionalMemberAccess]

        return self._collect_metrics(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        """
        Requests a copy of the object at ``src_path`` to ``dest_path``.

        The source metadata is fetched first so the object's size can be reported.

        :return: The content length of the source object.
        """
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)
        self._refresh_oci_client_if_needed()

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            copy_details = oci.object_storage.models.CopyObjectDetails(
                source_object_name=src_key, destination_bucket=dest_bucket, destination_object_name=dest_key
            )

            self._oci_client.copy_object(
                namespace_name=self._namespace, bucket_name=src_bucket, copy_object_details=copy_details
            )

            return src_object.content_length

        return self._collect_metrics(
            _invoke_api,
            operation="COPY",
            bucket=src_bucket,
            key=src_key,
            put_object_size=src_object.content_length,
        )

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Deletes the object at ``path``.

        :param path: Bucket and object key joined as a single path.
        :param if_match: When not ``None``, forwarded to ``delete_object`` as ``if_match``.
        """
        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> None:
            # Only send ``if_match`` when the caller supplied one.
            optional_kwargs = {} if if_match is None else {"if_match": if_match}
            self._oci_client.delete_object(self._namespace, bucket, key, **optional_kwargs)

        return self._collect_metrics(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        """
        Reports whether any object or common prefix exists under ``path`` treated as a directory prefix.
        """
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._oci_client.list_objects(
                namespace_name=self._namespace,
                bucket_name=bucket,
                prefix=key,
                delimiter="/",
            )
            # Check if there are any contents or common prefixes
            if not response:
                return False
            return bool(response.data.objects or response.data.prefixes)

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Returns metadata for the object or "directory" at ``path``.

        :param path: Bucket and object key joined as a single path. A trailing ``/`` marks a directory lookup.
        :param strict: When ``True`` and the object is missing, also check whether ``path`` is a directory
            before giving up.

        :raises FileNotFoundError: If neither an object nor (when checked) a directory exists at ``path``.
        """
        if path.endswith("/"):
            # If path is a "directory", then metadata is not guaranteed to exist if
            # it is a "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            raise FileNotFoundError(f"Directory {path} does not exist.")

        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> ObjectMetadata:
            response = self._oci_client.head_object(
                namespace_name=self._namespace, bucket_name=bucket, object_name=key
            )
            headers = response.headers  # pyright: ignore [reportOptionalMemberAccess]
            return ObjectMetadata(
                key=path,
                content_length=int(headers["Content-Length"]),
                content_type=headers.get("Content-Type", None),
                last_modified=dateutil_parser(headers["last-modified"]),
                etag=headers.get("etag", None),
            )

        try:
            return self._collect_metrics(_invoke_api, operation="HEAD", bucket=bucket, key=key)
        except FileNotFoundError:
            if strict:
                # If the object does not exist on the given path, we will append a trailing slash and
                # check if the path is a directory.
                path = self._append_delimiter(path)
                if self._is_dir(path):
                    return ObjectMetadata(
                        key=path,
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )
            raise

    def _list_objects(
        self,
        prefix: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        """
        Lazily lists objects under ``prefix``, following pagination.

        :param prefix: Bucket and key prefix joined as a single path.
        :param start_after: Only yield objects whose key is strictly greater than this key.
        :param end_at: Only yield objects whose key is less than or equal to this key.
        :param include_directories: When ``True``, list with a ``/`` delimiter and also yield directory entries.

        :return: An iterator of :py:class:`ObjectMetadata` whose keys are prefixed with the bucket name.
        """
        bucket, prefix = split_path(prefix)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            # ListObjects only includes object names by default.
            #
            # Request additional fields needed for creating an ObjectMetadata.
            fields = ",".join(
                [
                    "etag",
                    "name",
                    "size",
                    "timeModified",
                ]
            )
            list_kwargs: dict[str, Any] = {
                "namespace_name": self._namespace,
                "bucket_name": bucket,
                "prefix": prefix,
                "fields": fields,
            }
            if include_directories:
                # A delimiter makes the service group deeper keys into common prefixes.
                list_kwargs["delimiter"] = "/"

            next_start_with: Optional[str] = start_after
            while True:
                # ``start`` is ≥ instead of >, so ``start_after`` itself is filtered out below.
                response = self._oci_client.list_objects(start=next_start_with, **list_kwargs)

                if not response:
                    return

                if include_directories:
                    for directory in response.data.prefixes or []:
                        yield ObjectMetadata(
                            # Include the bucket so directory keys match the object keys yielded below.
                            key=os.path.join(bucket, directory.rstrip("/")),
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )

                # OCI guarantees lexicographical order.
                for response_object in response.data.objects:  # pyright: ignore [reportOptionalMemberAccess]
                    key = response_object.name
                    if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                        if key.endswith("/"):
                            # Objects whose name ends with "/" are only reported as directories.
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(bucket, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=response_object.time_modified,
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key),
                                type="file",
                                content_length=response_object.size,
                                last_modified=response_object.time_modified,
                                etag=response_object.etag,
                            )
                    elif start_after != key:
                        # Out of range and not the ``start_after`` key itself: stop listing.
                        return
                next_start_with = response.data.next_start_with  # pyright: ignore [reportOptionalMemberAccess]
                if next_start_with is None or (end_at is not None and end_at < next_start_with):
                    return

        # NOTE(review): ``_invoke_api`` is a generator function, so ``_collect_metrics`` only wraps the creation
        # of the generator; the listing requests themselves run later, while the caller iterates.
        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        """
        Uploads a local file or file-like object to ``remote_path``.

        :param remote_path: Bucket and object key joined as a single path.
        :param f: A local file path, or a seekable file-like object.
        :param attributes: Accepted for interface compatibility; not used by this implementation.

        :return: The size of the uploaded content in bytes.
        """
        bucket, key = split_path(remote_path)
        file_size: int = 0
        self._refresh_oci_client_if_needed()

        if isinstance(f, str):
            file_size = os.path.getsize(f)

            def _invoke_api() -> int:
                # Only large files get an explicit part size; smaller ones use the upload manager's defaults.
                extra_kwargs: dict[str, Any] = {}
                if file_size > self._multipart_threshold:
                    extra_kwargs = {"part_size": self._multipart_chunk_size, "allow_parallel_uploads": True}
                self._upload_manager.upload_file(
                    namespace_name=self._namespace,
                    bucket_name=bucket,
                    object_name=key,
                    file_path=f,
                    **extra_kwargs,
                )

                return file_size

            return self._collect_metrics(
                _invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=file_size
            )
        else:
            # Convert file-like object to BytesIO because stream_ref cannot work with StringIO.
            if isinstance(f, io.StringIO):
                f = io.BytesIO(f.getvalue().encode("utf-8"))

            # Measure the stream by seeking to its end, then rewind so the upload starts at the beginning.
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            def _invoke_api() -> int:
                extra_kwargs: dict[str, Any] = {}
                if file_size > self._multipart_threshold:
                    extra_kwargs = {"part_size": self._multipart_chunk_size, "allow_parallel_uploads": True}
                self._upload_manager.upload_stream(
                    namespace_name=self._namespace,
                    bucket_name=bucket,
                    object_name=key,
                    stream_ref=f,
                    **extra_kwargs,
                )

                return file_size

            return self._collect_metrics(
                _invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=file_size
            )

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """
        Downloads the object at ``remote_path`` into a local file path or file-like object.

        When ``f`` is a path, the content is streamed into a hidden temporary file in the destination
        directory and renamed into place once complete; the temporary file is removed if the download fails.

        :param remote_path: Bucket and object key joined as a single path.
        :param f: A local file path, or a writable file-like object.
        :param metadata: Previously fetched object metadata; fetched here when omitted.

        :return: The content length reported by the object's metadata.
        """
        self._refresh_oci_client_if_needed()

        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            target_dir = os.path.dirname(f)
            # A bare filename has an empty dirname, which ``os.makedirs`` rejects.
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)

            def _invoke_api() -> int:
                response = self._oci_client.get_object(
                    namespace_name=self._namespace, bucket_name=bucket, object_name=key
                )
                temp_file_path: Optional[str] = None
                try:
                    # Write next to the destination so the final rename does not cross directories.
                    with tempfile.NamedTemporaryFile(
                        mode="wb", delete=False, dir=target_dir or ".", prefix="."
                    ) as fp:
                        temp_file_path = fp.name
                        for chunk in response.data.raw.stream(1024 * 1024, decode_content=False):  # pyright: ignore [reportOptionalMemberAccess]
                            fp.write(chunk)
                    os.rename(src=temp_file_path, dst=f)
                except BaseException:
                    # Do not leave a partial hidden file behind when streaming or the rename fails.
                    if temp_file_path is not None:
                        try:
                            os.unlink(temp_file_path)
                        except OSError:
                            pass
                    raise

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )
        else:

            def _invoke_api() -> int:
                response = self._oci_client.get_object(
                    namespace_name=self._namespace, bucket_name=bucket, object_name=key
                )
                chunks = response.data.raw.stream(1024 * 1024, decode_content=False)  # pyright: ignore [reportOptionalMemberAccess]
                if isinstance(f, io.StringIO):
                    # StringIO only accepts text, so gather the raw bytes and decode them as UTF-8 in one go.
                    f.write(b"".join(chunks).decode("utf-8"))
                else:
                    for chunk in chunks:
                        f.write(chunk)

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )