# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import io
import os
import time
from datetime import datetime
from typing import IO, Any, Callable, Iterator, Optional, Tuple, Union

from aistore.sdk import Client
from aistore.sdk.authn import AuthNClient
from aistore.sdk.errors import AISError
from dateutil.parser import parse as dateutil_parser
from requests.exceptions import HTTPError

from ..types import (
    Credentials,
    CredentialsProvider,
    ObjectMetadata,
    Range,
)
from ..utils import split_path
from .base import BaseStorageProvider

PROVIDER = "ais"


class StaticAISCredentialProvider(CredentialsProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.CredentialsProvider` that provides static AIStore credentials.
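
    Example (a minimal sketch; the endpoint and credentials are illustrative)::

        provider = StaticAISCredentialProvider(
            username="admin",
            password="admin",
            authn_endpoint="https://aistore-authn.example.com:52001",
        )
        token = provider.get_credentials().token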
    """

    _username: Optional[str]
    _password: Optional[str]
    _authn_endpoint: Optional[str]
    _token: Optional[str]
    _skip_verify: bool
    _ca_cert: Optional[str]

    def __init__(
        self,
        username: Optional[str] = None,
        password: Optional[str] = None,
        authn_endpoint: Optional[str] = None,
        token: Optional[str] = None,
        skip_verify: bool = True,
        ca_cert: Optional[str] = None,
    ):
        """
        Initializes the :py:class:`StaticAISCredentialProvider` with the given credentials.

        :param username: The username for the AIStore authentication.
        :param password: The password for the AIStore authentication.
        :param authn_endpoint: The AIStore authentication endpoint.
        :param token: The AIStore authentication token. This is used for authentication if username,
            password and authn_endpoint are not provided.
        :param skip_verify: If true, skip SSL certificate verification.
        :param ca_cert: Path to a CA certificate file for SSL verification.
        """
        self._username = username
        self._password = password
        self._authn_endpoint = authn_endpoint
        self._token = token
        self._skip_verify = skip_verify
        self._ca_cert = ca_cert

    def get_credentials(self) -> Credentials:
        if self._username and self._password and self._authn_endpoint:
            authn_client = AuthNClient(self._authn_endpoint, self._skip_verify, self._ca_cert)
            self._token = authn_client.login(self._username, self._password)
        return Credentials(token=self._token, access_key="", secret_key="", expiration=None)

    def refresh_credentials(self) -> None:
        pass


class AIStoreStorageProvider(BaseStorageProvider):
    def __init__(
        self,
        endpoint: str,
        provider: str = PROVIDER,
        skip_verify: bool = True,
        ca_cert: Optional[str] = None,
        timeout: Optional[Union[float, Tuple[float, float]]] = None,
        base_path: str = "",
        credentials_provider: Optional[CredentialsProvider] = None,
        **kwargs: Any,
    ) -> None:
        """
        AIStore client for managing buckets, objects, and ETL jobs.

        :param endpoint: The AIStore endpoint.
        :param provider: The name of the AIStore provider (defaults to ``ais``).
        :param skip_verify: Whether to skip SSL certificate verification.
        :param ca_cert: Path to a CA certificate file for SSL verification.
        :param timeout: Request timeout in seconds; a single float
            for both connect/read timeouts (e.g., ``5.0``), a tuple for separate connect/read
            timeouts (e.g., ``(3.0, 10.0)``), or ``None`` to disable timeout.
        :param credentials_provider: The provider to retrieve AIStore credentials. If not provided,
            the ``AIS_AUTHN_TOKEN`` environment variable will be used.
        :param base_path: The root prefix path within the bucket where all operations will be scoped.
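
        Example (a minimal sketch; the endpoint and token are illustrative)::

            credentials = StaticAISCredentialProvider(token="my-token")
            provider = AIStoreStorageProvider(
                endpoint="http://aistore.example.com:51080",
                credentials_provider=credentials,
            )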
        """
        super().__init__(base_path=base_path, provider_name=PROVIDER)

        token = None
        if credentials_provider:
            token = credentials_provider.get_credentials().token
            self.client = Client(
                endpoint=endpoint, skip_verify=skip_verify, ca_cert=ca_cert, timeout=timeout, token=token
            )
        else:
            self.client = Client(endpoint=endpoint)
        self.provider = provider

    def _collect_metrics(
        self,
        func: Callable,
        operation: str,
        bucket: str,
        key: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> Any:
        """
        Collects and records performance metrics around object storage operations
        such as ``PUT``, ``GET``, ``DELETE``, etc.

        This method wraps an object storage operation and measures the time it takes to complete, along with recording
        the size of the object if applicable. It handles errors like timeouts and client errors and ensures
        proper logging of duration and object size.

        :param func: The function that performs the actual object storage operation.
        :param operation: The type of operation being performed (e.g., ``PUT``, ``GET``, ``DELETE``).
        :param bucket: The name of the object storage bucket involved in the operation.
        :param key: The key of the object within the object storage bucket.
        :param put_object_size: The size of the object being uploaded, if applicable (for ``PUT`` operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for ``GET`` operations).

        :return: The result of the object storage operation, typically the return value of the ``func`` callable.
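
        Example (illustrative; this mirrors how the methods below wrap SDK calls in a closure)::

            def _invoke_api() -> bytes:
                return b"..."  # the actual SDK call goes here

            data = self._collect_metrics(_invoke_api, operation="GET", bucket="mybucket", key="my/key")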
        """
        start_time = time.time()
        status_code = 200

        object_size = None
        if operation == "PUT":
            object_size = put_object_size
        elif operation == "GET" and get_object_size:
            object_size = get_object_size

        try:
            result = func()
            if operation == "GET" and object_size is None:
                object_size = len(result)
            return result
        except AISError as error:
            status_code = error.status_code
            raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}") from error
        except HTTPError as error:
            status_code = error.response.status_code
            if status_code == 404:
                raise FileNotFoundError(f"Object {bucket}/{key} does not exist.")  # pylint: disable=raise-missing-from
            else:
                raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}") from error
        except Exception as error:
            status_code = -1
            raise RuntimeError(f"Failed to {operation} object(s) at {bucket}/{key}") from error
        finally:
            elapsed_time = time.time() - start_time
            self._metric_helper.record_duration(
                elapsed_time, provider=PROVIDER, operation=operation, bucket=bucket, status_code=status_code
            )
            if object_size:
                self._metric_helper.record_object_size(
                    object_size, provider=PROVIDER, operation=operation, bucket=bucket, status_code=status_code
                )

    def _put_object(self, path: str, body: bytes) -> None:
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            obj = self.client.bucket(bucket, self.provider).object(obj_name=key)
            obj.put_content(body)

        return self._collect_metrics(_invoke_api, operation="PUT", bucket=bucket, key=key, put_object_size=len(body))

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        bucket, key = split_path(path)
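        # HTTP Range headers are inclusive on both ends, hence the -1 when computing the end offset.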
        if byte_range:
            bytes_range = f"bytes={byte_range.offset}-{byte_range.offset + byte_range.size - 1}"
        else:
            bytes_range = None

        def _invoke_api() -> bytes:
            obj = self.client.bucket(bucket, self.provider).object(obj_name=key)
            if byte_range:
                reader = obj.get(byte_range=bytes_range)  # pyright: ignore [reportArgumentType]
            else:
                reader = obj.get()
            return reader.read_all()

        return self._collect_metrics(_invoke_api, operation="GET", bucket=bucket, key=key)

    def _copy_object(self, src_path: str, dest_path: str) -> None:
        raise AttributeError("AIStore does not support copy operations")

    def _delete_object(self, path: str) -> None:
        bucket, key = split_path(path)

        def _invoke_api() -> None:
            obj = self.client.bucket(bucket, self.provider).object(obj_name=key)
            obj.delete()

        return self._collect_metrics(_invoke_api, operation="DELETE", bucket=bucket, key=key)

    def _get_object_metadata(self, path: str) -> ObjectMetadata:
        bucket, key = split_path(path)

        def _invoke_api() -> ObjectMetadata:
            obj = self.client.bucket(bck_name=bucket, provider=self.provider).object(obj_name=key)
            props = obj.head()
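            # "Ais-Atime" is reported in nanoseconds since the Unix epoch; convert to seconds.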
            last_modified = datetime.fromtimestamp(int(props.get("Ais-Atime")) // 1_000_000_000)  # pyright: ignore [reportArgumentType]
            return ObjectMetadata(
                key=key,
                content_length=int(props.get("Content-Length")),  # pyright: ignore [reportArgumentType]
                last_modified=last_modified,
                etag=props.get("Ais-Checksum-Value", None),
            )

        return self._collect_metrics(_invoke_api, operation="HEAD", bucket=bucket, key=key)

    def _list_objects(
        self, prefix: str, start_after: Optional[str] = None, end_at: Optional[str] = None
    ) -> Iterator[ObjectMetadata]:
        bucket, prefix = split_path(prefix)

        def _invoke_api() -> Iterator[ObjectMetadata]:
            # AIS has no start key option like other object stores.
            all_objects = self.client.bucket(bck_name=bucket, provider=self.provider).list_all_objects_iter(
                prefix=prefix, props="name,size,atime,checksum,copies"
            )

            # Assume AIS guarantees lexicographical order.
            for obj in all_objects:
                key = obj.name
                if (start_after is None or start_after < key) and (end_at is None or key <= end_at):
                    yield ObjectMetadata(
                        key=key,
                        content_length=int(obj.props.size),
                        last_modified=dateutil_parser(obj.props.access_time),
                        etag=obj.props.checksum_value,
                    )
                elif end_at is not None and end_at < key:
                    return

        return self._collect_metrics(_invoke_api, operation="LIST", bucket=bucket, key=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO]) -> None:
        if isinstance(f, str):
            with open(f, "rb") as fp:
                self._put_object(remote_path, fp.read())
        else:
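            # Text buffers (e.g., io.StringIO) must be encoded to bytes before upload; binary buffers pass through as-is.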
            if isinstance(f, io.StringIO):
                self._put_object(remote_path, f.read().encode("utf-8"))
            else:
                self._put_object(remote_path, f.read())

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> None:
        if not metadata:
            metadata = self._get_object_metadata(remote_path)

        if isinstance(f, str):
            os.makedirs(os.path.dirname(f) or ".", exist_ok=True)  # guard against an empty dirname for bare filenames
            with open(f, "wb") as fp:
                fp.write(self._get_object(remote_path))
        else:
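            # Decode bytes for text buffers (e.g., io.StringIO); write raw bytes for binary buffers.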
            if isinstance(f, io.StringIO):
                f.write(self._get_object(remote_path).decode("utf-8"))
            else:
                f.write(self._get_object(remote_path))