# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
import asyncio
import atexit
import os
import posixpath
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Any, Optional, Union

from fsspec.asyn import AsyncFileSystem, _run_coros_in_chunks

from ..client import StorageClient
from ..file import ObjectFile, PosixFile
from ..shortcuts import resolve_storage_client
from ..types import MSC_PROTOCOL_NAME
30
# Shared executor that runs the blocking StorageClient calls on behalf of the async filesystem.
# Its size comes from the MSC_MAX_WORKERS environment variable (default: 8 threads).
_GLOBAL_THREAD_POOL = ThreadPoolExecutor(
    max_workers=int(os.getenv("MSC_MAX_WORKERS", "8")),
)

# At interpreter exit, shut the pool down without blocking on work that is still pending.
atexit.register(_GLOBAL_THREAD_POOL.shutdown, wait=False)
# pyright: reportIncompatibleMethodOverride=false
class MultiStorageAsyncFileSystem(AsyncFileSystem):
    """
    Custom :py:class:`fsspec.asyn.AsyncFileSystem` implementation for MSC protocol (``msc://``).
    Uses :py:class:`multistorageclient.StorageClient` for backend operations.

    Each ``_<name>`` coroutine delegates to the synchronous ``<name>`` method by running it on the
    module-level ``_GLOBAL_THREAD_POOL`` executor (see :py:meth:`asynchronize_sync`).
    """

    protocol = MSC_PROTOCOL_NAME

    def __init__(self, **kwargs: Any) -> None:
        """
        Initializes the :py:class:`MultiStorageAsyncFileSystem`.

        :param kwargs: Additional arguments for the :py:class:`fsspec.asyn.AsyncFileSystem`.
        """
        super().__init__(**kwargs)

    def resolve_path_and_storage_client(self, path: Union[str, os.PathLike]) -> tuple[StorageClient, str]:
        """
        Resolves the path and retrieves the associated :py:class:`multistorageclient.StorageClient`.

        :param path: The file path to resolve. Leading ``/`` characters are stripped before resolution.

        :return: A tuple containing the :py:class:`multistorageclient.StorageClient` and the resolved path.
        """
        # Use unstrip_protocol to prepend our 'msc://' protocol only if it wasn't given in "path".
        return resolve_storage_client(self.unstrip_protocol(str(path).lstrip("/")))

    @staticmethod
    def asynchronize_sync(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        """
        Runs a synchronous function asynchronously using asyncio.

        The call is submitted to the shared module-level thread pool so blocking storage I/O does not run
        on the event loop thread.

        :param func: The synchronous function to be executed asynchronously.
        :param args: Positional arguments to pass to the function.
        :param kwargs: Keyword arguments to pass to the function.

        :return: An awaitable (:py:class:`asyncio.Future`) resolving to ``func(*args, **kwargs)``.
        """
        try:
            # Normal case: called from a coroutine, so use the loop that is actually running.
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # Called outside a running loop: keep the previous behaviour (get_event_loop is deprecated
            # for this situation on newer Python versions, so callers should prefer a coroutine context).
            loop = asyncio.get_event_loop()
        return loop.run_in_executor(_GLOBAL_THREAD_POOL, partial(func, *args, **kwargs))

    @staticmethod
    def _metadata_to_info(storage_client: StorageClient, metadata: Any) -> dict[str, Any]:
        """
        Converts one object's metadata into the info dictionary returned by :py:meth:`ls` and :py:meth:`info`.

        :param storage_client: The client the metadata came from; its ``profile`` prefixes the returned name.
        :param metadata: A metadata record exposing ``key``, ``etag``, ``last_modified``, ``content_length``,
            ``content_type`` and ``type``.

        :return: A dictionary with ``name``, ``ETag``, ``LastModified``, ``size``, ``ContentType`` and ``type``.
        """
        return {
            # posixpath (not os.path) so object names always use "/" separators, even on Windows.
            "name": posixpath.join(storage_client.profile, metadata.key),
            "ETag": metadata.etag,
            "LastModified": metadata.last_modified,
            "size": metadata.content_length,
            "ContentType": metadata.content_type,
            "type": metadata.type,
        }

    def ls(self, path: str, detail: bool = True, **kwargs: Any) -> Union[list[dict[str, Any]], list[str]]:
        """
        Lists the contents of a directory, including sub-directories.

        :param path: The directory path to list.
        :param detail: Whether to return detailed information for each file.
        :param kwargs: Additional arguments for list functionality (currently unused).

        :return: A list of file names or detailed information depending on the 'detail' argument.
        """
        storage_client, dir_path = self.resolve_path_and_storage_client(path)

        # Scope the listing to the directory itself by forcing a trailing "/" — presumably so that
        # "data" does not also match keys under "data2/" (depends on StorageClient.list prefix semantics).
        if dir_path and not dir_path.endswith("/"):
            dir_path += "/"

        objects = storage_client.list(path=dir_path, include_directories=True)

        if detail:
            return [self._metadata_to_info(storage_client, obj) for obj in objects]
        return [posixpath.join(storage_client.profile, obj.key) for obj in objects]

    async def _ls(self, path: str, detail: bool = True, **kwargs: Any) -> Union[list[dict[str, Any]], list[str]]:
        """
        Asynchronously lists the contents of a directory.

        :param path: The directory path to list.
        :param detail: Whether to return detailed information for each file.
        :param kwargs: Additional arguments for list functionality.

        :return: A list of file names or detailed information depending on the 'detail' argument.
        """
        return await self.asynchronize_sync(self.ls, path, detail, **kwargs)

    def info(self, path: str, **kwargs: Any) -> dict[str, Any]:
        """
        Retrieves metadata information for a file.

        :param path: The file path to retrieve information for.
        :param kwargs: Additional arguments for info functionality (currently unused).

        :return: A dictionary containing file metadata such as ETag, last modified, and size.
        """
        storage_client, file_path = self.resolve_path_and_storage_client(path)
        return self._metadata_to_info(storage_client, storage_client.info(file_path))

    async def _info(self, path: str, **kwargs: Any) -> dict[str, Any]:
        """
        Asynchronously retrieves metadata information for a file.

        :param path: The file path to retrieve information for.
        :param kwargs: Additional arguments for info functionality.

        :return: A dictionary containing file metadata such as ETag, last modified, and size.
        """
        return await self.asynchronize_sync(self.info, path, **kwargs)

    def rm_file(self, path: str, **kwargs: Any):
        """
        Removes a file, or a directory when ``recursive=True`` is passed.

        :param path: The file or directory path to remove.
        :param kwargs: Additional arguments for remove functionality. Only ``recursive`` (default ``False``)
            is read; it is forwarded to :py:meth:`multistorageclient.StorageClient.delete`.
        """
        storage_client, file_path = self.resolve_path_and_storage_client(path)
        storage_client.delete(file_path, recursive=kwargs.get("recursive", False))

    async def _rm_file(self, path: str, **kwargs: Any):
        """
        Asynchronously removes a file.

        :param path: The file or directory path to remove.
        :param kwargs: Additional arguments for remove functionality.
        """
        return await self.asynchronize_sync(self.rm_file, path, **kwargs)

    async def _rm(self, path, recursive=False, batch_size=None, **kwargs):
        """
        Asynchronously removes one or more files or directories.

        The parent class expands the path and deletes the files in parallel. This implementation does not
        do that: it passes the ``recursive`` value down so that the delete method in the StorageClient
        handles directory deletion.

        :param path: The file or directory path to remove, or a list/tuple/set of such paths.
        :param recursive: Whether to recursively remove directories.
        :param batch_size: The maximum number of paths deleted concurrently. ``None`` or ``0`` means no throttling.
        :param kwargs: Additional arguments for remove functionality.

        :return: The list of results of the per-path :py:meth:`_rm_file` calls.
        """
        # fsspec's rm() accepts a single path or a list of paths.
        paths = list(path) if isinstance(path, (list, tuple, set)) else [path]
        if not paths:
            # With batch_size=-1, fsspec sizes the batch as len(coros) and asserts it is > 0.
            return []

        return await _run_coros_in_chunks(
            [self._rm_file(p, recursive=recursive, **kwargs) for p in paths],
            batch_size=batch_size or -1,  # -1 disables throttling
            nofiles=True,
        )

    def cp_file(self, path1: str, path2: str, **kwargs: Any):
        """
        Copies a file from the source path to the destination path.

        :param path1: The source file path.
        :param path2: The destination file path.
        :param kwargs: Additional arguments for copy functionality (currently unused).

        :raises AttributeError: If the source and destination paths are associated with different profiles.
        """
        src_storage_client, src_path = self.resolve_path_and_storage_client(path1)
        dest_storage_client, dest_path = self.resolve_path_and_storage_client(path2)

        if src_storage_client != dest_storage_client:
            raise AttributeError(
                f"Cannot copy file from '{path1}' to '{path2}' because the source and destination paths are associated with different profiles. Cross-profile file operations are not supported."
            )

        src_storage_client.copy(src_path, dest_path)

    async def _cp_file(self, path1: str, path2: str, **kwargs: Any):
        """
        Asynchronously copies a file from the source path to the destination path.

        :param path1: The source file path.
        :param path2: The destination file path.
        :param kwargs: Additional arguments for copy functionality.

        :raises AttributeError: If the source and destination paths are associated with different profiles.
        """
        await self.asynchronize_sync(self.cp_file, path1, path2, **kwargs)

    def get_file(self, rpath: str, lpath: str, **kwargs: Any) -> None:
        """
        Downloads a file from the remote path to the local path.

        :param rpath: The remote path of the file to download.
        :param lpath: The local path to store the file.
        :param kwargs: Additional arguments for file retrieval functionality (currently unused).
        """
        storage_client, remote_path = self.resolve_path_and_storage_client(rpath)
        storage_client.download_file(remote_path, lpath)

    async def _get_file(self, rpath: str, lpath: str, **kwargs: Any) -> None:
        """
        Asynchronously downloads a file from the remote path to the local path.

        :param rpath: The remote path of the file to download.
        :param lpath: The local path to store the file.
        :param kwargs: Additional arguments for file retrieval functionality.
        """
        await self.asynchronize_sync(self.get_file, rpath, lpath, **kwargs)

    def put_file(self, lpath: str, rpath: str, **kwargs: Any) -> None:
        """
        Uploads a local file to the remote path.

        :param lpath: The local path of the file to upload.
        :param rpath: The remote path to store the file.
        :param kwargs: Additional arguments for file upload functionality (currently unused).
        """
        storage_client, remote_path = self.resolve_path_and_storage_client(rpath)
        storage_client.upload_file(remote_path, lpath)

    async def _put_file(self, lpath: str, rpath: str, **kwargs: Any) -> None:
        """
        Asynchronously uploads a local file to the remote path.

        :param lpath: The local path of the file to upload.
        :param rpath: The remote path to store the file.
        :param kwargs: Additional arguments for file upload functionality.
        """
        await self.asynchronize_sync(self.put_file, lpath, rpath, **kwargs)

    def open(self, path: str, mode: str = "rb", **kwargs: Any) -> Union[PosixFile, ObjectFile]:
        """
        Opens a file at the given path.

        :param path: The file path to open.
        :param mode: The mode in which to open the file.
        :param kwargs: Additional arguments for file opening. NOTE(review): these are not forwarded to
            :py:meth:`multistorageclient.StorageClient.open`, so fsspec options such as ``compression`` or
            ``block_size`` are currently ignored.

        :return: A :py:class:`PosixFile` or :py:class:`ObjectFile` representing the opened file.
        """
        storage_client, file_path = self.resolve_path_and_storage_client(path)
        return storage_client.open(file_path, mode)

    async def _open(self, path: str, mode: str = "rb", **kwargs: Any) -> Union[PosixFile, ObjectFile]:
        """
        Asynchronously opens a file at the given path.

        NOTE(review): fsspec's ``AbstractFileSystem._open`` is a synchronous hook; this coroutine override is
        only safe because :py:meth:`open` is overridden and never calls ``_open`` — verify no other caller does.

        :param path: The file path to open.
        :param mode: The mode in which to open the file.
        :param kwargs: Additional arguments for file opening.

        :return: A :py:class:`PosixFile` or :py:class:`ObjectFile` representing the opened file.
        """
        return await self.asynchronize_sync(self.open, path, mode, **kwargs)

    def pipe_file(self, path: str, value: bytes, **kwargs: Any) -> None:
        """
        Writes a value (bytes) directly to a file at the given path.

        :param path: The file path to write the value to.
        :param value: The bytes to write to the file.
        :param kwargs: Additional arguments for writing functionality (currently unused).
        """
        storage_client, file_path = self.resolve_path_and_storage_client(path)
        storage_client.write(file_path, value)

    async def _pipe_file(self, path: str, value: bytes, **kwargs: Any) -> None:
        """
        Asynchronously writes a value (bytes) directly to a file at the given path.

        :param path: The file path to write the value to.
        :param value: The bytes to write to the file.
        :param kwargs: Additional arguments for writing functionality.
        """
        await self.asynchronize_sync(self.pipe_file, path, value, **kwargs)

    def cat_file(self, path: str, start: Optional[int] = None, end: Optional[int] = None, **kwargs: Any) -> bytes:
        """
        Reads the contents of a file at the given path, optionally restricted to a byte range.

        :param path: The file path to read from.
        :param start: Offset of the first byte to return. Negative values count back from the end of the file.
            ``None`` reads from the beginning.
        :param end: Offset one past the last byte to return. Negative values count back from the end of the file.
            ``None`` reads to the end.
        :param kwargs: Additional arguments for file reading functionality (currently unused).

        :return: The requested contents of the file as bytes.
        """
        storage_client, file_path = self.resolve_path_and_storage_client(path)
        data = storage_client.read(file_path)
        if start is None and end is None:
            return data
        # fsspec's byte-range contract (used by cat_ranges) follows Python slice semantics.
        # The whole object is read and then sliced, so a ranged read does not reduce transfer size.
        return data[start:end]

    async def _cat_file(
        self, path: str, start: Optional[int] = None, end: Optional[int] = None, **kwargs: Any
    ) -> bytes:
        """
        Asynchronously reads the contents of a file at the given path, optionally restricted to a byte range.

        :param path: The file path to read from.
        :param start: Offset of the first byte to return (slice semantics; ``None`` for the beginning).
        :param end: Offset one past the last byte to return (slice semantics; ``None`` for the end).
        :param kwargs: Additional arguments for file reading functionality.

        :return: The requested contents of the file as bytes.
        """
        return await self.asynchronize_sync(self.cat_file, path, start=start, end=end, **kwargs)