1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import io
17import os
18import tempfile
19import time
20from collections.abc import Callable, Iterator, Sequence, Sized
21from typing import IO, Any, Optional, TypeVar, Union
22
23import oci
24import opentelemetry.metrics as api_metrics
25from dateutil.parser import parse as dateutil_parser
26from oci._vendor.requests.exceptions import (
27 ChunkedEncodingError,
28 ConnectionError,
29 ContentDecodingError,
30)
31from oci.exceptions import ServiceError
32from oci.object_storage import ObjectStorageClient, UploadManager
33from oci.retry import DEFAULT_RETRY_STRATEGY, RetryStrategyBuilder
34
35from ..telemetry import Telemetry
36from ..telemetry.attributes.base import AttributesProvider
37from ..types import (
38 AWARE_DATETIME_MIN,
39 CredentialsProvider,
40 ObjectMetadata,
41 PreconditionFailedError,
42 Range,
43 RetryableError,
44)
45from ..utils import split_path, validate_attributes
46from .base import BaseStorageProvider
47
# Return type of the callables wrapped by ``OracleStorageProvider._collect_metrics``.
_T = TypeVar("_T")

# Number of bytes in one mebibyte.
MB = 2**20

# Default multipart tuning: uploads larger than the threshold are sent in parts of the chunk size.
MULTIPART_CHUNKSIZE = 256 * MB
MULTIPART_THRESHOLD = 2 * MULTIPART_CHUNKSIZE

# Provider name handed to the base class.
PROVIDER = "oci"
56
57
[docs]
58class OracleStorageProvider(BaseStorageProvider):
59 """
60 A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with
61 Oracle Cloud Infrastructure (OCI) Object Storage.
62 """
63
64 def __init__(
65 self,
66 namespace: str,
67 base_path: str = "",
68 credentials_provider: Optional[CredentialsProvider] = None,
69 retry_strategy: Optional[dict[str, Any]] = None,
70 metric_counters: dict[Telemetry.CounterName, api_metrics.Counter] = {},
71 metric_gauges: dict[Telemetry.GaugeName, api_metrics._Gauge] = {},
72 metric_attributes_providers: Sequence[AttributesProvider] = (),
73 **kwargs: Any,
74 ) -> None:
75 """
76 Initializes an instance of :py:class:`OracleStorageProvider`.
77
78 :param namespace: The OCI Object Storage namespace. This is a unique identifier assigned to each tenancy.
79 :param base_path: The root prefix path within the bucket where all operations will be scoped.
80 :param credentials_provider: The provider to retrieve OCI credentials.
81 :param retry_strategy: ``oci.retry.RetryStrategyBuilder`` parameters.
82 :param metric_counters: Metric counters.
83 :param metric_gauges: Metric gauges.
84 :param metric_attributes_providers: Metric attributes providers.
85 """
86 super().__init__(
87 base_path=base_path,
88 provider_name=PROVIDER,
89 metric_counters=metric_counters,
90 metric_gauges=metric_gauges,
91 metric_attributes_providers=metric_attributes_providers,
92 )
93
94 self._namespace = namespace
95 self._credentials_provider = credentials_provider
96 self._retry_strategy = (
97 DEFAULT_RETRY_STRATEGY
98 if retry_strategy is None
99 else RetryStrategyBuilder(**retry_strategy).get_retry_strategy()
100 )
101 self._oci_client = self._create_oci_client()
102 self._upload_manager = UploadManager(self._oci_client)
103 self._multipart_threshold = int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD))
104 self._multipart_chunksize = int(kwargs.get("multipart_chunksize", MULTIPART_CHUNKSIZE))
105
106 def _create_oci_client(self) -> ObjectStorageClient:
107 config = oci.config.from_file()
108 return ObjectStorageClient(config, retry_strategy=self._retry_strategy)
109
110 def _refresh_oci_client_if_needed(self) -> None:
111 """
112 Refreshes the OCI client if the current credentials are expired.
113 """
114 if self._credentials_provider:
115 credentials = self._credentials_provider.get_credentials()
116 if credentials.is_expired():
117 self._credentials_provider.refresh_credentials()
118 self._oci_client = self._create_oci_client()
119 self._upload_manager = UploadManager(
120 self._oci_client, allow_parallel_uploads=True, parallel_process_count=4
121 )
122
123 def _collect_metrics(
124 self,
125 func: Callable[[], _T],
126 operation: str,
127 bucket: str,
128 key: str,
129 put_object_size: Optional[int] = None,
130 get_object_size: Optional[int] = None,
131 ) -> _T:
132 """
133 Collects and records performance metrics around object storage operations such as PUT, GET, DELETE, etc.
134
135 This method wraps an object storage operation and measures the time it takes to complete, along with recording
136 the size of the object if applicable. It handles errors like timeouts and client errors and ensures
137 proper logging of duration and object size.
138
139 :param func: The function that performs the actual object storage operation.
140 :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
141 :param bucket: The name of the object storage bucket involved in the operation.
142 :param key: The key of the object within the object storage bucket.
143 :param put_object_size: The size of the object being uploaded, if applicable (for PUT operations).
144 :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).
145
146 :return: The result of the object storage operation, typically the return value of the `func` callable.
147 """
148 start_time = time.time()
149 status_code = 200
150
151 object_size = None
152 if operation == "PUT":
153 object_size = put_object_size
154 elif operation == "GET" and get_object_size:
155 object_size = get_object_size
156
157 try:
158 result = func()
159 if operation == "GET" and object_size is None and isinstance(result, Sized):
160 object_size = len(result)
161 return result
162 except ServiceError as error:
163 status_code = error.status
164 request_id = error.request_id
165 endpoint = error.request_endpoint
166 error_info = f"request_id: {request_id}, endpoint: {endpoint}, status_code: {status_code}"
167
168 if status_code == 404:
169 raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}") # pylint: disable=raise-missing-from
170 elif status_code == 412:
171 raise PreconditionFailedError(
172 f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
173 ) from error
174 elif status_code == 429:
175 raise RetryableError(
176 f"Too many request to {operation} object(s) at {bucket}/{key}. {error_info}"
177 ) from error
178 else:
179 raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
180 except (ConnectionError, ChunkedEncodingError, ContentDecodingError) as error:
181 status_code = -1
182 raise RetryableError(
183 f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}"
184 ) from error
185 except Exception as error:
186 status_code = -1
187 raise RuntimeError(
188 f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
189 ) from error
190 finally:
191 elapsed_time = time.time() - start_time
192 self._metric_helper.record_duration(
193 elapsed_time, provider=self._provider_name, operation=operation, bucket=bucket, status_code=status_code
194 )
195 if object_size:
196 self._metric_helper.record_object_size(
197 object_size,
198 provider=self._provider_name,
199 operation=operation,
200 bucket=bucket,
201 status_code=status_code,
202 )
203
204 def _put_object(
205 self,
206 path: str,
207 body: bytes,
208 if_match: Optional[str] = None,
209 if_none_match: Optional[str] = None,
210 attributes: Optional[dict[str, str]] = None,
211 ) -> int:
212 bucket, key = split_path(path)
213 self._refresh_oci_client_if_needed()
214
215 # OCI only supports if_none_match=="*"
216 # refer: https://docs.oracle.com/en-us/iaas/tools/python/2.150.0/api/object_storage/client/oci.object_storage.ObjectStorageClient.html?highlight=put_object#oci.object_storage.ObjectStorageClient.put_object
217 def _invoke_api() -> int:
218 validated_attributes = validate_attributes(attributes)
219 self._oci_client.put_object(
220 namespace_name=self._namespace,
221 bucket_name=bucket,
222 object_name=key,
223 put_object_body=body,
224 opc_meta=validated_attributes or {}, # Pass metadata or empty dict
225 if_match=if_match,
226 if_none_match=if_none_match,
227 )
228
229 return len(body)
230
231 return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=len(body))
232
233 def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
234 bucket, key = split_path(path)
235 self._refresh_oci_client_if_needed()
236
237 def _invoke_api() -> bytes:
238 if byte_range:
239 bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
240 else:
241 bytes_range = None
242 response = self._oci_client.get_object(
243 namespace_name=self._namespace, bucket_name=bucket, object_name=key, range=bytes_range
244 )
245 return response.data.content # pyright: ignore [reportOptionalMemberAccess]
246
247 return self._collect_metrics(_invoke_api, operation="GET", bucket=bucket, key=key)
248
249 def _copy_object(self, src_path: str, dest_path: str) -> int:
250 src_bucket, src_key = split_path(src_path)
251 dest_bucket, dest_key = split_path(dest_path)
252 self._refresh_oci_client_if_needed()
253
254 src_object = self._get_object_metadata(src_path)
255
256 def _invoke_api() -> int:
257 copy_details = oci.object_storage.models.CopyObjectDetails(
258 source_object_name=src_key, destination_bucket=dest_bucket, destination_object_name=dest_key
259 )
260
261 self._oci_client.copy_object(
262 namespace_name=self._namespace, bucket_name=src_bucket, copy_object_details=copy_details
263 )
264
265 return src_object.content_length
266
267 return self._collect_metrics(
268 _invoke_api,
269 operation="COPY",
270 bucket=src_bucket,
271 key=src_key,
272 put_object_size=src_object.content_length,
273 )
274
275 def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
276 bucket, key = split_path(path)
277 self._refresh_oci_client_if_needed()
278
279 def _invoke_api() -> None:
280 namespace_name = self._namespace
281 bucket_name = bucket
282 object_name = key
283 if if_match is not None:
284 self._oci_client.delete_object(namespace_name, bucket_name, object_name, if_match=if_match)
285 else:
286 self._oci_client.delete_object(namespace_name, bucket_name, object_name)
287
288 return self._collect_metrics(_invoke_api, operation="DELETE", bucket=bucket, key=key)
289
290 def _is_dir(self, path: str) -> bool:
291 # Ensure the path ends with '/' to mimic a directory
292 path = self._append_delimiter(path)
293
294 bucket, key = split_path(path)
295 self._refresh_oci_client_if_needed()
296
297 def _invoke_api() -> bool:
298 # List objects with the given prefix
299 response = self._oci_client.list_objects(
300 namespace_name=self._namespace,
301 bucket_name=bucket,
302 prefix=key,
303 delimiter="/",
304 )
305 # Check if there are any contents or common prefixes
306 if response:
307 return bool(response.data.objects or response.data.prefixes)
308 return False
309
310 return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=key)
311
312 def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
313 bucket, key = split_path(path)
314 if path.endswith("/") or (bucket and not key):
315 # If path ends with "/" or empty key name is provided, then assume it's a "directory",
316 # which metadata is not guaranteed to exist for cases such as
317 # "virtual prefix" that was never explicitly created.
318 if self._is_dir(path):
319 return ObjectMetadata(
320 key=path,
321 type="directory",
322 content_length=0,
323 last_modified=AWARE_DATETIME_MIN,
324 )
325 else:
326 raise FileNotFoundError(f"Directory {path} does not exist.")
327 else:
328 self._refresh_oci_client_if_needed()
329
330 def _invoke_api() -> ObjectMetadata:
331 response = self._oci_client.head_object(
332 namespace_name=self._namespace, bucket_name=bucket, object_name=key
333 )
334
335 # Extract custom metadata from headers with 'opc-meta-' prefix
336 attributes = {}
337 if response.headers: # pyright: ignore [reportOptionalMemberAccess]
338 for metadata_key, metadata_val in response.headers.items(): # pyright: ignore [reportOptionalMemberAccess]
339 if metadata_key.startswith("opc-meta-"):
340 # Remove the 'opc-meta-' prefix to get the original key
341 metadata_key = metadata_key[len("opc-meta-") :]
342 attributes[metadata_key] = metadata_val
343
344 return ObjectMetadata(
345 key=path,
346 content_length=int(response.headers["Content-Length"]), # pyright: ignore [reportOptionalMemberAccess]
347 content_type=response.headers.get("Content-Type", None), # pyright: ignore [reportOptionalMemberAccess]
348 last_modified=dateutil_parser(response.headers["last-modified"]), # pyright: ignore [reportOptionalMemberAccess]
349 etag=response.headers.get("etag", None), # pyright: ignore [reportOptionalMemberAccess]
350 metadata=attributes if attributes else None,
351 )
352
353 try:
354 return self._collect_metrics(_invoke_api, operation="HEAD", bucket=bucket, key=key)
355 except FileNotFoundError as error:
356 if strict:
357 # If the object does not exist on the given path, we will append a trailing slash and
358 # check if the path is a directory.
359 path = self._append_delimiter(path)
360 if self._is_dir(path):
361 return ObjectMetadata(
362 key=path,
363 type="directory",
364 content_length=0,
365 last_modified=AWARE_DATETIME_MIN,
366 )
367 raise error
368
369 def _list_objects(
370 self,
371 path: str,
372 start_after: Optional[str] = None,
373 end_at: Optional[str] = None,
374 include_directories: bool = False,
375 ) -> Iterator[ObjectMetadata]:
376 bucket, prefix = split_path(path)
377 self._refresh_oci_client_if_needed()
378
379 def _invoke_api() -> Iterator[ObjectMetadata]:
380 # ListObjects only includes object names by default.
381 #
382 # Request additional fields needed for creating an ObjectMetadata.
383 fields = ",".join(
384 [
385 "etag",
386 "name",
387 "size",
388 "timeModified",
389 ]
390 )
391 next_start_with: Optional[str] = start_after
392 while True:
393 if include_directories:
394 response = self._oci_client.list_objects(
395 namespace_name=self._namespace,
396 bucket_name=bucket,
397 prefix=prefix,
398 # This is ≥ instead of >.
399 start=next_start_with,
400 delimiter="/",
401 fields=fields,
402 )
403 else:
404 response = self._oci_client.list_objects(
405 namespace_name=self._namespace,
406 bucket_name=bucket,
407 prefix=prefix,
408 # This is ≥ instead of >.
409 start=next_start_with,
410 fields=fields,
411 )
412
413 if not response:
414 return []
415
416 if include_directories:
417 for directory in response.data.prefixes:
418 yield ObjectMetadata(
419 key=directory.rstrip("/"),
420 type="directory",
421 content_length=0,
422 last_modified=AWARE_DATETIME_MIN,
423 )
424
425 # OCI guarantees lexicographical order.
426 for response_object in response.data.objects: # pyright: ignore [reportOptionalMemberAccess]
427 key = response_object.name
428 if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
429 if key.endswith("/"):
430 if include_directories:
431 yield ObjectMetadata(
432 key=os.path.join(bucket, key.rstrip("/")),
433 type="directory",
434 content_length=0,
435 last_modified=response_object.time_modified,
436 )
437 else:
438 yield ObjectMetadata(
439 key=os.path.join(bucket, key),
440 type="file",
441 content_length=response_object.size,
442 last_modified=response_object.time_modified,
443 etag=response_object.etag,
444 )
445 elif start_after != key:
446 return
447 next_start_with = response.data.next_start_with # pyright: ignore [reportOptionalMemberAccess]
448 if next_start_with is None or (end_at is not None and end_at < next_start_with):
449 return
450
451 return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=prefix)
452
453 def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
454 bucket, key = split_path(remote_path)
455 file_size: int = 0
456 self._refresh_oci_client_if_needed()
457
458 validated_attributes = validate_attributes(attributes)
459 if isinstance(f, str):
460 file_size = os.path.getsize(f)
461
462 def _invoke_api() -> int:
463 if file_size > self._multipart_threshold:
464 self._upload_manager.upload_file(
465 namespace_name=self._namespace,
466 bucket_name=bucket,
467 object_name=key,
468 file_path=f,
469 part_size=self._multipart_chunksize,
470 allow_parallel_uploads=True,
471 metadata=validated_attributes or {},
472 )
473 else:
474 self._upload_manager.upload_file(
475 namespace_name=self._namespace,
476 bucket_name=bucket,
477 object_name=key,
478 file_path=f,
479 metadata=validated_attributes or {},
480 )
481
482 return file_size
483
484 return self._collect_metrics(
485 _invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=file_size
486 )
487 else:
488 # Convert file-like object to BytesIO because stream_ref cannot work with StringIO.
489 if isinstance(f, io.StringIO):
490 f = io.BytesIO(f.getvalue().encode("utf-8"))
491
492 f.seek(0, io.SEEK_END)
493 file_size = f.tell()
494 f.seek(0)
495
496 def _invoke_api() -> int:
497 if file_size > self._multipart_threshold:
498 self._upload_manager.upload_stream(
499 namespace_name=self._namespace,
500 bucket_name=bucket,
501 object_name=key,
502 stream_ref=f,
503 part_size=self._multipart_chunksize,
504 allow_parallel_uploads=True,
505 metadata=validated_attributes or {},
506 )
507 else:
508 self._upload_manager.upload_stream(
509 namespace_name=self._namespace,
510 bucket_name=bucket,
511 object_name=key,
512 stream_ref=f,
513 metadata=validated_attributes or {},
514 )
515
516 return file_size
517
518 return self._collect_metrics(
519 _invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=file_size
520 )
521
522 def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
523 self._refresh_oci_client_if_needed()
524
525 if metadata is None:
526 metadata = self._get_object_metadata(remote_path)
527
528 bucket, key = split_path(remote_path)
529
530 if isinstance(f, str):
531 if os.path.dirname(f):
532 os.makedirs(os.path.dirname(f), exist_ok=True)
533
534 def _invoke_api() -> int:
535 response = self._oci_client.get_object(
536 namespace_name=self._namespace, bucket_name=bucket, object_name=key
537 )
538 with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(f), prefix=".") as fp:
539 temp_file_path = fp.name
540 for chunk in response.data.raw.stream(1024 * 1024, decode_content=False): # pyright: ignore [reportOptionalMemberAccess]
541 fp.write(chunk)
542 os.rename(src=temp_file_path, dst=f)
543
544 return metadata.content_length
545
546 return self._collect_metrics(
547 _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
548 )
549 else:
550
551 def _invoke_api() -> int:
552 response = self._oci_client.get_object(
553 namespace_name=self._namespace, bucket_name=bucket, object_name=key
554 )
555 # Convert file-like object to BytesIO because stream_ref cannot work with StringIO.
556 if isinstance(f, io.StringIO):
557 bytes_fileobj = io.BytesIO()
558 for chunk in response.data.raw.stream(1024 * 1024, decode_content=False): # pyright: ignore [reportOptionalMemberAccess]
559 bytes_fileobj.write(chunk)
560 f.write(bytes_fileobj.getvalue().decode("utf-8"))
561 else:
562 for chunk in response.data.raw.stream(1024 * 1024, decode_content=False): # pyright: ignore [reportOptionalMemberAccess]
563 f.write(chunk)
564
565 return metadata.content_length
566
567 return self._collect_metrics(
568 _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
569 )