# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import json
import logging
import os
import shutil
import tempfile
from collections.abc import Callable, Iterator
from datetime import datetime, timezone
from enum import Enum
from io import BytesIO, StringIO
from typing import IO, Any, Optional, TypeVar, Union

import xattr

from ..telemetry import Telemetry
from ..types import AWARE_DATETIME_MIN, ObjectMetadata, Range
from ..utils import create_attribute_filter_evaluator, matches_attribute_filter_expression, validate_attributes
from .base import BaseStorageProvider

_T = TypeVar("_T")

PROVIDER = "file"
READ_CHUNK_SIZE = 8192

logger = logging.getLogger(__name__)


class _EntryType(Enum):
    """
    An enum representing the type of an entry in a directory.
    """

    FILE = 1
    DIRECTORY = 2
    DIRECTORY_TO_EXPLORE = 3


def atomic_write(source: Union[str, IO], destination: str, attributes: Optional[dict[str, str]] = None):
    """
    Writes the contents of a file to the specified destination path.

    This function ensures that the file write operation is atomic, meaning the output file is either fully written or not modified at all.
    This is achieved by writing to a temporary file first and then renaming it to the destination path.

    :param source: The input file to read from. It can be a string representing the path to a file, or an open file-like object (IO).
    :param destination: The path to the destination file where the contents should be written.
    :param attributes: The attributes to set on the file, stored as an extended attribute (``user.json``).
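
    Example (an illustrative sketch, not part of the documented API; the paths and attribute
    values below are hypothetical)::

        atomic_write(source=BytesIO(b"payload"), destination="/data/output.bin")
        atomic_write(source="/data/output.bin", destination="/data/copy.bin", attributes={"owner": "alice"})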
    """

    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(destination), prefix=".") as fp:
        temp_file_path = fp.name
        if isinstance(source, str):
            with open(source, mode="rb") as src:
                while chunk := src.read(READ_CHUNK_SIZE):
                    fp.write(chunk)
        else:
            while chunk := source.read(READ_CHUNK_SIZE):
                fp.write(chunk)

    # Set attributes on temp file if provided
    validated_attributes = validate_attributes(attributes)
    if validated_attributes:
        try:
            xattr.setxattr(temp_file_path, "user.json", json.dumps(validated_attributes).encode("utf-8"))
        except OSError as e:
            logger.debug(f"Failed to set extended attributes on temp file {temp_file_path}: {e}")

    os.rename(src=temp_file_path, dst=destination)


class PosixFileStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with POSIX file systems.
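
    Example (an illustrative sketch; the base path and file name below are hypothetical)::

        provider = PosixFileStorageProvider(base_path="/mnt/data")
        if provider.is_file("models/checkpoint.pt"):
            ...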
    """

    def __init__(
        self,
        base_path: str,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ) -> None:
        """
        :param base_path: The root prefix path within the POSIX file system where all operations will be scoped.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        """

        # Validate POSIX path
        if base_path == "":
            base_path = "/"

        if not base_path.startswith("/"):
            raise ValueError(f"The base_path {base_path} must be an absolute path.")

        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        path: str,
    ) -> _T:
        """
        Translates file operation errors into ``RuntimeError``, letting ``FileNotFoundError`` propagate unchanged.

        :param func: The function that performs the actual file operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param path: The path to the object.

        :return: The result of the file operation, typically the return value of the ``func`` callable.
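
        Example (an illustrative sketch)::

            size = self._translate_errors(lambda: os.path.getsize(path), operation="HEAD", path=path)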
        """
        try:
            return func()
        except FileNotFoundError:
            raise
        except Exception as error:
            raise RuntimeError(f"Failed to {operation} object(s) at {path}, error: {error}") from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        def _invoke_api() -> int:
            os.makedirs(os.path.dirname(path), exist_ok=True)
            atomic_write(source=BytesIO(body), destination=path, attributes=attributes)
            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", path=path)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        def _invoke_api() -> bytes:
            if byte_range:
                with open(path, "rb") as f:
                    f.seek(byte_range.offset)
                    return f.read(byte_range.size)
            else:
                with open(path, "rb") as f:
                    return f.read()

        return self._translate_errors(_invoke_api, operation="GET", path=path)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            os.makedirs(os.path.dirname(dest_path), exist_ok=True)
            atomic_write(source=src_path, destination=dest_path, attributes=src_object.metadata)

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", path=src_path)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        def _invoke_api() -> None:
            if os.path.isfile(path):
                os.remove(path)

        return self._translate_errors(_invoke_api, operation="DELETE", path=path)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        is_dir = os.path.isdir(path)
        if is_dir:
            path = self._append_delimiter(path)

        def _invoke_api() -> ObjectMetadata:
            # Read optional custom metadata stored in extended attributes
            metadata_dict = {}
            try:
                json_bytes = xattr.getxattr(path, "user.json")
                metadata_dict = json.loads(json_bytes.decode("utf-8"))
            except (OSError, KeyError, json.JSONDecodeError, AttributeError) as e:
                # Ignore if the xattr doesn't exist, can't be read, or is corrupted
                logger.debug(f"Failed to read extended attributes from {path}: {e}")

            return ObjectMetadata(
                key=path,
                type="directory" if is_dir else "file",
                content_length=0 if is_dir else os.path.getsize(path),
                last_modified=datetime.fromtimestamp(os.path.getmtime(path), tz=timezone.utc),
                metadata=metadata_dict if metadata_dict else None,
            )

        return self._translate_errors(_invoke_api, operation="HEAD", path=path)

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        def _invoke_api() -> Iterator[ObjectMetadata]:
            if os.path.isfile(path):
                yield ObjectMetadata(
                    key=os.path.relpath(path, self._base_path),  # relative path to the base path
                    content_length=os.path.getsize(path),
                    last_modified=datetime.fromtimestamp(os.path.getmtime(path), tz=timezone.utc),
                )
            dir_path = path.rstrip("/") + "/"
            if not os.path.isdir(dir_path):  # expect the input to be a directory
                return

            yield from self._explore_directory(dir_path, start_after, end_at, include_directories)

        return self._translate_errors(_invoke_api, operation="LIST", path=path)

    def _explore_directory(
        self, dir_path: str, start_after: Optional[str], end_at: Optional[str], include_directories: bool
    ) -> Iterator[ObjectMetadata]:
        """
        Recursively explore a directory and yield objects in lexicographical order.
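
        Example (an illustrative sketch; the directory layout below is hypothetical). Given a base
        path of ``/base`` containing::

            logs/a.txt
            logs/b.txt
            readme.md

        listing ``/base/`` yields ``logs/a.txt``, ``logs/b.txt``, then ``readme.md``: entries are
        sorted within each directory and subdirectories are expanded in place.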
        """
        try:
            # List contents of current directory
            dir_entries = os.listdir(dir_path)
            dir_entries.sort()  # Sort entries for consistent ordering

            # Collect all entries in this directory
            entries = []

            for entry in dir_entries:
                full_path = os.path.join(dir_path, entry)

                relative_path = os.path.relpath(full_path, self._base_path)

                # Check if this entry is within our range
                if (start_after is None or start_after < relative_path) and (end_at is None or relative_path <= end_at):
                    if os.path.isfile(full_path):
                        entries.append((relative_path, full_path, _EntryType.FILE))
                    elif os.path.isdir(full_path):
                        if include_directories:
                            entries.append((relative_path, full_path, _EntryType.DIRECTORY))
                        else:
                            # Add directory for recursive exploration
                            entries.append((relative_path, full_path, _EntryType.DIRECTORY_TO_EXPLORE))

            # Sort entries by relative path
            entries.sort(key=lambda x: x[0])

            # Process entries in order
            for relative_path, full_path, entry_type in entries:
                if entry_type == _EntryType.FILE:
                    yield ObjectMetadata(
                        key=relative_path,
                        content_length=os.path.getsize(full_path),
                        last_modified=datetime.fromtimestamp(os.path.getmtime(full_path), tz=timezone.utc),
                    )
                elif entry_type == _EntryType.DIRECTORY:
                    yield ObjectMetadata(
                        key=relative_path,
                        content_length=0,
                        type="directory",
                        last_modified=AWARE_DATETIME_MIN,
                    )
                elif entry_type == _EntryType.DIRECTORY_TO_EXPLORE:
                    # Recursively explore this directory
                    yield from self._explore_directory(full_path, start_after, end_at, include_directories)

        except (OSError, PermissionError) as e:
            logger.warning(f"Failed to list contents of {dir_path}, caused by: {e}")
            return

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        os.makedirs(os.path.dirname(remote_path), exist_ok=True)

        filesize: int = 0
        if isinstance(f, str):
            filesize = os.path.getsize(f)
        elif isinstance(f, StringIO):
            filesize = len(f.getvalue().encode("utf-8"))
        else:
            filesize = len(f.getvalue())  # type: ignore

        def _invoke_api() -> int:
            atomic_write(source=f, destination=remote_path, attributes=attributes)

            return filesize

        return self._translate_errors(_invoke_api, operation="PUT", path=remote_path)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        filesize = metadata.content_length if metadata else os.path.getsize(remote_path)

        if isinstance(f, str):

            def _invoke_api() -> int:
                if os.path.dirname(f):
                    os.makedirs(os.path.dirname(f), exist_ok=True)
                atomic_write(source=remote_path, destination=f)

                return filesize

            return self._translate_errors(_invoke_api, operation="GET", path=remote_path)
        elif isinstance(f, StringIO):

            def _invoke_api() -> int:
                with open(remote_path, "r", encoding="utf-8") as src:
                    while chunk := src.read(READ_CHUNK_SIZE):
                        f.write(chunk)

                return filesize

            return self._translate_errors(_invoke_api, operation="GET", path=remote_path)
        else:

            def _invoke_api() -> int:
                with open(remote_path, "rb") as src:
                    while chunk := src.read(READ_CHUNK_SIZE):
                        f.write(chunk)

                return filesize

            return self._translate_errors(_invoke_api, operation="GET", path=remote_path)

    def glob(self, pattern: str, attribute_filter_expression: Optional[str] = None) -> list[str]:
        """
        Matches and retrieves a list of object keys that match the given pattern.

        :param pattern: The pattern to match object keys against, supporting wildcards (e.g., ``**/*.txt``).
        :param attribute_filter_expression: An optional attribute filter expression to apply to the result.

        :return: A list of object keys that match the pattern.
        """
        pattern = self._prepend_base_path(pattern)
        keys = list(glob.glob(pattern, recursive=True))
        if attribute_filter_expression:
            filtered_keys = []
            evaluator = create_attribute_filter_evaluator(attribute_filter_expression)
            for key in keys:
                obj_metadata = self._get_object_metadata(key)
                if matches_attribute_filter_expression(obj_metadata, evaluator):
                    filtered_keys.append(key)
            keys = filtered_keys
        if self._base_path == "/":
            return keys
        else:
            # NOTE: PosixFileStorageProvider does not have the concept of bucket and prefix,
            # so we drop the base_path from the returned keys.
            return [key.replace(self._base_path, "", 1).lstrip("/") for key in keys]

    def is_file(self, path: str) -> bool:
        """Return ``True`` if the given path points to a file."""
        path = self._prepend_base_path(path)
        return os.path.isfile(path)

    def rmtree(self, path: str) -> None:
        """Recursively delete the directory tree at the given path."""
        path = self._prepend_base_path(path)
        shutil.rmtree(path)