1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import io
17import os
18import tempfile
19from collections.abc import Callable, Iterator
20from typing import IO, Any, Optional, TypeVar, Union
21
22import oci
23from dateutil.parser import parse as dateutil_parser
24from oci._vendor.requests.exceptions import (
25 ChunkedEncodingError,
26 ConnectionError,
27 ContentDecodingError,
28)
29from oci.auth.signers import SecurityTokenSigner
30from oci.exceptions import ServiceError
31from oci.object_storage import ObjectStorageClient, UploadManager
32from oci.retry import DEFAULT_RETRY_STRATEGY, RetryStrategyBuilder
33from oci.signer import load_private_key_from_file
34
35from ..constants import DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT
36from ..telemetry import Telemetry
37from ..types import (
38 AWARE_DATETIME_MIN,
39 CredentialsProvider,
40 ObjectMetadata,
41 PreconditionFailedError,
42 Range,
43 RetryableError,
44 SymlinkHandling,
45)
46from ..utils import safe_makedirs, split_path, validate_attributes
47from .base import BaseStorageProvider
48
# Generic return type of the callables wrapped by ``OracleStorageProvider._translate_errors``.
_T = TypeVar("_T")

# Number of bytes in one mebibyte.
MB = 1024 * 1024

# Default file size above which uploads pass an explicit part size to the upload manager.
MULTIPART_THRESHOLD = 64 * MB
# Default part size used for uploads above the multipart threshold.
MULTIPART_CHUNKSIZE = 32 * MB

# Provider name handed to the base storage provider.
PROVIDER = "oci"
57
58
[docs]
59class OracleStorageProvider(BaseStorageProvider):
60 """
61 A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with
62 Oracle Cloud Infrastructure (OCI) Object Storage.
63 """
64
65 _namespace: str
66 _credentials_provider: Optional[CredentialsProvider]
67 _oci_client: ObjectStorageClient
68
69 def __init__(
70 self,
71 namespace: str,
72 base_path: str = "",
73 credentials_provider: Optional[CredentialsProvider] = None,
74 retry_strategy: Optional[dict[str, Any]] = None,
75 config_dict: Optional[dict[str, Any]] = None,
76 telemetry_provider: Optional[Callable[[], Telemetry]] = None,
77 **kwargs: Any,
78 ) -> None:
79 """
80 Initializes an instance of :py:class:`OracleStorageProvider`.
81
82 :param namespace: The OCI Object Storage namespace. This is a unique identifier assigned to each tenancy.
83 :param base_path: The root prefix path within the bucket where all operations will be scoped.
84 :param credentials_provider: The provider to retrieve OCI credentials.
85 :param retry_strategy: ``oci.retry.RetryStrategyBuilder`` parameters.
86 :param config_dict: Resolved MSC config.
87 :param telemetry_provider: A function that provides a telemetry instance.
88 """
89 super().__init__(
90 base_path=base_path,
91 provider_name=PROVIDER,
92 config_dict=config_dict,
93 telemetry_provider=telemetry_provider,
94 )
95
96 self._namespace = namespace
97 self._credentials_provider = credentials_provider
98 self._retry_strategy = (
99 DEFAULT_RETRY_STRATEGY
100 if retry_strategy is None
101 else RetryStrategyBuilder(**retry_strategy).get_retry_strategy()
102 )
103 self._timeout = kwargs.get("timeout")
104 if self._timeout is None:
105 self._timeout = (DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT)
106 self._oci_client = self._create_oci_client()
107 self._upload_manager = UploadManager(self._oci_client)
108 self._multipart_threshold = int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD))
109 self._multipart_chunksize = int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE))
110
111 def _create_oci_client(self) -> ObjectStorageClient:
112 config = oci.config.from_file()
113 kwargs = {"retry_strategy": self._retry_strategy}
114
115 # OCI doesn't support `authentication_type=security_token` OCI config entries in their SDKs yet. Manually configure signers.
116 #
117 # https://github.com/oracle/oci-python-sdk/blob/v2.169.0/src/oci/util.py#L213-L225
118 # https://github.com/oracle/oci-ruby-sdk/issues/70
119 if "security_token_file" in config:
120 with open(config["security_token_file"], "r") as security_token_file:
121 kwargs["signer"] = SecurityTokenSigner(
122 private_key=load_private_key_from_file(
123 filename=config["key_file"],
124 pass_phrase=config.get("pass_phrase"),
125 ),
126 # The OCI documentation + CLI are unforgiving about newline-terminated security token files.
127 #
128 # Do not ignore trailing newlines in case future upstream automatic signer configuration does the same.
129 #
130 # https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm#sdk_authentication_methods_session_token
131 # https://github.com/oracle/oci-cli/blob/v3.77.0/src/oci_cli/cli_session.py#L143-L144
132 token=security_token_file.read(),
133 )
134
135 client = ObjectStorageClient(config, **kwargs)
136 client.base_client.timeout = self._timeout
137 return client
138
139 def _refresh_oci_client_if_needed(self) -> None:
140 """
141 Refreshes the OCI client if the current credentials are expired.
142 """
143 if self._credentials_provider:
144 credentials = self._credentials_provider.get_credentials()
145 if credentials.is_expired():
146 self._credentials_provider.refresh_credentials()
147 self._oci_client = self._create_oci_client()
148 self._upload_manager = UploadManager(
149 self._oci_client, allow_parallel_uploads=True, parallel_process_count=4
150 )
151
152 def _translate_errors(
153 self,
154 func: Callable[[], _T],
155 operation: str,
156 bucket: str,
157 key: str,
158 ) -> _T:
159 """
160 Translates errors like timeouts and client errors.
161
162 :param func: The function that performs the actual object storage operation.
163 :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
164 :param bucket: The name of the object storage bucket involved in the operation.
165 :param key: The key of the object within the object storage bucket.
166
167 :return: The result of the object storage operation, typically the return value of the `func` callable.
168 """
169 try:
170 return func()
171 except ServiceError as error:
172 status_code = error.status
173 request_id = error.request_id
174 endpoint = error.request_endpoint
175 error_info = f"request_id: {request_id}, endpoint: {endpoint}, status_code: {status_code}"
176
177 if status_code == 404:
178 raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}") # pylint: disable=raise-missing-from
179 elif status_code == 412:
180 raise PreconditionFailedError(
181 f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
182 ) from error
183 elif status_code == 429:
184 raise RetryableError(
185 f"Too many request to {operation} object(s) at {bucket}/{key}. {error_info}"
186 ) from error
187 else:
188 raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
189 except (ConnectionError, ChunkedEncodingError, ContentDecodingError) as error:
190 raise RetryableError(
191 f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}"
192 ) from error
193 except FileNotFoundError:
194 raise
195 except Exception as error:
196 raise RuntimeError(
197 f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
198 ) from error
199
200 def _put_object(
201 self,
202 path: str,
203 body: bytes,
204 if_match: Optional[str] = None,
205 if_none_match: Optional[str] = None,
206 attributes: Optional[dict[str, str]] = None,
207 ) -> int:
208 bucket, key = split_path(path)
209 self._refresh_oci_client_if_needed()
210
211 # OCI only supports if_none_match=="*"
212 # refer: https://docs.oracle.com/en-us/iaas/tools/python/2.150.0/api/object_storage/client/oci.object_storage.ObjectStorageClient.html?highlight=put_object#oci.object_storage.ObjectStorageClient.put_object
213 def _invoke_api() -> int:
214 validated_attributes = validate_attributes(attributes)
215 self._oci_client.put_object(
216 namespace_name=self._namespace,
217 bucket_name=bucket,
218 object_name=key,
219 put_object_body=body,
220 opc_meta=validated_attributes or {}, # Pass metadata or empty dict
221 if_match=if_match,
222 if_none_match=if_none_match,
223 )
224
225 return len(body)
226
227 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
228
229 def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
230 bucket, key = split_path(path)
231 self._refresh_oci_client_if_needed()
232
233 def _invoke_api() -> bytes:
234 if byte_range:
235 bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
236 else:
237 bytes_range = None
238 response = self._oci_client.get_object(
239 namespace_name=self._namespace, bucket_name=bucket, object_name=key, range=bytes_range
240 )
241 return response.data.content # pyright: ignore [reportOptionalMemberAccess]
242
243 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
244
245 def _copy_object(self, src_path: str, dest_path: str) -> int:
246 src_bucket, src_key = split_path(src_path)
247 dest_bucket, dest_key = split_path(dest_path)
248 self._refresh_oci_client_if_needed()
249
250 src_object = self._get_object_metadata(src_path)
251
252 def _invoke_api() -> int:
253 copy_details = oci.object_storage.models.CopyObjectDetails(
254 source_object_name=src_key, destination_bucket=dest_bucket, destination_object_name=dest_key
255 )
256
257 self._oci_client.copy_object(
258 namespace_name=self._namespace, bucket_name=src_bucket, copy_object_details=copy_details
259 )
260
261 return src_object.content_length
262
263 return self._translate_errors(_invoke_api, operation="COPY", bucket=src_bucket, key=src_key)
264
265 def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
266 bucket, key = split_path(path)
267 self._refresh_oci_client_if_needed()
268
269 def _invoke_api() -> None:
270 namespace_name = self._namespace
271 bucket_name = bucket
272 object_name = key
273 if if_match is not None:
274 self._oci_client.delete_object(namespace_name, bucket_name, object_name, if_match=if_match)
275 else:
276 self._oci_client.delete_object(namespace_name, bucket_name, object_name)
277
278 return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)
279
280 def _is_dir(self, path: str) -> bool:
281 # Ensure the path ends with '/' to mimic a directory
282 path = self._append_delimiter(path)
283
284 bucket, key = split_path(path)
285 self._refresh_oci_client_if_needed()
286
287 def _invoke_api() -> bool:
288 # List objects with the given prefix
289 response = self._oci_client.list_objects(
290 namespace_name=self._namespace,
291 bucket_name=bucket,
292 prefix=key,
293 delimiter="/",
294 )
295 # Check if there are any contents or common prefixes
296 if response:
297 return bool(response.data.objects or response.data.prefixes)
298 return False
299
300 return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)
301
302 def _make_symlink(self, path: str, target: str) -> None:
303 bucket, key = split_path(path)
304 target_bucket, target_key = split_path(target)
305 if bucket != target_bucket:
306 raise ValueError(f"Cannot create cross-bucket symlink: '{bucket}' -> '{target_bucket}'.")
307 relative_target = ObjectMetadata.encode_symlink_target(key, target_key)
308 self._refresh_oci_client_if_needed()
309
310 def _invoke_api() -> None:
311 self._oci_client.put_object(
312 namespace_name=self._namespace,
313 bucket_name=bucket,
314 object_name=key,
315 put_object_body=b"",
316 opc_meta={"msc-symlink-target": relative_target},
317 )
318
319 self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
320
321 def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
322 bucket, key = split_path(path)
323 if path.endswith("/") or (bucket and not key):
324 # If path ends with "/" or empty key name is provided, then assume it's a "directory",
325 # which metadata is not guaranteed to exist for cases such as
326 # "virtual prefix" that was never explicitly created.
327 if self._is_dir(path):
328 return ObjectMetadata(
329 key=path,
330 type="directory",
331 content_length=0,
332 last_modified=AWARE_DATETIME_MIN,
333 )
334 else:
335 raise FileNotFoundError(f"Directory {path} does not exist.")
336 else:
337 self._refresh_oci_client_if_needed()
338
339 def _invoke_api() -> ObjectMetadata:
340 response = self._oci_client.head_object(
341 namespace_name=self._namespace, bucket_name=bucket, object_name=key
342 )
343
344 # Extract custom metadata from headers with 'opc-meta-' prefix
345 attributes = {}
346 if response.headers: # pyright: ignore [reportOptionalMemberAccess]
347 for metadata_key, metadata_val in response.headers.items(): # pyright: ignore [reportOptionalMemberAccess]
348 if metadata_key.startswith("opc-meta-"):
349 # Remove the 'opc-meta-' prefix to get the original key
350 metadata_key = metadata_key[len("opc-meta-") :]
351 attributes[metadata_key] = metadata_val
352
353 user_metadata = attributes if attributes else None
354 symlink_target = user_metadata.get("msc-symlink-target") if user_metadata else None
355 return ObjectMetadata(
356 key=path,
357 content_length=int(response.headers["Content-Length"]), # pyright: ignore [reportOptionalMemberAccess]
358 content_type=response.headers.get("Content-Type", None), # pyright: ignore [reportOptionalMemberAccess]
359 last_modified=dateutil_parser(response.headers["last-modified"]), # pyright: ignore [reportOptionalMemberAccess]
360 etag=response.headers.get("etag", None), # pyright: ignore [reportOptionalMemberAccess]
361 metadata=user_metadata,
362 symlink_target=symlink_target,
363 )
364
365 try:
366 return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
367 except FileNotFoundError as error:
368 if strict:
369 # If the object does not exist on the given path, we will append a trailing slash and
370 # check if the path is a directory.
371 path = self._append_delimiter(path)
372 if self._is_dir(path):
373 return ObjectMetadata(
374 key=path,
375 type="directory",
376 content_length=0,
377 last_modified=AWARE_DATETIME_MIN,
378 )
379 raise error
380
381 def _list_objects(
382 self,
383 path: str,
384 start_after: Optional[str] = None,
385 end_at: Optional[str] = None,
386 include_directories: bool = False,
387 symlink_handling: SymlinkHandling = SymlinkHandling.FOLLOW,
388 ) -> Iterator[ObjectMetadata]:
389 bucket, prefix = split_path(path)
390 self._refresh_oci_client_if_needed()
391
392 def _invoke_api() -> Iterator[ObjectMetadata]:
393 # ListObjects only includes object names by default.
394 #
395 # Request additional fields needed for creating an ObjectMetadata.
396 fields = ",".join(
397 [
398 "etag",
399 "name",
400 "size",
401 "timeModified",
402 ]
403 )
404 next_start_with: Optional[str] = start_after
405 while True:
406 if include_directories:
407 response = self._oci_client.list_objects(
408 namespace_name=self._namespace,
409 bucket_name=bucket,
410 prefix=prefix,
411 # This is ≥ instead of >.
412 start=next_start_with,
413 delimiter="/",
414 fields=fields,
415 )
416 else:
417 response = self._oci_client.list_objects(
418 namespace_name=self._namespace,
419 bucket_name=bucket,
420 prefix=prefix,
421 # This is ≥ instead of >.
422 start=next_start_with,
423 fields=fields,
424 )
425
426 if not response:
427 return []
428
429 if include_directories:
430 for directory in response.data.prefixes:
431 prefix_key = directory.rstrip("/")
432 # Filter by start_after and end_at if specified
433 if (start_after is None or start_after < prefix_key) and (
434 end_at is None or prefix_key <= end_at
435 ):
436 yield ObjectMetadata(
437 key=os.path.join(bucket, prefix_key),
438 type="directory",
439 content_length=0,
440 last_modified=AWARE_DATETIME_MIN,
441 )
442 elif end_at is not None and end_at < prefix_key:
443 return
444
445 # OCI guarantees lexicographical order.
446 for response_object in response.data.objects: # pyright: ignore [reportOptionalMemberAccess]
447 key = response_object.name
448 if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
449 if key.endswith("/"):
450 if include_directories:
451 yield ObjectMetadata(
452 key=os.path.join(bucket, key.rstrip("/")),
453 type="directory",
454 content_length=0,
455 last_modified=response_object.time_modified,
456 )
457 else:
458 symlink_target = None
459 if response_object.size == 0:
460 try:
461 meta = self._get_object_metadata(os.path.join(bucket, key))
462 symlink_target = meta.symlink_target
463 except Exception:
464 symlink_target = None
465 yield ObjectMetadata(
466 key=os.path.join(bucket, key),
467 type="file",
468 content_length=response_object.size,
469 last_modified=response_object.time_modified,
470 etag=response_object.etag,
471 symlink_target=symlink_target,
472 )
473 elif start_after != key:
474 return
475 next_start_with = response.data.next_start_with # pyright: ignore [reportOptionalMemberAccess]
476 if next_start_with is None or (end_at is not None and end_at < next_start_with):
477 return
478
479 return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)
480
    @property
    def supports_parallel_listing(self) -> bool:
        """Whether this provider supports parallel listing; always ``True`` here."""
        return True
484
485 def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
486 bucket, key = split_path(remote_path)
487 file_size: int = 0
488 self._refresh_oci_client_if_needed()
489
490 validated_attributes = validate_attributes(attributes)
491 if isinstance(f, str):
492 file_size = os.path.getsize(f)
493
494 def _invoke_api() -> int:
495 if file_size > self._multipart_threshold:
496 self._upload_manager.upload_file(
497 namespace_name=self._namespace,
498 bucket_name=bucket,
499 object_name=key,
500 file_path=f,
501 part_size=self._multipart_chunksize,
502 allow_parallel_uploads=True,
503 metadata=validated_attributes or {},
504 )
505 else:
506 self._upload_manager.upload_file(
507 namespace_name=self._namespace,
508 bucket_name=bucket,
509 object_name=key,
510 file_path=f,
511 metadata=validated_attributes or {},
512 )
513
514 return file_size
515
516 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
517 else:
518 # Convert file-like object to BytesIO because stream_ref cannot work with StringIO.
519 if isinstance(f, io.StringIO):
520 f = io.BytesIO(f.getvalue().encode("utf-8"))
521
522 f.seek(0, io.SEEK_END)
523 file_size = f.tell()
524 f.seek(0)
525
526 def _invoke_api() -> int:
527 if file_size > self._multipart_threshold:
528 self._upload_manager.upload_stream(
529 namespace_name=self._namespace,
530 bucket_name=bucket,
531 object_name=key,
532 stream_ref=f,
533 part_size=self._multipart_chunksize,
534 allow_parallel_uploads=True,
535 metadata=validated_attributes or {},
536 )
537 else:
538 self._upload_manager.upload_stream(
539 namespace_name=self._namespace,
540 bucket_name=bucket,
541 object_name=key,
542 stream_ref=f,
543 metadata=validated_attributes or {},
544 )
545
546 return file_size
547
548 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
549
550 def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
551 self._refresh_oci_client_if_needed()
552
553 if metadata is None:
554 metadata = self._get_object_metadata(remote_path)
555
556 bucket, key = split_path(remote_path)
557
558 if isinstance(f, str):
559 if os.path.dirname(f):
560 safe_makedirs(os.path.dirname(f))
561
562 def _invoke_api() -> int:
563 response = self._oci_client.get_object(
564 namespace_name=self._namespace, bucket_name=bucket, object_name=key
565 )
566 with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
567 temp_file_path = fp.name
568 for chunk in response.data.raw.stream(1024 * 1024, decode_content=False): # pyright: ignore [reportOptionalMemberAccess]
569 fp.write(chunk)
570 os.rename(src=temp_file_path, dst=f)
571
572 return metadata.content_length
573
574 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
575 else:
576
577 def _invoke_api() -> int:
578 response = self._oci_client.get_object(
579 namespace_name=self._namespace, bucket_name=bucket, object_name=key
580 )
581 # Convert file-like object to BytesIO because stream_ref cannot work with StringIO.
582 if isinstance(f, io.StringIO):
583 bytes_fileobj = io.BytesIO()
584 for chunk in response.data.raw.stream(1024 * 1024, decode_content=False): # pyright: ignore [reportOptionalMemberAccess]
585 bytes_fileobj.write(chunk)
586 f.write(bytes_fileobj.getvalue().decode("utf-8"))
587 else:
588 for chunk in response.data.raw.stream(1024 * 1024, decode_content=False): # pyright: ignore [reportOptionalMemberAccess]
589 f.write(chunk)
590
591 return metadata.content_length
592
593 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)