1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import glob
17import json
18import logging
19import os
20import shutil
21import tempfile
22from collections.abc import Callable, Iterator
23from datetime import datetime, timezone
24from enum import Enum
25from io import BytesIO, StringIO
26from typing import IO, Any, Optional, TypeVar, Union
27
28import xattr
29
30from ..telemetry import Telemetry
31from ..types import AWARE_DATETIME_MIN, ObjectMetadata, Range
32from ..utils import (
33 create_attribute_filter_evaluator,
34 matches_attribute_filter_expression,
35 safe_makedirs,
36 validate_attributes,
37)
38from .base import BaseStorageProvider
39
# Generic return type used by PosixFileStorageProvider._translate_errors so the
# wrapper preserves the wrapped callable's return type.
_T = TypeVar("_T")

# Provider name reported to BaseStorageProvider (and surfaced in telemetry).
PROVIDER = "file"
# Number of bytes read per iteration when streaming file contents.
READ_CHUNK_SIZE = 8192

logger = logging.getLogger(__name__)
46
47
48class _EntryType(Enum):
49 """
50 An enum representing the type of an entry in a directory.
51 """
52
53 FILE = 1
54 DIRECTORY = 2
55 DIRECTORY_TO_EXPLORE = 3
56
57
def atomic_write(source: Union[str, IO], destination: str, attributes: Optional[dict[str, str]] = None):
    """
    Writes the contents of a file to the specified destination path.

    This function ensures that the file write operation is atomic, meaning the output file is either fully written or not modified at all.
    This is achieved by writing to a temporary file first and then renaming it to the destination path.
    If any step fails, the temporary file is removed so no partial output is left behind.

    :param source: The input file to read from. It can be a string representing the path to a file, or an open file-like object (IO).
    :param destination: The path to the destination file where the contents should be written.
    :param attributes: The attributes to set on the file.
    """

    temp_file_path = ""
    try:
        # Create the temp file in the destination directory so the final rename
        # stays on the same filesystem and is therefore atomic.
        with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(destination), prefix=".") as fp:
            temp_file_path = fp.name
            if isinstance(source, str):
                with open(source, mode="rb") as src:
                    while chunk := src.read(READ_CHUNK_SIZE):
                        fp.write(chunk)
            else:
                while chunk := source.read(READ_CHUNK_SIZE):
                    fp.write(chunk)

        # Set attributes on the temp file (before it becomes visible) if provided.
        validated_attributes = validate_attributes(attributes)
        if validated_attributes:
            try:
                xattr.setxattr(temp_file_path, "user.json", json.dumps(validated_attributes).encode("utf-8"))
            except OSError as e:
                # Best-effort: filesystems without xattr support must not fail the write.
                logger.debug(f"Failed to set extended attributes on temp file {temp_file_path}: {e}")

        os.rename(src=temp_file_path, dst=destination)
    except Exception:
        # Don't leak the temp file when the copy, attribute step, or rename fails.
        if temp_file_path:
            try:
                os.unlink(temp_file_path)
            except OSError:
                pass
        raise
89
90
class PosixFileStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with POSIX file systems.
    """

    def __init__(
        self,
        base_path: str,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ) -> None:
        """
        :param base_path: The root prefix path within the POSIX file system where all operations will be scoped.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.

        :raises ValueError: If ``base_path`` is a non-empty relative path.
        """

        # An empty base path means the filesystem root; anything else must be absolute.
        if base_path == "":
            base_path = "/"

        if not base_path.startswith("/"):
            raise ValueError(f"The base_path {base_path} must be an absolute path.")

        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        path: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        ``FileNotFoundError`` is re-raised untouched so callers can distinguish a
        missing object from a genuine I/O failure; any other exception is wrapped
        in a ``RuntimeError`` that names the operation and path.

        :param func: The function that performs the actual file operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param path: The path to the object.

        :return: The result of the file operation, typically the return value of the `func` callable.
        """
        try:
            return func()
        except FileNotFoundError:
            raise
        except Exception as error:
            raise RuntimeError(f"Failed to {operation} object(s) at {path}, error: {error}") from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        """
        Atomically writes ``body`` to ``path``, creating parent directories as needed.

        :return: The number of bytes written.
        """

        def _invoke_api() -> int:
            safe_makedirs(os.path.dirname(path))
            atomic_write(source=BytesIO(body), destination=path, attributes=attributes)
            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", path=path)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Reads the object at ``path``, optionally restricted to ``byte_range``.
        """

        def _invoke_api() -> bytes:
            with open(path, "rb") as f:
                if byte_range:
                    f.seek(byte_range.offset)
                    return f.read(byte_range.size)
                return f.read()

        return self._translate_errors(_invoke_api, operation="GET", path=path)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        """
        Copies ``src_path`` to ``dest_path``, carrying over its custom attributes.

        :return: The content length of the source object.
        """
        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            safe_makedirs(os.path.dirname(dest_path))
            atomic_write(source=src_path, destination=dest_path, attributes=src_object.metadata)

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", path=src_path)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Deletes the file at ``path``. Missing or non-file paths are a silent no-op.
        """

        def _invoke_api() -> None:
            # os.path.isfile() already implies existence, so no separate exists() check.
            if os.path.isfile(path):
                os.remove(path)

        return self._translate_errors(_invoke_api, operation="DELETE", path=path)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Builds an :py:class:`ObjectMetadata` for ``path`` from ``os.stat`` data and
        the optional JSON document stored in the ``user.json`` extended attribute.
        """
        is_dir = os.path.isdir(path)
        if is_dir:
            path = self._append_delimiter(path)

        def _invoke_api() -> ObjectMetadata:
            # Custom attributes live in the "user.json" xattr; a missing, unreadable,
            # or corrupted xattr simply means "no custom metadata".
            # (IOError is an alias of OSError, so it is not listed separately.)
            metadata_dict = {}
            try:
                json_bytes = xattr.getxattr(path, "user.json")
                metadata_dict = json.loads(json_bytes.decode("utf-8"))
            except (OSError, KeyError, json.JSONDecodeError, AttributeError) as e:
                logger.debug(f"Failed to read extended attributes from {path}: {e}")

            return ObjectMetadata(
                key=path,
                type="directory" if is_dir else "file",
                content_length=0 if is_dir else os.path.getsize(path),
                last_modified=datetime.fromtimestamp(os.path.getmtime(path), tz=timezone.utc),
                metadata=metadata_dict if metadata_dict else None,
            )

        return self._translate_errors(_invoke_api, operation="HEAD", path=path)

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        """
        Yields objects under ``path`` in lexicographical key order.

        :param path: A file or directory path. A file yields exactly one entry.
        :param start_after: Exclusive lower bound on keys (relative to the base path).
        :param end_at: Inclusive upper bound on keys (relative to the base path).
        :param include_directories: Yield immediate subdirectories instead of recursing.
        :param follow_symlinks: Whether to follow symbolic links.
        """
        # Bounds arrive as absolute-ish keys; convert them to base-relative form
        # so they compare against the relative keys produced below.
        start_after = os.path.relpath(start_after, self._base_path) if start_after else None
        end_at = os.path.relpath(end_at, self._base_path) if end_at else None

        def _invoke_api() -> Iterator[ObjectMetadata]:
            if os.path.isfile(path):
                yield ObjectMetadata(
                    key=os.path.relpath(path, self._base_path),  # relative path to the base path
                    content_length=os.path.getsize(path),
                    last_modified=datetime.fromtimestamp(os.path.getmtime(path), tz=timezone.utc),
                )
                return
            dir_path = path.rstrip("/") + "/"
            if not os.path.isdir(dir_path):  # expect the input to be a directory
                return

            yield from self._explore_directory(dir_path, start_after, end_at, include_directories, follow_symlinks)

        return self._translate_errors(_invoke_api, operation="LIST", path=path)

    def _explore_directory(
        self,
        dir_path: str,
        start_after: Optional[str],
        end_at: Optional[str],
        include_directories: bool,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        """
        Recursively explore a directory and yield objects in lexicographical order.

        :param dir_path: The directory path to explore
        :param start_after: The key to start after
        :param end_at: The key to end at
        :param include_directories: Whether to include directories in the result
        :param follow_symlinks: Whether to follow symbolic links. When False, symlinks are skipped.
        """
        try:
            # Collect (relative_path, full_path, type) tuples for this directory level.
            entries = []

            for entry in sorted(os.listdir(dir_path)):
                full_path = os.path.join(dir_path, entry)

                # Skip symlinks if follow_symlinks is False
                if not follow_symlinks and os.path.islink(full_path):
                    continue

                relative_path = os.path.relpath(full_path, self._base_path)

                # Strict range filter applied to entries that are yielded directly.
                in_range = (start_after is None or start_after < relative_path) and (
                    end_at is None or relative_path <= end_at
                )

                if os.path.isfile(full_path):
                    if in_range:
                        entries.append((relative_path, full_path, _EntryType.FILE))
                elif os.path.isdir(full_path):
                    if include_directories:
                        if in_range:
                            entries.append((relative_path, full_path, _EntryType.DIRECTORY))
                    elif (end_at is None or relative_path <= end_at) and (
                        start_after is None
                        or start_after < relative_path
                        or start_after.startswith(relative_path)
                    ):
                        # Descend even when the directory itself sorts <= start_after:
                        # with start_after "a/b", directory "a" fails the strict filter
                        # but still contains in-range children such as "a/c". Every key
                        # under a directory is greater than the directory's own key, so
                        # the end_at check alone is a safe upper-bound prune.
                        entries.append((relative_path, full_path, _EntryType.DIRECTORY_TO_EXPLORE))

            # Sort entries by relative path for a deterministic yield order.
            entries.sort(key=lambda item: item[0])

            for relative_path, full_path, entry_type in entries:
                if entry_type == _EntryType.FILE:
                    yield ObjectMetadata(
                        key=relative_path,
                        content_length=os.path.getsize(full_path),
                        last_modified=datetime.fromtimestamp(os.path.getmtime(full_path), tz=timezone.utc),
                    )
                elif entry_type == _EntryType.DIRECTORY:
                    yield ObjectMetadata(
                        key=relative_path,
                        content_length=0,
                        type="directory",
                        last_modified=AWARE_DATETIME_MIN,
                    )
                elif entry_type == _EntryType.DIRECTORY_TO_EXPLORE:
                    # Recursively explore this directory
                    yield from self._explore_directory(
                        full_path, start_after, end_at, include_directories, follow_symlinks
                    )

        except (OSError, PermissionError) as e:
            # Unreadable directories are skipped rather than aborting the listing.
            logger.warning(f"Failed to list contents of {dir_path}, caused by: {e}")
            return

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        """
        Atomically copies a local file or file-like object to ``remote_path``.

        :return: The size in bytes of the uploaded content.
        """
        safe_makedirs(os.path.dirname(remote_path))

        if isinstance(f, str):
            filesize: int = os.path.getsize(f)
        elif isinstance(f, StringIO):
            filesize = len(f.getvalue().encode("utf-8"))
        else:
            # Assumes a BytesIO-like object exposing getvalue() — TODO confirm callers.
            filesize = len(f.getvalue())  # type: ignore

        def _invoke_api() -> int:
            atomic_write(source=f, destination=remote_path, attributes=attributes)

            return filesize

        return self._translate_errors(_invoke_api, operation="PUT", path=remote_path)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        """
        Copies ``remote_path`` into a local path, text stream, or binary stream.

        :return: The size in bytes of the downloaded content.
        """
        filesize = metadata.content_length if metadata else os.path.getsize(remote_path)

        if isinstance(f, str):

            def _invoke_api() -> int:
                if os.path.dirname(f):
                    safe_makedirs(os.path.dirname(f))
                atomic_write(source=remote_path, destination=f)

                return filesize

        elif isinstance(f, StringIO):

            def _invoke_api() -> int:
                # Text-mode copy so the StringIO receives decoded characters.
                with open(remote_path, "r", encoding="utf-8") as src:
                    while chunk := src.read(READ_CHUNK_SIZE):
                        f.write(chunk)

                return filesize

        else:

            def _invoke_api() -> int:
                with open(remote_path, "rb") as src:
                    while chunk := src.read(READ_CHUNK_SIZE):
                        f.write(chunk)

                return filesize

        return self._translate_errors(_invoke_api, operation="GET", path=remote_path)

    def glob(self, pattern: str, attribute_filter_expression: Optional[str] = None) -> list[str]:
        """
        Returns base-relative keys matching a glob ``pattern``, optionally filtered
        by an attribute filter expression evaluated against each object's metadata.
        """
        pattern = self._prepend_base_path(pattern)
        keys = glob.glob(pattern, recursive=True)
        if attribute_filter_expression:
            evaluator = create_attribute_filter_evaluator(attribute_filter_expression)
            keys = [
                key for key in keys if matches_attribute_filter_expression(self._get_object_metadata(key), evaluator)
            ]
        if self._base_path == "/":
            return keys
        # NOTE: PosixStorageProvider does not have the concept of bucket and prefix.
        # So we drop the base_path from it. removeprefix (rather than replace) ensures
        # only a leading occurrence of the base path is stripped.
        return [key.removeprefix(self._base_path).lstrip("/") for key in keys]

    def is_file(self, path: str) -> bool:
        """
        Returns True if ``path`` (relative to the base path) is a regular file.
        """
        path = self._prepend_base_path(path)
        return os.path.isfile(path)

    def rmtree(self, path: str) -> None:
        """
        Recursively deletes the directory tree at ``path`` (relative to the base path).
        """
        path = self._prepend_base_path(path)
        shutil.rmtree(path)