Source code for multistorageclient.providers.oci

  1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16import io
 17import os
 18import tempfile
 19import time
 20from collections.abc import Callable, Iterator, Sequence, Sized
 21from typing import IO, Any, Optional, TypeVar, Union
 22
 23import oci
 24import opentelemetry.metrics as api_metrics
 25from dateutil.parser import parse as dateutil_parser
 26from oci._vendor.requests.exceptions import (
 27    ChunkedEncodingError,
 28    ConnectionError,
 29    ContentDecodingError,
 30)
 31from oci.exceptions import ServiceError
 32from oci.object_storage import ObjectStorageClient, UploadManager
 33from oci.retry import DEFAULT_RETRY_STRATEGY, RetryStrategyBuilder
 34
 35from ..telemetry import Telemetry
 36from ..telemetry.attributes.base import AttributesProvider
 37from ..types import (
 38    AWARE_DATETIME_MIN,
 39    CredentialsProvider,
 40    ObjectMetadata,
 41    PreconditionFailedError,
 42    Range,
 43    RetryableError,
 44)
 45from ..utils import split_path, validate_attributes
 46from .base import BaseStorageProvider
 47
 48_T = TypeVar("_T")
 49
 50MB = 1024 * 1024
 51
 52MULTIPART_THRESHOLD = 512 * MB
 53MULTIPART_CHUNK_SIZE = 256 * MB
 54
 55PROVIDER = "oci"
 56
 57
class OracleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with
    Oracle Cloud Infrastructure (OCI) Object Storage.
    """

    def __init__(
        self,
        namespace: str,
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        retry_strategy: Optional[dict[str, Any]] = None,
        metric_counters: dict[Telemetry.CounterName, api_metrics.Counter] = {},
        metric_gauges: dict[Telemetry.GaugeName, api_metrics._Gauge] = {},
        metric_attributes_providers: Sequence[AttributesProvider] = (),
        **kwargs: Any,
    ) -> None:
        """
        Initializes an instance of :py:class:`OracleStorageProvider`.

        :param namespace: The OCI Object Storage namespace. This is a unique identifier assigned to each tenancy.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve OCI credentials.
        :param retry_strategy: ``oci.retry.RetryStrategyBuilder`` parameters. When omitted,
            ``oci.retry.DEFAULT_RETRY_STRATEGY`` is used.
        :param metric_counters: Metric counters.
        :param metric_gauges: Metric gauges.
        :param metric_attributes_providers: Metric attributes providers.
        :param kwargs: Extra options. ``multipart_threshold`` (default ``MULTIPART_THRESHOLD``) and
            ``multipart_chunksize`` (default ``MULTIPART_CHUNK_SIZE``) are byte counts converted with ``int()``.
            No other keys are read by this constructor.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            metric_counters=metric_counters,
            metric_gauges=metric_gauges,
            metric_attributes_providers=metric_attributes_providers,
        )

        self._namespace = namespace
        self._credentials_provider = credentials_provider
        self._retry_strategy = (
            DEFAULT_RETRY_STRATEGY
            if retry_strategy is None
            else RetryStrategyBuilder(**retry_strategy).get_retry_strategy()
        )
        self._oci_client = self._create_oci_client()
        self._upload_manager = self._create_upload_manager()
        self._multipart_threshold = int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD))
        self._multipart_chunk_size = int(kwargs.get("multipart_chunksize", MULTIPART_CHUNK_SIZE))

    def _create_oci_client(self) -> ObjectStorageClient:
        """
        Creates an OCI Object Storage client from the default OCI config file, using the configured retry strategy.

        :return: A new :py:class:`oci.object_storage.ObjectStorageClient`.
        """
        # NOTE(review): the credentials provider is not consulted here; the client is always built from
        # ``oci.config.from_file()`` with default arguments, so a credentials refresh only takes effect if the
        # provider updates that config file. Confirm this is intended.
        config = oci.config.from_file()
        return ObjectStorageClient(config, retry_strategy=self._retry_strategy)

    def _create_upload_manager(self) -> UploadManager:
        """
        Creates the upload manager bound to the current OCI client.

        Used both at construction time and after a client refresh. This way uploads are configured identically
        before and after credentials are refreshed (previously the refresh path used different parallelism).

        :return: A new :py:class:`oci.object_storage.UploadManager`.
        """
        return UploadManager(self._oci_client)

    def _refresh_oci_client_if_needed(self) -> None:
        """
        Refreshes the OCI client if the current credentials are expired.

        When a credentials provider is configured and reports expired credentials, the credentials are refreshed
        and both the OCI client and the upload manager are recreated.
        """
        if not self._credentials_provider:
            return
        credentials = self._credentials_provider.get_credentials()
        if credentials.is_expired():
            self._credentials_provider.refresh_credentials()
            self._oci_client = self._create_oci_client()
            self._upload_manager = self._create_upload_manager()

    def _collect_metrics(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> _T:
        """
        Collects and records performance metrics around object storage operations such as PUT, GET, DELETE, etc.

        This method wraps an object storage operation and measures the time it takes to complete, along with recording
        the size of the object if applicable. It handles errors like timeouts and client errors and ensures
        proper logging of duration and object size.

        :param func: The function that performs the actual object storage operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the object storage bucket involved in the operation.
        :param key: The key of the object within the object storage bucket.
        :param put_object_size: The size of the object being uploaded, if applicable (for PUT operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).

        :return: The result of the object storage operation, typically the return value of the `func` callable.

        :raises FileNotFoundError: If OCI responds with HTTP 404.
        :raises PreconditionFailedError: If OCI responds with HTTP 412.
        :raises RetryableError: If OCI responds with HTTP 429, or on connection / chunked-encoding /
            content-decoding errors.
        :raises RuntimeError: For any other failure raised by ``func``.
        """
        start_time = time.time()
        status_code = 200

        # NOTE(review): only PUT uses ``put_object_size``; a size passed for other operations (e.g. COPY) is
        # not recorded.
        object_size = None
        if operation == "PUT":
            object_size = put_object_size
        elif operation == "GET" and get_object_size:
            object_size = get_object_size

        try:
            result = func()
            # Without an explicit GET size, fall back to the length of the returned payload.
            if operation == "GET" and object_size is None and isinstance(result, Sized):
                object_size = len(result)
            return result
        except ServiceError as error:
            status_code = error.status
            request_id = error.request_id
            endpoint = error.request_endpoint
            error_info = f"request_id: {request_id}, endpoint: {endpoint}, status_code: {status_code}"

            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 429:
                raise RetryableError(
                    f"Too many requests to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except (ConnectionError, ChunkedEncodingError, ContentDecodingError) as error:
            status_code = -1
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}"
            ) from error
        except Exception as error:
            status_code = -1
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error
        finally:
            elapsed_time = time.time() - start_time
            self._metric_helper.record_duration(
                elapsed_time, provider=self._provider_name, operation=operation, bucket=bucket, status_code=status_code
            )
            if object_size:
                self._metric_helper.record_object_size(
                    object_size,
                    provider=self._provider_name,
                    operation=operation,
                    bucket=bucket,
                    status_code=status_code,
                )

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads ``body`` as a single object with ``put_object``.

        :param path: ``bucket/key`` of the object.
        :param body: The object contents.
        :param if_match: Optional ETag precondition.
        :param if_none_match: Optional precondition; OCI only supports ``"*"``.
        :param attributes: Optional custom metadata, validated and sent as ``opc_meta``.

        :return: The number of bytes uploaded.
        """
        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        # OCI only supports if_none_match=="*"
        # refer: https://docs.oracle.com/en-us/iaas/tools/python/2.150.0/api/object_storage/client/oci.object_storage.ObjectStorageClient.html?highlight=put_object#oci.object_storage.ObjectStorageClient.put_object
        def _invoke_api() -> int:
            validated_attributes = validate_attributes(attributes)
            self._oci_client.put_object(
                namespace_name=self._namespace,
                bucket_name=bucket,
                object_name=key,
                put_object_body=body,
                opc_meta=validated_attributes or {},  # Pass metadata or empty dict
                if_match=if_match,
                if_none_match=if_none_match,
            )

            return len(body)

        return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=len(body))

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Downloads an object (or a byte range of it) into memory.

        :param path: ``bucket/key`` of the object.
        :param byte_range: Optional range; sent as an inclusive HTTP ``bytes=offset-(offset+size-1)`` header.

        :return: The downloaded bytes.
        """
        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> bytes:
            if byte_range:
                bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
            else:
                bytes_range = None
            response = self._oci_client.get_object(
                namespace_name=self._namespace, bucket_name=bucket, object_name=key, range=bytes_range
            )
            return response.data.content  # pyright: ignore [reportOptionalMemberAccess]

        return self._collect_metrics(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        """
        Copies an object within this provider's namespace.

        :param src_path: ``bucket/key`` of the source object.
        :param dest_path: ``bucket/key`` of the destination object.

        :return: The source object's content length.
        """
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)
        self._refresh_oci_client_if_needed()

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            # NOTE(review): OCI's copy_object is documented as a request that the service processes
            # asynchronously, and CopyObjectDetails also carries ``destination_region``, which is not set here.
            # Confirm the destination is guaranteed to exist when this returns.
            copy_details = oci.object_storage.models.CopyObjectDetails(
                source_object_name=src_key,
                destination_namespace=self._namespace,
                destination_bucket=dest_bucket,
                destination_object_name=dest_key,
            )

            self._oci_client.copy_object(
                namespace_name=self._namespace, bucket_name=src_bucket, copy_object_details=copy_details
            )

            return src_object.content_length

        return self._collect_metrics(
            _invoke_api,
            operation="COPY",
            bucket=src_bucket,
            key=src_key,
            put_object_size=src_object.content_length,
        )

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Deletes an object.

        :param path: ``bucket/key`` of the object.
        :param if_match: Optional ETag precondition; only sent when not ``None``.
        """
        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> None:
            if if_match is not None:
                self._oci_client.delete_object(self._namespace, bucket, key, if_match=if_match)
            else:
                self._oci_client.delete_object(self._namespace, bucket, key)

        return self._collect_metrics(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        """
        Checks whether ``path`` behaves like a directory, i.e. has objects or common prefixes beneath it.

        :param path: ``bucket/key`` to check; a trailing ``/`` is appended if missing.

        :return: ``True`` if a ``/``-delimited listing of the prefix returns any objects or prefixes.
        """
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._oci_client.list_objects(
                namespace_name=self._namespace,
                bucket_name=bucket,
                prefix=key,
                delimiter="/",
            )
            # Check if there are any contents or common prefixes
            if response:
                return bool(response.data.objects or response.data.prefixes)
            return False

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Fetches metadata for an object or a (possibly virtual) directory.

        :param path: ``bucket/key``. A trailing ``/`` means the path is treated as a directory.
        :param strict: When the object is missing, also check whether ``path + "/"`` is a directory.

        :return: The object's metadata, or a directory entry with ``AWARE_DATETIME_MIN`` as its timestamp.

        :raises FileNotFoundError: If neither an object nor (where checked) a directory exists.
        """
        if path.endswith("/"):
            # If path is a "directory", then metadata is not guaranteed to exist if
            # it is a "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            bucket, key = split_path(path)
            self._refresh_oci_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                response = self._oci_client.head_object(
                    namespace_name=self._namespace, bucket_name=bucket, object_name=key
                )
                headers = response.headers  # pyright: ignore [reportOptionalMemberAccess]

                # Extract custom metadata from headers with the 'opc-meta-' prefix, stripping that prefix to
                # recover the original key.
                attributes = {}
                if headers:
                    for metadata_key, metadata_val in headers.items():
                        if metadata_key.startswith("opc-meta-"):
                            attributes[metadata_key[len("opc-meta-") :]] = metadata_val

                return ObjectMetadata(
                    key=path,
                    content_length=int(headers["Content-Length"]),
                    content_type=headers.get("Content-Type", None),
                    last_modified=dateutil_parser(headers["last-modified"]),
                    etag=headers.get("etag", None),
                    metadata=attributes if attributes else None,
                )

            try:
                return self._collect_metrics(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        prefix: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        """
        Lists objects under ``prefix`` in key order, requesting one page at a time.

        :param prefix: ``bucket/key-prefix`` to list.
        :param start_after: Only yield objects whose key is strictly greater than this (exclusive).
        :param end_at: Only yield objects whose key is less than or equal to this (inclusive).
        :param include_directories: When true, list with ``delimiter="/"``. The returned common prefixes and any
            ``/``-terminated objects are then yielded as directory entries.

        :return: An iterator of :py:class:`ObjectMetadata`; every yielded key is prefixed with the bucket name.
        """
        bucket, prefix = split_path(prefix)
        self._refresh_oci_client_if_needed()

        # ListObjects only includes object names by default.
        # Request additional fields needed for creating an ObjectMetadata.
        fields = ",".join(["etag", "name", "size", "timeModified"])
        # Only group keys by "/" (returning common prefixes) when directories were requested.
        delimiter_kwargs: dict[str, Any] = {"delimiter": "/"} if include_directories else {}

        def _invoke_api() -> Iterator[ObjectMetadata]:
            next_start_with: Optional[str] = start_after
            while True:
                response = self._oci_client.list_objects(
                    namespace_name=self._namespace,
                    bucket_name=bucket,
                    prefix=prefix,
                    # This is ≥ instead of >; a key equal to start_after is skipped below.
                    start=next_start_with,
                    fields=fields,
                    **delimiter_kwargs,
                )

                if not response:
                    return

                if include_directories:
                    # Tolerate a missing (None) prefixes list.
                    for directory in response.data.prefixes or ():
                        yield ObjectMetadata(
                            # Prefix with the bucket so directory keys match the file keys yielded below.
                            key=os.path.join(bucket, directory.rstrip("/")),
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )

                # OCI guarantees lexicographical order.
                for response_object in response.data.objects:  # pyright: ignore [reportOptionalMemberAccess]
                    key = response_object.name
                    if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                        if key.endswith("/"):
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(bucket, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=response_object.time_modified,
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key),
                                type="file",
                                content_length=response_object.size,
                                last_modified=response_object.time_modified,
                                etag=response_object.etag,
                            )
                    elif start_after != key:
                        # Out of range and not the inclusive start key: keys are ordered, so we are past end_at.
                        return
                next_start_with = response.data.next_start_with  # pyright: ignore [reportOptionalMemberAccess]
                if next_start_with is None or (end_at is not None and end_at < next_start_with):
                    return

        # NOTE(review): ``_invoke_api`` is a generator, so ``_collect_metrics`` only times its creation, and
        # ServiceErrors raised during iteration are not translated. Consider wrapping each page request instead.
        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _build_upload_kwargs(self, size: int, validated_attributes: Optional[dict[str, str]]) -> dict[str, Any]:
        """
        Builds the keyword arguments shared by ``UploadManager.upload_file`` and ``UploadManager.upload_stream``.

        :param size: Size in bytes of the data being uploaded.
        :param validated_attributes: Validated custom metadata, or ``None``/empty for none.

        :return: ``metadata``, plus ``part_size`` and ``allow_parallel_uploads`` when ``size`` exceeds the
            configured multipart threshold.
        """
        upload_kwargs: dict[str, Any] = {"metadata": validated_attributes or {}}
        if size > self._multipart_threshold:
            upload_kwargs["part_size"] = self._multipart_chunk_size
            upload_kwargs["allow_parallel_uploads"] = True
        return upload_kwargs

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        """
        Uploads a local file path or a file-like object through the OCI upload manager.

        :param remote_path: ``bucket/key`` of the destination object.
        :param f: A local file path, or a file-like object. ``io.StringIO`` content is encoded as UTF-8; other
            objects are rewound and uploaded from the start.
        :param attributes: Optional custom metadata.

        :return: The number of bytes uploaded.
        """
        bucket, key = split_path(remote_path)
        self._refresh_oci_client_if_needed()

        validated_attributes = validate_attributes(attributes)

        if isinstance(f, str):
            file_path = f
            file_size = os.path.getsize(file_path)

            def _invoke_api() -> int:
                self._upload_manager.upload_file(
                    namespace_name=self._namespace,
                    bucket_name=bucket,
                    object_name=key,
                    file_path=file_path,
                    **self._build_upload_kwargs(file_size, validated_attributes),
                )
                return file_size

        else:
            stream = f
            # Convert file-like object to BytesIO because stream_ref cannot work with StringIO.
            if isinstance(stream, io.StringIO):
                stream = io.BytesIO(stream.getvalue().encode("utf-8"))

            # Measure the stream, then rewind so the upload starts from the beginning.
            stream.seek(0, io.SEEK_END)
            file_size = stream.tell()
            stream.seek(0)

            def _invoke_api() -> int:
                self._upload_manager.upload_stream(
                    namespace_name=self._namespace,
                    bucket_name=bucket,
                    object_name=key,
                    stream_ref=stream,
                    **self._build_upload_kwargs(file_size, validated_attributes),
                )
                return file_size

        return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=file_size)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """
        Downloads an object to a local file path or into a file-like object.

        For a path, the data is streamed into a hidden temporary file in the destination directory. That file is
        then moved onto the destination, so the destination path is only written once the download completes.

        :param remote_path: ``bucket/key`` of the object.
        :param f: A local file path, or a file-like object. For ``io.StringIO`` the bytes are decoded as UTF-8.
        :param metadata: Pre-fetched object metadata; fetched with a HEAD request when ``None``.

        :return: The object's content length according to ``metadata``.
        """
        self._refresh_oci_client_if_needed()

        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            local_path = f
            local_dir = os.path.dirname(local_path)
            # A bare filename has no parent to create (os.makedirs("") would raise).
            if local_dir:
                os.makedirs(local_dir, exist_ok=True)

            def _invoke_api() -> int:
                response = self._oci_client.get_object(
                    namespace_name=self._namespace, bucket_name=bucket, object_name=key
                )
                temp_file_path: Optional[str] = None
                try:
                    # Temp file lives in the destination directory so the final os.replace stays on one filesystem.
                    with tempfile.NamedTemporaryFile(
                        mode="wb", delete=False, dir=local_dir or ".", prefix="."
                    ) as fp:
                        temp_file_path = fp.name
                        for chunk in response.data.raw.stream(1024 * 1024, decode_content=False):  # pyright: ignore [reportOptionalMemberAccess]
                            fp.write(chunk)
                    # os.replace (unlike os.rename) also overwrites an existing destination on Windows.
                    os.replace(temp_file_path, local_path)
                except BaseException:
                    # Don't leave a partially-written temporary file behind.
                    if temp_file_path is not None and os.path.exists(temp_file_path):
                        os.remove(temp_file_path)
                    raise

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )
        else:

            def _invoke_api() -> int:
                response = self._oci_client.get_object(
                    namespace_name=self._namespace, bucket_name=bucket, object_name=key
                )
                chunks = response.data.raw.stream(1024 * 1024, decode_content=False)  # pyright: ignore [reportOptionalMemberAccess]
                if isinstance(f, io.StringIO):
                    # Text streams only accept str: buffer all bytes and decode once, so multi-byte UTF-8
                    # sequences split across chunk boundaries still decode correctly.
                    f.write(b"".join(chunks).decode("utf-8"))
                else:
                    for chunk in chunks:
                        f.write(chunk)

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )