Source code for multistorageclient.providers.oci

  1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16import io
 17import os
 18import tempfile
 19from collections.abc import Callable, Iterator
 20from typing import IO, Any, Optional, TypeVar, Union
 21
 22import oci
 23from dateutil.parser import parse as dateutil_parser
 24from oci._vendor.requests.exceptions import (
 25    ChunkedEncodingError,
 26    ConnectionError,
 27    ContentDecodingError,
 28)
 29from oci.exceptions import ServiceError
 30from oci.object_storage import ObjectStorageClient, UploadManager
 31from oci.retry import DEFAULT_RETRY_STRATEGY, RetryStrategyBuilder
 32
 33from ..constants import DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT
 34from ..telemetry import Telemetry
 35from ..types import (
 36    AWARE_DATETIME_MIN,
 37    CredentialsProvider,
 38    ObjectMetadata,
 39    PreconditionFailedError,
 40    Range,
 41    RetryableError,
 42)
 43from ..utils import safe_makedirs, split_path, validate_attributes
 44from .base import BaseStorageProvider
 45
 46_T = TypeVar("_T")
 47
 48MB = 1024 * 1024
 49
 50MULTIPART_THRESHOLD = 64 * MB
 51MULTIPART_CHUNKSIZE = 32 * MB
 52
 53PROVIDER = "oci"
 54
 55
class OracleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with
    Oracle Cloud Infrastructure (OCI) Object Storage.
    """

    def __init__(
        self,
        namespace: str,
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        retry_strategy: Optional[dict[str, Any]] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initializes an instance of :py:class:`OracleStorageProvider`.

        :param namespace: The OCI Object Storage namespace. This is a unique identifier assigned to each tenancy.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve OCI credentials.
        :param retry_strategy: ``oci.retry.RetryStrategyBuilder`` parameters.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        :param kwargs: Optional settings read by this provider: ``timeout`` (assigned to the OCI base client;
            defaults to ``(DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT)``), ``multipart_threshold`` and
            ``multipart_chunksize`` (both coerced to ``int``).
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._namespace = namespace
        self._credentials_provider = credentials_provider
        self._retry_strategy = (
            DEFAULT_RETRY_STRATEGY
            if retry_strategy is None
            else RetryStrategyBuilder(**retry_strategy).get_retry_strategy()
        )
        self._timeout = kwargs.get("timeout")
        if self._timeout is None:
            self._timeout = (DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT)
        self._oci_client = self._create_oci_client()
        self._upload_manager = UploadManager(self._oci_client)
        self._multipart_threshold = int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD))
        self._multipart_chunksize = int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE))

    def _create_oci_client(self) -> ObjectStorageClient:
        """
        Builds an ``ObjectStorageClient`` from the OCI config file, applying the configured retry strategy
        and timeout.

        :return: A new OCI Object Storage client.
        """
        config = oci.config.from_file()
        client = ObjectStorageClient(config, retry_strategy=self._retry_strategy)
        client.base_client.timeout = self._timeout
        return client

    def _refresh_oci_client_if_needed(self) -> None:
        """
        Refreshes the OCI client if the current credentials are expired.

        Does nothing when no credentials provider was supplied.
        """
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._oci_client = self._create_oci_client()
                # NOTE(review): the constructor builds the UploadManager without these keyword arguments,
                # so upload parallelism settings may differ after a refresh -- confirm this is intended.
                self._upload_manager = UploadManager(
                    self._oci_client, allow_parallel_uploads=True, parallel_process_count=4
                )

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual object storage operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the object storage bucket involved in the operation.
        :param key: The key of the object within the object storage bucket.

        :return: The result of the object storage operation, typically the return value of the `func` callable.
        :raises FileNotFoundError: On a 404 service error, or when ``func`` itself raises it.
        :raises PreconditionFailedError: On a 412 service error.
        :raises RetryableError: On a 429 service error or a connection/encoding/decoding failure.
        :raises RuntimeError: On any other service error or unexpected exception.
        """
        try:
            return func()
        except ServiceError as error:
            status_code = error.status
            request_id = error.request_id
            endpoint = error.request_endpoint
            error_info = f"request_id: {request_id}, endpoint: {endpoint}, status_code: {status_code}"

            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 429:
                raise RetryableError(
                    f"Too many request to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except (ConnectionError, ChunkedEncodingError, ContentDecodingError) as error:
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}"
            ) from error
        except FileNotFoundError:
            # Already in its final form; let it pass through instead of wrapping it in RuntimeError below.
            raise
        except Exception as error:
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads ``body`` as a single object.

        :param path: Bucket and key of the destination object, as understood by ``split_path``.
        :param body: The object content.
        :param if_match: Optional value forwarded as the ``if_match`` precondition.
        :param if_none_match: Optional value forwarded as the ``if_none_match`` precondition.
        :param attributes: Optional custom metadata; validated and sent as ``opc_meta``.

        :return: The number of bytes in ``body``.
        """
        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        # OCI only supports if_none_match=="*"
        # refer: https://docs.oracle.com/en-us/iaas/tools/python/2.150.0/api/object_storage/client/oci.object_storage.ObjectStorageClient.html?highlight=put_object#oci.object_storage.ObjectStorageClient.put_object
        def _invoke_api() -> int:
            validated_attributes = validate_attributes(attributes)
            self._oci_client.put_object(
                namespace_name=self._namespace,
                bucket_name=bucket,
                object_name=key,
                put_object_body=body,
                opc_meta=validated_attributes or {},  # Pass metadata or empty dict
                if_match=if_match,
                if_none_match=if_none_match,
            )

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Downloads an object, or a byte range of it, into memory.

        :param path: Bucket and key of the object, as understood by ``split_path``.
        :param byte_range: Optional range; ``offset`` and ``size`` are turned into an inclusive
            ``bytes=start-end`` range header.

        :return: The content of the response.
        """
        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> bytes:
            if byte_range:
                # The range end is inclusive, hence the "- 1".
                bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
            else:
                bytes_range = None
            response = self._oci_client.get_object(
                namespace_name=self._namespace, bucket_name=bucket, object_name=key, range=bytes_range
            )
            return response.data.content  # pyright: ignore [reportOptionalMemberAccess]

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        """
        Requests a copy of an object within the provider's namespace.

        :param src_path: Bucket and key of the source object.
        :param dest_path: Bucket and key of the destination object.

        :return: The content length of the source object, read from its metadata before the copy is requested.
        """
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)
        self._refresh_oci_client_if_needed()

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            copy_details = oci.object_storage.models.CopyObjectDetails(
                source_object_name=src_key, destination_bucket=dest_bucket, destination_object_name=dest_key
            )

            self._oci_client.copy_object(
                namespace_name=self._namespace, bucket_name=src_bucket, copy_object_details=copy_details
            )

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=src_bucket, key=src_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Deletes an object.

        :param path: Bucket and key of the object, as understood by ``split_path``.
        :param if_match: Optional value forwarded as the ``if_match`` precondition; omitted from the call
            when ``None``.
        """
        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> None:
            namespace_name = self._namespace
            bucket_name = bucket
            object_name = key
            if if_match is not None:
                self._oci_client.delete_object(namespace_name, bucket_name, object_name, if_match=if_match)
            else:
                self._oci_client.delete_object(namespace_name, bucket_name, object_name)

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        """
        Checks whether any object or common prefix exists under ``path`` treated as a directory prefix.

        :param path: Bucket and key prefix to check.

        :return: ``True`` if the listing returns at least one object or prefix, ``False`` otherwise.
        """
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._oci_client.list_objects(
                namespace_name=self._namespace,
                bucket_name=bucket,
                prefix=key,
                delimiter="/",
            )
            # Check if there are any contents or common prefixes
            if response:
                return bool(response.data.objects or response.data.prefixes)
            return False

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Retrieves metadata for an object or a "directory" (key prefix).

        :param path: Bucket and key of the object or prefix.
        :param strict: When ``True`` and the object is not found, additionally check whether ``path`` is a
            directory before giving up.

        :return: Metadata for the object, or directory metadata with zero length and ``AWARE_DATETIME_MIN``
            as the modification time.
        :raises FileNotFoundError: If neither an object nor a directory exists at ``path``.
        """
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If path ends with "/" or empty key name is provided, then assume it's a "directory",
            # which metadata is not guaranteed to exist for cases such as
            # "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            self._refresh_oci_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                response = self._oci_client.head_object(
                    namespace_name=self._namespace, bucket_name=bucket, object_name=key
                )

                # Extract custom metadata from headers with 'opc-meta-' prefix
                attributes = {}
                if response.headers:  # pyright: ignore [reportOptionalMemberAccess]
                    for metadata_key, metadata_val in response.headers.items():  # pyright: ignore [reportOptionalMemberAccess]
                        if metadata_key.startswith("opc-meta-"):
                            # Remove the 'opc-meta-' prefix to get the original key
                            metadata_key = metadata_key[len("opc-meta-") :]
                            attributes[metadata_key] = metadata_val

                return ObjectMetadata(
                    key=path,
                    content_length=int(response.headers["Content-Length"]),  # pyright: ignore [reportOptionalMemberAccess]
                    content_type=response.headers.get("Content-Type", None),  # pyright: ignore [reportOptionalMemberAccess]
                    last_modified=dateutil_parser(response.headers["last-modified"]),  # pyright: ignore [reportOptionalMemberAccess]
                    etag=response.headers.get("etag", None),  # pyright: ignore [reportOptionalMemberAccess]
                    metadata=attributes if attributes else None,
                )

            try:
                return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        """
        Lazily lists objects under a prefix, fetching one page per request.

        Every page request goes through :py:meth:`_translate_errors`, so failures that occur while the
        returned iterator is being consumed are translated the same way as for the other operations.

        :param path: Bucket and key prefix to list.
        :param start_after: Only yield keys strictly greater than this key.
        :param end_at: Only yield keys less than or equal to this key.
        :param include_directories: When ``True``, list with a ``/`` delimiter and yield common prefixes
            (and keys ending in ``/``) as directory entries.
        :param follow_symlinks: Not used by this provider.

        :return: An iterator over the matching objects' metadata.
        """
        bucket, prefix = split_path(path)
        self._refresh_oci_client_if_needed()

        # ListObjects only includes object names by default.
        #
        # Request additional fields needed for creating an ObjectMetadata.
        fields = ",".join(
            [
                "etag",
                "name",
                "size",
                "timeModified",
            ]
        )

        def _fetch_page(start: Optional[str]) -> Any:
            """Requests a single page of results with error translation applied."""

            def _invoke_api() -> Any:
                if include_directories:
                    return self._oci_client.list_objects(
                        namespace_name=self._namespace,
                        bucket_name=bucket,
                        prefix=prefix,
                        # This is ≥ instead of >.
                        start=start,
                        delimiter="/",
                        fields=fields,
                    )
                return self._oci_client.list_objects(
                    namespace_name=self._namespace,
                    bucket_name=bucket,
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start=start,
                    fields=fields,
                )

            return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

        def _iterate() -> Iterator[ObjectMetadata]:
            next_start_with: Optional[str] = start_after
            while True:
                response = _fetch_page(next_start_with)

                if not response:
                    return

                if include_directories:
                    for directory in response.data.prefixes:
                        prefix_key = directory.rstrip("/")
                        # Filter by start_after and end_at if specified
                        if (start_after is None or start_after < prefix_key) and (
                            end_at is None or prefix_key <= end_at
                        ):
                            yield ObjectMetadata(
                                key=os.path.join(bucket, prefix_key),
                                type="directory",
                                content_length=0,
                                last_modified=AWARE_DATETIME_MIN,
                            )
                        elif end_at is not None and end_at < prefix_key:
                            return

                # OCI guarantees lexicographical order.
                for response_object in response.data.objects:  # pyright: ignore [reportOptionalMemberAccess]
                    key = response_object.name
                    if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                        if key.endswith("/"):
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(bucket, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=response_object.time_modified,
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key),
                                type="file",
                                content_length=response_object.size,
                                last_modified=response_object.time_modified,
                                etag=response_object.etag,
                            )
                    elif start_after != key:
                        # The only out-of-range key that may be skipped is start_after itself (the listing
                        # start is inclusive); anything else means the listing has moved past end_at.
                        return
                next_start_with = response.data.next_start_with  # pyright: ignore [reportOptionalMemberAccess]
                if next_start_with is None or (end_at is not None and end_at < next_start_with):
                    return

        return _iterate()

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        """
        Uploads a local file or a file-like object through the upload manager.

        The configured part size is passed only when the size exceeds the multipart threshold.

        :param remote_path: Bucket and key of the destination object.
        :param f: A local file path, or a seekable file-like object. ``io.StringIO`` content is encoded as
            UTF-8 before upload.
        :param attributes: Optional custom metadata; validated and sent as object metadata.

        :return: The size in bytes of the uploaded content.
        """
        bucket, key = split_path(remote_path)
        file_size: int = 0
        self._refresh_oci_client_if_needed()

        validated_attributes = validate_attributes(attributes)
        if isinstance(f, str):
            file_size = os.path.getsize(f)

            def _invoke_api() -> int:
                if file_size > self._multipart_threshold:
                    self._upload_manager.upload_file(
                        namespace_name=self._namespace,
                        bucket_name=bucket,
                        object_name=key,
                        file_path=f,
                        part_size=self._multipart_chunksize,
                        allow_parallel_uploads=True,
                        metadata=validated_attributes or {},
                    )
                else:
                    self._upload_manager.upload_file(
                        namespace_name=self._namespace,
                        bucket_name=bucket,
                        object_name=key,
                        file_path=f,
                        metadata=validated_attributes or {},
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
        else:
            # Convert file-like object to BytesIO because stream_ref cannot work with StringIO.
            if isinstance(f, io.StringIO):
                f = io.BytesIO(f.getvalue().encode("utf-8"))

            # Measure the stream by seeking to its end, then rewind so the upload starts at the beginning.
            f.seek(0, io.SEEK_END)
            file_size = f.tell()
            f.seek(0)

            def _invoke_api() -> int:
                if file_size > self._multipart_threshold:
                    self._upload_manager.upload_stream(
                        namespace_name=self._namespace,
                        bucket_name=bucket,
                        object_name=key,
                        stream_ref=f,
                        part_size=self._multipart_chunksize,
                        allow_parallel_uploads=True,
                        metadata=validated_attributes or {},
                    )
                else:
                    self._upload_manager.upload_stream(
                        namespace_name=self._namespace,
                        bucket_name=bucket,
                        object_name=key,
                        stream_ref=f,
                        metadata=validated_attributes or {},
                    )

                return file_size

            return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """
        Downloads an object to a local path or writes it into a file-like object.

        When ``f`` is a path, the content is streamed into a hidden temporary file in the destination
        directory and renamed into place once complete. The temporary file is removed if the download or
        the rename fails.

        :param remote_path: Bucket and key of the object to download.
        :param f: A local file path, or a writable file-like object. For ``io.StringIO`` the content is
            decoded as UTF-8 before being written.
        :param metadata: Metadata of the object, if already known; fetched when ``None``.

        :return: The object's content length as reported by its metadata.
        """
        self._refresh_oci_client_if_needed()

        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            if os.path.dirname(f):
                safe_makedirs(os.path.dirname(f))

            def _invoke_api() -> int:
                response = self._oci_client.get_object(
                    namespace_name=self._namespace, bucket_name=bucket, object_name=key
                )
                temp_file_path: Optional[str] = None
                renamed = False
                try:
                    with tempfile.NamedTemporaryFile(
                        mode="wb", delete=False, dir=os.path.dirname(f), prefix="."
                    ) as fp:
                        temp_file_path = fp.name
                        for chunk in response.data.raw.stream(1024 * 1024, decode_content=False):  # pyright: ignore [reportOptionalMemberAccess]
                            fp.write(chunk)
                    os.rename(src=temp_file_path, dst=f)
                    renamed = True
                finally:
                    # delete=False means nothing else cleans up a partially written file on failure.
                    if not renamed and temp_file_path is not None:
                        try:
                            os.remove(temp_file_path)
                        except OSError:
                            # Best-effort cleanup; do not mask the original failure.
                            pass

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
        else:

            def _invoke_api() -> int:
                response = self._oci_client.get_object(
                    namespace_name=self._namespace, bucket_name=bucket, object_name=key
                )
                # The response streams bytes, so for a StringIO target buffer everything first and
                # decode it once before writing.
                if isinstance(f, io.StringIO):
                    bytes_fileobj = io.BytesIO()
                    for chunk in response.data.raw.stream(1024 * 1024, decode_content=False):  # pyright: ignore [reportOptionalMemberAccess]
                        bytes_fileobj.write(chunk)
                    f.write(bytes_fileobj.getvalue().decode("utf-8"))
                else:
                    for chunk in response.data.raw.stream(1024 * 1024, decode_content=False):  # pyright: ignore [reportOptionalMemberAccess]
                        f.write(chunk)

                return metadata.content_length

            return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)