Source code for multistorageclient.providers.oci

  1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16import io
 17import os
 18import tempfile
 19from collections.abc import Callable, Iterator
 20from typing import IO, Any, Optional, TypeVar, Union
 21
 22import oci
 23from dateutil.parser import parse as dateutil_parser
 24from oci._vendor.requests.exceptions import (
 25    ChunkedEncodingError,
 26    ConnectionError,
 27    ContentDecodingError,
 28)
 29from oci.exceptions import ServiceError
 30from oci.object_storage import ObjectStorageClient, UploadManager
 31from oci.retry import DEFAULT_RETRY_STRATEGY, RetryStrategyBuilder
 32
 33from ..telemetry import Telemetry
 34from ..types import (
 35    AWARE_DATETIME_MIN,
 36    CredentialsProvider,
 37    ObjectMetadata,
 38    PreconditionFailedError,
 39    Range,
 40    RetryableError,
 41)
 42from ..utils import split_path, validate_attributes
 43from .base import BaseStorageProvider
 44
 45_T = TypeVar("_T")
 46
 47MB = 1024 * 1024
 48
 49MULTIPART_THRESHOLD = 64 * MB
 50MULTIPART_CHUNKSIZE = 32 * MB
 51
 52PROVIDER = "oci"
 53
 54
class OracleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with
    Oracle Cloud Infrastructure (OCI) Object Storage.
    """

    def __init__(
        self,
        namespace: str,
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        retry_strategy: Optional[dict[str, Any]] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initializes an instance of :py:class:`OracleStorageProvider`.

        :param namespace: The OCI Object Storage namespace. This is a unique identifier assigned to each tenancy.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve OCI credentials.
        :param retry_strategy: ``oci.retry.RetryStrategyBuilder`` parameters.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        :param kwargs: Optional overrides. ``multipart_threshold`` and ``multipart_chunksize`` (bytes) are read
            here; any other key is ignored by this constructor.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._namespace = namespace
        self._credentials_provider = credentials_provider
        self._retry_strategy = (
            DEFAULT_RETRY_STRATEGY
            if retry_strategy is None
            else RetryStrategyBuilder(**retry_strategy).get_retry_strategy()
        )
        self._oci_client = self._create_oci_client()
        self._upload_manager = UploadManager(self._oci_client)
        self._multipart_threshold = int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD))
        self._multipart_chunksize = int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE))

    def _create_oci_client(self) -> ObjectStorageClient:
        """
        Builds an Object Storage client from the default OCI config file, using this provider's retry strategy.
        """
        config = oci.config.from_file()
        return ObjectStorageClient(config, retry_strategy=self._retry_strategy)

    def _refresh_oci_client_if_needed(self) -> None:
        """
        Refreshes the OCI client if the current credentials are expired.

        Does nothing when no credentials provider was configured.
        """
        if not self._credentials_provider:
            return

        credentials = self._credentials_provider.get_credentials()
        if credentials.is_expired():
            self._credentials_provider.refresh_credentials()
            self._oci_client = self._create_oci_client()
            # NOTE(review): the upload manager is rebuilt here with explicit parallelism settings,
            # whereas ``__init__`` builds it with the SDK defaults -- confirm this difference is intended.
            self._upload_manager = UploadManager(
                self._oci_client, allow_parallel_uploads=True, parallel_process_count=4
            )

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual object storage operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the object storage bucket involved in the operation.
        :param key: The key of the object within the object storage bucket.

        :return: The result of the object storage operation, typically the return value of the `func` callable.

        :raises FileNotFoundError: The service answered with HTTP 404.
        :raises PreconditionFailedError: The service answered with HTTP 412.
        :raises RetryableError: The service answered with HTTP 429, or a connection/encoding error occurred.
        :raises RuntimeError: Any other service error or unexpected exception.
        """
        try:
            return func()
        except ServiceError as error:
            status_code = error.status
            request_id = error.request_id
            endpoint = error.request_endpoint
            error_info = f"request_id: {request_id}, endpoint: {endpoint}, status_code: {status_code}"

            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 429:
                raise RetryableError(
                    f"Too many request to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except (ConnectionError, ChunkedEncodingError, ContentDecodingError) as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}"
            ) from error
        except Exception as error:
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads ``body`` to ``path`` and returns the number of bytes in ``body``.

        :param path: ``bucket/key`` path of the destination object.
        :param body: The object content.
        :param if_match: Optional ETag precondition forwarded to the service.
        :param if_none_match: Optional ETag precondition forwarded to the service.
        :param attributes: Optional custom metadata; validated before being sent.
        """
        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        # OCI only supports if_none_match=="*"
        # refer: https://docs.oracle.com/en-us/iaas/tools/python/2.150.0/api/object_storage/client/oci.object_storage.ObjectStorageClient.html?highlight=put_object#oci.object_storage.ObjectStorageClient.put_object
        def _invoke_api() -> int:
            validated_attributes = validate_attributes(attributes)
            self._oci_client.put_object(
                namespace_name=self._namespace,
                bucket_name=bucket,
                object_name=key,
                put_object_body=body,
                opc_meta=validated_attributes or {},  # Pass metadata or empty dict
                if_match=if_match,
                if_none_match=if_none_match,
            )

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Returns the content of the object at ``path``, optionally restricted to ``byte_range``.
        """
        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> bytes:
            bytes_range: Optional[str] = None
            if byte_range:
                # HTTP byte ranges are inclusive on both ends, hence the ``- 1``.
                last_byte = byte_range.offset + byte_range.size - 1
                bytes_range = f"bytes={byte_range.offset}-{last_byte}"
            response = self._oci_client.get_object(
                namespace_name=self._namespace, bucket_name=bucket, object_name=key, range=bytes_range
            )
            return response.data.content  # pyright: ignore [reportOptionalMemberAccess]

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        """
        Requests a copy of ``src_path`` to ``dest_path`` and returns the source object's content length.
        """
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)
        self._refresh_oci_client_if_needed()

        # Looked up first so that the size can be reported to the caller.
        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            copy_details = oci.object_storage.models.CopyObjectDetails(
                source_object_name=src_key, destination_bucket=dest_bucket, destination_object_name=dest_key
            )

            self._oci_client.copy_object(
                namespace_name=self._namespace, bucket_name=src_bucket, copy_object_details=copy_details
            )

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=src_bucket, key=src_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Deletes the object at ``path``; ``if_match`` is forwarded only when it is provided.
        """
        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> None:
            if if_match is not None:
                self._oci_client.delete_object(self._namespace, bucket, key, if_match=if_match)
            else:
                self._oci_client.delete_object(self._namespace, bucket, key)

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        """
        Returns ``True`` when listing ``path`` (with a trailing delimiter) yields any object or common prefix.
        """
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._oci_client.list_objects(
                namespace_name=self._namespace,
                bucket_name=bucket,
                prefix=key,
                delimiter="/",
            )
            # Check if there are any contents or common prefixes
            if not response:
                return False
            return bool(response.data.objects or response.data.prefixes)

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    @staticmethod
    def _directory_metadata(key: str) -> ObjectMetadata:
        """
        Builds the placeholder metadata used for "directories" that have no backing object.
        """
        return ObjectMetadata(
            key=key,
            type="directory",
            content_length=0,
            last_modified=AWARE_DATETIME_MIN,
        )

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Returns metadata for the object or "directory" at ``path``.

        :param path: ``bucket/key`` path to inspect.
        :param strict: When ``True`` and no object exists at ``path``, additionally check whether ``path`` is a
            directory before giving up.

        :raises FileNotFoundError: Neither an object nor a directory exists at ``path``.
        """
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If path ends with "/" or empty key name is provided, then assume it's a "directory",
            # which metadata is not guaranteed to exist for cases such as
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return self._directory_metadata(path)
            raise FileNotFoundError(f"Directory {path} does not exist.")

        self._refresh_oci_client_if_needed()

        def _invoke_api() -> ObjectMetadata:
            response = self._oci_client.head_object(
                namespace_name=self._namespace, bucket_name=bucket, object_name=key
            )

            # Extract custom metadata from headers with 'opc-meta-' prefix
            attributes = {}
            if response.headers:  # pyright: ignore [reportOptionalMemberAccess]
                for metadata_key, metadata_val in response.headers.items():  # pyright: ignore [reportOptionalMemberAccess]
                    if metadata_key.startswith("opc-meta-"):
                        # Remove the 'opc-meta-' prefix to get the original key
                        attributes[metadata_key[len("opc-meta-") :]] = metadata_val

            return ObjectMetadata(
                key=path,
                content_length=int(response.headers["Content-Length"]),  # pyright: ignore [reportOptionalMemberAccess]
                content_type=response.headers.get("Content-Type", None),  # pyright: ignore [reportOptionalMemberAccess]
                last_modified=dateutil_parser(response.headers["last-modified"]),  # pyright: ignore [reportOptionalMemberAccess]
                etag=response.headers.get("etag", None),  # pyright: ignore [reportOptionalMemberAccess]
                metadata=attributes if attributes else None,
            )

        try:
            return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
        except FileNotFoundError:
            if strict:
                # If the object does not exist on the given path, we will append a trailing slash and
                # check if the path is a directory.
                dir_path = self._append_delimiter(path)
                if self._is_dir(dir_path):
                    return self._directory_metadata(dir_path)
            raise

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        """
        Lazily lists objects under ``path``.

        :param path: ``bucket/prefix`` path to list.
        :param start_after: Only keys strictly greater than this value are yielded.
        :param end_at: Only keys less than or equal to this value are yielded.
        :param include_directories: When ``True``, list with a ``/`` delimiter and also yield directory entries.

        :return: An iterator of :py:class:`ObjectMetadata`. Each page is requested through
            :py:meth:`_translate_errors`, so service errors raised while iterating are translated too.
        """
        bucket, prefix = split_path(path)
        self._refresh_oci_client_if_needed()

        # ListObjects only includes object names by default.
        #
        # Request additional fields needed for creating an ObjectMetadata.
        fields = ",".join(["etag", "name", "size", "timeModified"])

        def _fetch_page(start: Optional[str]) -> Any:
            def _invoke_api() -> Any:
                request: dict[str, Any] = {
                    "namespace_name": self._namespace,
                    "bucket_name": bucket,
                    "prefix": prefix,
                    # This is ≥ instead of >.
                    "start": start,
                    "fields": fields,
                }
                if include_directories:
                    request["delimiter"] = "/"
                return self._oci_client.list_objects(**request)

            return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

        def _iterate() -> Iterator[ObjectMetadata]:
            next_start_with: Optional[str] = start_after
            while True:
                response = _fetch_page(next_start_with)
                if not response:
                    return

                if include_directories:
                    for directory in response.data.prefixes or []:
                        # Keys are reported relative to the bucket, like the object entries below.
                        yield self._directory_metadata(os.path.join(bucket, directory.rstrip("/")))

                # OCI guarantees lexicographical order.
                for response_object in response.data.objects:  # pyright: ignore [reportOptionalMemberAccess]
                    key = response_object.name
                    in_range = (start_after is None or start_after < key) and (end_at is None or key <= end_at)
                    if in_range:
                        if key.endswith("/"):
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(bucket, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=response_object.time_modified,
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key),
                                type="file",
                                content_length=response_object.size,
                                last_modified=response_object.time_modified,
                                etag=response_object.etag,
                            )
                    elif start_after != key:
                        # Out of range and not the (inclusive) start marker itself: nothing further can match.
                        return

                next_start_with = response.data.next_start_with  # pyright: ignore [reportOptionalMemberAccess]
                if next_start_with is None or (end_at is not None and end_at < next_start_with):
                    return

        return _iterate()

    def _upload_kwargs(
        self, bucket: str, key: str, file_size: int, metadata: Optional[dict[str, str]]
    ) -> dict[str, Any]:
        """
        Builds the keyword arguments shared by ``UploadManager.upload_file`` and ``upload_stream``.

        Part size and parallel-upload options are only added when ``file_size`` exceeds the multipart threshold.
        """
        upload_kwargs: dict[str, Any] = {
            "namespace_name": self._namespace,
            "bucket_name": bucket,
            "object_name": key,
            "metadata": metadata or {},
        }
        if file_size > self._multipart_threshold:
            upload_kwargs["part_size"] = self._multipart_chunksize
            upload_kwargs["allow_parallel_uploads"] = True
        return upload_kwargs

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        """
        Uploads a local file (given by path) or a file-like object to ``remote_path``.

        :param remote_path: ``bucket/key`` path of the destination object.
        :param f: A local file path, or a seekable file-like object.
        :param attributes: Optional custom metadata; validated before being sent.

        :return: The number of bytes uploaded (file size, or stream length measured by seeking to its end).
        """
        bucket, key = split_path(remote_path)
        file_size: int = 0
        self._refresh_oci_client_if_needed()

        validated_attributes = validate_attributes(attributes)
        if isinstance(f, str):
            file_size = os.path.getsize(f)

            def _invoke_api() -> int:
                self._upload_manager.upload_file(
                    file_path=f,
                    **self._upload_kwargs(bucket, key, file_size, validated_attributes),
                )
                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            # Convert file-like object to BytesIO because stream_ref cannot work with StringIO.
            if isinstance(f, io.StringIO):
                f = io.BytesIO(f.getvalue().encode("utf-8"))

            # Measure the stream by seeking to its end, then rewind so the whole content is uploaded.
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            def _invoke_api() -> int:
                self._upload_manager.upload_stream(
                    stream_ref=f,
                    **self._upload_kwargs(bucket, key, file_size, validated_attributes),
                )
                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    @staticmethod
    def _iter_chunks(response: Any) -> Iterator[bytes]:
        """
        Streams the raw body of a ``get_object`` response in 1 MiB chunks without content decoding.
        """
        return response.data.raw.stream(1024 * 1024, decode_content=False)  # pyright: ignore [reportOptionalMemberAccess]

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """
        Downloads the object at ``remote_path`` into a local file path or a file-like object.

        :param remote_path: ``bucket/key`` path of the source object.
        :param f: A local file path, or a writable file-like object.
        :param metadata: Previously fetched metadata for the object; looked up when omitted.

        :return: The object's content length as reported by its metadata.
        """
        self._refresh_oci_client_if_needed()

        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            target_dir = os.path.dirname(f)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)

            def _invoke_api() -> int:
                response = self._oci_client.get_object(
                    namespace_name=self._namespace, bucket_name=bucket, object_name=key
                )
                # Write to a hidden temporary file next to the destination and rename it into place
                # afterwards, so a partially downloaded file never appears under the final name.
                temp_file_path: Optional[str] = None
                try:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=target_dir, prefix=".") as fp:
                        temp_file_path = fp.name
                        for chunk in self._iter_chunks(response):
                            fp.write(chunk)
                    os.rename(src=temp_file_path, dst=f)
                except BaseException:
                    # Do not leave the temporary file behind when streaming or the rename fails.
                    if temp_file_path is not None:
                        try:
                            os.unlink(temp_file_path)
                        except OSError:
                            pass
                    raise

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:

            def _invoke_api() -> int:
                response = self._oci_client.get_object(
                    namespace_name=self._namespace, bucket_name=bucket, object_name=key
                )
                if isinstance(f, io.StringIO):
                    # StringIO only accepts text: collect the whole payload first and decode it once.
                    payload = b"".join(self._iter_chunks(response))
                    f.write(payload.decode("utf-8"))
                else:
                    for chunk in self._iter_chunks(response):
                        f.write(chunk)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)