# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import json
import logging
import os
import shutil
import tempfile
import time
from collections.abc import Callable, Iterator, Sequence, Sized
from datetime import datetime, timezone
from enum import Enum
from io import BytesIO, StringIO
from typing import IO, Any, Optional, TypeVar, Union

import opentelemetry.metrics as api_metrics
import xattr

from ..telemetry import Telemetry
from ..telemetry.attributes.base import AttributesProvider
from ..types import AWARE_DATETIME_MIN, ObjectMetadata, Range
from ..utils import create_attribute_filter_evaluator, matches_attribute_filter_expression, validate_attributes
from .base import BaseStorageProvider

_T = TypeVar("_T")

PROVIDER = "file"

logger = logging.getLogger(__name__)


class _EntryType(Enum):
    """
    An enum representing the type of an entry in a directory.
    """

    FILE = 1
    DIRECTORY = 2
    DIRECTORY_TO_EXPLORE = 3


def atomic_write(source: Union[str, IO], destination: str, attributes: Optional[dict[str, str]] = None):
    """
    Writes the contents of a file to the specified destination path.

    This function ensures that the file write operation is atomic, meaning the output file is either fully written or not modified at all.
    This is achieved by writing to a temporary file first and then renaming it to the destination path.

    :param source: The input file to read from. It can be a string representing the path to a file, or an open file-like object (IO).
    :param destination: The path to the destination file where the contents should be written.
    :param attributes: The attributes to set on the file.
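
    Example (an illustrative sketch; the paths and attribute values are hypothetical)::

        from io import BytesIO

        # Readers of the destination never observe a partially written file.
        atomic_write(source=BytesIO(b"hello"), destination="/tmp/example.txt")

        # Copy an existing file and tag the result with extended attributes.
        atomic_write(source="/tmp/example.txt", destination="/tmp/copy.txt", attributes={"owner": "alice"})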
65 """
66
67 with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(destination), prefix=".") as fp:
68 temp_file_path = fp.name
69 if isinstance(source, str):
70 with open(source, mode="rb") as src:
71 fp.write(src.read())
72 else:
73 fp.write(source.read())
74
75 # Set attributes on temp file if provided
76 validated_attributes = validate_attributes(attributes)
77 if validated_attributes:
78 try:
79 xattr.setxattr(temp_file_path, "user.json", json.dumps(validated_attributes).encode("utf-8"))
80 except OSError as e:
81 logger.debug(f"Failed to set extended attributes on temp file {temp_file_path}: {e}")
82
83 os.rename(src=temp_file_path, dst=destination)
84
85
class PosixFileStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with POSIX file systems.
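
    Example (a minimal sketch; the base path is hypothetical, and the
    ``put_object``/``get_object`` calls assume the public interface defined by
    :py:class:`multistorageclient.types.StorageProvider`)::

        provider = PosixFileStorageProvider(base_path="/data")
        provider.put_object("foo/bar.txt", b"hello")
        assert provider.get_object("foo/bar.txt") == b"hello"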
89 """
90
91 def __init__(
92 self,
93 base_path: str,
94 metric_counters: dict[Telemetry.CounterName, api_metrics.Counter] = {},
95 metric_gauges: dict[Telemetry.GaugeName, api_metrics._Gauge] = {},
96 metric_attributes_providers: Sequence[AttributesProvider] = (),
97 **kwargs: Any,
98 ) -> None:
99 """
100 :param base_path: The root prefix path within the POSIX file system where all operations will be scoped.
101 :param metric_counters: Metric counters.
102 :param metric_gauges: Metric gauges.
103 :param metric_attributes_providers: Metric attributes providers.
104 """
105
106 # Validate POSIX path
107 if base_path == "":
108 base_path = "/"
109
110 if not base_path.startswith("/"):
111 raise ValueError(f"The base_path {base_path} must be an absolute path.")
112
113 super().__init__(
114 base_path=base_path,
115 provider_name=PROVIDER,
116 metric_counters=metric_counters,
117 metric_gauges=metric_gauges,
118 metric_attributes_providers=metric_attributes_providers,
119 )
120
121 def _collect_metrics(
122 self,
123 func: Callable[[], _T],
124 operation: str,
125 path: str,
126 put_object_size: Optional[int] = None,
127 get_object_size: Optional[int] = None,
128 ) -> _T:
129 """
130 Collects and records performance metrics around file operations such as PUT, GET, DELETE, etc.
131
132 This method wraps an file operation and measures the time it takes to complete, along with recording
133 the size of the object if applicable.
134
135 :param func: The function that performs the actual file operation.
136 :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
137 :param path: The path to the object.
138 :param put_object_size: The size of the object being uploaded, if applicable (for PUT operations).
139 :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).
140
141 :return: The result of the file operation, typically the return value of the `func` callable.
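
        Example (an illustrative sketch of the internal call pattern; the path
        is hypothetical)::

            def _invoke_api() -> bytes:
                with open("/data/foo.bin", "rb") as f:
                    return f.read()

            data = self._collect_metrics(_invoke_api, operation="GET", path="/data/foo.bin")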
142 """
143 start_time = time.time()
144 status_code = 200
145
146 object_size = None
147 if operation == "PUT":
148 object_size = put_object_size
149 elif operation == "GET" and get_object_size is not None:
150 object_size = get_object_size
151
152 try:
153 result = func()
154 if operation == "GET" and object_size is None and isinstance(result, Sized):
155 object_size = len(result)
156 return result
157 except FileNotFoundError as error:
158 status_code = 404
159 raise error
160 except Exception as error:
161 status_code = -1
162 raise RuntimeError(f"Failed to {operation} object(s) at {path}, error: {error}") from error
163 finally:
164 elapsed_time = time.time() - start_time
165 self._metric_helper.record_duration(
166 elapsed_time, provider=self._provider_name, operation=operation, bucket="", status_code=status_code
167 )
168 if object_size:
169 self._metric_helper.record_object_size(
170 object_size, provider=self._provider_name, operation=operation, bucket="", status_code=status_code
171 )
172
    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        def _invoke_api() -> int:
            os.makedirs(os.path.dirname(path), exist_ok=True)
            atomic_write(source=BytesIO(body), destination=path, attributes=attributes)
            return len(body)

        return self._collect_metrics(_invoke_api, operation="PUT", path=path, put_object_size=len(body))

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        def _invoke_api() -> bytes:
            if byte_range:
                with open(path, "rb") as f:
                    f.seek(byte_range.offset)
                    return f.read(byte_range.size)
            else:
                with open(path, "rb") as f:
                    return f.read()

        return self._collect_metrics(_invoke_api, operation="GET", path=path)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            os.makedirs(os.path.dirname(dest_path), exist_ok=True)
            atomic_write(source=src_path, destination=dest_path, attributes=src_object.metadata)

            return src_object.content_length

        return self._collect_metrics(
            _invoke_api,
            operation="COPY",
            path=src_path,
            put_object_size=src_object.content_length,
        )

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        def _invoke_api() -> None:
            if os.path.exists(path) and os.path.isfile(path):
                os.remove(path)

        return self._collect_metrics(_invoke_api, operation="DELETE", path=path)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        is_dir = os.path.isdir(path)
        if is_dir:
            path = self._append_delimiter(path)

        def _invoke_api() -> ObjectMetadata:
            # Get basic file attributes
            metadata_dict = {}
            try:
                json_bytes = xattr.getxattr(path, "user.json")
                metadata_dict = json.loads(json_bytes.decode("utf-8"))
            except (OSError, KeyError, json.JSONDecodeError, AttributeError) as e:
                # Silently ignore if the xattr doesn't exist, can't be read, or is corrupted.
                # IOError is an alias of OSError, so it is covered here as well.
                logger.debug(f"Failed to read extended attributes from {path}: {e}")

            return ObjectMetadata(
                key=path,
                type="directory" if is_dir else "file",
                content_length=0 if is_dir else os.path.getsize(path),
                last_modified=datetime.fromtimestamp(os.path.getmtime(path), tz=timezone.utc),
                metadata=metadata_dict if metadata_dict else None,
            )

        return self._collect_metrics(_invoke_api, operation="HEAD", path=path)

    def _list_objects(
        self,
        prefix: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
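        """
        Lists objects whose full paths start with ``prefix``, yielded in
        lexicographical order of their keys relative to ``base_path``.

        ``start_after`` (exclusive) and ``end_at`` (inclusive) bound the
        results and are compared against the relative keys. When
        ``include_directories`` is true, directories are yielded instead of
        being recursed into.

        Example (an editorial sketch; the paths are hypothetical)::

            # With base_path="/data", yields keys after "logs/2024-01-01.txt"
            # up to and including "logs/2024-01-03.txt".
            objects = self._list_objects(
                "/data/logs/",
                start_after="logs/2024-01-01.txt",
                end_at="logs/2024-01-03.txt",
            )
        """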
        def _invoke_api() -> Iterator[ObjectMetadata]:
            # Get parent directory and list its contents
            parent_dir = os.path.dirname(prefix)
            if not os.path.exists(parent_dir):
                return

            yield from self._explore_directory(parent_dir, prefix, start_after, end_at, include_directories)

        return self._collect_metrics(_invoke_api, operation="LIST", path=prefix)

    def _explore_directory(
        self, dir_path: str, prefix: str, start_after: Optional[str], end_at: Optional[str], include_directories: bool
    ) -> Iterator[ObjectMetadata]:
        """
        Recursively explores a directory and yields objects in lexicographical order.
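
        For example, given a hypothetical tree containing ``a/b.txt`` and
        ``a/c/d.txt``, the entries are yielded as ``a/b.txt`` then
        ``a/c/d.txt``: each directory's entries are sorted before files are
        yielded and subdirectories are recursed into.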
271 """
272 try:
273 # List contents of current directory
274 dir_entries = os.listdir(dir_path)
275 dir_entries.sort() # Sort entries for consistent ordering
276
277 # Collect all entries in this directory
278 entries = []
279
            for entry in dir_entries:
                full_path = os.path.join(dir_path, entry)
                if not full_path.startswith(prefix):
                    continue

                relative_path = os.path.relpath(full_path, self._base_path)

                # Check if this entry is within our range
                in_range = (start_after is None or start_after < relative_path) and (
                    end_at is None or relative_path <= end_at
                )

                if os.path.isfile(full_path):
                    if in_range:
                        entries.append((relative_path, full_path, _EntryType.FILE))
                elif os.path.isdir(full_path):
                    if include_directories:
                        if in_range:
                            entries.append((relative_path, full_path, _EntryType.DIRECTORY))
                    # A directory that sorts at or before start_after can still
                    # contain keys after it (e.g. start_after="a/b" and the
                    # directory "a"), so recurse into a directory when it is an
                    # ancestor of start_after or is itself within the range.
                    elif (end_at is None or relative_path <= end_at) and (
                        start_after is None or start_after < relative_path or start_after.startswith(relative_path)
                    ):
                        entries.append((relative_path, full_path, _EntryType.DIRECTORY_TO_EXPLORE))

            # Sort entries by relative path
            entries.sort(key=lambda x: x[0])

            # Process entries in order
            for relative_path, full_path, entry_type in entries:
                if entry_type == _EntryType.FILE:
                    yield ObjectMetadata(
                        key=relative_path,
                        content_length=os.path.getsize(full_path),
                        last_modified=datetime.fromtimestamp(os.path.getmtime(full_path), tz=timezone.utc),
                    )
                elif entry_type == _EntryType.DIRECTORY:
                    yield ObjectMetadata(
                        key=relative_path,
                        content_length=0,
                        type="directory",
                        last_modified=AWARE_DATETIME_MIN,
                    )
                elif entry_type == _EntryType.DIRECTORY_TO_EXPLORE:
                    # Recursively explore this directory
                    yield from self._explore_directory(full_path, prefix, start_after, end_at, include_directories)

        except (OSError, PermissionError) as e:
            logger.warning(f"Failed to list contents of {dir_path}, caused by: {e}")
            return

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        os.makedirs(os.path.dirname(remote_path), exist_ok=True)

        filesize: int = 0
        if isinstance(f, str):
            filesize = os.path.getsize(f)
        elif isinstance(f, StringIO):
            filesize = len(f.getvalue().encode("utf-8"))
        else:
            filesize = len(f.getvalue())  # type: ignore

        def _invoke_api() -> int:
            if isinstance(f, StringIO):
                # atomic_write writes binary data, so encode the text contents first.
                atomic_write(source=BytesIO(f.getvalue().encode("utf-8")), destination=remote_path, attributes=attributes)
            else:
                atomic_write(source=f, destination=remote_path, attributes=attributes)

            return filesize

        return self._collect_metrics(_invoke_api, operation="PUT", path=remote_path, put_object_size=filesize)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
        filesize = metadata.content_length if metadata else os.path.getsize(remote_path)

        if isinstance(f, str):

            def _invoke_api() -> int:
                os.makedirs(os.path.dirname(f), exist_ok=True)
                atomic_write(source=remote_path, destination=f)

                return filesize

            return self._collect_metrics(_invoke_api, operation="GET", path=remote_path, get_object_size=filesize)
        elif isinstance(f, StringIO):

            def _invoke_api() -> int:
                with open(remote_path, "r", encoding="utf-8") as src:
                    f.write(src.read())

                return filesize

            return self._collect_metrics(_invoke_api, operation="GET", path=remote_path, get_object_size=filesize)
        else:

            def _invoke_api() -> int:
                with open(remote_path, "rb") as src:
                    f.write(src.read())

                return filesize

            return self._collect_metrics(_invoke_api, operation="GET", path=remote_path, get_object_size=filesize)

    def glob(self, pattern: str, attribute_filter_expression: Optional[str] = None) -> list[str]:
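        """
        Returns the keys of files under ``base_path`` that match a glob
        pattern, optionally keeping only matches whose extended-attribute
        metadata satisfies ``attribute_filter_expression``.

        Example (an editorial sketch; the pattern and the filter syntax shown
        are illustrative)::

            keys = provider.glob("logs/**/*.txt", attribute_filter_expression='owner = "alice"')
        """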
        pattern = self._prepend_base_path(pattern)
        keys = list(glob.glob(pattern, recursive=True))
        if attribute_filter_expression:
            filtered_keys = []
            evaluator = create_attribute_filter_evaluator(attribute_filter_expression)
            for key in keys:
                obj_metadata = self._get_object_metadata(key)
                if matches_attribute_filter_expression(obj_metadata, evaluator):
                    filtered_keys.append(key)
            keys = filtered_keys
        if self._base_path == "/":
            return keys
        else:
            # NOTE: PosixFileStorageProvider does not have the concept of bucket and prefix,
            # so we drop the base_path from the returned keys.
            return [key.replace(self._base_path, "", 1).lstrip("/") for key in keys]

    def is_file(self, path: str) -> bool:
        path = self._prepend_base_path(path)
        return os.path.isfile(path)

    def rmtree(self, path: str) -> None:
        path = self._prepend_base_path(path)
        shutil.rmtree(path)