1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import io
17import os
18import tempfile
19from collections.abc import Callable, Iterator
20from typing import IO, Any, Optional, TypeVar, Union
21
22import oci
23from dateutil.parser import parse as dateutil_parser
24from oci._vendor.requests.exceptions import (
25 ChunkedEncodingError,
26 ConnectionError,
27 ContentDecodingError,
28)
29from oci.exceptions import ServiceError
30from oci.object_storage import ObjectStorageClient, UploadManager
31from oci.retry import DEFAULT_RETRY_STRATEGY, RetryStrategyBuilder
32
33from ..telemetry import Telemetry
34from ..types import (
35 AWARE_DATETIME_MIN,
36 CredentialsProvider,
37 ObjectMetadata,
38 PreconditionFailedError,
39 Range,
40 RetryableError,
41)
42from ..utils import split_path, validate_attributes
43from .base import BaseStorageProvider
44
# Generic result type of the callables passed to the provider's error translator.
_T = TypeVar("_T")

# Number of bytes in one mebibyte.
MB = 1 << 20

# Default size above which uploads request explicit part sizing, and the
# default part size used for them. Both can be overridden through the
# ``multipart_threshold`` / ``multipart_chunksize`` keyword arguments of the provider.
MULTIPART_THRESHOLD = MB * 64
MULTIPART_CHUNKSIZE = MB * 32

# Provider name reported to the base storage provider.
PROVIDER = "oci"
53
54
[docs]
class OracleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with
    Oracle Cloud Infrastructure (OCI) Object Storage.
    """

    def __init__(
        self,
        namespace: str,
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        retry_strategy: Optional[dict[str, Any]] = None,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initializes an instance of :py:class:`OracleStorageProvider`.

        :param namespace: The OCI Object Storage namespace. This is a unique identifier assigned to each tenancy.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve OCI credentials.
        :param retry_strategy: ``oci.retry.RetryStrategyBuilder`` parameters.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        :param kwargs: Optional ``multipart_threshold`` and ``multipart_chunksize`` overrides (in bytes).
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

        self._namespace = namespace
        self._credentials_provider = credentials_provider
        # Fall back to the SDK default strategy unless builder parameters were supplied.
        self._retry_strategy = (
            DEFAULT_RETRY_STRATEGY
            if retry_strategy is None
            else RetryStrategyBuilder(**retry_strategy).get_retry_strategy()
        )
        self._oci_client = self._create_oci_client()
        self._upload_manager = UploadManager(self._oci_client)
        self._multipart_threshold = int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD))
        self._multipart_chunksize = int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE))

    def _create_oci_client(self) -> ObjectStorageClient:
        """
        Creates an ``ObjectStorageClient`` from the configuration returned by ``oci.config.from_file()``,
        using the retry strategy selected in the constructor.

        :return: A new Object Storage client.
        """
        # NOTE(review): the credentials provider is not consulted when building the client;
        # confirm that relying on the OCI config file alone is intended.
        config = oci.config.from_file()
        return ObjectStorageClient(config, retry_strategy=self._retry_strategy)

    def _refresh_oci_client_if_needed(self) -> None:
        """
        Refreshes the OCI client if the current credentials are expired.

        Does nothing when no credentials provider was configured.
        """
        if not self._credentials_provider:
            return

        credentials = self._credentials_provider.get_credentials()
        if credentials.is_expired():
            self._credentials_provider.refresh_credentials()
            self._oci_client = self._create_oci_client()
            # NOTE(review): the constructor builds the UploadManager with default arguments while this
            # path passes explicit parallelism settings; confirm which configuration is intended.
            self._upload_manager = UploadManager(
                self._oci_client, allow_parallel_uploads=True, parallel_process_count=4
            )

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual object storage operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the object storage bucket involved in the operation.
        :param key: The key of the object within the object storage bucket.

        :return: The result of the object storage operation, typically the return value of the `func` callable.
        :raises FileNotFoundError: On a 404 service error, or when ``func`` itself raises it.
        :raises PreconditionFailedError: On a 412 service error.
        :raises RetryableError: On a 429 service error or a connection/decoding failure.
        :raises RuntimeError: On any other failure.
        """
        try:
            return func()
        except ServiceError as error:
            status_code = error.status
            request_id = error.request_id
            endpoint = error.request_endpoint
            error_info = f"request_id: {request_id}, endpoint: {endpoint}, status_code: {status_code}"

            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}")  # pylint: disable=raise-missing-from
            elif status_code == 412:
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 429:
                raise RetryableError(
                    f"Too many request to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except (ConnectionError, ChunkedEncodingError, ContentDecodingError) as error:
            # Transport-level failures are surfaced as retryable.
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}"
            ) from error
        except FileNotFoundError:
            # Already the exception type callers expect; pass it through untouched.
            raise
        except Exception as error:
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Uploads ``body`` to ``path`` with a single ``put_object`` request.

        :param path: Path of the form ``bucket/key``.
        :param body: The object content.
        :param if_match: Optional ETag precondition forwarded to the service.
        :param if_none_match: Optional precondition forwarded to the service.
        :param attributes: Optional custom metadata; validated before being sent.

        :return: The number of bytes in ``body``.
        """
        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        # OCI only supports if_none_match=="*"
        # refer: https://docs.oracle.com/en-us/iaas/tools/python/2.150.0/api/object_storage/client/oci.object_storage.ObjectStorageClient.html?highlight=put_object#oci.object_storage.ObjectStorageClient.put_object
        def _invoke_api() -> int:
            validated_attributes = validate_attributes(attributes)
            self._oci_client.put_object(
                namespace_name=self._namespace,
                bucket_name=bucket,
                object_name=key,
                put_object_body=body,
                opc_meta=validated_attributes or {},  # Pass metadata or empty dict
                if_match=if_match,
                if_none_match=if_none_match,
            )

            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Downloads an object, or a byte range of it, into memory.

        :param path: Path of the form ``bucket/key``.
        :param byte_range: Optional range; converted to an inclusive HTTP ``bytes=start-end`` header.

        :return: The content returned by the service.
        """
        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> bytes:
            range_header: Optional[str] = None
            if byte_range:
                # HTTP ranges are inclusive, hence the "- 1" on the end offset.
                last_byte = byte_range.offset + byte_range.size - 1
                range_header = f"bytes={byte_range.offset}-{last_byte}"
            response = self._oci_client.get_object(
                namespace_name=self._namespace, bucket_name=bucket, object_name=key, range=range_header
            )
            return response.data.content  # pyright: ignore [reportOptionalMemberAccess]

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        """
        Requests a copy of the object at ``src_path`` to ``dest_path``.

        :param src_path: Source path of the form ``bucket/key``.
        :param dest_path: Destination path of the form ``bucket/key``.

        :return: The content length of the source object, read before the copy is requested.
        """
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)
        self._refresh_oci_client_if_needed()

        # Looked up first so a missing source fails before any copy is requested.
        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            # NOTE(review): no destination namespace/region is set here; confirm the service
            # accepts the request without them for same-namespace copies.
            copy_details = oci.object_storage.models.CopyObjectDetails(
                source_object_name=src_key, destination_bucket=dest_bucket, destination_object_name=dest_key
            )

            self._oci_client.copy_object(
                namespace_name=self._namespace, bucket_name=src_bucket, copy_object_details=copy_details
            )

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", bucket=src_bucket, key=src_key)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Deletes the object at ``path``.

        :param path: Path of the form ``bucket/key``.
        :param if_match: Optional ETag precondition; only forwarded when provided.
        """
        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> None:
            # Only send the precondition when the caller supplied one.
            optional_args: dict[str, Any] = {}
            if if_match is not None:
                optional_args["if_match"] = if_match
            self._oci_client.delete_object(self._namespace, bucket, key, **optional_args)

        return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        """
        Checks whether ``path`` behaves like a directory, i.e. whether listing it as a prefix
        (with a trailing delimiter) yields any object or common prefix.

        :param path: Path of the form ``bucket/prefix``.

        :return: ``True`` if anything exists under the prefix, ``False`` otherwise.
        """
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._oci_client.list_objects(
                namespace_name=self._namespace,
                bucket_name=bucket,
                prefix=key,
                delimiter="/",
            )
            if not response:
                return False
            # Any object or common prefix under the key means the "directory" exists.
            return bool(response.data.objects or response.data.prefixes)

        return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Retrieves metadata for the object or directory at ``path``.

        :param path: Path of the form ``bucket/key``. A trailing ``/`` or an empty key is treated as a directory.
        :param strict: When ``True`` and no object exists at ``path``, additionally check whether
            ``path`` is a directory before giving up.

        :return: Metadata describing the object or directory.
        :raises FileNotFoundError: If neither an object nor a directory exists at ``path``.
        """
        bucket, key = split_path(path)
        if path.endswith("/") or (bucket and not key):
            # If path ends with "/" or empty key name is provided, then assume it's a "directory",
            # which metadata is not guaranteed to exist for cases such as
            # "virtual prefix" that was never explicitly created.
            if not self._is_dir(path):
                raise FileNotFoundError(f"Directory {path} does not exist.")
            return ObjectMetadata(
                key=path,
                type="directory",
                content_length=0,
                last_modified=AWARE_DATETIME_MIN,
            )

        self._refresh_oci_client_if_needed()

        def _invoke_api() -> ObjectMetadata:
            response = self._oci_client.head_object(
                namespace_name=self._namespace, bucket_name=bucket, object_name=key
            )
            headers = response.headers  # pyright: ignore [reportOptionalMemberAccess]

            # Extract custom metadata from headers with 'opc-meta-' prefix
            meta_prefix = "opc-meta-"
            attributes = {}
            if headers:
                for header_name, header_value in headers.items():
                    if header_name.startswith(meta_prefix):
                        # Remove the 'opc-meta-' prefix to get the original key
                        attributes[header_name[len(meta_prefix) :]] = header_value

            return ObjectMetadata(
                key=path,
                content_length=int(headers["Content-Length"]),
                content_type=headers.get("Content-Type", None),
                last_modified=dateutil_parser(headers["last-modified"]),
                etag=headers.get("etag", None),
                metadata=attributes if attributes else None,
            )

        try:
            return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
        except FileNotFoundError:
            if strict:
                # If the object does not exist on the given path, we will append a trailing slash and
                # check if the path is a directory.
                dir_path = self._append_delimiter(path)
                if self._is_dir(dir_path):
                    return ObjectMetadata(
                        key=dir_path,
                        type="directory",
                        content_length=0,
                        last_modified=AWARE_DATETIME_MIN,
                    )
            raise

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        """
        Lazily lists objects (and optionally directories) under ``path``.

        :param path: Path of the form ``bucket/prefix``.
        :param start_after: Only yield objects whose key is strictly greater than this key.
        :param end_at: Only yield objects whose key is less than or equal to this key.
        :param include_directories: List one level only (``/`` delimiter) and yield directories too.
        :param follow_symlinks: Accepted for interface compatibility; not used by this provider.

        :return: An iterator of metadata entries whose keys are prefixed with the bucket name.
        """
        bucket, prefix = split_path(path)
        self._refresh_oci_client_if_needed()

        # ListObjects only includes object names by default.
        #
        # Request additional fields needed for creating an ObjectMetadata.
        fields = ",".join(["etag", "name", "size", "timeModified"])

        list_kwargs: dict[str, Any] = {"prefix": prefix, "fields": fields}
        if include_directories:
            list_kwargs["delimiter"] = "/"

        def _fetch_page(start: Optional[str]) -> Any:
            # Each page request is translated on its own: the listing is a lazy generator, so
            # wrapping the generator function as a whole would never see errors raised mid-iteration.
            return self._translate_errors(
                lambda: self._oci_client.list_objects(
                    namespace_name=self._namespace,
                    bucket_name=bucket,
                    # This is ≥ instead of >.
                    start=start,
                    **list_kwargs,
                ),
                operation="LIST",
                bucket=bucket,
                key=prefix,
            )

        def _iterate() -> Iterator[ObjectMetadata]:
            next_start_with: Optional[str] = start_after
            while True:
                response = _fetch_page(next_start_with)
                if not response:
                    return

                if include_directories:
                    for directory in response.data.prefixes or []:
                        yield ObjectMetadata(
                            # Keyed with the bucket, like every other entry produced here.
                            key=os.path.join(bucket, directory.rstrip("/")),
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )

                # OCI guarantees lexicographical order.
                for response_object in response.data.objects:  # pyright: ignore [reportOptionalMemberAccess]
                    key = response_object.name
                    if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                        if key.endswith("/"):
                            # Explicit directory marker object.
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(bucket, key.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=response_object.time_modified,
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, key),
                                type="file",
                                content_length=response_object.size,
                                last_modified=response_object.time_modified,
                                etag=response_object.etag,
                            )
                    elif start_after != key:
                        # Out of range and not merely the inclusive start key: nothing further can match.
                        return

                next_start_with = response.data.next_start_with  # pyright: ignore [reportOptionalMemberAccess]
                if next_start_with is None or (end_at is not None and end_at < next_start_with):
                    return

        return _iterate()

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        """
        Uploads a local file or a file-like object to ``remote_path`` through the upload manager.

        :param remote_path: Destination path of the form ``bucket/key``.
        :param f: A local file path, or a file-like object (read from its beginning).
        :param attributes: Optional custom metadata; validated before being sent.

        :return: The size of the uploaded content in bytes.
        """
        bucket, key = split_path(remote_path)
        self._refresh_oci_client_if_needed()

        validated_attributes = validate_attributes(attributes)
        upload_kwargs: dict[str, Any] = {
            "namespace_name": self._namespace,
            "bucket_name": bucket,
            "object_name": key,
            "metadata": validated_attributes or {},
        }

        from_path = isinstance(f, str)
        if from_path:
            file_size: int = os.path.getsize(f)
            upload_kwargs["file_path"] = f
        else:
            # Convert file-like object to BytesIO because stream_ref cannot work with StringIO.
            stream = io.BytesIO(f.getvalue().encode("utf-8")) if isinstance(f, io.StringIO) else f

            # Measure the stream, then rewind so the upload starts from the beginning.
            stream.seek(0, io.SEEK_END)
            file_size = stream.tell()
            stream.seek(0)
            upload_kwargs["stream_ref"] = stream

        if file_size > self._multipart_threshold:
            # Large payloads get an explicit part size and parallel part uploads.
            upload_kwargs["part_size"] = self._multipart_chunksize
            upload_kwargs["allow_parallel_uploads"] = True

        def _invoke_api() -> int:
            if from_path:
                self._upload_manager.upload_file(**upload_kwargs)
            else:
                self._upload_manager.upload_stream(**upload_kwargs)
            return file_size

        return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """
        Downloads the object at ``remote_path`` into a local file or a file-like object.

        When ``f`` is a path, the content is streamed into a hidden temporary file in the destination
        directory and moved into place once complete; the temporary file is removed if the download fails.

        :param remote_path: Source path of the form ``bucket/key``.
        :param f: A local file path, or a writable file-like object.
        :param metadata: Optional pre-fetched metadata; looked up when omitted.

        :return: The content length reported by the object's metadata.
        """
        self._refresh_oci_client_if_needed()

        if metadata is None:
            metadata = self._get_object_metadata(remote_path)
        content_length = metadata.content_length

        bucket, key = split_path(remote_path)

        def _stream_chunks() -> Iterator[bytes]:
            response = self._oci_client.get_object(
                namespace_name=self._namespace, bucket_name=bucket, object_name=key
            )
            return response.data.raw.stream(1024 * 1024, decode_content=False)  # pyright: ignore [reportOptionalMemberAccess]

        if isinstance(f, str):
            target_dir = os.path.dirname(f)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)

            def _invoke_api() -> int:
                chunks = _stream_chunks()
                temp_file_path: Optional[str] = None
                try:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=target_dir, prefix=".") as fp:
                        temp_file_path = fp.name
                        for chunk in chunks:
                            fp.write(chunk)
                    # os.replace overwrites an existing destination on every platform.
                    os.replace(temp_file_path, f)
                except BaseException:
                    # Do not leave a partial hidden file behind when the download fails.
                    if temp_file_path is not None:
                        try:
                            os.unlink(temp_file_path)
                        except OSError:
                            pass
                    raise

                return content_length

        else:

            def _invoke_api() -> int:
                chunks = _stream_chunks()
                if isinstance(f, io.StringIO):
                    # StringIO cannot accept bytes: buffer the payload, then decode it in one go.
                    bytes_fileobj = io.BytesIO()
                    for chunk in chunks:
                        bytes_fileobj.write(chunk)
                    f.write(bytes_fileobj.getvalue().decode("utf-8"))
                else:
                    for chunk in chunks:
                        f.write(chunk)

                return content_length

        return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)