Source code for multistorageclient.providers.oci

  1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16import io
 17import os
 18import tempfile
 19import time
 20from collections.abc import Callable, Iterator, Sequence, Sized
 21from typing import IO, Any, Optional, TypeVar, Union
 22
 23import oci
 24import opentelemetry.metrics as api_metrics
 25from dateutil.parser import parse as dateutil_parser
 26from oci._vendor.requests.exceptions import (
 27    ChunkedEncodingError,
 28    ConnectionError,
 29    ContentDecodingError,
 30)
 31from oci.exceptions import ServiceError
 32from oci.object_storage import ObjectStorageClient, UploadManager
 33from oci.retry import DEFAULT_RETRY_STRATEGY, RetryStrategyBuilder
 34
 35from ..telemetry import Telemetry
 36from ..telemetry.attributes.base import AttributesProvider
 37from ..types import (
 38    AWARE_DATETIME_MIN,
 39    CredentialsProvider,
 40    ObjectMetadata,
 41    PreconditionFailedError,
 42    Range,
 43    RetryableError,
 44)
 45from ..utils import split_path, validate_attributes
 46from .base import BaseStorageProvider
 47
 48_T = TypeVar("_T")
 49
 50MB = 1024 * 1024
 51
 52MULTIPART_THRESHOLD = 512 * MB
 53MULTIPART_CHUNKSIZE = 256 * MB
 54
 55PROVIDER = "oci"
 56
 57
class OracleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with
    Oracle Cloud Infrastructure (OCI) Object Storage.
    """

    def __init__(
        self,
        namespace: str,
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        retry_strategy: Optional[dict[str, Any]] = None,
        metric_counters: Optional[dict[Telemetry.CounterName, api_metrics.Counter]] = None,
        metric_gauges: Optional[dict[Telemetry.GaugeName, api_metrics._Gauge]] = None,
        metric_attributes_providers: Sequence[AttributesProvider] = (),
        **kwargs: Any,
    ) -> None:
        """
        Initializes an instance of :py:class:`OracleStorageProvider`.

        :param namespace: The OCI Object Storage namespace. This is a unique identifier assigned to each tenancy.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve OCI credentials.
        :param retry_strategy: ``oci.retry.RetryStrategyBuilder`` parameters.
        :param metric_counters: Metric counters. Defaults to an empty mapping.
        :param metric_gauges: Metric gauges. Defaults to an empty mapping.
        :param metric_attributes_providers: Metric attributes providers.
        :param kwargs: Optional tuning knobs. ``multipart_threshold`` and ``multipart_chunksize`` (both in bytes)
            override :py:data:`MULTIPART_THRESHOLD` and :py:data:`MULTIPART_CHUNKSIZE`.
        """
        # Build fresh dictionaries per instance instead of sharing a mutable default between calls.
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            metric_counters={} if metric_counters is None else metric_counters,
            metric_gauges={} if metric_gauges is None else metric_gauges,
            metric_attributes_providers=metric_attributes_providers,
        )

        self._namespace = namespace
        self._credentials_provider = credentials_provider
        self._retry_strategy = (
            DEFAULT_RETRY_STRATEGY
            if retry_strategy is None
            else RetryStrategyBuilder(**retry_strategy).get_retry_strategy()
        )
        self._oci_client = self._create_oci_client()
        self._upload_manager = UploadManager(self._oci_client)
        self._multipart_threshold = int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD))
        self._multipart_chunksize = int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE))

    def _create_oci_client(self) -> ObjectStorageClient:
        """
        Creates an Object Storage client from the OCI configuration returned by ``oci.config.from_file()``.

        :return: A client configured with this provider's retry strategy.
        """
        config = oci.config.from_file()
        return ObjectStorageClient(config, retry_strategy=self._retry_strategy)

    def _refresh_oci_client_if_needed(self) -> None:
        """
        Refreshes the OCI client if the current credentials are expired.

        Does nothing when no credentials provider was supplied.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._oci_client = self._create_oci_client()
                # NOTE(review): this upload manager is built with different arguments than the one created
                # in ``__init__`` (which passes only the client) -- confirm which configuration is intended.
                self._upload_manager = UploadManager(
                    self._oci_client, allow_parallel_uploads=True, parallel_process_count=4
                )

    def _collect_metrics(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> _T:
        """
        Collects and records performance metrics around object storage operations such as PUT, GET, DELETE, etc.

        This method wraps an object storage operation and measures the time it takes to complete, along with recording
        the size of the object if applicable. It translates OCI service errors and connection errors into the
        exceptions listed below and always records the duration, even on failure.

        :param func: The function that performs the actual object storage operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the object storage bucket involved in the operation.
        :param key: The key of the object within the object storage bucket.
        :param put_object_size: The size of the object being uploaded, if applicable (for PUT operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).

        :return: The result of the object storage operation, typically the return value of the `func` callable.

        :raises FileNotFoundError: If the service responds with status 404.
        :raises PreconditionFailedError: If the service responds with status 412.
        :raises RetryableError: If the service responds with status 429 or a connection/decoding error occurs.
        :raises RuntimeError: For any other failure.
        """
        start_time = time.time()
        status_code = 200

        # Only PUT and GET operations report an object size.
        object_size = None
        if operation == "PUT":
            object_size = put_object_size
        elif operation == "GET" and get_object_size:
            object_size = get_object_size

        try:
            result = func()
            # For GETs without a known size, fall back to the length of the returned payload.
            if operation == "GET" and object_size is None and isinstance(result, Sized):
                object_size = len(result)
            return result
        except ServiceError as error:
            status_code = error.status
            request_id = error.request_id
            endpoint = error.request_endpoint
            error_info = f"request_id: {request_id}, endpoint: {endpoint}, status_code: {status_code}"

            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 429:
                raise RetryableError(
                    f"Too many request to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except (ConnectionError, ChunkedEncodingError, ContentDecodingError) as error:
            status_code = -1
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}"
            ) from error
        except Exception as error:
            status_code = -1
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error
        finally:
            # Runs on success and on failure so every call is timed with its final status code.
            elapsed_time = time.time() - start_time
            self._metric_helper.record_duration(
                elapsed_time, provider=self._provider_name, operation=operation, bucket=bucket, status_code=status_code
            )
            if object_size:
                self._metric_helper.record_object_size(
                    object_size,
                    provider=self._provider_name,
                    operation=operation,
                    bucket=bucket,
                    status_code=status_code,
                )

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads ``body`` as a single object.

        :param path: Object path in ``bucket/key`` form.
        :param body: The object content.
        :param if_match: Forwarded to ``put_object`` as ``if_match``.
        :param if_none_match: Forwarded to ``put_object`` as ``if_none_match``.
        :param attributes: Custom metadata; validated with ``validate_attributes`` before being sent.

        :return: The number of bytes uploaded.
        """
        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        # OCI only supports if_none_match=="*"
        # refer: https://docs.oracle.com/en-us/iaas/tools/python/2.150.0/api/object_storage/client/oci.object_storage.ObjectStorageClient.html?highlight=put_object#oci.object_storage.ObjectStorageClient.put_object
        def _invoke_api() -> int:
            validated_attributes = validate_attributes(attributes)
            self._oci_client.put_object(
                namespace_name=self._namespace,
                bucket_name=bucket,
                object_name=key,
                put_object_body=body,
                opc_meta=validated_attributes or {},  # Pass metadata or empty dict
                if_match=if_match,
                if_none_match=if_none_match,
            )

            return len(body)

        return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=len(body))

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Downloads an object, or a byte range of it, into memory.

        :param path: Object path in ``bucket/key`` form.
        :param byte_range: Optional range; converted to an inclusive HTTP ``bytes=start-end`` header.

        :return: The downloaded bytes.
        """
        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> bytes:
            if byte_range:
                # The end offset of the range header is inclusive, hence the "- 1".
                bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
            else:
                bytes_range = None
            response = self._oci_client.get_object(
                namespace_name=self._namespace, bucket_name=bucket, object_name=key, range=bytes_range
            )
            return response.data.content  # pyright: ignore [reportOptionalMemberAccess]

        return self._collect_metrics(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        """
        Copies an object to another bucket/key within the same namespace.

        :param src_path: Source object path in ``bucket/key`` form.
        :param dest_path: Destination object path in ``bucket/key`` form.

        :return: The content length of the source object.
        """
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)
        self._refresh_oci_client_if_needed()

        # Looked up first so the size can be returned; raises if the source does not exist.
        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            copy_details = oci.object_storage.models.CopyObjectDetails(
                source_object_name=src_key, destination_bucket=dest_bucket, destination_object_name=dest_key
            )

            self._oci_client.copy_object(
                namespace_name=self._namespace, bucket_name=src_bucket, copy_object_details=copy_details
            )

            return src_object.content_length

        return self._collect_metrics(
            _invoke_api,
            operation="COPY",
            bucket=src_bucket,
            key=src_key,
            put_object_size=src_object.content_length,
        )

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Deletes an object.

        :param path: Object path in ``bucket/key`` form.
        :param if_match: When given, forwarded to ``delete_object`` as ``if_match``.
        """
        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> None:
            namespace_name = self._namespace
            bucket_name = bucket
            object_name = key
            if if_match is not None:
                self._oci_client.delete_object(namespace_name, bucket_name, object_name, if_match=if_match)
            else:
                self._oci_client.delete_object(namespace_name, bucket_name, object_name)

        return self._collect_metrics(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        """
        Checks whether any object or common prefix exists under ``path`` treated as a directory.

        :param path: Path in ``bucket/key`` form; a trailing delimiter is appended if missing.

        :return: ``True`` if the listing under the prefix is non-empty.
        """
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._oci_client.list_objects(
                namespace_name=self._namespace,
                bucket_name=bucket,
                prefix=key,
                delimiter="/",
            )
            # Check if there are any contents or common prefixes
            if response:
                return bool(response.data.objects or response.data.prefixes)
            return False

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Retrieves metadata for an object or a "directory" (prefix).

        :param path: Path in ``bucket/key`` form.
        :param strict: When ``True`` and no object exists at ``path``, additionally check whether ``path`` is a
            directory before giving up.

        :return: The metadata of the object, or directory metadata with zero content length.

        :raises FileNotFoundError: If neither an object nor (when checked) a directory exists at ``path``.
        """
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If path ends with "/" or empty key name is provided, then assume it's a "directory",
            # which metadata is not guaranteed to exist for cases such as
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            raise FileNotFoundError(f"Directory {path} does not exist.")

        self._refresh_oci_client_if_needed()

        def _invoke_api() -> ObjectMetadata:
            response = self._oci_client.head_object(
                namespace_name=self._namespace, bucket_name=bucket, object_name=key
            )

            # Extract custom metadata from headers with 'opc-meta-' prefix
            attributes = {}
            if response.headers:  # pyright: ignore [reportOptionalMemberAccess]
                for metadata_key, metadata_val in response.headers.items():  # pyright: ignore [reportOptionalMemberAccess]
                    if metadata_key.startswith("opc-meta-"):
                        # Remove the 'opc-meta-' prefix to get the original key
                        metadata_key = metadata_key[len("opc-meta-") :]
                        attributes[metadata_key] = metadata_val

            return ObjectMetadata(
                key=path,
                content_length=int(response.headers["Content-Length"]),  # pyright: ignore [reportOptionalMemberAccess]
                content_type=response.headers.get("Content-Type", None),  # pyright: ignore [reportOptionalMemberAccess]
                last_modified=dateutil_parser(response.headers["last-modified"]),  # pyright: ignore [reportOptionalMemberAccess]
                etag=response.headers.get("etag", None),  # pyright: ignore [reportOptionalMemberAccess]
                metadata=attributes if attributes else None,
            )

        try:
            return self._collect_metrics(_invoke_api, operation="HEAD", bucket=bucket, key=key)
        except FileNotFoundError:
            if strict:
                # If the object does not exist on the given path, we will append a trailing slash and
                # check if the path is a directory.
                directory_path = self._append_delimiter(path)
                if self._is_dir(directory_path):
                    return ObjectMetadata(
                        key=directory_path,
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )
            raise

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        """
        Lazily lists objects under a prefix.

        :param path: Prefix in ``bucket/prefix`` form.
        :param start_after: Only object keys strictly greater than this key are yielded.
        :param end_at: Only object keys less than or equal to this key are yielded.
        :param include_directories: When ``True``, list with the ``/`` delimiter and also yield
            directory entries for common prefixes and for keys ending in ``/``.

        :return: An iterator of object metadata whose keys are prefixed with the bucket name.
        """
        bucket, prefix = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> Iterator[ObjectMetadata]:
            # ListObjects only includes object names by default.
            #
            # Request additional fields needed for creating an ObjectMetadata.
            fields = ",".join(
                [
                    "etag",
                    "name",
                    "size",
                    "timeModified",
                ]
            )
            next_start_with: Optional[str] = start_after
            while True:
                if include_directories:
                    response = self._oci_client.list_objects(
                        namespace_name=self._namespace,
                        bucket_name=bucket,
                        prefix=prefix,
                        # This is ≥ instead of >.
                        start=next_start_with,
                        delimiter="/",
                        fields=fields,
                    )
                else:
                    response = self._oci_client.list_objects(
                        namespace_name=self._namespace,
                        bucket_name=bucket,
                        prefix=prefix,
                        # This is ≥ instead of >.
                        start=next_start_with,
                        fields=fields,
                    )

                if not response:
                    # A bare return ends the generator; a return value would be silently discarded.
                    return

                if include_directories:
                    # Guard against a missing prefixes list.
                    for directory in response.data.prefixes or []:
                        yield ObjectMetadata(
                            # Prefix the bucket so directory keys have the same shape as the object keys below.
                            key=os.path.join(bucket, directory.rstrip("/")),
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )

                # OCI guarantees lexicographical order.
                for response_object in response.data.objects:  # pyright: ignore [reportOptionalMemberAccess]
                    key = response_object.name
                    if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                        if key.endswith("/"):
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(bucket, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=response_object.time_modified,
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key),
                                type="file",
                                content_length=response_object.size,
                                last_modified=response_object.time_modified,
                                etag=response_object.etag,
                            )
                    elif start_after != key:
                        # Out of range and not the (inclusive) start key itself: nothing further can match.
                        return
                next_start_with = response.data.next_start_with  # pyright: ignore [reportOptionalMemberAccess]
                if next_start_with is None or (end_at is not None and end_at < next_start_with):
                    return

        # NOTE(review): ``_invoke_api`` is a generator function, so ``_collect_metrics`` only times the creation
        # of the generator; the list requests run (and may raise untranslated errors) while the caller iterates.
        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        """
        Uploads a local file or a file-like object.

        :param remote_path: Destination path in ``bucket/key`` form.
        :param f: A local file path, or a seekable file-like object.
        :param attributes: Custom metadata; validated with ``validate_attributes`` before being sent.

        :return: The number of bytes uploaded.
        """
        bucket, key = split_path(remote_path)
        file_size: int = 0
        self._refresh_oci_client_if_needed()

        validated_attributes = validate_attributes(attributes)
        if isinstance(f, str):
            file_size = os.path.getsize(f)

            def _invoke_api() -> int:
                if file_size > self._multipart_threshold:
                    # Large files are uploaded with an explicit part size.
                    self._upload_manager.upload_file(
                        namespace_name=self._namespace,
                        bucket_name=bucket,
                        object_name=key,
                        file_path=f,
                        part_size=self._multipart_chunksize,
                        allow_parallel_uploads=True,
                        metadata=validated_attributes or {},
                    )
                else:
                    self._upload_manager.upload_file(
                        namespace_name=self._namespace,
                        bucket_name=bucket,
                        object_name=key,
                        file_path=f,
                        metadata=validated_attributes or {},
                    )

                return file_size

            return self._collect_metrics(
                _invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=file_size
            )
        else:
            # Convert file-like object to BytesIO because stream_ref cannot work with StringIO.
            if isinstance(f, io.StringIO):
                f = io.BytesIO(f.getvalue().encode("utf-8"))

            # Measure the stream by seeking to its end, then rewind so the upload starts at the beginning.
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            def _invoke_api() -> int:
                if file_size > self._multipart_threshold:
                    self._upload_manager.upload_stream(
                        namespace_name=self._namespace,
                        bucket_name=bucket,
                        object_name=key,
                        stream_ref=f,
                        part_size=self._multipart_chunksize,
                        allow_parallel_uploads=True,
                        metadata=validated_attributes or {},
                    )
                else:
                    self._upload_manager.upload_stream(
                        namespace_name=self._namespace,
                        bucket_name=bucket,
                        object_name=key,
                        stream_ref=f,
                        metadata=validated_attributes or {},
                    )

                return file_size

            return self._collect_metrics(
                _invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=file_size
            )

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """
        Downloads an object to a local file path or into a file-like object.

        :param remote_path: Source path in ``bucket/key`` form.
        :param f: A local destination path, or a writable file-like object.
        :param metadata: Metadata of the object, if already known; fetched otherwise.

        :return: The content length of the object as reported by its metadata.
        """
        self._refresh_oci_client_if_needed()

        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            if os.path.dirname(f):
                os.makedirs(os.path.dirname(f), exist_ok=True)

            def _invoke_api() -> int:
                response = self._oci_client.get_object(
                    namespace_name=self._namespace, bucket_name=bucket, object_name=key
                )
                # Stream into a hidden temporary file next to the destination, then move it into place so a
                # partially downloaded file is never visible at the destination path.
                temp_file_path: Optional[str] = None
                try:
                    with tempfile.NamedTemporaryFile(
                        mode="wb", delete=False, dir=os.path.dirname(f), prefix="."
                    ) as fp:
                        temp_file_path = fp.name
                        for chunk in response.data.raw.stream(1024 * 1024, decode_content=False):  # pyright: ignore [reportOptionalMemberAccess]
                            fp.write(chunk)
                    # os.replace overwrites an existing destination on every platform.
                    os.replace(temp_file_path, f)
                    temp_file_path = None
                finally:
                    # Remove the temporary file if the download or the move failed.
                    if temp_file_path is not None:
                        try:
                            os.remove(temp_file_path)
                        except OSError:
                            pass

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )
        else:

            def _invoke_api() -> int:
                response = self._oci_client.get_object(
                    namespace_name=self._namespace, bucket_name=bucket, object_name=key
                )
                # A StringIO target needs text: buffer the raw bytes first, then decode them as UTF-8.
                if isinstance(f, io.StringIO):
                    bytes_fileobj = io.BytesIO()
                    for chunk in response.data.raw.stream(1024 * 1024, decode_content=False):  # pyright: ignore [reportOptionalMemberAccess]
                        bytes_fileobj.write(chunk)
                    f.write(bytes_fileobj.getvalue().decode("utf-8"))
                else:
                    for chunk in response.data.raw.stream(1024 * 1024, decode_content=False):  # pyright: ignore [reportOptionalMemberAccess]
                        f.write(chunk)

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )