1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import glob
17import os
18import json
19import tempfile
20import time
21from datetime import datetime, timezone
22from io import BytesIO, StringIO
23from typing import IO, Any, Callable, Iterator, List, Optional, Union, Dict
24
25from ..types import ObjectMetadata, Range, AWARE_DATETIME_MIN
26from .base import BaseStorageProvider
27
28PROVIDER = "file"
29
30
def atomic_write(source: Union[str, IO], destination: str):
    """
    Writes the contents of a file to the specified destination path atomically.

    The data is first written to a temporary file created in the destination's
    directory and then renamed over the destination, so readers observe either
    the old content or the complete new content, never a partial write
    (``os.rename`` is atomic within a single file system, which is why the
    temporary file must live next to the destination).

    If the write fails, the temporary file is removed so no partial output is
    left behind.

    :param source: The input file to read from. It can be a string representing the path to a file,
        or an open file-like object (IO) positioned at the data to copy.
    :param destination: The path to the destination file where the contents should be written.
    """
    temp_file_path = None
    try:
        with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(destination), prefix=".") as fp:
            temp_file_path = fp.name
            if isinstance(source, str):
                with open(source, mode="rb") as src:
                    # Copy in chunks to avoid loading large files into memory.
                    while chunk := src.read(1024 * 1024):
                        fp.write(chunk)
            else:
                fp.write(source.read())
        os.rename(src=temp_file_path, dst=destination)
    except Exception:
        # Don't leak the partially written temporary file on failure.
        if temp_file_path and os.path.exists(temp_file_path):
            os.remove(temp_file_path)
        raise
49
50
class PosixFileStorageProvider(BaseStorageProvider):
    """
    A :py:class:`BaseStorageProvider` implementation backed by a POSIX file system.

    Object keys are mapped to file paths under ``base_path``. PUT/GET/COPY/DELETE/
    HEAD/LIST operations are implemented with ordinary file I/O, and writes are made
    atomic via :py:func:`atomic_write`. Optional object metadata is persisted in the
    ``user.json`` extended attribute when the platform supports xattrs.
    """

    def __init__(self, base_path: str, **kwargs: Any) -> None:
        """
        :param base_path: The absolute POSIX path under which all objects are stored.
            An empty string is treated as the file system root ``/``.
        :raises ValueError: If ``base_path`` is not an absolute path.
        """
        # Validate POSIX path
        if base_path == "":
            base_path = "/"

        if not base_path.startswith("/"):
            raise ValueError(f"The base_path {base_path} must be an absolute path.")

        super().__init__(base_path=base_path, provider_name=PROVIDER)

    def _collect_metrics(
        self,
        func: Callable,
        operation: str,
        path: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> Any:
        """
        Collects and records performance metrics around file operations such as PUT, GET, DELETE, etc.

        This method wraps a file operation and measures the time it takes to complete, along with recording
        the size of the object if applicable.

        :param func: The function that performs the actual file operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param path: The path to the object.
        :param put_object_size: The size of the object being uploaded, if applicable (for PUT operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).

        :return: The result of the file operation, typically the return value of the `func` callable.
        """
        start_time = time.time()
        status_code = 200

        object_size = None
        if operation == "PUT":
            object_size = put_object_size
        elif operation == "GET" and get_object_size is not None:
            object_size = get_object_size

        try:
            result = func()
            if operation == "GET" and object_size is None:
                # Fall back to the size of the payload actually returned.
                object_size = len(result)
            return result
        except FileNotFoundError as error:
            status_code = 404
            raise error
        except Exception as error:
            status_code = -1
            raise RuntimeError(f"Failed to {operation} object(s) at {path}") from error
        finally:
            elapsed_time = time.time() - start_time
            self._metric_helper.record_duration(
                elapsed_time, provider=self._provider_name, operation=operation, bucket="", status_code=status_code
            )
            # Explicit None check so zero-byte objects are still recorded.
            if object_size is not None:
                self._metric_helper.record_object_size(
                    object_size, provider=self._provider_name, operation=operation, bucket="", status_code=status_code
                )

    def _put_object(self, path: str, body: bytes, metadata: Optional[Dict[str, str]] = None) -> None:
        """
        Atomically writes ``body`` to ``path``, creating parent directories as needed.

        :param path: The destination file path.
        :param body: The raw bytes to store.
        :param metadata: Optional string key/value metadata, persisted as a JSON blob
            in the ``user.json`` extended attribute when the platform supports xattrs.
        """

        def _invoke_api() -> None:
            os.makedirs(os.path.dirname(path), exist_ok=True)
            atomic_write(source=BytesIO(body), destination=path)

            # Set metadata attributes if setxattr is available
            if metadata and hasattr(os, "setxattr"):
                json_str = json.dumps(metadata)
                json_bytes = json_str.encode("utf-8")
                os.setxattr(path, "user.json", json_bytes, flags=0)  # type: ignore

        return self._collect_metrics(_invoke_api, operation="PUT", path=path, put_object_size=len(body))

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        """
        Reads the file at ``path``, optionally restricted to ``byte_range``.

        :param path: The file path to read.
        :param byte_range: Optional offset/size window to read instead of the whole file.
        :return: The file contents (or the requested slice) as bytes.
        """

        def _invoke_api() -> bytes:
            if byte_range:
                with open(path, "rb") as f:
                    f.seek(byte_range.offset)
                    return f.read(byte_range.size)
            else:
                with open(path, "rb") as f:
                    return f.read()

        return self._collect_metrics(_invoke_api, operation="GET", path=path)

    def _copy_object(self, src_path: str, dest_path: str) -> None:
        """
        Atomically copies the file at ``src_path`` to ``dest_path``, creating parent
        directories as needed.
        """

        def _invoke_api() -> None:
            os.makedirs(os.path.dirname(dest_path), exist_ok=True)
            atomic_write(source=src_path, destination=dest_path)

        src_object = self._get_object_metadata(src_path)

        return self._collect_metrics(
            _invoke_api,
            operation="COPY",
            path=src_path,
            put_object_size=src_object.content_length,
        )

    def _delete_object(self, path: str, if_match: Optional[str] = None) -> None:
        """
        Deletes the file at ``path`` if it exists; missing files are a no-op.

        :param if_match: Accepted for interface compatibility; not used by the
            POSIX provider.
        """

        def _invoke_api() -> None:
            if os.path.exists(path) and os.path.isfile(path):
                os.remove(path)

        return self._collect_metrics(_invoke_api, operation="DELETE", path=path)

    def _get_object_metadata(self, path: str, strict: bool = True) -> ObjectMetadata:
        """
        Returns the :py:class:`ObjectMetadata` for the file or directory at ``path``.

        Any metadata previously stored in the ``user.json`` extended attribute is
        loaded when the platform supports xattrs; failures to read it are ignored.
        """
        is_dir = os.path.isdir(path)
        if is_dir:
            path = self._append_delimiter(path)

        def _invoke_api() -> ObjectMetadata:
            # Get basic file attributes
            metadata_dict = {}
            if hasattr(os, "getxattr"):
                try:
                    json_bytes = os.getxattr(path, "user.json")  # type: ignore
                    metadata_dict = json.loads(json_bytes.decode("utf-8"))
                except OSError:
                    # Best-effort: missing or unreadable xattrs are not an error.
                    pass

            return ObjectMetadata(
                key=path,
                type="directory" if is_dir else "file",
                content_length=0 if is_dir else os.path.getsize(path),
                last_modified=datetime.fromtimestamp(os.path.getmtime(path), tz=timezone.utc),
                metadata=metadata_dict if metadata_dict else None,
            )

        return self._collect_metrics(_invoke_api, operation="HEAD", path=path)

    def _list_objects(
        self,
        prefix: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        """
        Yields :py:class:`ObjectMetadata` for files under ``prefix`` in lexicographical
        order of their paths relative to the base path.

        :param prefix: The directory path to walk.
        :param start_after: Only yield keys strictly greater than this value.
        :param end_at: Only yield keys less than or equal to this value.
        :param include_directories: If True, also yield immediate subdirectories and
            only walk a single level (non-recursive listing).
        """

        def _invoke_api() -> Iterator[ObjectMetadata]:
            # Assume the file system guarantees lexicographical order (some don't).
            for root, dirs, files in os.walk(prefix):
                dirs.sort()
                if include_directories:
                    for dir in dirs:
                        full_path = os.path.join(root, dir)
                        relative_path = os.path.relpath(full_path, self._base_path)
                        yield ObjectMetadata(
                            key=relative_path,
                            content_length=0,
                            type="directory",
                            last_modified=AWARE_DATETIME_MIN,
                        )

                # This is in reverse lexicographical order on some systems for some reason.
                for name in sorted(files):
                    full_path = os.path.join(root, name)
                    # Keys are relative to the base path (not the listing prefix).
                    relative_path = os.path.relpath(full_path, self._base_path)
                    if (start_after is None or start_after < relative_path) and (
                        end_at is None or relative_path <= end_at
                    ):
                        yield ObjectMetadata(
                            key=relative_path,
                            content_length=os.path.getsize(full_path),
                            last_modified=datetime.fromtimestamp(os.path.getmtime(full_path), tz=timezone.utc),
                        )
                    elif end_at is not None and end_at < relative_path:
                        # Files are sorted, so everything after this is out of range.
                        return

                # Only walk one level
                if include_directories:
                    break

        return self._collect_metrics(_invoke_api, operation="LIST", path=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO]) -> None:
        """
        Atomically writes a local file or file-like object to ``remote_path``.

        :param f: A source file path, a ``StringIO``, or a binary file-like object
            supporting ``getvalue()``.
        """
        os.makedirs(os.path.dirname(remote_path), exist_ok=True)

        def _invoke_api() -> None:
            atomic_write(source=f, destination=remote_path)

        if isinstance(f, str):
            filesize = os.path.getsize(f)
            return self._collect_metrics(_invoke_api, operation="PUT", path=remote_path, put_object_size=filesize)
        elif isinstance(f, StringIO):
            # StringIO holds text; size is measured after UTF-8 encoding.
            filesize = len(f.getvalue().encode("utf-8"))
            return self._collect_metrics(_invoke_api, operation="PUT", path=remote_path, put_object_size=filesize)
        else:
            filesize = len(f.getvalue())  # type: ignore
            return self._collect_metrics(_invoke_api, operation="PUT", path=remote_path, put_object_size=filesize)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> None:
        """
        Copies the file at ``remote_path`` into a local path or file-like object.

        :param f: A destination file path, a ``StringIO`` (text copy, UTF-8), or a
            binary file-like object.
        :param metadata: Optional pre-fetched metadata used for the size metric;
            otherwise the size is read from the file system.
        """
        filesize = metadata.content_length if metadata else os.path.getsize(remote_path)

        if isinstance(f, str):

            def _invoke_api() -> None:
                # `f` may be a bare filename whose dirname is "", which os.makedirs
                # rejects; fall back to the current directory in that case.
                os.makedirs(os.path.dirname(f) or ".", exist_ok=True)
                atomic_write(source=remote_path, destination=f)

            return self._collect_metrics(_invoke_api, operation="GET", path=remote_path, get_object_size=filesize)
        elif isinstance(f, StringIO):

            def _invoke_api() -> None:
                with open(remote_path, "r", encoding="utf-8") as src:
                    f.write(src.read())

            return self._collect_metrics(_invoke_api, operation="GET", path=remote_path, get_object_size=filesize)
        else:

            def _invoke_api() -> None:
                with open(remote_path, "rb") as src:
                    f.write(src.read())

            return self._collect_metrics(_invoke_api, operation="GET", path=remote_path, get_object_size=filesize)

    def glob(self, pattern: str) -> List[str]:
        """
        Returns the keys matching a glob ``pattern``, resolved against the base path.

        :param pattern: A glob pattern (``**`` is supported via ``recursive=True``).
        :return: Matching keys with the base path prefix stripped.
        """
        pattern = self._realpath(pattern)
        keys = list(glob.glob(pattern, recursive=True))
        if self._base_path == "/":
            return keys
        else:
            # NOTE: PosixStorageProvider does not have the concept of bucket and prefix.
            # So we drop the base_path from it.
            return [key.replace(self._base_path, "", 1).lstrip("/") for key in keys]

    def is_file(self, path: str) -> bool:
        """
        Returns True if ``path`` (resolved against the base path) is a regular file.
        """
        path = self._realpath(path)
        return os.path.isfile(path)