# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
16import io
17import os
18import tempfile
19import time
20from collections.abc import Callable, Iterator, Sequence, Sized
21from typing import IO, Any, Optional, TypeVar, Union
22
23import oci
24import opentelemetry.metrics as api_metrics
25from dateutil.parser import parse as dateutil_parser
26from oci._vendor.requests.exceptions import (
27 ChunkedEncodingError,
28 ConnectionError,
29 ContentDecodingError,
30)
31from oci.exceptions import ServiceError
32from oci.object_storage import ObjectStorageClient, UploadManager
33from oci.retry import DEFAULT_RETRY_STRATEGY, RetryStrategyBuilder
34
35from ..telemetry import Telemetry
36from ..telemetry.attributes.base import AttributesProvider
37from ..types import (
38 AWARE_DATETIME_MIN,
39 CredentialsProvider,
40 ObjectMetadata,
41 PreconditionFailedError,
42 Range,
43 RetryableError,
44)
45from ..utils import split_path, validate_attributes
46from .base import BaseStorageProvider
47
# Generic result type returned by the callable wrapped in OracleStorageProvider._collect_metrics.
_T = TypeVar("_T")

# One mebibyte (2**20 bytes).
MB = 1 << 20

# Default byte count above which _upload_file requests explicit part sizing
# (overridable via the ``multipart_threshold`` constructor kwarg).
MULTIPART_THRESHOLD = 512 * MB
# Default part size in bytes for those uploads (overridable via ``multipart_chunksize``).
MULTIPART_CHUNK_SIZE = 256 * MB

# Provider name reported to the base class and attached to metrics.
PROVIDER = "oci"
56
57
class OracleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with
    Oracle Cloud Infrastructure (OCI) Object Storage.

    Operation methods receive paths of the form ``bucket/key`` (split with :py:func:`split_path`). Every remote
    call is routed through :py:meth:`_collect_metrics`, which records duration/size metrics and translates OCI
    errors into the project's exception types.
    """

    def __init__(
        self,
        namespace: str,
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        retry_strategy: Optional[dict[str, Any]] = None,
        metric_counters: Optional[dict[Telemetry.CounterName, api_metrics.Counter]] = None,
        metric_gauges: Optional[dict[Telemetry.GaugeName, api_metrics._Gauge]] = None,
        metric_attributes_providers: Sequence[AttributesProvider] = (),
        **kwargs: Any,
    ) -> None:
        """
        Initializes an instance of :py:class:`OracleStorageProvider`.

        :param namespace: The OCI Object Storage namespace. This is a unique identifier assigned to each tenancy.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve OCI credentials.
        :param retry_strategy: ``oci.retry.RetryStrategyBuilder`` parameters. When omitted,
            ``oci.retry.DEFAULT_RETRY_STRATEGY`` is used.
        :param metric_counters: Metric counters. Defaults to an empty mapping.
        :param metric_gauges: Metric gauges. Defaults to an empty mapping.
        :param metric_attributes_providers: Metric attributes providers.
        :param kwargs: Optional ``multipart_threshold`` and ``multipart_chunksize`` (bytes); other keys are ignored.
        """
        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            # Build fresh dicts per instance rather than sharing mutable default arguments.
            metric_counters=metric_counters if metric_counters is not None else {},
            metric_gauges=metric_gauges if metric_gauges is not None else {},
            metric_attributes_providers=metric_attributes_providers,
        )

        self._namespace = namespace
        self._credentials_provider = credentials_provider
        self._retry_strategy = (
            DEFAULT_RETRY_STRATEGY
            if retry_strategy is None
            else RetryStrategyBuilder(**retry_strategy).get_retry_strategy()
        )
        self._oci_client = self._create_oci_client()
        self._upload_manager = UploadManager(self._oci_client)
        # Coerced with int(), presumably because configuration may supply these as strings.
        self._multipart_threshold = int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD))
        self._multipart_chunk_size = int(kwargs.get("multipart_chunksize", MULTIPART_CHUNK_SIZE))

    def _create_oci_client(self) -> ObjectStorageClient:
        """
        Creates an Object Storage client from ``oci.config.from_file()`` (called with the SDK defaults) using this
        provider's retry strategy.
        """
        config = oci.config.from_file()
        return ObjectStorageClient(config, retry_strategy=self._retry_strategy)

    def _refresh_oci_client_if_needed(self) -> None:
        """
        Refreshes the OCI client if the current credentials are expired.

        Only acts when a credentials provider is configured and reports expired credentials. In that case, it
        refreshes the provider and rebuilds both the client and the upload manager.
        """
        # NOTE(review): the rebuilt client is still created from ``oci.config.from_file()``, not from the refreshed
        # credentials. The rebuilt UploadManager also uses different options (parallel uploads, 4 processes)
        # from the one built in __init__. Confirm both are intentional.
        if self._credentials_provider:
            credentials = self._credentials_provider.get_credentials()
            if credentials.is_expired():
                self._credentials_provider.refresh_credentials()
                self._oci_client = self._create_oci_client()
                self._upload_manager = UploadManager(
                    self._oci_client, allow_parallel_uploads=True, parallel_process_count=4
                )

    def _collect_metrics(
        self,
        func: Callable[[], _T],
        operation: str,
        bucket: str,
        key: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> _T:
        """
        Collects and records performance metrics around object storage operations such as PUT, GET, DELETE, etc.

        This method wraps an object storage operation and measures the time it takes to complete, along with recording
        the size of the object if applicable. It translates OCI service errors and transport errors into the
        project's exception types and always records the duration, even on failure.

        :param func: The function that performs the actual object storage operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE", "COPY").
        :param bucket: The name of the object storage bucket involved in the operation.
        :param key: The key of the object within the object storage bucket.
        :param put_object_size: The size of the object being written, if applicable (PUT and COPY operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).

        :return: The result of the object storage operation, typically the return value of the `func` callable.

        :raises FileNotFoundError: The service answered 404.
        :raises PreconditionFailedError: The service answered 412.
        :raises RetryableError: The service answered 429, or a connection/chunked-encoding/decoding error occurred.
        :raises RuntimeError: Any other failure.
        """
        # perf_counter is monotonic, so durations are not skewed by wall-clock adjustments.
        start_time = time.perf_counter()
        status_code = 200

        object_size = None
        if operation in ("PUT", "COPY"):
            # COPY callers pass the copied size as put_object_size; record it as well.
            object_size = put_object_size
        elif operation == "GET" and get_object_size:
            object_size = get_object_size

        try:
            result = func()
            if operation == "GET" and object_size is None and isinstance(result, Sized):
                object_size = len(result)
            return result
        except ServiceError as error:
            status_code = error.status
            request_id = error.request_id
            endpoint = error.request_endpoint
            error_info = f"request_id: {request_id}, endpoint: {endpoint}, status_code: {status_code}"

            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist. {error_info}") from error
            elif status_code == 412:
                raise PreconditionFailedError(
                    f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            elif status_code == 429:
                raise RetryableError(
                    f"Too many request to {operation} object(s) at {bucket}/{key}. {error_info}"
                ) from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}. {error_info}") from error
        except (ConnectionError, ChunkedEncodingError, ContentDecodingError) as error:
            status_code = -1
            raise RetryableError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}"
            ) from error
        except Exception as error:
            status_code = -1
            raise RuntimeError(
                f"Failed to {operation} object(s) at {bucket}/{key}, error type: {type(error).__name__}, error: {error}"
            ) from error
        finally:
            elapsed_time = time.perf_counter() - start_time
            self._metric_helper.record_duration(
                elapsed_time, provider=self._provider_name, operation=operation, bucket=bucket, status_code=status_code
            )
            if object_size:
                self._metric_helper.record_object_size(
                    object_size,
                    provider=self._provider_name,
                    operation=operation,
                    bucket=bucket,
                    status_code=status_code,
                )

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Writes ``body`` to ``path`` with a single PutObject call.

        :param path: Destination path in ``bucket/key`` form.
        :param body: Object contents.
        :param if_match: Optional ETag precondition.
        :param if_none_match: Optional precondition; OCI only supports ``"*"``.
        :param attributes: Optional metadata, validated and sent as ``opc_meta``.
        :return: Number of bytes written.
        """
        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        # OCI only supports if_none_match=="*"
        # refer: https://docs.oracle.com/en-us/iaas/tools/python/2.150.0/api/object_storage/client/oci.object_storage.ObjectStorageClient.html?highlight=put_object#oci.object_storage.ObjectStorageClient.put_object
        def _invoke_api() -> int:
            validated_attributes = validate_attributes(attributes)
            self._oci_client.put_object(
                namespace_name=self._namespace,
                bucket_name=bucket,
                object_name=key,
                put_object_body=body,
                opc_meta=validated_attributes or {},  # Pass metadata or empty dict
                if_match=if_match,
                if_none_match=if_none_match,
            )

            return len(body)

        return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=len(body))

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Reads an object, or the inclusive byte span ``[offset, offset + size - 1]`` when ``byte_range`` is given.

        :param path: Source path in ``bucket/key`` form.
        :param byte_range: Optional range to fetch.
        :return: The object bytes.
        """
        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> bytes:
            if byte_range:
                # HTTP Range end offsets are inclusive.
                bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
            else:
                bytes_range = None
            response = self._oci_client.get_object(
                namespace_name=self._namespace, bucket_name=bucket, object_name=key, range=bytes_range
            )
            return response.data.content  # pyright: ignore [reportOptionalMemberAccess]

        return self._collect_metrics(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        """
        Copies ``src_path`` to ``dest_path`` within this provider's namespace.

        :param src_path: Source path in ``bucket/key`` form.
        :param dest_path: Destination path in ``bucket/key`` form.
        :return: Size of the source object in bytes, taken from its metadata before the copy.
        """
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)
        self._refresh_oci_client_if_needed()

        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            # NOTE(review): OCI's CopyObject API documents destination namespace/region as required fields, and
            # performs the copy as an asynchronous work request. Confirm this call succeeds as written, and that
            # callers do not read the destination before the copy completes.
            copy_details = oci.object_storage.models.CopyObjectDetails(
                source_object_name=src_key, destination_bucket=dest_bucket, destination_object_name=dest_key
            )

            self._oci_client.copy_object(
                namespace_name=self._namespace, bucket_name=src_bucket, copy_object_details=copy_details
            )

            return src_object.content_length

        return self._collect_metrics(
            _invoke_api,
            operation="COPY",
            bucket=src_bucket,
            key=src_key,
            put_object_size=src_object.content_length,
        )

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Deletes the object at ``path``, optionally guarded by an ETag precondition.

        :param path: Object path in ``bucket/key`` form.
        :param if_match: Optional ETag that must match for the delete to proceed.
        """
        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> None:
            namespace_name = self._namespace
            bucket_name = bucket
            object_name = key
            if if_match is not None:
                self._oci_client.delete_object(namespace_name, bucket_name, object_name, if_match=if_match)
            else:
                self._oci_client.delete_object(namespace_name, bucket_name, object_name)

        return self._collect_metrics(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        """
        Returns whether ``path`` (treated as ``path/``) has at least one object or sub-prefix beneath it.
        """
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._oci_client.list_objects(
                namespace_name=self._namespace,
                bucket_name=bucket,
                prefix=key,
                delimiter="/",
            )
            # Check if there are any contents or common prefixes
            if response:
                return bool(response.data.objects or response.data.prefixes)
            return False

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Returns metadata for a file or a (possibly virtual) directory.

        :param path: Path in ``bucket/key`` form; a trailing ``/`` means "treat as directory".
        :param strict: When the HEAD request finds no object, also check whether ``path/`` is a directory.
        :return: File metadata from HEAD, or a synthetic directory entry.
        :raises FileNotFoundError: Neither an object nor (when checked) a directory exists.
        """
        if path.endswith("/"):
            # If path is a "directory", then metadata is not guaranteed to exist if
            # it is a "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=AWARE_DATETIME_MIN,
                )
            else:
                raise FileNotFoundError(f"Directory {path} does not exist.")
        else:
            bucket, key = split_path(path)
            self._refresh_oci_client_if_needed()

            def _invoke_api() -> ObjectMetadata:
                response = self._oci_client.head_object(
                    namespace_name=self._namespace, bucket_name=bucket, object_name=key
                )
                return ObjectMetadata(
                    key=path,
                    content_length=int(response.headers["Content-Length"]),  # pyright: ignore [reportOptionalMemberAccess]
                    content_type=response.headers.get("Content-Type", None),  # pyright: ignore [reportOptionalMemberAccess]
                    last_modified=dateutil_parser(response.headers["last-modified"]),  # pyright: ignore [reportOptionalMemberAccess]
                    etag=response.headers.get("etag", None),  # pyright: ignore [reportOptionalMemberAccess]
                )

            try:
                return self._collect_metrics(_invoke_api, operation="HEAD", bucket=bucket, key=key)
            except FileNotFoundError as error:
                if strict:
                    # If the object does not exist on the given path, we will append a trailing slash and
                    # check if the path is a directory.
                    path = self._append_delimiter(path)
                    if self._is_dir(path):
                        return ObjectMetadata(
                            key=path,
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )
                raise error

    def _list_objects(
        self,
        prefix: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        """
        Lazily lists objects under ``prefix`` whose keys fall in ``(start_after, end_at]``.

        Each ListObjects page request is wrapped in :py:meth:`_collect_metrics`. As a result, per-request
        metrics are recorded and service errors raised while paging are translated like every other operation.

        :param prefix: Listing prefix in ``bucket/key-prefix`` form.
        :param start_after: Exclusive lower bound on object keys.
        :param end_at: Inclusive upper bound on object keys.
        :param include_directories: List one level only (``/`` delimiter) and also yield directory entries.
        :return: Iterator of metadata whose keys are ``bucket``-prefixed.
        """
        bucket, prefix = split_path(prefix)
        self._refresh_oci_client_if_needed()

        # ListObjects only includes object names by default.
        #
        # Request additional fields needed for creating an ObjectMetadata.
        fields = ",".join(["etag", "name", "size", "timeModified"])

        def _list_page(start: Optional[str]) -> Any:
            # Fetch one page. ``start`` is ≥ instead of >, so ``start_after`` itself may come back.
            if include_directories:
                return self._oci_client.list_objects(
                    namespace_name=self._namespace,
                    bucket_name=bucket,
                    prefix=prefix,
                    start=start,
                    delimiter="/",
                    fields=fields,
                )
            return self._oci_client.list_objects(
                namespace_name=self._namespace,
                bucket_name=bucket,
                prefix=prefix,
                start=start,
                fields=fields,
            )

        def _iterate() -> Iterator[ObjectMetadata]:
            next_start_with: Optional[str] = start_after
            while True:
                # Meter each API call individually. The default argument binds the current cursor for this call.
                response = self._collect_metrics(
                    lambda start=next_start_with: _list_page(start),
                    operation="LIST",
                    bucket=bucket,
                    key=prefix,
                )

                if not response:
                    return

                page = response.data

                if include_directories:
                    for directory in page.prefixes or []:
                        yield ObjectMetadata(
                            # Prefix keys with the bucket, like every other entry this method yields.
                            key=os.path.join(bucket, directory.rstrip("/")),
                            type="directory",
                            content_length=0,
                            last_modified=AWARE_DATETIME_MIN,
                        )

                # OCI guarantees lexicographical order.
                for response_object in page.objects:
                    object_name = response_object.name
                    if (start_after is None or start_after < object_name) and (
                        end_at is None or object_name <= end_at
                    ):
                        if object_name.endswith("/"):
                            if include_directories:
                                yield ObjectMetadata(
                                    key=os.path.join(bucket, object_name.rstrip("/")),
                                    type="directory",
                                    content_length=0,
                                    last_modified=response_object.time_modified,
                                )
                        else:
                            yield ObjectMetadata(
                                key=os.path.join(bucket, object_name),
                                type="file",
                                content_length=response_object.size,
                                last_modified=response_object.time_modified,
                                etag=response_object.etag,
                            )
                    elif start_after != object_name:
                        # Past end_at: ordering means nothing later can match.
                        return

                next_start_with = page.next_start_with
                if next_start_with is None or (end_at is not None and end_at < next_start_with):
                    return

        return _iterate()

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        """
        Uploads a local file path or a file-like object to ``remote_path``.

        Sources larger than the multipart threshold are uploaded with the configured part size and parallel part
        uploads enabled. Smaller ones use the upload manager's defaults.

        :param remote_path: Destination path in ``bucket/key`` form.
        :param f: Local file path, or a file-like object (``io.StringIO`` content is encoded as UTF-8).
        :param attributes: Optional metadata; validated and attached to the uploaded object.
        :return: Number of bytes uploaded.
        """
        bucket, key = split_path(remote_path)
        self._refresh_oci_client_if_needed()

        def _upload_options(size: int) -> dict[str, Any]:
            # Keyword arguments shared by UploadManager.upload_file and upload_stream.
            options: dict[str, Any] = {}
            # Validated inside the metered call so failures surface the same way as in _put_object.
            validated_attributes = validate_attributes(attributes)
            if validated_attributes:
                options["metadata"] = validated_attributes
            if size > self._multipart_threshold:
                options["part_size"] = self._multipart_chunk_size
                options["allow_parallel_uploads"] = True
            return options

        if isinstance(f, str):
            file_path = f
            file_size = os.path.getsize(file_path)

            def _invoke_api() -> int:
                self._upload_manager.upload_file(
                    namespace_name=self._namespace,
                    bucket_name=bucket,
                    object_name=key,
                    file_path=file_path,
                    **_upload_options(file_size),
                )
                return file_size

            return self._collect_metrics(
                _invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=file_size
            )
        else:
            # Convert file-like object to BytesIO because stream_ref cannot work with StringIO.
            stream: IO = io.BytesIO(f.getvalue().encode("utf-8")) if isinstance(f, io.StringIO) else f

            # Measure the size from the end offset, then rewind so the upload starts at byte 0.
            stream.seek(0, io.SEEK_END)
            stream_size = stream.tell()
            stream.seek(0)

            def _invoke_stream_api() -> int:
                self._upload_manager.upload_stream(
                    namespace_name=self._namespace,
                    bucket_name=bucket,
                    object_name=key,
                    stream_ref=stream,
                    **_upload_options(stream_size),
                )
                return stream_size

            return self._collect_metrics(
                _invoke_stream_api, operation="PUT", bucket=bucket, key=key, put_object_size=stream_size
            )

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """
        Downloads ``remote_path`` to a local path or into a writable file-like object.

        For a local path, data is streamed into a hidden temporary file in the target directory, which is then
        moved over the destination. On failure, the temporary file is removed.

        :param remote_path: Source path in ``bucket/key`` form.
        :param f: Local destination path, or a file-like object (text streams receive UTF-8 decoded content).
        :param metadata: Pre-fetched metadata; fetched with HEAD when omitted.
        :return: Object size in bytes, as reported by ``metadata``.
        """
        self._refresh_oci_client_if_needed()

        if metadata is None:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            local_path = f
            local_dir = os.path.dirname(local_path)
            # A bare filename has an empty dirname, and os.makedirs("") would raise.
            if local_dir:
                os.makedirs(local_dir, exist_ok=True)

            def _invoke_api() -> int:
                response = self._oci_client.get_object(
                    namespace_name=self._namespace, bucket_name=bucket, object_name=key
                )
                # Same directory as the destination so the final move stays on one filesystem.
                fp = tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=local_dir or os.curdir, prefix=".")
                try:
                    with fp:
                        for chunk in response.data.raw.stream(1024 * 1024, decode_content=False):  # pyright: ignore [reportOptionalMemberAccess]
                            fp.write(chunk)
                    # os.replace overwrites an existing destination on every platform (os.rename does not on Windows).
                    os.replace(fp.name, local_path)
                except BaseException:
                    # Best-effort cleanup so a failed download does not leave a hidden partial file behind.
                    try:
                        os.remove(fp.name)
                    except OSError:
                        pass
                    raise

                return metadata.content_length

            return self._collect_metrics(
                _invoke_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )
        else:

            def _invoke_stream_api() -> int:
                response = self._oci_client.get_object(
                    namespace_name=self._namespace, bucket_name=bucket, object_name=key
                )
                if isinstance(f, io.TextIOBase):
                    # Text sinks need str. Buffer all raw bytes first, then decode once as UTF-8, so that a
                    # multi-byte character split across chunk boundaries still decodes correctly.
                    bytes_fileobj = io.BytesIO()
                    for chunk in response.data.raw.stream(1024 * 1024, decode_content=False):  # pyright: ignore [reportOptionalMemberAccess]
                        bytes_fileobj.write(chunk)
                    f.write(bytes_fileobj.getvalue().decode("utf-8"))
                else:
                    for chunk in response.data.raw.stream(1024 * 1024, decode_content=False):  # pyright: ignore [reportOptionalMemberAccess]
                        f.write(chunk)

                return metadata.content_length

            return self._collect_metrics(
                _invoke_stream_api, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )
545 )