Source code for multistorageclient.providers.oci

  1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16import io
 17import os
 18import tempfile
 19from collections.abc import Callable, Iterator
 20from typing import IO, Any, Optional, TypeVar, Union
 21
 22import oci
 23from dateutil.parser import parse as dateutil_parser
 24from oci._vendor.requests.exceptions import (
 25    ChunkedEncodingError,
 26    ConnectionError,
 27    ContentDecodingError,
 28)
 29from oci.auth.signers import SecurityTokenSigner
 30from oci.exceptions import ServiceError
 31from oci.object_storage import ObjectStorageClient, UploadManager
 32from oci.retry import DEFAULT_RETRY_STRATEGY, RetryStrategyBuilder
 33from oci.signer import load_private_key_from_file
 34
 35from ..constants import DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT
 36from ..telemetry import Telemetry
 37from ..types import (
 38    AWARE_DATETIME_MIN,
 39    CredentialsProvider,
 40    ObjectMetadata,
 41    PreconditionFailedError,
 42    Range,
 43    RetryableError,
 44)
 45from ..utils import safe_makedirs, split_path, validate_attributes
 46from .base import BaseStorageProvider
 47
 48_T = TypeVar("_T")
 49
 50MB = 1024 * 1024
 51
 52MULTIPART_THRESHOLD = 64 * MB
 53MULTIPART_CHUNKSIZE = 32 * MB
 54
 55PROVIDER = "oci"
 56
 57
class OracleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with
    Oracle Cloud Infrastructure (OCI) Object Storage.
    """

    _namespace: str
    _credentials_provider: Optional[CredentialsProvider]
    _oci_client: ObjectStorageClient

    def __init__(
        self,
        namespace: str,
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        retry_strategy: Optional[dict[str, Any]] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initializes an instance of :py:class:`OracleStorageProvider`.

        :param namespace: The OCI Object Storage namespace. This is a unique identifier assigned to each tenancy.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve OCI credentials.
        :param retry_strategy: ``oci.retry.RetryStrategyBuilder`` parameters. When omitted,
            ``oci.retry.DEFAULT_RETRY_STRATEGY`` is used.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        :param kwargs: Optional settings:

            - ``timeout``: assigned to the OCI base client's ``timeout``. Defaults to
              ``(DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT)``.
            - ``multipart_threshold``: size in bytes above which uploads are split into parts.
              Defaults to ``MULTIPART_THRESHOLD``.
            - ``multipart_chunksize``: part size in bytes. Defaults to ``MULTIPART_CHUNKSIZE``.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._namespace = namespace
        self._credentials_provider = credentials_provider
        self._retry_strategy = (
            DEFAULT_RETRY_STRATEGY
            if retry_strategy is None
            else RetryStrategyBuilder(**retry_strategy).get_retry_strategy()
        )
        self._timeout = kwargs.get("timeout")
        if self._timeout is None:
            self._timeout = (DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT)
        self._oci_client = self._create_oci_client()
        self._upload_manager = UploadManager(self._oci_client)
        self._multipart_threshold = int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD))
        self._multipart_chunksize = int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE))

    def _create_oci_client(self) -> ObjectStorageClient:
        """
        Builds an :py:class:`ObjectStorageClient` from the default OCI config file.

        If the config contains a ``security_token_file`` entry, a :py:class:`SecurityTokenSigner`
        is constructed manually from that token and the configured ``key_file``.

        :return: A client configured with this provider's retry strategy and timeout.
        """
        config = oci.config.from_file()
        kwargs = {"retry_strategy": self._retry_strategy}

        # OCI doesn't support `authentication_type=security_token` OCI config entries in their SDKs yet. Manually configure signers.
        #
        # https://github.com/oracle/oci-python-sdk/blob/v2.169.0/src/oci/util.py#L213-L225
        # https://github.com/oracle/oci-ruby-sdk/issues/70
        if "security_token_file" in config:
            with open(config["security_token_file"], "r") as security_token_file:
                kwargs["signer"] = SecurityTokenSigner(
                    private_key=load_private_key_from_file(
                        filename=config["key_file"],
                        pass_phrase=config.get("pass_phrase"),
                    ),
                    # The OCI documentation + CLI are unforgiving about newline-terminated security token files.
                    #
                    # Do not ignore trailing newlines in case future upstream automatic signer configuration does the same.
                    #
                    # https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm#sdk_authentication_methods_session_token
                    # https://github.com/oracle/oci-cli/blob/v3.77.0/src/oci_cli/cli_session.py#L143-L144
                    token=security_token_file.read(),
                )

        client = ObjectStorageClient(config, **kwargs)
        client.base_client.timeout = self._timeout
        return client

    def _refresh_oci_client_if_needed(self) -> None:
        """
        Refreshes the OCI client if the current credentials are expired.

        Only acts when a credentials provider was supplied. On expiry it refreshes the credentials,
        then rebuilds both the OCI client and the upload manager.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._oci_client = self._create_oci_client()
                # NOTE(review): __init__ builds UploadManager(self._oci_client) with SDK defaults, while the
                # refreshed manager passes explicit parallel settings. Confirm which configuration is intended.
                self._upload_manager = UploadManager(
                    self._oci_client, allow_parallel_uploads=True, parallel_process_count=4
                )

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual object storage operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the object storage bucket involved in the operation.
        :param key: The key of the object within the object storage bucket.

        :return: The result of the object storage operation, typically the return value of the `func` callable.

        :raises FileNotFoundError: On HTTP 404, or if ``func`` itself raises it.
        :raises PreconditionFailedError: On HTTP 412.
        :raises RetryableError: On HTTP 429, or on connection / chunked-encoding / content-decoding errors.
        :raises RuntimeError: On any other service error or unexpected exception.
        """
        try:
            return func()
        except ServiceError as error:
            status_code = error.status
            request_id = error.request_id
            endpoint = error.request_endpoint
            error_info = f"request_id: {request_id}, endpoint: {endpoint}, status_code: {status_code}"

            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 429:
                raise RetryableError(
                    f"Too many requests to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except (ConnectionError, ChunkedEncodingError, ContentDecodingError) as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}"
            ) from error
        except FileNotFoundError:
            raise
        except Exception as error:
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads ``body`` as a single object.

        :param path: ``bucket/key`` path of the object.
        :param body: Object contents.
        :param if_match: Optional ETag precondition forwarded to OCI.
        :param if_none_match: Optional precondition forwarded to OCI. Per the OCI docs linked below, only ``"*"`` is supported.
        :param attributes: Optional user metadata, validated and sent as ``opc_meta``.
        :return: Number of bytes uploaded.
        """
        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        # OCI only supports if_none_match=="*"
        # refer: https://docs.oracle.com/en-us/iaas/tools/python/2.150.0/api/object_storage/client/oci.object_storage.ObjectStorageClient.html?highlight=put_object#oci.object_storage.ObjectStorageClient.put_object
        def _invoke_api() -> int:
            validated_attributes = validate_attributes(attributes)
            self._oci_client.put_object(
                namespace_name=self._namespace,
                bucket_name=bucket,
                object_name=key,
                put_object_body=body,
                opc_meta=validated_attributes or {},  # Pass metadata or empty dict
                if_match=if_match,
                if_none_match=if_none_match,
            )

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Downloads an object, or a slice of it, into memory.

        :param path: ``bucket/key`` path of the object.
        :param byte_range: Optional range. It is converted to an inclusive HTTP ``bytes=start-end`` header.
        :return: The object bytes.
        """
        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> bytes:
            if byte_range:
                # HTTP ranges are inclusive on both ends, hence the "- 1".
                bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
            else:
                bytes_range = None
            response = self._oci_client.get_object(
                namespace_name=self._namespace, bucket_name=bucket, object_name=key, range=bytes_range
            )
            return response.data.content  # pyright: ignore [reportOptionalMemberAccess]

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        """
        Copies an object within the configured namespace.

        :param src_path: ``bucket/key`` path of the source object.
        :param dest_path: ``bucket/key`` path of the destination object.
        :return: The source object's content length, read via a metadata lookup before the copy.
        """
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)
        self._refresh_oci_client_if_needed()

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            copy_details = oci.object_storage.models.CopyObjectDetails(
                source_object_name=src_key, destination_bucket=dest_bucket, destination_object_name=dest_key
            )

            self._oci_client.copy_object(
                namespace_name=self._namespace, bucket_name=src_bucket, copy_object_details=copy_details
            )

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=src_bucket, key=src_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Deletes an object.

        :param path: ``bucket/key`` path of the object.
        :param if_match: Optional ETag precondition. It is only forwarded to OCI when not ``None``.
        """
        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> None:
            namespace_name = self._namespace
            bucket_name = bucket
            object_name = key
            if if_match is not None:
                self._oci_client.delete_object(namespace_name, bucket_name, object_name, if_match=if_match)
            else:
                self._oci_client.delete_object(namespace_name, bucket_name, object_name)

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        """
        Returns whether any object or sub-prefix exists under ``path`` treated as a directory.

        :param path: ``bucket/key`` path; a trailing ``/`` is appended before listing.
        :return: ``True`` if the listing returns at least one object or prefix.
        """
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._oci_client.list_objects(
                namespace_name=self._namespace,
                bucket_name=bucket,
                prefix=key,
                delimiter="/",
            )
            # Check if there are any contents or common prefixes
            if response:
                return bool(response.data.objects or response.data.prefixes)
            return False

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Returns metadata for an object or a "directory" prefix.

        :param path: ``bucket/key`` path.
        :param strict: When the object is not found, also check whether ``path`` is a directory prefix.
        :return: File metadata from a HEAD request, or directory metadata with zero length.
        :raises FileNotFoundError: If neither an object nor (where checked) a directory exists.
        """
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If path ends with "/" or empty key name is provided, then assume it's a "directory",
            # which metadata is not guaranteed to exist for cases such as
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            self._refresh_oci_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                response = self._oci_client.head_object(
                    namespace_name=self._namespace, bucket_name=bucket, object_name=key
                )

                # Extract custom metadata from headers with 'opc-meta-' prefix
                attributes = {}
                if response.headers:  # pyright: ignore [reportOptionalMemberAccess]
                    for metadata_key, metadata_val in response.headers.items():  # pyright: ignore [reportOptionalMemberAccess]
                        if metadata_key.startswith("opc-meta-"):
                            # Remove the 'opc-meta-' prefix to get the original key
                            metadata_key = metadata_key[len("opc-meta-") :]
                            attributes[metadata_key] = metadata_val

                return ObjectMetadata(
                    key=path,
                    content_length=int(response.headers["Content-Length"]),  # pyright: ignore [reportOptionalMemberAccess]
                    content_type=response.headers.get("Content-Type", None),  # pyright: ignore [reportOptionalMemberAccess]
                    last_modified=dateutil_parser(response.headers["last-modified"]),  # pyright: ignore [reportOptionalMemberAccess]
                    etag=response.headers.get("etag", None),  # pyright: ignore [reportOptionalMemberAccess]
                    metadata=attributes if attributes else None,
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        """
        Lazily lists objects (and optionally directory prefixes) under ``path``.

        :param path: ``bucket/prefix`` to list.
        :param start_after: Exclusive lower bound on keys.
        :param end_at: Inclusive upper bound on keys.
        :param include_directories: When ``True``, list with ``/`` as delimiter and also yield directory entries.
        :param follow_symlinks: Accepted for interface compatibility; not used by this provider.
        :return: An iterator of :py:class:`ObjectMetadata`. Each page request is wrapped in
            :py:meth:`_translate_errors`, so errors raised while iterating are translated too.
        """
        # Resolved eagerly so that path errors and credential refresh happen at call time, not first iteration.
        bucket, prefix = split_path(path)
        self._refresh_oci_client_if_needed()

        # ListObjects only includes object names by default.
        #
        # Request additional fields needed for creating an ObjectMetadata.
        fields = ",".join(
            [
                "etag",
                "name",
                "size",
                "timeModified",
            ]
        )

        def _iterate() -> Iterator[ObjectMetadata]:
            next_start_with: Optional[str] = start_after

            def _list_page() -> Any:
                # Reads the current `next_start_with` at call time; invoked once per loop iteration below.
                page_kwargs: dict[str, Any] = {
                    "namespace_name": self._namespace,
                    "bucket_name": bucket,
                    "prefix": prefix,
                    # This is ≥ instead of >.
                    "start": next_start_with,
                    "fields": fields,
                }
                if include_directories:
                    page_kwargs["delimiter"] = "/"
                return self._oci_client.list_objects(**page_kwargs)

            while True:
                # Translate errors per page: wrapping the generator itself would only guard its creation.
                response = self._translate_errors(_list_page, operation="LIST", bucket=bucket, key=prefix)

                if not response:
                    return

                if include_directories:
                    for directory in response.data.prefixes or ():  # pyright: ignore [reportOptionalMemberAccess]
                        prefix_key = directory.rstrip("/")
                        # Filter by start_after and end_at if specified. Out-of-range prefixes are skipped
                        # rather than ending the listing: prefixes and objects arrive as separate lists,
                        # so objects in this same page may still be within range.
                        if (start_after is None or start_after < prefix_key) and (
                            end_at is None or prefix_key <= end_at
                        ):
                            yield ObjectMetadata(
                                key=os.path.join(bucket, prefix_key),
                                type="directory",
                                content_length=0,
                                last_modified=AWARE_DATETIME_MIN,
                            )

                # OCI guarantees lexicographical order.
                for response_object in response.data.objects:  # pyright: ignore [reportOptionalMemberAccess]
                    key = response_object.name
                    if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                        if key.endswith("/"):
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(bucket, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=response_object.time_modified,
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key),
                                type="file",
                                content_length=response_object.size,
                                last_modified=response_object.time_modified,
                                etag=response_object.etag,
                            )
                    elif start_after != key:
                        # Since `start` is inclusive, only a key equal to start_after may be skipped; any other
                        # out-of-range key is past end_at, so the listing is complete.
                        return

                next_start_with = response.data.next_start_with  # pyright: ignore [reportOptionalMemberAccess]
                if next_start_with is None or (end_at is not None and end_at < next_start_with):
                    return

        return _iterate()

    @property
    def supports_parallel_listing(self) -> bool:
        """This provider supports parallel listing."""
        return True

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        """
        Uploads a local file path or a file-like object.

        Payloads larger than ``self._multipart_threshold`` are uploaded with
        ``part_size=self._multipart_chunksize`` and ``allow_parallel_uploads=True``.

        :param remote_path: ``bucket/key`` destination.
        :param f: Local file path, or a file-like object. A ``StringIO`` is re-encoded as UTF-8 bytes.
        :param attributes: Optional user metadata, validated and sent as ``metadata``.
        :return: Number of bytes uploaded.
        """
        bucket, key = split_path(remote_path)
        file_size: int = 0
        self._refresh_oci_client_if_needed()

        validated_attributes = validate_attributes(attributes)
        if isinstance(f, str):
            file_size = os.path.getsize(f)

            def _invoke_api() -> int:
                if file_size > self._multipart_threshold:
                    self._upload_manager.upload_file(
                        namespace_name=self._namespace,
                        bucket_name=bucket,
                        object_name=key,
                        file_path=f,
                        part_size=self._multipart_chunksize,
                        allow_parallel_uploads=True,
                        metadata=validated_attributes or {},
                    )
                else:
                    self._upload_manager.upload_file(
                        namespace_name=self._namespace,
                        bucket_name=bucket,
                        object_name=key,
                        file_path=f,
                        metadata=validated_attributes or {},
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            # Convert file-like object to BytesIO because stream_ref cannot work with StringIO.
            if isinstance(f, io.StringIO):
                f = io.BytesIO(f.getvalue().encode("utf-8"))

            # Measure the stream size by seeking to the end, then rewind for the upload.
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            def _invoke_api() -> int:
                if file_size > self._multipart_threshold:
                    self._upload_manager.upload_stream(
                        namespace_name=self._namespace,
                        bucket_name=bucket,
                        object_name=key,
                        stream_ref=f,
                        part_size=self._multipart_chunksize,
                        allow_parallel_uploads=True,
                        metadata=validated_attributes or {},
                    )
                else:
                    self._upload_manager.upload_stream(
                        namespace_name=self._namespace,
                        bucket_name=bucket,
                        object_name=key,
                        stream_ref=f,
                        metadata=validated_attributes or {},
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """
        Downloads an object to a local path or into a file-like object.

        For a local path, data is streamed into a hidden temporary file in the destination directory,
        which then replaces the target. If the download fails, the temporary file is removed.

        :param remote_path: ``bucket/key`` source.
        :param f: Local destination path, or a writable file-like object. A ``StringIO`` receives UTF-8-decoded text.
        :param metadata: Pre-fetched metadata. It is looked up if omitted.
        :return: ``metadata.content_length``.
        """
        self._refresh_oci_client_if_needed()

        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            if os.path.dirname(f):
                safe_makedirs(os.path.dirname(f))

            def _invoke_api() -> int:
                response = self._oci_client.get_object(
                    namespace_name=self._namespace, bucket_name=bucket, object_name=key
                )
                temp_file_path: Optional[str] = None
                try:
                    with tempfile.NamedTemporaryFile(
                        mode="wb", delete=False, dir=os.path.dirname(f), prefix="."
                    ) as fp:
                        temp_file_path = fp.name
                        for chunk in response.data.raw.stream(1024 * 1024, decode_content=False):  # pyright: ignore [reportOptionalMemberAccess]
                            fp.write(chunk)
                    # os.replace (unlike os.rename) also overwrites an existing destination on Windows.
                    os.replace(temp_file_path, f)
                except BaseException:
                    # Do not leave a partial hidden temp file behind on failure.
                    if temp_file_path is not None and os.path.exists(temp_file_path):
                        try:
                            os.remove(temp_file_path)
                        except OSError:
                            pass
                    raise

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:

            def _invoke_api() -> int:
                response = self._oci_client.get_object(
                    namespace_name=self._namespace, bucket_name=bucket, object_name=key
                )
                # Convert file-like object to BytesIO because stream_ref cannot work with StringIO.
                if isinstance(f, io.StringIO):
                    bytes_fileobj = io.BytesIO()
                    for chunk in response.data.raw.stream(1024 * 1024, decode_content=False):  # pyright: ignore [reportOptionalMemberAccess]
                        bytes_fileobj.write(chunk)
                    f.write(bytes_fileobj.getvalue().decode("utf-8"))
                else:
                    for chunk in response.data.raw.stream(1024 * 1024, decode_content=False):  # pyright: ignore [reportOptionalMemberAccess]
                        f.write(chunk)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)