# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16import glob
17import json
18import logging
19import os
20import shutil
21import tempfile
22from collections.abc import Callable, Iterator
23from datetime import datetime, timezone
24from enum import Enum
25from io import BytesIO, StringIO
26from typing import IO, Any, Optional, TypeVar, Union
27
28import xattr
29
30from ..telemetry import Telemetry
31from ..types import AWARE_DATETIME_MIN, ObjectMetadata, Range
32from ..utils import create_attribute_filter_evaluator, matches_attribute_filter_expression, validate_attributes
33from .base import BaseStorageProvider
34
35_T = TypeVar("_T")
36
37PROVIDER = "file"
38READ_CHUNK_SIZE = 8192
39
40logger = logging.getLogger(__name__)
41
42
43class _EntryType(Enum):
44 """
45 An enum representing the type of an entry in a directory.
46 """
47
48 FILE = 1
49 DIRECTORY = 2
50 DIRECTORY_TO_EXPLORE = 3
51
52
[docs]
53def atomic_write(source: Union[str, IO], destination: str, attributes: Optional[dict[str, str]] = None):
54 """
55 Writes the contents of a file to the specified destination path.
56
57 This function ensures that the file write operation is atomic, meaning the output file is either fully written or not modified at all.
58 This is achieved by writing to a temporary file first and then renaming it to the destination path.
59
60 :param source: The input file to read from. It can be a string representing the path to a file, or an open file-like object (IO).
61 :param destination: The path to the destination file where the contents should be written.
62 :param attributes: The attributes to set on the file.
63 """
64
65 with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(destination), prefix=".") as fp:
66 temp_file_path = fp.name
67 if isinstance(source, str):
68 with open(source, mode="rb") as src:
69 while chunk := src.read(READ_CHUNK_SIZE):
70 fp.write(chunk)
71 else:
72 while chunk := source.read(READ_CHUNK_SIZE):
73 fp.write(chunk)
74
75 # Set attributes on temp file if provided
76 validated_attributes = validate_attributes(attributes)
77 if validated_attributes:
78 try:
79 xattr.setxattr(temp_file_path, "user.json", json.dumps(validated_attributes).encode("utf-8"))
80 except OSError as e:
81 logger.debug(f"Failed to set extended attributes on temp file {temp_file_path}: {e}")
82
83 os.rename(src=temp_file_path, dst=destination)
84
85
[docs]
86class PosixFileStorageProvider(BaseStorageProvider):
87 """
88 A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with POSIX file systems.
89 """
90
91 def __init__(
92 self,
93 base_path: str,
94 config_dict: Optional[dict[str, Any]] = None,
95 telemetry_provider: Optional[Callable[[], Telemetry]] = None,
96 **kwargs: Any,
97 ) -> None:
98 """
99 :param base_path: The root prefix path within the POSIX file system where all operations will be scoped.
100 :param config_dict: Resolved MSC config.
101 :param telemetry_provider: A function that provides a telemetry instance.
102 """
103
104 # Validate POSIX path
105 if base_path == "":
106 base_path = "/"
107
108 if not base_path.startswith("/"):
109 raise ValueError(f"The base_path {base_path} must be an absolute path.")
110
111 super().__init__(
112 base_path=base_path,
113 provider_name=PROVIDER,
114 config_dict=config_dict,
115 telemetry_provider=telemetry_provider,
116 )
117
118 def _translate_errors(
119 self,
120 func: Callable[[], _T],
121 operation: str,
122 path: str,
123 ) -> _T:
124 """
125 Translates errors like timeouts and client errors.
126
127 :param func: The function that performs the actual file operation.
128 :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
129 :param path: The path to the object.
130
131 :return: The result of the file operation, typically the return value of the `func` callable.
132 """
133 try:
134 return func()
135 except FileNotFoundError:
136 raise
137 except Exception as error:
138 raise RuntimeError(f"Failed to {operation} object(s) at {path}, error: {error}") from error
139
140 def _put_object(
141 self,
142 path: str,
143 body: bytes,
144 if_match: Optional[str] = None,
145 if_none_match: Optional[str] = None,
146 attributes: Optional[dict[str, str]] = None,
147 ) -> int:
148 def _invoke_api() -> int:
149 os.makedirs(os.path.dirname(path), exist_ok=True)
150 atomic_write(source=BytesIO(body), destination=path, attributes=attributes)
151 return len(body)
152
153 return self._translate_errors(_invoke_api, operation="PUT", path=path)
154
155 def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
156 def _invoke_api() -> bytes:
157 if byte_range:
158 with open(path, "rb") as f:
159 f.seek(byte_range.offset)
160 return f.read(byte_range.size)
161 else:
162 with open(path, "rb") as f:
163 return f.read()
164
165 return self._translate_errors(_invoke_api, operation="GET", path=path)
166
167 def _copy_object(self, src_path: str, dest_path: str) -> int:
168 src_object = self._get_object_metadata(src_path)
169
170 def _invoke_api() -> int:
171 os.makedirs(os.path.dirname(dest_path), exist_ok=True)
172 atomic_write(source=src_path, destination=dest_path, attributes=src_object.metadata)
173
174 return src_object.content_length
175
176 return self._translate_errors(_invoke_api, operation="COPY", path=src_path)
177
178 def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
179 def _invoke_api() -> None:
180 if os.path.exists(path) and os.path.isfile(path):
181 os.remove(path)
182
183 return self._translate_errors(_invoke_api, operation="DELETE", path=path)
184
185 def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
186 is_dir = os.path.isdir(path)
187 if is_dir:
188 path = self._append_delimiter(path)
189
190 def _invoke_api() -> ObjectMetadata:
191 # Get basic file attributes
192 metadata_dict = {}
193 try:
194 json_bytes = xattr.getxattr(path, "user.json")
195 metadata_dict = json.loads(json_bytes.decode("utf-8"))
196 except (OSError, IOError, KeyError, json.JSONDecodeError, AttributeError) as e:
197 # Silently ignore if xattr doesn't exist, can't be read, or is corrupted
198 logger.debug(f"Failed to read extended attributes from {path}: {e}")
199 pass
200
201 return ObjectMetadata(
202 key=path,
203 type="directory" if is_dir else "file",
204 content_length=0 if is_dir else os.path.getsize(path),
205 last_modified=datetime.fromtimestamp(os.path.getmtime(path), tz=timezone.utc),
206 metadata=metadata_dict if metadata_dict else None,
207 )
208
209 return self._translate_errors(_invoke_api, operation="HEAD", path=path)
210
211 def _list_objects(
212 self,
213 path: str,
214 start_after: Optional[str] = None,
215 end_at: Optional[str] = None,
216 include_directories: bool = False,
217 follow_symlinks: bool = True,
218 ) -> Iterator[ObjectMetadata]:
219 start_after = os.path.relpath(start_after, self._base_path) if start_after else None
220 end_at = os.path.relpath(end_at, self._base_path) if end_at else None
221
222 def _invoke_api() -> Iterator[ObjectMetadata]:
223 if os.path.isfile(path):
224 yield ObjectMetadata(
225 key=os.path.relpath(path, self._base_path), # relative path to the base path
226 content_length=os.path.getsize(path),
227 last_modified=datetime.fromtimestamp(os.path.getmtime(path), tz=timezone.utc),
228 )
229 dir_path = path.rstrip("/") + "/"
230 if not os.path.isdir(dir_path): # expect the input to be a directory
231 return
232
233 yield from self._explore_directory(dir_path, start_after, end_at, include_directories, follow_symlinks)
234
235 return self._translate_errors(_invoke_api, operation="LIST", path=path)
236
237 def _explore_directory(
238 self,
239 dir_path: str,
240 start_after: Optional[str],
241 end_at: Optional[str],
242 include_directories: bool,
243 follow_symlinks: bool = True,
244 ) -> Iterator[ObjectMetadata]:
245 """
246 Recursively explore a directory and yield objects in lexicographical order.
247
248 :param dir_path: The directory path to explore
249 :param start_after: The key to start after
250 :param end_at: The key to end at
251 :param include_directories: Whether to include directories in the result
252 :param follow_symlinks: Whether to follow symbolic links. When False, symlinks are skipped.
253 """
254 try:
255 # List contents of current directory
256 dir_entries = os.listdir(dir_path)
257 dir_entries.sort() # Sort entries for consistent ordering
258
259 # Collect all entries in this directory
260 entries = []
261
262 for entry in dir_entries:
263 full_path = os.path.join(dir_path, entry)
264
265 # Skip symlinks if follow_symlinks is False
266 if not follow_symlinks and os.path.islink(full_path):
267 continue
268
269 relative_path = os.path.relpath(full_path, self._base_path)
270
271 # Check if this entry is within our range
272 if (start_after is None or start_after < relative_path) and (end_at is None or relative_path <= end_at):
273 if os.path.isfile(full_path):
274 entries.append((relative_path, full_path, _EntryType.FILE))
275 elif os.path.isdir(full_path):
276 if include_directories:
277 entries.append((relative_path, full_path, _EntryType.DIRECTORY))
278 else:
279 # Add directory for recursive exploration
280 entries.append((relative_path, full_path, _EntryType.DIRECTORY_TO_EXPLORE))
281
282 # Sort entries by relative path
283 entries.sort(key=lambda x: x[0])
284
285 # Process entries in order
286 for relative_path, full_path, entry_type in entries:
287 if entry_type == _EntryType.FILE:
288 yield ObjectMetadata(
289 key=relative_path,
290 content_length=os.path.getsize(full_path),
291 last_modified=datetime.fromtimestamp(os.path.getmtime(full_path), tz=timezone.utc),
292 )
293 elif entry_type == _EntryType.DIRECTORY:
294 yield ObjectMetadata(
295 key=relative_path,
296 content_length=0,
297 type="directory",
298 last_modified=AWARE_DATETIME_MIN,
299 )
300 elif entry_type == _EntryType.DIRECTORY_TO_EXPLORE:
301 # Recursively explore this directory
302 yield from self._explore_directory(
303 full_path, start_after, end_at, include_directories, follow_symlinks
304 )
305
306 except (OSError, PermissionError) as e:
307 logger.warning(f"Failed to list contents of {dir_path}, caused by: {e}")
308 return
309
310 def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
311 os.makedirs(os.path.dirname(remote_path), exist_ok=True)
312
313 filesize: int = 0
314 if isinstance(f, str):
315 filesize = os.path.getsize(f)
316 elif isinstance(f, StringIO):
317 filesize = len(f.getvalue().encode("utf-8"))
318 else:
319 filesize = len(f.getvalue()) # type: ignore
320
321 def _invoke_api() -> int:
322 atomic_write(source=f, destination=remote_path, attributes=attributes)
323
324 return filesize
325
326 return self._translate_errors(_invoke_api, operation="PUT", path=remote_path)
327
328 def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
329 filesize = metadata.content_length if metadata else os.path.getsize(remote_path)
330
331 if isinstance(f, str):
332
333 def _invoke_api() -> int:
334 if os.path.dirname(f):
335 os.makedirs(os.path.dirname(f), exist_ok=True)
336 atomic_write(source=remote_path, destination=f)
337
338 return filesize
339
340 return self._translate_errors(_invoke_api, operation="GET", path=remote_path)
341 elif isinstance(f, StringIO):
342
343 def _invoke_api() -> int:
344 with open(remote_path, "r", encoding="utf-8") as src:
345 while chunk := src.read(READ_CHUNK_SIZE):
346 f.write(chunk)
347
348 return filesize
349
350 return self._translate_errors(_invoke_api, operation="GET", path=remote_path)
351 else:
352
353 def _invoke_api() -> int:
354 with open(remote_path, "rb") as src:
355 while chunk := src.read(READ_CHUNK_SIZE):
356 f.write(chunk)
357
358 return filesize
359
360 return self._translate_errors(_invoke_api, operation="GET", path=remote_path)
361
[docs]
362 def glob(self, pattern: str, attribute_filter_expression: Optional[str] = None) -> list[str]:
363 pattern = self._prepend_base_path(pattern)
364 keys = list(glob.glob(pattern, recursive=True))
365 if attribute_filter_expression:
366 filtered_keys = []
367 evaluator = create_attribute_filter_evaluator(attribute_filter_expression)
368 for key in keys:
369 obj_metadata = self._get_object_metadata(key)
370 if matches_attribute_filter_expression(obj_metadata, evaluator):
371 filtered_keys.append(key)
372 keys = filtered_keys
373 if self._base_path == "/":
374 return keys
375 else:
376 # NOTE: PosixStorageProvider does not have the concept of bucket and prefix.
377 # So we drop the base_path from it.
378 return [key.replace(self._base_path, "", 1).lstrip("/") for key in keys]
379
[docs]
380 def is_file(self, path: str) -> bool:
381 path = self._prepend_base_path(path)
382 return os.path.isfile(path)
383
[docs]
384 def rmtree(self, path: str) -> None:
385 path = self._prepend_base_path(path)
386 shutil.rmtree(path)