1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import io
17import os
18import tempfile
19from collections.abc import Callable, Iterator
20from typing import IO, Any, Optional, TypeVar, Union
21
22import oci
23from dateutil.parser import parse as dateutil_parser
24from oci._vendor.requests.exceptions import (
25 ChunkedEncodingError,
26 ConnectionError,
27 ContentDecodingError,
28)
29from oci.exceptions import ServiceError
30from oci.object_storage import ObjectStorageClient, UploadManager
31from oci.retry import DEFAULT_RETRY_STRATEGY, RetryStrategyBuilder
32
33from ..telemetry import Telemetry
34from ..types import (
35 AWARE_DATETIME_MIN,
36 CredentialsProvider,
37 ObjectMetadata,
38 PreconditionFailedError,
39 Range,
40 RetryableError,
41)
42from ..utils import split_path, validate_attributes
43from .base import BaseStorageProvider
44
#: Return type of the callable wrapped by ``OracleStorageProvider._translate_errors``.
_T = TypeVar("_T")

#: Number of bytes in one mebibyte.
MB = 1 << 20

#: Default upload size (bytes) above which ``_upload_file`` passes an explicit part size and enables
#: parallel part uploads; overridable through the ``multipart_threshold`` keyword argument.
MULTIPART_THRESHOLD = MB * 64
#: Default multipart part size in bytes; overridable through the ``multipart_chunksize`` keyword argument.
MULTIPART_CHUNKSIZE = MB * 32

#: Provider name handed to the base storage provider as ``provider_name``.
PROVIDER = "oci"
53
54
[docs]
55class OracleStorageProvider(BaseStorageProvider):
56 """
57 A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with
58 Oracle Cloud Infrastructure (OCI) Object Storage.
59 """
60
61 def __init__(
62 self,
63 namespace: str,
64 base_path: str = "",
65 credentials_provider: Optional[CredentialsProvider] = None,
66 retry_strategy: Optional[dict[str, Any]] = None,
67 config_dict: Optional[dict[str, Any]] = None,
68 telemetry_provider: Optional[Callable[[], Telemetry]] = None,
69 **kwargs: Any,
70 ) -> None:
71 """
72 Initializes an instance of :py:class:`OracleStorageProvider`.
73
74 :param namespace: The OCI Object Storage namespace. This is a unique identifier assigned to each tenancy.
75 :param base_path: The root prefix path within the bucket where all operations will be scoped.
76 :param credentials_provider: The provider to retrieve OCI credentials.
77 :param retry_strategy: ``oci.retry.RetryStrategyBuilder`` parameters.
78 :param config_dict: Resolved MSC config.
79 :param telemetry_provider: A function that provides a telemetry instance.
80 """
81 super().__init__(
82 base_path=base_path,
83 provider_name=PROVIDER,
84 config_dict=config_dict,
85 telemetry_provider=telemetry_provider,
86 )
87
88 self._namespace = namespace
89 self._credentials_provider = credentials_provider
90 self._retry_strategy = (
91 DEFAULT_RETRY_STRATEGY
92 if retry_strategy is None
93 else RetryStrategyBuilder(**retry_strategy).get_retry_strategy()
94 )
95 self._oci_client = self._create_oci_client()
96 self._upload_manager = UploadManager(self._oci_client)
97 self._multipart_threshold = int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD))
98 self._multipart_chunksize = int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE))
99
100 def _create_oci_client(self) -> ObjectStorageClient:
101 config = oci.config.from_file()
102 return ObjectStorageClient(config, retry_strategy=self._retry_strategy)
103
104 def _refresh_oci_client_if_needed(self) -> None:
105 """
106 Refreshes the OCI client if the current credentials are expired.
107 """
108 if self._credentials_provider:
109 credentials = self._credentials_provider.get_credentials()
110 if credentials.is_expired():
111 self._credentials_provider.refresh_credentials()
112 self._oci_client = self._create_oci_client()
113 self._upload_manager = UploadManager(
114 self._oci_client, allow_parallel_uploads=True, parallel_process_count=4
115 )
116
117 def _translate_errors(
118 self,
119 func: Callable[[], _T],
120 operation: str,
121 bucket: str,
122 key: str,
123 ) -> _T:
124 """
125 Translates errors like timeouts and client errors.
126
127 :param func: The function that performs the actual object storage operation.
128 :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
129 :param bucket: The name of the object storage bucket involved in the operation.
130 :param key: The key of the object within the object storage bucket.
131
132 :return: The result of the object storage operation, typically the return value of the `func` callable.
133 """
134 try:
135 return func()
136 except ServiceError as error:
137 status_code = error.status
138 request_id = error.request_id
139 endpoint = error.request_endpoint
140 error_info = f"request_id: {request_id}, endpoint: {endpoint}, status_code: {status_code}"
141
142 if status_code == 404:
143 raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}") # pylint: disable=raise-missing-from
144 elif status_code == 412:
145 raise PreconditionFailedError(
146 f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
147 ) from error
148 elif status_code == 429:
149 raise RetryableError(
150 f"Too many request to {operation} object(s) at {bucket}/{key}. {error_info}"
151 ) from error
152 else:
153 raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
154 except (ConnectionError, ChunkedEncodingError, ContentDecodingError) as error:
155 raise RetryableError(
156 f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}"
157 ) from error
158 except Exception as error:
159 raise RuntimeError(
160 f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
161 ) from error
162
163 def _put_object(
164 self,
165 path: str,
166 body: bytes,
167 if_match: Optional[str] = None,
168 if_none_match: Optional[str] = None,
169 attributes: Optional[dict[str, str]] = None,
170 ) -> int:
171 bucket, key = split_path(path)
172 self._refresh_oci_client_if_needed()
173
174 # OCI only supports if_none_match=="*"
175 # refer: https://docs.oracle.com/en-us/iaas/tools/python/2.150.0/api/object_storage/client/oci.object_storage.ObjectStorageClient.html?highlight=put_object#oci.object_storage.ObjectStorageClient.put_object
176 def _invoke_api() -> int:
177 validated_attributes = validate_attributes(attributes)
178 self._oci_client.put_object(
179 namespace_name=self._namespace,
180 bucket_name=bucket,
181 object_name=key,
182 put_object_body=body,
183 opc_meta=validated_attributes or {}, # Pass metadata or empty dict
184 if_match=if_match,
185 if_none_match=if_none_match,
186 )
187
188 return len(body)
189
190 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
191
192 def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
193 bucket, key = split_path(path)
194 self._refresh_oci_client_if_needed()
195
196 def _invoke_api() -> bytes:
197 if byte_range:
198 bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
199 else:
200 bytes_range = None
201 response = self._oci_client.get_object(
202 namespace_name=self._namespace, bucket_name=bucket, object_name=key, range=bytes_range
203 )
204 return response.data.content # pyright: ignore [reportOptionalMemberAccess]
205
206 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
207
208 def _copy_object(self, src_path: str, dest_path: str) -> int:
209 src_bucket, src_key = split_path(src_path)
210 dest_bucket, dest_key = split_path(dest_path)
211 self._refresh_oci_client_if_needed()
212
213 src_object = self._get_object_metadata(src_path)
214
215 def _invoke_api() -> int:
216 copy_details = oci.object_storage.models.CopyObjectDetails(
217 source_object_name=src_key, destination_bucket=dest_bucket, destination_object_name=dest_key
218 )
219
220 self._oci_client.copy_object(
221 namespace_name=self._namespace, bucket_name=src_bucket, copy_object_details=copy_details
222 )
223
224 return src_object.content_length
225
226 return self._translate_errors(_invoke_api, operation="COPY", bucket=src_bucket, key=src_key)
227
228 def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
229 bucket, key = split_path(path)
230 self._refresh_oci_client_if_needed()
231
232 def _invoke_api() -> None:
233 namespace_name = self._namespace
234 bucket_name = bucket
235 object_name = key
236 if if_match is not None:
237 self._oci_client.delete_object(namespace_name, bucket_name, object_name, if_match=if_match)
238 else:
239 self._oci_client.delete_object(namespace_name, bucket_name, object_name)
240
241 return self._translate_errors(_invoke_api, operation="DELETE", bucket=bucket, key=key)
242
243 def _is_dir(self, path: str) -> bool:
244 # Ensure the path ends with '/' to mimic a directory
245 path = self._append_delimiter(path)
246
247 bucket, key = split_path(path)
248 self._refresh_oci_client_if_needed()
249
250 def _invoke_api() -> bool:
251 # List objects with the given prefix
252 response = self._oci_client.list_objects(
253 namespace_name=self._namespace,
254 bucket_name=bucket,
255 prefix=key,
256 delimiter="/",
257 )
258 # Check if there are any contents or common prefixes
259 if response:
260 return bool(response.data.objects or response.data.prefixes)
261 return False
262
263 return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=key)
264
265 def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
266 bucket, key = split_path(path)
267 if path.endswith("/") or (bucket and not key):
268 # If path ends with "/" or empty key name is provided, then assume it's a "directory",
269 # which metadata is not guaranteed to exist for cases such as
270 # "virtual prefix" that was never explicitly created.
271 if self._is_dir(path):
272 return ObjectMetadata(
273 key=path,
274 type="directory",
275 content_length=0,
276 last_modified=AWARE_DATETIME_MIN,
277 )
278 else:
279 raise FileNotFoundError(f"Directory {path} does not exist.")
280 else:
281 self._refresh_oci_client_if_needed()
282
283 def _invoke_api() -> ObjectMetadata:
284 response = self._oci_client.head_object(
285 namespace_name=self._namespace, bucket_name=bucket, object_name=key
286 )
287
288 # Extract custom metadata from headers with 'opc-meta-' prefix
289 attributes = {}
290 if response.headers: # pyright: ignore [reportOptionalMemberAccess]
291 for metadata_key, metadata_val in response.headers.items(): # pyright: ignore [reportOptionalMemberAccess]
292 if metadata_key.startswith("opc-meta-"):
293 # Remove the 'opc-meta-' prefix to get the original key
294 metadata_key = metadata_key[len("opc-meta-") :]
295 attributes[metadata_key] = metadata_val
296
297 return ObjectMetadata(
298 key=path,
299 content_length=int(response.headers["Content-Length"]), # pyright: ignore [reportOptionalMemberAccess]
300 content_type=response.headers.get("Content-Type", None), # pyright: ignore [reportOptionalMemberAccess]
301 last_modified=dateutil_parser(response.headers["last-modified"]), # pyright: ignore [reportOptionalMemberAccess]
302 etag=response.headers.get("etag", None), # pyright: ignore [reportOptionalMemberAccess]
303 metadata=attributes if attributes else None,
304 )
305
306 try:
307 return self._translate_errors(_invoke_api, operation="HEAD", bucket=bucket, key=key)
308 except FileNotFoundError as error:
309 if strict:
310 # If the object does not exist on the given path, we will append a trailing slash and
311 # check if the path is a directory.
312 path = self._append_delimiter(path)
313 if self._is_dir(path):
314 return ObjectMetadata(
315 key=path,
316 type="directory",
317 content_length=0,
318 last_modified=AWARE_DATETIME_MIN,
319 )
320 raise error
321
322 def _list_objects(
323 self,
324 path: str,
325 start_after: Optional[str] = None,
326 end_at: Optional[str] = None,
327 include_directories: bool = False,
328 ) -> Iterator[ObjectMetadata]:
329 bucket, prefix = split_path(path)
330 self._refresh_oci_client_if_needed()
331
332 def _invoke_api() -> Iterator[ObjectMetadata]:
333 # ListObjects only includes object names by default.
334 #
335 # Request additional fields needed for creating an ObjectMetadata.
336 fields = ",".join(
337 [
338 "etag",
339 "name",
340 "size",
341 "timeModified",
342 ]
343 )
344 next_start_with: Optional[str] = start_after
345 while True:
346 if include_directories:
347 response = self._oci_client.list_objects(
348 namespace_name=self._namespace,
349 bucket_name=bucket,
350 prefix=prefix,
351 # This is ≥ instead of >.
352 start=next_start_with,
353 delimiter="/",
354 fields=fields,
355 )
356 else:
357 response = self._oci_client.list_objects(
358 namespace_name=self._namespace,
359 bucket_name=bucket,
360 prefix=prefix,
361 # This is ≥ instead of >.
362 start=next_start_with,
363 fields=fields,
364 )
365
366 if not response:
367 return []
368
369 if include_directories:
370 for directory in response.data.prefixes:
371 yield ObjectMetadata(
372 key=directory.rstrip("/"),
373 type="directory",
374 content_length=0,
375 last_modified=AWARE_DATETIME_MIN,
376 )
377
378 # OCI guarantees lexicographical order.
379 for response_object in response.data.objects: # pyright: ignore [reportOptionalMemberAccess]
380 key = response_object.name
381 if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
382 if key.endswith("/"):
383 if include_directories:
384 yield ObjectMetadata(
385 key=os.path.join(bucket, key.rstrip("/")),
386 type="directory",
387 content_length=0,
388 last_modified=response_object.time_modified,
389 )
390 else:
391 yield ObjectMetadata(
392 key=os.path.join(bucket, key),
393 type="file",
394 content_length=response_object.size,
395 last_modified=response_object.time_modified,
396 etag=response_object.etag,
397 )
398 elif start_after != key:
399 return
400 next_start_with = response.data.next_start_with # pyright: ignore [reportOptionalMemberAccess]
401 if next_start_with is None or (end_at is not None and end_at < next_start_with):
402 return
403
404 return self._translate_errors(_invoke_api, operation="LIST", bucket=bucket, key=prefix)
405
406 def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
407 bucket, key = split_path(remote_path)
408 file_size: int = 0
409 self._refresh_oci_client_if_needed()
410
411 validated_attributes = validate_attributes(attributes)
412 if isinstance(f, str):
413 file_size = os.path.getsize(f)
414
415 def _invoke_api() -> int:
416 if file_size > self._multipart_threshold:
417 self._upload_manager.upload_file(
418 namespace_name=self._namespace,
419 bucket_name=bucket,
420 object_name=key,
421 file_path=f,
422 part_size=self._multipart_chunksize,
423 allow_parallel_uploads=True,
424 metadata=validated_attributes or {},
425 )
426 else:
427 self._upload_manager.upload_file(
428 namespace_name=self._namespace,
429 bucket_name=bucket,
430 object_name=key,
431 file_path=f,
432 metadata=validated_attributes or {},
433 )
434
435 return file_size
436
437 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
438 else:
439 # Convert file-like object to BytesIO because stream_ref cannot work with StringIO.
440 if isinstance(f, io.StringIO):
441 f = io.BytesIO(f.getvalue().encode("utf-8"))
442
443 f.seek(0, io.SEEK_END)
444 file_size = f.tell()
445 f.seek(0)
446
447 def _invoke_api() -> int:
448 if file_size > self._multipart_threshold:
449 self._upload_manager.upload_stream(
450 namespace_name=self._namespace,
451 bucket_name=bucket,
452 object_name=key,
453 stream_ref=f,
454 part_size=self._multipart_chunksize,
455 allow_parallel_uploads=True,
456 metadata=validated_attributes or {},
457 )
458 else:
459 self._upload_manager.upload_stream(
460 namespace_name=self._namespace,
461 bucket_name=bucket,
462 object_name=key,
463 stream_ref=f,
464 metadata=validated_attributes or {},
465 )
466
467 return file_size
468
469 return self._translate_errors(_invoke_api, operation="PUT", bucket=bucket, key=key)
470
471 def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
472 self._refresh_oci_client_if_needed()
473
474 if metadata is None:
475 metadata = self._get_object_metadata(remote_path)
476
477 bucket, key = split_path(remote_path)
478
479 if isinstance(f, str):
480 if os.path.dirname(f):
481 os.makedirs(os.path.dirname(f), exist_ok=True)
482
483 def _invoke_api() -> int:
484 response = self._oci_client.get_object(
485 namespace_name=self._namespace, bucket_name=bucket, object_name=key
486 )
487 with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
488 temp_file_path = fp.name
489 for chunk in response.data.raw.stream(1024 * 1024, decode_content=False): # pyright: ignore [reportOptionalMemberAccess]
490 fp.write(chunk)
491 os.rename(src=temp_file_path, dst=f)
492
493 return metadata.content_length
494
495 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)
496 else:
497
498 def _invoke_api() -> int:
499 response = self._oci_client.get_object(
500 namespace_name=self._namespace, bucket_name=bucket, object_name=key
501 )
502 # Convert file-like object to BytesIO because stream_ref cannot work with StringIO.
503 if isinstance(f, io.StringIO):
504 bytes_fileobj = io.BytesIO()
505 for chunk in response.data.raw.stream(1024 * 1024, decode_content=False): # pyright: ignore [reportOptionalMemberAccess]
506 bytes_fileobj.write(chunk)
507 f.write(bytes_fileobj.getvalue().decode("utf-8"))
508 else:
509 for chunk in response.data.raw.stream(1024 * 1024, decode_content=False): # pyright: ignore [reportOptionalMemberAccess]
510 f.write(chunk)
511
512 return metadata.content_length
513
514 return self._translate_errors(_invoke_api, operation="GET", bucket=bucket, key=key)