# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import json
import logging
import os
import shutil
import tempfile
from collections.abc import Callable, Iterator
from datetime import datetime, timezone
from enum import Enum
from io import BytesIO, StringIO
from typing import IO, Any, Optional, TypeVar, Union

import xattr

from ..telemetry import Telemetry
from ..types import AWARE_DATETIME_MIN, ObjectMetadata, Range
from ..utils import create_attribute_filter_evaluator, matches_attribute_filter_expression, validate_attributes
from .base import BaseStorageProvider

_T = TypeVar("_T")

PROVIDER = "file"
READ_CHUNK_SIZE = 8192

logger = logging.getLogger(__name__)


class _EntryType(Enum):
    """
    An enum representing the type of an entry in a directory.
    """

    FILE = 1
    DIRECTORY = 2
    DIRECTORY_TO_EXPLORE = 3


def atomic_write(source: Union[str, IO], destination: str, attributes: Optional[dict[str, str]] = None):
    """
    Writes the contents of a source file or file-like object to the specified destination path.

    This function ensures that the write is atomic: the output file is either fully written or not modified at all.
    This is achieved by writing to a temporary file first and then renaming it to the destination path.

    :param source: The input file to read from. It can be a string representing the path to a file, or an open file-like object (IO).
    :param destination: The path to the destination file where the contents should be written.
    :param attributes: Optional attributes to set on the destination file as extended attributes.
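
    Example (illustrative; the paths and attribute values are hypothetical)::

        atomic_write(source="/tmp/staging.bin", destination="/data/final.bin")
        atomic_write(source=BytesIO(b"hello"), destination="/data/hello.txt", attributes={"owner": "alice"})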
    """

    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(destination), prefix=".") as fp:
        temp_file_path = fp.name
        if isinstance(source, str):
            with open(source, mode="rb") as src:
                while chunk := src.read(READ_CHUNK_SIZE):
                    fp.write(chunk)
        else:
            while chunk := source.read(READ_CHUNK_SIZE):
                fp.write(chunk)

        # Set attributes on temp file if provided
        validated_attributes = validate_attributes(attributes)
        if validated_attributes:
            try:
                xattr.setxattr(temp_file_path, "user.json", json.dumps(validated_attributes).encode("utf-8"))
            except OSError as e:
                logger.debug(f"Failed to set extended attributes on temp file {temp_file_path}: {e}")

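    # os.rename() is atomic on POSIX when source and destination are on the same
    # file system; since the temporary file was created in the destination's directory
    # above, the destination is either fully written or left untouched.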
    os.rename(src=temp_file_path, dst=destination)


class PosixFileStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with POSIX file systems.
    """

    def __init__(
        self,
        base_path: str,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ) -> None:
        """
        :param base_path: The root prefix path within the POSIX file system where all operations will be scoped.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
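
        Example (illustrative; the base path is hypothetical)::

            provider = PosixFileStorageProvider(base_path="/mnt/datasets")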
        """

        # Validate POSIX path
        if base_path == "":
            base_path = "/"

        if not base_path.startswith("/"):
            raise ValueError(f"The base_path {base_path} must be an absolute path.")

        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        path: str,
    ) -> _T:
        """
        Translates file system errors into runtime errors that include the operation and path for context; ``FileNotFoundError`` is re-raised unchanged.

        :param func: The function that performs the actual file operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param path: The path to the object.

        :return: The result of the file operation, typically the return value of the `func` callable.
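
        Example (illustrative; mirrors how methods in this class wrap their operations)::

            return self._translate_errors(_invoke_api, operation="GET", path=path)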
        """
        try:
            return func()
        except FileNotFoundError:
            raise
        except Exception as error:
            raise RuntimeError(f"Failed to {operation} object(s) at {path}, error: {error}") from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        def _invoke_api() -> int:
            os.makedirs(os.path.dirname(path), exist_ok=True)
            atomic_write(source=BytesIO(body), destination=path, attributes=attributes)
            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", path=path)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        def _invoke_api() -> bytes:
            if byte_range:
                with open(path, "rb") as f:
                    f.seek(byte_range.offset)
                    return f.read(byte_range.size)
            else:
                with open(path, "rb") as f:
                    return f.read()

        return self._translate_errors(_invoke_api, operation="GET", path=path)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            os.makedirs(os.path.dirname(dest_path), exist_ok=True)
            atomic_write(source=src_path, destination=dest_path, attributes=src_object.metadata)

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", path=src_path)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        def _invoke_api() -> None:
            if os.path.exists(path) and os.path.isfile(path):
                os.remove(path)

        return self._translate_errors(_invoke_api, operation="DELETE", path=path)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        is_dir = os.path.isdir(path)
        if is_dir:
            path = self._append_delimiter(path)

        def _invoke_api() -> ObjectMetadata:
            # Get basic file attributes
            metadata_dict = {}
            try:
                json_bytes = xattr.getxattr(path, "user.json")
                metadata_dict = json.loads(json_bytes.decode("utf-8"))
            except (OSError, IOError, KeyError, json.JSONDecodeError, AttributeError) as e:
                # Ignore if the xattr doesn't exist, can't be read, or is corrupted
                logger.debug(f"Failed to read extended attributes from {path}: {e}")

            return ObjectMetadata(
                key=path,
                type="directory" if is_dir else "file",
                content_length=0 if is_dir else os.path.getsize(path),
                last_modified=datetime.fromtimestamp(os.path.getmtime(path), tz=timezone.utc),
                metadata=metadata_dict if metadata_dict else None,
            )

        return self._translate_errors(_invoke_api, operation="HEAD", path=path)

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        def _invoke_api() -> Iterator[ObjectMetadata]:
            if os.path.isfile(path):
                yield ObjectMetadata(
                    key=os.path.relpath(path, self._base_path),  # relative path to the base path
                    content_length=os.path.getsize(path),
                    last_modified=datetime.fromtimestamp(os.path.getmtime(path), tz=timezone.utc),
                )
            dir_path = path.rstrip("/") + "/"
            if not os.path.isdir(dir_path):  # expect the input to be a directory
                return

            yield from self._explore_directory(dir_path, start_after, end_at, include_directories, follow_symlinks)

        return self._translate_errors(_invoke_api, operation="LIST", path=path)

    def _explore_directory(
        self,
        dir_path: str,
        start_after: Optional[str],
        end_at: Optional[str],
        include_directories: bool,
        follow_symlinks: bool = True,
    ) -> Iterator[ObjectMetadata]:
        """
        Recursively explore a directory and yield objects in lexicographical order.

        :param dir_path: The directory path to explore.
        :param start_after: The key to start after (exclusive).
        :param end_at: The key to end at (inclusive).
        :param include_directories: Whether to include directories in the result.
        :param follow_symlinks: Whether to follow symbolic links. When False, symlinks are skipped.
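
        Example (illustrative; the directory and keys are hypothetical)::

            # Yields entries whose keys k (relative to the base path) satisfy
            # start_after < k <= end_at, in lexicographical order.
            yield from self._explore_directory(
                "/data/logs/", start_after="logs/2024-01", end_at="logs/2024-12", include_directories=False
            )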
        """
        try:
            # List contents of current directory
            dir_entries = os.listdir(dir_path)
            dir_entries.sort()  # Sort entries for consistent ordering

            # Collect all entries in this directory
            entries = []

            for entry in dir_entries:
                full_path = os.path.join(dir_path, entry)

                # Skip symlinks if follow_symlinks is False
                if not follow_symlinks and os.path.islink(full_path):
                    continue

                relative_path = os.path.relpath(full_path, self._base_path)

                # Check if this entry is within our range
                if (start_after is None or start_after < relative_path) and (end_at is None or relative_path <= end_at):
                    if os.path.isfile(full_path):
                        entries.append((relative_path, full_path, _EntryType.FILE))
                    elif os.path.isdir(full_path):
                        if include_directories:
                            entries.append((relative_path, full_path, _EntryType.DIRECTORY))
                        else:
                            # Add directory for recursive exploration
                            entries.append((relative_path, full_path, _EntryType.DIRECTORY_TO_EXPLORE))

            # Sort entries by relative path
            entries.sort(key=lambda x: x[0])

            # Process entries in order
            for relative_path, full_path, entry_type in entries:
                if entry_type == _EntryType.FILE:
                    yield ObjectMetadata(
                        key=relative_path,
                        content_length=os.path.getsize(full_path),
                        last_modified=datetime.fromtimestamp(os.path.getmtime(full_path), tz=timezone.utc),
                    )
                elif entry_type == _EntryType.DIRECTORY:
                    yield ObjectMetadata(
                        key=relative_path,
                        content_length=0,
                        type="directory",
                        last_modified=AWARE_DATETIME_MIN,
                    )
                elif entry_type == _EntryType.DIRECTORY_TO_EXPLORE:
                    # Recursively explore this directory
                    yield from self._explore_directory(
                        full_path, start_after, end_at, include_directories, follow_symlinks
                    )

        except (OSError, PermissionError) as e:
            logger.warning(f"Failed to list contents of {dir_path}, caused by: {e}")
            return

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        os.makedirs(os.path.dirname(remote_path), exist_ok=True)

        filesize: int = 0
        if isinstance(f, str):
            filesize = os.path.getsize(f)
        elif isinstance(f, StringIO):
            filesize = len(f.getvalue().encode("utf-8"))
        else:
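            # Assumes a BytesIO-like object that exposes getvalue()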
            filesize = len(f.getvalue())  # type: ignore

        def _invoke_api() -> int:
            atomic_write(source=f, destination=remote_path, attributes=attributes)

            return filesize

        return self._translate_errors(_invoke_api, operation="PUT", path=remote_path)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        filesize = metadata.content_length if metadata else os.path.getsize(remote_path)

        if isinstance(f, str):

            def _invoke_api() -> int:
                if os.path.dirname(f):
                    os.makedirs(os.path.dirname(f), exist_ok=True)
                atomic_write(source=remote_path, destination=f)

                return filesize

            return self._translate_errors(_invoke_api, operation="GET", path=remote_path)
        elif isinstance(f, StringIO):

            def _invoke_api() -> int:
                with open(remote_path, "r", encoding="utf-8") as src:
                    while chunk := src.read(READ_CHUNK_SIZE):
                        f.write(chunk)

                return filesize

            return self._translate_errors(_invoke_api, operation="GET", path=remote_path)
        else:

            def _invoke_api() -> int:
                with open(remote_path, "rb") as src:
                    while chunk := src.read(READ_CHUNK_SIZE):
                        f.write(chunk)

                return filesize

            return self._translate_errors(_invoke_api, operation="GET", path=remote_path)

    def glob(self, pattern: str, attribute_filter_expression: Optional[str] = None) -> list[str]:
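        """
        Returns the keys matching the glob pattern, resolved against the base path.

        Unless the base path is ``/``, the base path prefix is stripped from the returned keys.
        If ``attribute_filter_expression`` is given, only objects whose extended attributes
        satisfy the expression are returned.

        Example (illustrative; the pattern is hypothetical)::

            keys = provider.glob("datasets/**/*.parquet")
        """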
        pattern = self._prepend_base_path(pattern)
        keys = list(glob.glob(pattern, recursive=True))
        if attribute_filter_expression:
            filtered_keys = []
            evaluator = create_attribute_filter_evaluator(attribute_filter_expression)
            for key in keys:
                obj_metadata = self._get_object_metadata(key)
                if matches_attribute_filter_expression(obj_metadata, evaluator):
                    filtered_keys.append(key)
            keys = filtered_keys
        if self._base_path == "/":
            return keys
        else:
            # NOTE: PosixFileStorageProvider does not have the concept of bucket and prefix.
            # So we drop the base_path from the returned keys.
            return [key.replace(self._base_path, "", 1).lstrip("/") for key in keys]

    def is_file(self, path: str) -> bool:
        path = self._prepend_base_path(path)
        return os.path.isfile(path)

    def rmtree(self, path: str) -> None:
        path = self._prepend_base_path(path)
        shutil.rmtree(path)