# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import io
import os
import tempfile
import time
from datetime import datetime
from typing import IO, Any, Callable, Iterator, Optional, Union

import oci
from dateutil.parser import parse as dateutil_parser
from oci._vendor.requests.exceptions import (
    ChunkedEncodingError,
    ConnectionError,
    ContentDecodingError,
)
from oci.exceptions import ServiceError
from oci.object_storage import ObjectStorageClient, UploadManager
from oci.retry import DEFAULT_RETRY_STRATEGY

from ..types import (
    CredentialsProvider,
    ObjectMetadata,
    Range,
    RetryableError,
)
from ..utils import split_path
from .base import BaseStorageProvider

# One mebibyte, in bytes; the unit for the multipart defaults below.
MB = 1 << 20

# Default size (bytes) above which uploads are split into parts.
MULTIPART_THRESHOLD = MB * 512
# Default part size (bytes) used for split uploads.
MULTIPART_CHUNK_SIZE = MB * 256

# Provider identifier used for metrics and the base-class provider name.
PROVIDER = "oci"


class OracleStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with
    Oracle Cloud Infrastructure (OCI) Object Storage.
    """

    def __init__(
        self,
        namespace: str,
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initializes an instance of :py:class:`OracleStorageProvider`.

        :param namespace: The OCI Object Storage namespace. This is a unique identifier assigned to each tenancy.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
        :param credentials_provider: The provider to retrieve OCI credentials.
        :param kwargs: Optional tuning values. ``multipart_threshold`` (bytes) is the size above which uploads are
            split into parts; ``multipart_chunksize`` (bytes) is the part size used for such uploads. They default
            to ``MULTIPART_THRESHOLD`` and ``MULTIPART_CHUNK_SIZE`` respectively.
        """
        super().__init__(base_path=base_path, provider_name=PROVIDER)

        self._namespace = namespace
        self._credentials_provider = credentials_provider
        self._oci_client = self._create_oci_client()
        self._upload_manager = self._create_upload_manager()
        self._multipart_threshold = int(kwargs.get("multipart_threshold", MULTIPART_THRESHOLD))
        self._multipart_chunk_size = int(kwargs.get("multipart_chunksize", MULTIPART_CHUNK_SIZE))

    def _create_oci_client(self) -> ObjectStorageClient:
        """
        Creates an OCI Object Storage client from the configuration returned by ``oci.config.from_file()``.

        The client is configured with the SDK's ``DEFAULT_RETRY_STRATEGY``.

        NOTE(review): ``self._credentials_provider`` is not consulted here; the client always uses the config
        file. Confirm this is intended.
        """
        config = oci.config.from_file()
        return ObjectStorageClient(config, retry_strategy=DEFAULT_RETRY_STRATEGY)

    def _create_upload_manager(self) -> UploadManager:
        """
        Creates the :py:class:`UploadManager` bound to the current OCI client.

        Shared by construction and credential refresh so both paths configure the upload manager identically.
        """
        return UploadManager(self._oci_client)

    def _refresh_oci_client_if_needed(self) -> None:
        """
        Refreshes the OCI client if the current credentials are expired.

        When a refresh happens, the OCI client and the upload manager are both recreated so the upload manager
        never keeps using the stale client.
        """
        if not self._credentials_provider:
            return
        credentials = self._credentials_provider.get_credentials()
        if credentials.is_expired():
            self._credentials_provider.refresh_credentials()
            self._oci_client = self._create_oci_client()
            self._upload_manager = self._create_upload_manager()

    def _collect_metrics(
        self,
        func: Callable,
        operation: str,
        bucket: str,
        key: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> Any:
        """
        Collects and records performance metrics around object storage operations such as PUT, GET, DELETE, etc.

        This method wraps an object storage operation and measures the time it takes to complete, along with recording
        the size of the object if applicable. It translates OCI/transport errors into the project's exception types
        and ensures the duration (and object size, when known and non-zero) is always recorded.

        :param func: The function that performs the actual object storage operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param bucket: The name of the object storage bucket involved in the operation.
        :param key: The key of the object within the object storage bucket.
        :param put_object_size: The size of the object being uploaded, if applicable (for PUT operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).

        :return: The result of the object storage operation, typically the return value of the `func` callable.

        :raises FileNotFoundError: If OCI responds with HTTP 404.
        :raises RetryableError: On HTTP 429 or on connection / chunked-encoding / content-decoding errors.
        :raises RuntimeError: On any other failure.
        """
        start_time = time.time()
        status_code = 200

        object_size: Optional[int] = None
        if operation == "PUT":
            object_size = put_object_size
        elif operation == "GET" and get_object_size is not None:
            # Use ``is not None`` so a known size of 0 is not mistaken for "unknown".
            object_size = get_object_size

        try:
            result = func()
            # Fall back to measuring the payload only when the operation actually returned bytes;
            # streaming downloads return None.
            if operation == "GET" and object_size is None and isinstance(result, (bytes, bytearray)):
                object_size = len(result)
            return result
        except ServiceError as error:
            status_code = error.status
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            elif status_code == 429:
                raise RetryableError(f"Too many request to {operation} object(s) at {bucket}/{key}.") from error
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}") from error
        except (ConnectionError, ChunkedEncodingError, ContentDecodingError) as error:
            status_code = -1
            raise RetryableError(f"Failed to {operation} object(s) at {bucket}/{key}") from error
        except Exception as error:
            status_code = -1
            raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}") from error
        finally:
            elapsed_time = time.time() - start_time
            self._metric_helper.record_duration(
                elapsed_time, provider=PROVIDER, operation=operation, bucket=bucket, status_code=status_code
            )
            if object_size:
                self._metric_helper.record_object_size(
                    object_size, provider=PROVIDER, operation=operation, bucket=bucket, status_code=status_code
                )

    def _put_object(self, path: str, body: bytes) -> None:
        """
        Uploads ``body`` as the object at ``path`` (``bucket/key``) with a single PUT request.
        """
        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> None:
            self._oci_client.put_object(
                namespace_name=self._namespace, bucket_name=bucket, object_name=key, put_object_body=body
            )

        return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=len(body))

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Reads the object at ``path`` (``bucket/key``) fully into memory.

        :param byte_range: When given, only ``byte_range.size`` bytes starting at ``byte_range.offset`` are
            requested, via an inclusive HTTP ``Range`` header.
        """
        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> bytes:
            bytes_range = None
            if byte_range:
                # HTTP byte ranges are inclusive on both ends.
                bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
            response = self._oci_client.get_object(
                namespace_name=self._namespace, bucket_name=bucket, object_name=key, range=bytes_range
            )
            return response.data.content  # pyright: ignore [reportOptionalMemberAccess]

        return self._collect_metrics(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> None:
        """
        Copies ``src_path`` to ``dest_path`` (both ``bucket/key``) using OCI's server-side CopyObject.

        NOTE(review): the OCI CopyObject API documents ``destinationRegion`` and ``destinationNamespace`` as
        required, and the copy is performed asynchronously by the service. Confirm this request shape is accepted
        and whether callers need to wait for completion.
        """
        src_bucket, src_key = split_path(src_path)
        dest_bucket, dest_key = split_path(dest_path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> None:
            copy_details = oci.object_storage.models.CopyObjectDetails(
                source_object_name=src_key, destination_bucket=dest_bucket, destination_object_name=dest_key
            )

            self._oci_client.copy_object(
                namespace_name=self._namespace, bucket_name=src_bucket, copy_object_details=copy_details
            )

        # Looked up so the copied size can be reported; also fails fast if the source does not exist.
        src_object = self._get_object_metadata(src_path)

        return self._collect_metrics(
            _invoke_api,
            operation="COPY",
            bucket=src_bucket,
            key=src_key,
            put_object_size=src_object.content_length,
        )

    def _delete_object(self, path: str) -> None:
        """
        Deletes the object at ``path`` (``bucket/key``).
        """
        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> None:
            self._oci_client.delete_object(namespace_name=self._namespace, bucket_name=bucket, object_name=key)

        return self._collect_metrics(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _is_dir(self, path: str) -> bool:
        """
        Returns ``True`` if any object or common prefix exists directly under ``path`` treated as a directory.
        """
        # Ensure the path ends with '/' to mimic a directory
        path = self._append_delimiter(path)

        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> bool:
            # List objects with the given prefix
            response = self._oci_client.list_objects(
                namespace_name=self._namespace,
                bucket_name=bucket,
                prefix=key,
                delimiter="/",
            )
            # Check if there are any contents or common prefixes
            if response:
                return bool(response.data.objects or response.data.prefixes)
            return False

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str) -> ObjectMetadata:
        """
        Returns metadata for the object or virtual directory at ``path`` (``bucket/key``).

        Paths ending in ``/`` are treated as directories. For other paths a HEAD request is issued; if that
        object does not exist, the path is re-checked as a directory before giving up.

        :raises FileNotFoundError: If neither an object nor a directory exists at ``path``.
        """
        if path.endswith("/"):
            # If path is a "directory", then metadata is not guaranteed to exist if
            # it is a "virtual prefix" that was never explicitly created.
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=datetime.min,
                )
            raise FileNotFoundError(f"Directory {path} does not exist.")

        bucket, key = split_path(path)
        self._refresh_oci_client_if_needed()

        def _invoke_api() -> ObjectMetadata:
            response = self._oci_client.head_object(namespace_name=self._namespace, bucket_name=bucket, object_name=key)
            return ObjectMetadata(
                key=path,
                content_length=int(response.headers["Content-Length"]),  # pyright: ignore [reportOptionalMemberAccess]
                content_type=response.headers.get("Content-Type", None),  # pyright: ignore [reportOptionalMemberAccess]
                last_modified=dateutil_parser(response.headers["last-modified"]),  # pyright: ignore [reportOptionalMemberAccess]
                etag=response.headers.get("etag", None),  # pyright: ignore [reportOptionalMemberAccess]
            )

        try:
            return self._collect_metrics(_invoke_api, operation="HEAD", bucket=bucket, key=key)
        except FileNotFoundError:
            # If the object does not exist on the given path, we will append a trailing slash and
            # check if the path is a directory.
            path = self._append_delimiter(path)
            if self._is_dir(path):
                return ObjectMetadata(
                    key=path,
                    type="directory",
                    content_length=0,
                    last_modified=datetime.min,
                )
            raise

    def _list_objects(
        self,
        prefix: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        """
        Lists objects under ``prefix`` (``bucket/key-prefix``).

        :param prefix: The ``bucket/key-prefix`` path to list.
        :param start_after: Only yield objects whose key is strictly greater than this key.
        :param end_at: Only yield objects whose key is less than or equal to this key.
        :param include_directories: When ``True``, list with a ``/`` delimiter and also yield the returned common
            prefixes as ``directory`` entries (without their trailing ``/``).

        :return: An iterator of :py:class:`ObjectMetadata`. Pages are requested lazily as it is consumed, and each
            page request is timed and error-translated by :py:meth:`_collect_metrics`.
        """
        bucket, prefix = split_path(prefix)
        self._refresh_oci_client_if_needed()

        # ListObjects only returns each object's name unless other fields are requested explicitly.
        list_kwargs = {"fields": "name,size,etag,timeModified"}
        if include_directories:
            list_kwargs["delimiter"] = "/"

        def _fetch_page(start: Optional[str]) -> Any:
            return self._collect_metrics(
                lambda: self._oci_client.list_objects(
                    namespace_name=self._namespace,
                    bucket_name=bucket,
                    prefix=prefix,
                    # This is ≥ instead of >.
                    start=start,
                    **list_kwargs,
                ),
                operation="LIST",
                bucket=bucket,
                key=prefix,
            )

        def _generate() -> Iterator[ObjectMetadata]:
            next_start_with: Optional[str] = start_after
            while True:
                response = _fetch_page(next_start_with)
                if not response:
                    return

                if include_directories:
                    for directory in response.data.prefixes or []:
                        yield ObjectMetadata(
                            key=directory.rstrip("/"),
                            type="directory",
                            content_length=0,
                            last_modified=datetime.min,
                        )

                # OCI guarantees lexicographical order.
                for response_object in response.data.objects or []:  # pyright: ignore [reportOptionalMemberAccess]
                    key = response_object.name
                    if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                        yield ObjectMetadata(
                            key=key,
                            content_length=response_object.size,
                            last_modified=response_object.time_modified,
                            etag=response_object.etag,
                        )
                    elif start_after != key:
                        # Past ``end_at``; given the ordering, no later key can match.
                        return
                next_start_with = response.data.next_start_with  # pyright: ignore [reportOptionalMemberAccess]
                if next_start_with is None or (end_at is not None and end_at < next_start_with):
                    return

        return _generate()

    def _multipart_upload_kwargs(self, filesize: int) -> dict:
        """
        Returns extra ``UploadManager`` keyword arguments for an upload of ``filesize`` bytes.

        Uploads larger than ``self._multipart_threshold`` use ``self._multipart_chunk_size`` parts with parallel
        uploads allowed; smaller uploads use the ``UploadManager`` defaults (empty dict).
        """
        if filesize > self._multipart_threshold:
            return {"part_size": self._multipart_chunk_size, "allow_parallel_uploads": True}
        return {}

    def _upload_file(self, remote_path: str, f: Union[str, IO]) -> None:
        """
        Uploads a local file path or a file-like object to ``remote_path`` (``bucket/key``).

        :param f: A local file path, or a file-like object. A :py:class:`io.StringIO` is encoded as UTF-8 first;
            other file-like objects are rewound and uploaded from the beginning.
        """
        bucket, key = split_path(remote_path)
        self._refresh_oci_client_if_needed()

        if isinstance(f, str):
            file_path = f
            filesize = os.path.getsize(file_path)

            def _upload_path() -> None:
                self._upload_manager.upload_file(
                    namespace_name=self._namespace,
                    bucket_name=bucket,
                    object_name=key,
                    file_path=file_path,
                    **self._multipart_upload_kwargs(filesize),
                )

            return self._collect_metrics(_upload_path, operation="PUT", bucket=bucket, key=key, put_object_size=filesize)

        # Convert file-like object to BytesIO because stream_ref cannot work with StringIO.
        stream: IO = io.BytesIO(f.getvalue().encode("utf-8")) if isinstance(f, io.StringIO) else f

        # Determine the size by seeking to the end, then rewind for the upload.
        stream.seek(0, io.SEEK_END)
        filesize = stream.tell()
        stream.seek(0)

        def _upload_stream() -> None:
            self._upload_manager.upload_stream(
                namespace_name=self._namespace,
                bucket_name=bucket,
                object_name=key,
                stream_ref=stream,
                **self._multipart_upload_kwargs(filesize),
            )

        return self._collect_metrics(_upload_stream, operation="PUT", bucket=bucket, key=key, put_object_size=filesize)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> None:
        """
        Downloads the object at ``remote_path`` (``bucket/key``) to a local path or a file-like object.

        :param f: A local file path, or a writable file-like object. For a path, the object is streamed into a
            hidden temporary file in the destination directory and then moved over ``f`` with :py:func:`os.replace`,
            so ``f`` is only ever replaced by a fully written file; the temporary file is removed on failure.
            A :py:class:`io.StringIO` receives the payload decoded as UTF-8.
        :param metadata: Pre-fetched object metadata; looked up when not supplied. Its ``content_length`` is the
            size reported to metrics.
        """
        self._refresh_oci_client_if_needed()

        if not metadata:
            metadata = self._get_object_metadata(remote_path)

        bucket, key = split_path(remote_path)

        if isinstance(f, str):
            local_path = f
            # ``os.path.dirname`` is "" for a bare filename; fall back to the current directory.
            local_dir = os.path.dirname(local_path) or "."
            os.makedirs(local_dir, exist_ok=True)

            def _download_to_path() -> None:
                response = self._oci_client.get_object(
                    namespace_name=self._namespace, bucket_name=bucket, object_name=key
                )
                temp_file_path: Optional[str] = None
                try:
                    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=local_dir, prefix=".") as fp:
                        temp_file_path = fp.name
                        for chunk in response.data.raw.stream(1024 * 1024, decode_content=False):  # pyright: ignore [reportOptionalMemberAccess]
                            fp.write(chunk)
                    os.replace(temp_file_path, local_path)
                except BaseException:
                    # Don't leave a partially written hidden temp file behind.
                    if temp_file_path is not None:
                        with contextlib.suppress(OSError):
                            os.remove(temp_file_path)
                    raise

            return self._collect_metrics(
                _download_to_path, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
            )

        def _download_to_stream() -> None:
            response = self._oci_client.get_object(namespace_name=self._namespace, bucket_name=bucket, object_name=key)
            raw_chunks = response.data.raw.stream(1024 * 1024, decode_content=False)  # pyright: ignore [reportOptionalMemberAccess]
            if isinstance(f, io.StringIO):
                # Join the raw bytes first so multi-byte UTF-8 sequences split across chunks decode correctly.
                f.write(b"".join(raw_chunks).decode("utf-8"))
            else:
                for chunk in raw_chunks:
                    f.write(chunk)

        return self._collect_metrics(
            _download_to_stream, operation="GET", bucket=bucket, key=key, get_object_size=metadata.content_length
        )
439 )