1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import glob
17import json
18import logging
19import os
20import shutil
21import tempfile
22import time
23from collections.abc import Callable, Iterator, Sequence, Sized
24from datetime import datetime, timezone
25from enum import Enum
26from io import BytesIO, StringIO
27from typing import IO, Any, Optional, TypeVar, Union
28
29import opentelemetry.metrics as api_metrics
30import xattr
31
32from ..telemetry import Telemetry
33from ..telemetry.attributes.base import AttributesProvider
34from ..types import AWARE_DATETIME_MIN, ObjectMetadata, Range
35from ..utils import create_attribute_filter_evaluator, matches_attribute_filter_expression, validate_attributes
36from .base import BaseStorageProvider
37
# Generic return type for the metric-collection wrapper (`_collect_metrics`).
_T = TypeVar("_T")

# Provider name reported to the base class and telemetry helpers.
PROVIDER = "file"

logger = logging.getLogger(__name__)
43
44
class _EntryType(Enum):
    """
    An enum representing the type of an entry in a directory.

    Used by ``_explore_directory`` to decide how each listed entry is emitted.
    """

    FILE = 1  # Regular file: yielded as a file ObjectMetadata.
    DIRECTORY = 2  # Directory yielded as-is (when include_directories=True).
    DIRECTORY_TO_EXPLORE = 3  # Directory to recurse into instead of yielding.
53
54
[docs]
def atomic_write(source: Union[str, IO], destination: str, attributes: Optional[dict[str, str]] = None):
    """
    Writes the contents of a file to the specified destination path.

    This function ensures that the file write operation is atomic, meaning the output file is either fully written or not modified at all.
    This is achieved by writing to a temporary file first and then renaming it to the destination path.

    :param source: The input file to read from. It can be a string representing the path to a file, or an open file-like object (IO).
    :param destination: The path to the destination file where the contents should be written.
    :param attributes: The attributes to set on the file (stored as a ``user.json`` extended attribute).
    """
    temp_file_path = None
    try:
        # The temp file lives in the destination's directory so the final
        # os.rename() stays within one filesystem and is therefore atomic.
        with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(destination), prefix=".") as fp:
            temp_file_path = fp.name
            if isinstance(source, str):
                with open(source, mode="rb") as src:
                    # Stream in chunks instead of loading the whole file into memory.
                    shutil.copyfileobj(src, fp)
            else:
                fp.write(source.read())

        # Set attributes on temp file if provided (best-effort; failure is only logged).
        validated_attributes = validate_attributes(attributes)
        if validated_attributes:
            try:
                xattr.setxattr(temp_file_path, "user.json", json.dumps(validated_attributes).encode("utf-8"))
            except OSError as e:
                logger.debug(f"Failed to set extended attributes on temp file {temp_file_path}: {e}")

        os.rename(src=temp_file_path, dst=destination)
    except Exception:
        # Don't leak an orphaned temp file if the copy, validation, or rename fails.
        if temp_file_path is not None:
            try:
                os.unlink(temp_file_path)
            except OSError:
                pass
        raise
84
85
[docs]
class PosixFileStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with POSIX file systems.
    """

    def __init__(
        self,
        base_path: str,
        metric_counters: Optional[dict[Telemetry.CounterName, api_metrics.Counter]] = None,
        metric_gauges: Optional[dict[Telemetry.GaugeName, api_metrics._Gauge]] = None,
        metric_attributes_providers: Sequence[AttributesProvider] = (),
        **kwargs: Any,
    ) -> None:
        """
        :param base_path: The root prefix path within the POSIX file system where all operations will be scoped.
        :param metric_counters: Metric counters. Defaults to an empty mapping.
        :param metric_gauges: Metric gauges. Defaults to an empty mapping.
        :param metric_attributes_providers: Metric attributes providers.
        """
        # ``None`` sentinels replace the previous mutable default arguments
        # (``{}``), which are shared across calls; explicit dicts behave as before.
        if metric_counters is None:
            metric_counters = {}
        if metric_gauges is None:
            metric_gauges = {}

        # Validate POSIX path: an empty base path means the filesystem root.
        if base_path == "":
            base_path = "/"

        if not base_path.startswith("/"):
            raise ValueError(f"The base_path {base_path} must be an absolute path.")

        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            metric_counters=metric_counters,
            metric_gauges=metric_gauges,
            metric_attributes_providers=metric_attributes_providers,
        )
120
121 def _collect_metrics(
122 self,
123 func: Callable[[], _T],
124 operation: str,
125 path: str,
126 put_object_size: Optional[int] = None,
127 get_object_size: Optional[int] = None,
128 ) -> _T:
129 """
130 Collects and records performance metrics around file operations such as PUT, GET, DELETE, etc.
131
132 This method wraps an file operation and measures the time it takes to complete, along with recording
133 the size of the object if applicable.
134
135 :param func: The function that performs the actual file operation.
136 :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
137 :param path: The path to the object.
138 :param put_object_size: The size of the object being uploaded, if applicable (for PUT operations).
139 :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).
140
141 :return: The result of the file operation, typically the return value of the `func` callable.
142 """
143 start_time = time.time()
144 status_code = 200
145
146 object_size = None
147 if operation == "PUT":
148 object_size = put_object_size
149 elif operation == "GET" and get_object_size is not None:
150 object_size = get_object_size
151
152 try:
153 result = func()
154 if operation == "GET" and object_size is None and isinstance(result, Sized):
155 object_size = len(result)
156 return result
157 except FileNotFoundError as error:
158 status_code = 404
159 raise error
160 except Exception as error:
161 status_code = -1
162 raise RuntimeError(f"Failed to {operation} object(s) at {path}, error: {error}") from error
163 finally:
164 elapsed_time = time.time() - start_time
165 self._metric_helper.record_duration(
166 elapsed_time, provider=self._provider_name, operation=operation, bucket="", status_code=status_code
167 )
168 if object_size:
169 self._metric_helper.record_object_size(
170 object_size, provider=self._provider_name, operation=operation, bucket="", status_code=status_code
171 )
172
173 def _put_object(
174 self,
175 path: str,
176 body: bytes,
177 if_match: Optional[str] = None,
178 if_none_match: Optional[str] = None,
179 attributes: Optional[dict[str, str]] = None,
180 ) -> int:
181 def _invoke_api() -> int:
182 os.makedirs(os.path.dirname(path), exist_ok=True)
183 atomic_write(source=BytesIO(body), destination=path, attributes=attributes)
184 return len(body)
185
186 return self._collect_metrics(_invoke_api, operation="PUT", path=path, put_object_size=len(body))
187
188 def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
189 def _invoke_api() -> bytes:
190 if byte_range:
191 with open(path, "rb") as f:
192 f.seek(byte_range.offset)
193 return f.read(byte_range.size)
194 else:
195 with open(path, "rb") as f:
196 return f.read()
197
198 return self._collect_metrics(_invoke_api, operation="GET", path=path)
199
200 def _copy_object(self, src_path: str, dest_path: str) -> int:
201 src_object = self._get_object_metadata(src_path)
202
203 def _invoke_api() -> int:
204 os.makedirs(os.path.dirname(dest_path), exist_ok=True)
205 atomic_write(source=src_path, destination=dest_path, attributes=src_object.metadata)
206
207 return src_object.content_length
208
209 return self._collect_metrics(
210 _invoke_api,
211 operation="COPY",
212 path=src_path,
213 put_object_size=src_object.content_length,
214 )
215
216 def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
217 def _invoke_api() -> None:
218 if os.path.exists(path) and os.path.isfile(path):
219 os.remove(path)
220
221 return self._collect_metrics(_invoke_api, operation="DELETE", path=path)
222
223 def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
224 is_dir = os.path.isdir(path)
225 if is_dir:
226 path = self._append_delimiter(path)
227
228 def _invoke_api() -> ObjectMetadata:
229 # Get basic file attributes
230 metadata_dict = {}
231 try:
232 json_bytes = xattr.getxattr(path, "user.json")
233 metadata_dict = json.loads(json_bytes.decode("utf-8"))
234 except (OSError, IOError, KeyError, json.JSONDecodeError, AttributeError) as e:
235 # Silently ignore if xattr doesn't exist, can't be read, or is corrupted
236 logger.debug(f"Failed to read extended attributes from {path}: {e}")
237 pass
238
239 return ObjectMetadata(
240 key=path,
241 type="directory" if is_dir else "file",
242 content_length=0 if is_dir else os.path.getsize(path),
243 last_modified=datetime.fromtimestamp(os.path.getmtime(path), tz=timezone.utc),
244 metadata=metadata_dict if metadata_dict else None,
245 )
246
247 return self._collect_metrics(_invoke_api, operation="HEAD", path=path)
248
249 def _list_objects(
250 self,
251 path: str,
252 start_after: Optional[str] = None,
253 end_at: Optional[str] = None,
254 include_directories: bool = False,
255 ) -> Iterator[ObjectMetadata]:
256 def _invoke_api() -> Iterator[ObjectMetadata]:
257 if os.path.isfile(path):
258 yield ObjectMetadata(
259 key=os.path.relpath(path, self._base_path), # relative path to the base path
260 content_length=os.path.getsize(path),
261 last_modified=datetime.fromtimestamp(os.path.getmtime(path), tz=timezone.utc),
262 )
263 dir_path = path.rstrip("/") + "/"
264 if not os.path.isdir(dir_path): # expect the input to be a directory
265 return
266
267 yield from self._explore_directory(dir_path, start_after, end_at, include_directories)
268
269 return self._collect_metrics(_invoke_api, operation="LIST", path=path)
270
    def _explore_directory(
        self, dir_path: str, start_after: Optional[str], end_at: Optional[str], include_directories: bool
    ) -> Iterator[ObjectMetadata]:
        """
        Recursively explore a directory and yield objects in lexicographical order.

        :param dir_path: Absolute path of the directory to explore.
        :param start_after: Exclusive lower bound for yielded keys (base-path-relative), or ``None``.
        :param end_at: Inclusive upper bound for yielded keys, or ``None``.
        :param include_directories: If ``True``, yield directories themselves instead of recursing into them.
        """
        try:
            # List contents of current directory
            dir_entries = os.listdir(dir_path)
            dir_entries.sort()  # Sort entries for consistent ordering

            # Collect all entries in this directory
            entries = []

            for entry in dir_entries:
                full_path = os.path.join(dir_path, entry)

                # Keys are always relative to the provider's base path.
                relative_path = os.path.relpath(full_path, self._base_path)

                # Check if this entry is within our range
                # NOTE(review): the same range test is applied to directory paths; a
                # directory whose own relative path falls outside (start_after, end_at]
                # is pruned entirely, even if it could contain in-range descendants —
                # confirm this is the intended listing semantics.
                if (start_after is None or start_after < relative_path) and (end_at is None or relative_path <= end_at):
                    if os.path.isfile(full_path):
                        entries.append((relative_path, full_path, _EntryType.FILE))
                    elif os.path.isdir(full_path):
                        if include_directories:
                            entries.append((relative_path, full_path, _EntryType.DIRECTORY))
                        else:
                            # Add directory for recursive exploration
                            entries.append((relative_path, full_path, _EntryType.DIRECTORY_TO_EXPLORE))

            # Sort entries by relative path
            entries.sort(key=lambda x: x[0])

            # Process entries in order
            for relative_path, full_path, entry_type in entries:
                if entry_type == _EntryType.FILE:
                    yield ObjectMetadata(
                        key=relative_path,
                        content_length=os.path.getsize(full_path),
                        last_modified=datetime.fromtimestamp(os.path.getmtime(full_path), tz=timezone.utc),
                    )
                elif entry_type == _EntryType.DIRECTORY:
                    # Directories are reported with zero length and a sentinel timestamp.
                    yield ObjectMetadata(
                        key=relative_path,
                        content_length=0,
                        type="directory",
                        last_modified=AWARE_DATETIME_MIN,
                    )
                elif entry_type == _EntryType.DIRECTORY_TO_EXPLORE:
                    # Recursively explore this directory
                    yield from self._explore_directory(full_path, start_after, end_at, include_directories)

        except (OSError, PermissionError) as e:
            # Unreadable directories are skipped with a warning rather than failing the listing.
            logger.warning(f"Failed to list contents of {dir_path}, caused by: {e}")
            return
326
327 def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
328 os.makedirs(os.path.dirname(remote_path), exist_ok=True)
329
330 filesize: int = 0
331 if isinstance(f, str):
332 filesize = os.path.getsize(f)
333 elif isinstance(f, StringIO):
334 filesize = len(f.getvalue().encode("utf-8"))
335 else:
336 filesize = len(f.getvalue()) # type: ignore
337
338 def _invoke_api() -> int:
339 atomic_write(source=f, destination=remote_path, attributes=attributes)
340
341 return filesize
342
343 return self._collect_metrics(_invoke_api, operation="PUT", path=remote_path, put_object_size=filesize)
344
345 def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
346 filesize = metadata.content_length if metadata else os.path.getsize(remote_path)
347
348 if isinstance(f, str):
349
350 def _invoke_api() -> int:
351 if os.path.dirname(f):
352 os.makedirs(os.path.dirname(f), exist_ok=True)
353 atomic_write(source=remote_path, destination=f)
354
355 return filesize
356
357 return self._collect_metrics(_invoke_api, operation="GET", path=remote_path, get_object_size=filesize)
358 elif isinstance(f, StringIO):
359
360 def _invoke_api() -> int:
361 with open(remote_path, "r", encoding="utf-8") as src:
362 f.write(src.read())
363
364 return filesize
365
366 return self._collect_metrics(_invoke_api, operation="GET", path=remote_path, get_object_size=filesize)
367 else:
368
369 def _invoke_api() -> int:
370 with open(remote_path, "rb") as src:
371 f.write(src.read())
372
373 return filesize
374
375 return self._collect_metrics(_invoke_api, operation="GET", path=remote_path, get_object_size=filesize)
376
[docs]
377 def glob(self, pattern: str, attribute_filter_expression: Optional[str] = None) -> list[str]:
378 pattern = self._prepend_base_path(pattern)
379 keys = list(glob.glob(pattern, recursive=True))
380 if attribute_filter_expression:
381 filtered_keys = []
382 evaluator = create_attribute_filter_evaluator(attribute_filter_expression)
383 for key in keys:
384 obj_metadata = self._get_object_metadata(key)
385 if matches_attribute_filter_expression(obj_metadata, evaluator):
386 filtered_keys.append(key)
387 keys = filtered_keys
388 if self._base_path == "/":
389 return keys
390 else:
391 # NOTE: PosixStorageProvider does not have the concept of bucket and prefix.
392 # So we drop the base_path from it.
393 return [key.replace(self._base_path, "", 1).lstrip("/") for key in keys]
394
[docs]
395 def is_file(self, path: str) -> bool:
396 path = self._prepend_base_path(path)
397 return os.path.isfile(path)
398
[docs]
399 def rmtree(self, path: str) -> None:
400 path = self._prepend_base_path(path)
401 shutil.rmtree(path)