Source code for multistorageclient.providers.posix_file

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import json
import logging
import os
import shutil
import tempfile
from collections.abc import Callable, Iterator
from datetime import datetime, timezone
from enum import Enum
from io import BytesIO, StringIO
from typing import IO, Any, Optional, TypeVar, Union

import xattr

from ..telemetry import Telemetry
from ..types import AWARE_DATETIME_MIN, ObjectMetadata, Range
from ..utils import create_attribute_filter_evaluator, matches_attribute_filter_expression, validate_attributes
from .base import BaseStorageProvider

_T = TypeVar("_T")

PROVIDER = "file"
READ_CHUNK_SIZE = 8192

logger = logging.getLogger(__name__)


class _EntryType(Enum):
    """
    An enum representing the type of an entry in a directory.
    """

    FILE = 1
    DIRECTORY = 2
    DIRECTORY_TO_EXPLORE = 3

def atomic_write(source: Union[str, IO], destination: str, attributes: Optional[dict[str, str]] = None):
    """
    Writes the contents of a file to the specified destination path.

    This function ensures that the file write operation is atomic, meaning the output file is either fully written or not modified at all.
    This is achieved by writing to a temporary file first and then renaming it to the destination path.

    :param source: The input file to read from. It can be a string representing the path to a file, or an open file-like object (IO).
    :param destination: The path to the destination file where the contents should be written.
    :param attributes: The attributes to set on the file.
    """

    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(destination), prefix=".") as fp:
        temp_file_path = fp.name
        if isinstance(source, str):
            with open(source, mode="rb") as src:
                while chunk := src.read(READ_CHUNK_SIZE):
                    fp.write(chunk)
        else:
            while chunk := source.read(READ_CHUNK_SIZE):
                fp.write(chunk)

    # Set attributes on temp file if provided
    validated_attributes = validate_attributes(attributes)
    if validated_attributes:
        try:
            xattr.setxattr(temp_file_path, "user.json", json.dumps(validated_attributes).encode("utf-8"))
        except OSError as e:
            logger.debug(f"Failed to set extended attributes on temp file {temp_file_path}: {e}")

    os.rename(src=temp_file_path, dst=destination)

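
# The sketch below is illustrative and not part of the module: the payload,
# destination path, and attribute values are hypothetical. It shows the
# contract of atomic_write(): data is staged in a hidden temp file inside the
# destination directory (so the final os.rename() stays on the same
# filesystem and is atomic), meaning readers observe either the old file or
# the complete new one, never a partial write.
def _example_atomic_write_usage() -> None:
    # The destination's parent directory must already exist, since the temp
    # file is created alongside the final file.
    atomic_write(
        source=BytesIO(b"checkpoint bytes"),  # hypothetical payload
        destination="/tmp/msc_demo/model.ckpt",  # hypothetical path
        attributes={"epoch": "3"},  # persisted as JSON in the "user.json" xattr, if the filesystem supports xattrs
    )
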
class PosixFileStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with POSIX file systems.
    """

    def __init__(
        self,
        base_path: str,
        config_dict: Optional[dict[str, Any]] = None,
        telemetry_provider: Optional[Callable[[], Telemetry]] = None,
        **kwargs: Any,
    ) -> None:
        """
        :param base_path: The root prefix path within the POSIX file system where all operations will be scoped.
        :param config_dict: Resolved MSC config.
        :param telemetry_provider: A function that provides a telemetry instance.
        """

        # Validate POSIX path
        if base_path == "":
            base_path = "/"

        if not base_path.startswith("/"):
            raise ValueError(f"The base_path {base_path} must be an absolute path.")

        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            config_dict=config_dict,
            telemetry_provider=telemetry_provider,
        )

    def _translate_errors(
        self,
        func: Callable[[], _T],
        operation: str,
        path: str,
    ) -> _T:
        """
        Translates errors like timeouts and client errors.

        :param func: The function that performs the actual file operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param path: The path to the object.

        :return: The result of the file operation, typically the return value of the `func` callable.
        """
        try:
            return func()
        except FileNotFoundError:
            raise
        except Exception as error:
            raise RuntimeError(f"Failed to {operation} object(s) at {path}, error: {error}") from error

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        def _invoke_api() -> int:
            os.makedirs(os.path.dirname(path), exist_ok=True)
            atomic_write(source=BytesIO(body), destination=path, attributes=attributes)
            return len(body)

        return self._translate_errors(_invoke_api, operation="PUT", path=path)

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        def _invoke_api() -> bytes:
            if byte_range:
                with open(path, "rb") as f:
                    f.seek(byte_range.offset)
                    return f.read(byte_range.size)
            else:
                with open(path, "rb") as f:
                    return f.read()

        return self._translate_errors(_invoke_api, operation="GET", path=path)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            os.makedirs(os.path.dirname(dest_path), exist_ok=True)
            atomic_write(source=src_path, destination=dest_path, attributes=src_object.metadata)

            return src_object.content_length

        return self._translate_errors(_invoke_api, operation="COPY", path=src_path)

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        def _invoke_api() -> None:
            if os.path.exists(path) and os.path.isfile(path):
                os.remove(path)

        return self._translate_errors(_invoke_api, operation="DELETE", path=path)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        is_dir = os.path.isdir(path)
        if is_dir:
            path = self._append_delimiter(path)

        def _invoke_api() -> ObjectMetadata:
            # Get basic file attributes
            metadata_dict = {}
            try:
                json_bytes = xattr.getxattr(path, "user.json")
                metadata_dict = json.loads(json_bytes.decode("utf-8"))
            except (OSError, IOError, KeyError, json.JSONDecodeError, AttributeError) as e:
                # Silently ignore if xattr doesn't exist, can't be read, or is corrupted
                logger.debug(f"Failed to read extended attributes from {path}: {e}")

            return ObjectMetadata(
                key=path,
                type="directory" if is_dir else "file",
                content_length=0 if is_dir else os.path.getsize(path),
                last_modified=datetime.fromtimestamp(os.path.getmtime(path), tz=timezone.utc),
                metadata=metadata_dict if metadata_dict else None,
            )

        return self._translate_errors(_invoke_api, operation="HEAD", path=path)

    def _list_objects(
        self,
        path: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        def _invoke_api() -> Iterator[ObjectMetadata]:
            if os.path.isfile(path):
                yield ObjectMetadata(
                    key=os.path.relpath(path, self._base_path),  # relative path to the base path
                    content_length=os.path.getsize(path),
                    last_modified=datetime.fromtimestamp(os.path.getmtime(path), tz=timezone.utc),
                )
            dir_path = path.rstrip("/") + "/"
            if not os.path.isdir(dir_path):  # expect the input to be a directory
                return

            yield from self._explore_directory(dir_path, start_after, end_at, include_directories)

        return self._translate_errors(_invoke_api, operation="LIST", path=path)

    def _explore_directory(
        self, dir_path: str, start_after: Optional[str], end_at: Optional[str], include_directories: bool
    ) -> Iterator[ObjectMetadata]:
        """
        Recursively explore a directory and yield objects in lexicographical order.
        """
        try:
            # List contents of current directory
            dir_entries = os.listdir(dir_path)
            dir_entries.sort()  # Sort entries for consistent ordering

            # Collect all entries in this directory
            entries = []

            for entry in dir_entries:
                full_path = os.path.join(dir_path, entry)

                relative_path = os.path.relpath(full_path, self._base_path)

                # Check if this entry is within our range
                if (start_after is None or start_after < relative_path) and (end_at is None or relative_path <= end_at):
                    if os.path.isfile(full_path):
                        entries.append((relative_path, full_path, _EntryType.FILE))
                    elif os.path.isdir(full_path):
                        if include_directories:
                            entries.append((relative_path, full_path, _EntryType.DIRECTORY))
                        else:
                            # Add directory for recursive exploration
                            entries.append((relative_path, full_path, _EntryType.DIRECTORY_TO_EXPLORE))

            # Sort entries by relative path
            entries.sort(key=lambda x: x[0])

            # Process entries in order
            for relative_path, full_path, entry_type in entries:
                if entry_type == _EntryType.FILE:
                    yield ObjectMetadata(
                        key=relative_path,
                        content_length=os.path.getsize(full_path),
                        last_modified=datetime.fromtimestamp(os.path.getmtime(full_path), tz=timezone.utc),
                    )
                elif entry_type == _EntryType.DIRECTORY:
                    yield ObjectMetadata(
                        key=relative_path,
                        content_length=0,
                        type="directory",
                        last_modified=AWARE_DATETIME_MIN,
                    )
                elif entry_type == _EntryType.DIRECTORY_TO_EXPLORE:
                    # Recursively explore this directory
                    yield from self._explore_directory(full_path, start_after, end_at, include_directories)

        except (OSError, PermissionError) as e:
            logger.warning(f"Failed to list contents of {dir_path}, caused by: {e}")
            return

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        os.makedirs(os.path.dirname(remote_path), exist_ok=True)

        filesize: int = 0
        if isinstance(f, str):
            filesize = os.path.getsize(f)
        elif isinstance(f, StringIO):
            filesize = len(f.getvalue().encode("utf-8"))
        else:
            filesize = len(f.getvalue())  # type: ignore

        def _invoke_api() -> int:
            atomic_write(source=f, destination=remote_path, attributes=attributes)

            return filesize

        return self._translate_errors(_invoke_api, operation="PUT", path=remote_path)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        filesize = metadata.content_length if metadata else os.path.getsize(remote_path)

        if isinstance(f, str):

            def _invoke_api() -> int:
                if os.path.dirname(f):
                    os.makedirs(os.path.dirname(f), exist_ok=True)
                atomic_write(source=remote_path, destination=f)

                return filesize

            return self._translate_errors(_invoke_api, operation="GET", path=remote_path)
        elif isinstance(f, StringIO):

            def _invoke_api() -> int:
                with open(remote_path, "r", encoding="utf-8") as src:
                    while chunk := src.read(READ_CHUNK_SIZE):
                        f.write(chunk)

                return filesize

            return self._translate_errors(_invoke_api, operation="GET", path=remote_path)
        else:

            def _invoke_api() -> int:
                with open(remote_path, "rb") as src:
                    while chunk := src.read(READ_CHUNK_SIZE):
                        f.write(chunk)

                return filesize

            return self._translate_errors(_invoke_api, operation="GET", path=remote_path)
    def glob(self, pattern: str, attribute_filter_expression: Optional[str] = None) -> list[str]:
        pattern = self._prepend_base_path(pattern)
        keys = list(glob.glob(pattern, recursive=True))
        if attribute_filter_expression:
            filtered_keys = []
            evaluator = create_attribute_filter_evaluator(attribute_filter_expression)
            for key in keys:
                obj_metadata = self._get_object_metadata(key)
                if matches_attribute_filter_expression(obj_metadata, evaluator):
                    filtered_keys.append(key)
            keys = filtered_keys
        if self._base_path == "/":
            return keys
        else:
            # NOTE: PosixFileStorageProvider does not have the concept of bucket and prefix,
            # so we drop the base_path from the returned keys.
            return [key.replace(self._base_path, "", 1).lstrip("/") for key in keys]
    def is_file(self, path: str) -> bool:
        path = self._prepend_base_path(path)
        return os.path.isfile(path)
    def rmtree(self, path: str) -> None:
        path = self._prepend_base_path(path)
        shutil.rmtree(path)
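

# Illustrative usage sketch (hypothetical base path and keys; not part of the
# module). It assumes the default construction shown in __init__ suffices.
# Every operation is scoped under base_path, and keys returned by glob() are
# made relative to it by stripping the base_path prefix.
def _example_provider_usage() -> None:
    provider = PosixFileStorageProvider(base_path="/data")
    # Expands the pattern under /data; results come back base-path-relative,
    # e.g. "configs/a.json" rather than "/data/configs/a.json".
    for key in provider.glob("**/*.json"):
        if provider.is_file(key):  # is_file() resolves the key against /data
            ...
    # Recursively removes /data/scratch.
    provider.rmtree("scratch")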