# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import json
import logging
import os
import shutil
import tempfile
import time
from collections.abc import Callable, Iterator, Sequence, Sized
from datetime import datetime, timezone
from io import BytesIO, StringIO
from typing import IO, Any, Optional, TypeVar, Union

import opentelemetry.metrics as api_metrics
import xattr

from ..telemetry import Telemetry
from ..telemetry.attributes.base import AttributesProvider
from ..types import AWARE_DATETIME_MIN, ObjectMetadata, Range
from ..utils import validate_attributes
from .base import BaseStorageProvider

_T = TypeVar("_T")

PROVIDER = "file"

logger = logging.getLogger(__name__)


def atomic_write(source: Union[str, IO], destination: str, attributes: Optional[dict[str, str]] = None):
    """
    Writes the contents of a file to the specified destination path.

    This function ensures that the file write operation is atomic, meaning the output file is either fully written or not modified at all.
    This is achieved by writing to a temporary file first and then renaming it to the destination path.

    :param source: The input file to read from. It can be a string representing the path to a file, or an open file-like object (IO).
    :param destination: The path to the destination file where the contents should be written.
    :param attributes: The attributes to set on the file.
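
    Example (illustrative only; the paths shown are hypothetical)::

        # Copy an existing file into place atomically.
        atomic_write("/tmp/staging/model.bin", "/data/model.bin")

        # Or write from an in-memory buffer, attaching extended attributes.
        atomic_write(BytesIO(b"payload"), "/data/payload.bin", attributes={"owner": "alice"})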
54 """
55
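    # The temporary file is created in the destination directory so that the final
    # os.rename() happens within a single filesystem, which makes the swap atomic on POSIX.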
    with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(destination), prefix=".") as fp:
        temp_file_path = fp.name
        if isinstance(source, str):
            with open(source, mode="rb") as src:
                fp.write(src.read())
        else:
            fp.write(source.read())

    # Set attributes on temp file if provided
    validated_attributes = validate_attributes(attributes)
    if validated_attributes:
        try:
            xattr.setxattr(temp_file_path, "user.json", json.dumps(validated_attributes).encode("utf-8"))
        except OSError as e:
            logger.debug(f"Failed to set extended attributes on temp file {temp_file_path}: {e}")

    os.rename(src=temp_file_path, dst=destination)


class PosixFileStorageProvider(BaseStorageProvider):
    """
    A concrete implementation of the :py:class:`multistorageclient.types.StorageProvider` for interacting with POSIX file systems.
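
    Example (a minimal sketch; ``/data`` is a hypothetical base path)::

        provider = PosixFileStorageProvider(base_path="/data")
        if provider.is_file("models/checkpoint.pt"):
            checkpoints = provider.glob("models/**/*.pt")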
78 """
79
    def __init__(
        self,
        base_path: str,
        metric_counters: dict[Telemetry.CounterName, api_metrics.Counter] = {},
        metric_gauges: dict[Telemetry.GaugeName, api_metrics._Gauge] = {},
        metric_attributes_providers: Sequence[AttributesProvider] = (),
        **kwargs: Any,
    ) -> None:
        """
        :param base_path: The root prefix path within the POSIX file system where all operations will be scoped.
        :param metric_counters: Metric counters.
        :param metric_gauges: Metric gauges.
        :param metric_attributes_providers: Metric attributes providers.
        """

        # Validate POSIX path
        if base_path == "":
            base_path = "/"

        if not base_path.startswith("/"):
            raise ValueError(f"The base_path {base_path} must be an absolute path.")

        super().__init__(
            base_path=base_path,
            provider_name=PROVIDER,
            metric_counters=metric_counters,
            metric_gauges=metric_gauges,
            metric_attributes_providers=metric_attributes_providers,
        )

    def _collect_metrics(
        self,
        func: Callable[[], _T],
        operation: str,
        path: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> _T:
        """
        Collects and records performance metrics around file operations such as PUT, GET, DELETE, etc.

        This method wraps a file operation and measures the time it takes to complete, along with recording
        the size of the object if applicable.

        :param func: The function that performs the actual file operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param path: The path to the object.
        :param put_object_size: The size of the object being uploaded, if applicable (for PUT operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).

        :return: The result of the file operation, typically the return value of the `func` callable.
        """
        start_time = time.time()
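        # Metrics use HTTP-style status codes: 200 for success, 404 when the
        # file is missing, and -1 for any other failure.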
        status_code = 200

        object_size = None
        if operation == "PUT":
            object_size = put_object_size
        elif operation == "GET" and get_object_size is not None:
            object_size = get_object_size

        try:
            result = func()
            if operation == "GET" and object_size is None and isinstance(result, Sized):
                object_size = len(result)
            return result
        except FileNotFoundError:
            status_code = 404
            raise
        except Exception as error:
            status_code = -1
            raise RuntimeError(f"Failed to {operation} object(s) at {path}, error: {error}") from error
        finally:
            elapsed_time = time.time() - start_time
            self._metric_helper.record_duration(
                elapsed_time, provider=self._provider_name, operation=operation, bucket="", status_code=status_code
            )
            if object_size:
                self._metric_helper.record_object_size(
                    object_size, provider=self._provider_name, operation=operation, bucket="", status_code=status_code
                )

    def _put_object(
        self,
        path: str,
        body: bytes,
        if_match: Optional[str] = None,
        if_none_match: Optional[str] = None,
        attributes: Optional[dict[str, str]] = None,
    ) -> int:
        def _invoke_api() -> int:
            os.makedirs(os.path.dirname(path), exist_ok=True)
            atomic_write(source=BytesIO(body), destination=path, attributes=attributes)
            return len(body)

        return self._collect_metrics(_invoke_api, operation="PUT", path=path, put_object_size=len(body))

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        def _invoke_api() -> bytes:
            if byte_range:
                with open(path, "rb") as f:
                    f.seek(byte_range.offset)
                    return f.read(byte_range.size)
            else:
                with open(path, "rb") as f:
                    return f.read()

        return self._collect_metrics(_invoke_api, operation="GET", path=path)

    def _copy_object(self, src_path: str, dest_path: str) -> int:
        src_object = self._get_object_metadata(src_path)

        def _invoke_api() -> int:
            os.makedirs(os.path.dirname(dest_path), exist_ok=True)
            atomic_write(source=src_path, destination=dest_path, attributes=src_object.metadata)

            return src_object.content_length

        return self._collect_metrics(
            _invoke_api,
            operation="COPY",
            path=src_path,
            put_object_size=src_object.content_length,
        )

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        def _invoke_api() -> None:
            # os.path.isfile() already returns False for missing paths, so deletes are idempotent
            if os.path.isfile(path):
                os.remove(path)

        return self._collect_metrics(_invoke_api, operation="DELETE", path=path)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        is_dir = os.path.isdir(path)
        if is_dir:
            path = self._append_delimiter(path)

        def _invoke_api() -> ObjectMetadata:
            # Read custom metadata from the "user.json" extended attribute, if present
            metadata_dict = {}
            try:
                json_bytes = xattr.getxattr(path, "user.json")
                metadata_dict = json.loads(json_bytes.decode("utf-8"))
            except (OSError, IOError, KeyError, json.JSONDecodeError, AttributeError) as e:
                # Ignore if the xattr doesn't exist, can't be read, or is corrupted
                logger.debug(f"Failed to read extended attributes from {path}: {e}")

            return ObjectMetadata(
                key=path,
                type="directory" if is_dir else "file",
                content_length=0 if is_dir else os.path.getsize(path),
                last_modified=datetime.fromtimestamp(os.path.getmtime(path), tz=timezone.utc),
                metadata=metadata_dict if metadata_dict else None,
            )

        return self._collect_metrics(_invoke_api, operation="HEAD", path=path)

    def _list_objects(
        self,
        prefix: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        def _invoke_api() -> Iterator[ObjectMetadata]:
            # Get parent directory and list its contents
            parent_dir = os.path.dirname(prefix)
            if not os.path.exists(parent_dir):
                return

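            # POSIX has no native prefix query, so emulate object-store listing:
            # match the parent directory's entries against the prefix, then
            # recurse into the matching directories.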
            # List contents of parent directory for prefix-based filtering
            try:
                entries = os.listdir(parent_dir)
            except (OSError, PermissionError) as e:
                logger.warning(f"Failed to list contents of {parent_dir}, caused by: {e}")
                entries = []

            matching_entries = []
            for entry in entries:
                full_path = os.path.join(parent_dir, entry)
                if full_path.startswith(prefix):
                    matching_entries.append(full_path)

            matching_entries.sort()

            # Collect directories to walk
            directories_to_walk = []

            for full_path in matching_entries:
                relative_path = os.path.relpath(full_path, self._base_path)
                if (start_after is None or start_after < relative_path) and (end_at is None or relative_path <= end_at):
                    if os.path.isfile(full_path):
                        yield ObjectMetadata(
                            key=relative_path,
                            content_length=os.path.getsize(full_path),
                            last_modified=datetime.fromtimestamp(os.path.getmtime(full_path), tz=timezone.utc),
                        )
                    elif os.path.isdir(full_path):
                        if include_directories:
                            yield ObjectMetadata(
                                key=relative_path,
                                content_length=0,
                                type="directory",
                                last_modified=AWARE_DATETIME_MIN,
                            )
                        else:
                            directories_to_walk.append(full_path)

            # Walk all directories only if include_directories is False
            if not include_directories:
                for dir_path in directories_to_walk:
                    for root, _, files in os.walk(dir_path):
                        for file_name in sorted(files):
                            file_path = os.path.join(root, file_name)
                            relative_path = os.path.relpath(file_path, self._base_path)

                            if (start_after is None or start_after < relative_path) and (
                                end_at is None or relative_path <= end_at
                            ):
                                yield ObjectMetadata(
                                    key=relative_path,
                                    content_length=os.path.getsize(file_path),
                                    last_modified=datetime.fromtimestamp(os.path.getmtime(file_path), tz=timezone.utc),
                                )

        return self._collect_metrics(_invoke_api, operation="LIST", path=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO], attributes: Optional[dict[str, str]] = None) -> int:
        os.makedirs(os.path.dirname(remote_path), exist_ok=True)

        filesize: int = 0
        if isinstance(f, str):
            filesize = os.path.getsize(f)
        elif isinstance(f, StringIO):
            filesize = len(f.getvalue().encode("utf-8"))
        else:
            # Assumes a BytesIO-like object exposing getvalue()
            filesize = len(f.getvalue())  # type: ignore

        def _invoke_api() -> int:
            atomic_write(source=f, destination=remote_path, attributes=attributes)

            return filesize

        return self._collect_metrics(_invoke_api, operation="PUT", path=remote_path, put_object_size=filesize)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> int:
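        # Prefer the size from caller-supplied metadata to avoid an extra stat call.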
        filesize = metadata.content_length if metadata else os.path.getsize(remote_path)

        if isinstance(f, str):

            def _invoke_api() -> int:
                os.makedirs(os.path.dirname(f), exist_ok=True)
                atomic_write(source=remote_path, destination=f)

                return filesize

            return self._collect_metrics(_invoke_api, operation="GET", path=remote_path, get_object_size=filesize)
        elif isinstance(f, StringIO):

            def _invoke_api() -> int:
                with open(remote_path, "r", encoding="utf-8") as src:
                    f.write(src.read())

                return filesize

            return self._collect_metrics(_invoke_api, operation="GET", path=remote_path, get_object_size=filesize)
        else:

            def _invoke_api() -> int:
                with open(remote_path, "rb") as src:
                    f.write(src.read())

                return filesize

            return self._collect_metrics(_invoke_api, operation="GET", path=remote_path, get_object_size=filesize)

    def glob(self, pattern: str) -> list[str]:
        pattern = self._prepend_base_path(pattern)
        keys = list(glob.glob(pattern, recursive=True))
        if self._base_path == "/":
            return keys
        else:
            # NOTE: PosixFileStorageProvider does not have the concept of bucket and prefix.
            # So we drop the base_path from it.
            return [key.replace(self._base_path, "", 1).lstrip("/") for key in keys]

    def is_file(self, path: str) -> bool:
        path = self._prepend_base_path(path)
        return os.path.isfile(path)

    def rmtree(self, path: str) -> None:
        path = self._prepend_base_path(path)
        shutil.rmtree(path)