# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import os
import tempfile
import time
from datetime import datetime, timezone
from io import BytesIO, StringIO
from typing import IO, Any, Callable, Iterator, List, Optional, Union

from ..types import ObjectMetadata, Range
from .base import BaseStorageProvider

PROVIDER = "file"


def atomic_write(source: Union[str, IO], destination: str):
    """
    Writes the contents of a file or file-like object to the specified destination path.

    This function ensures that the write is atomic, meaning the output file is either fully written or not modified at all.
    This is achieved by writing to a temporary file first and then renaming it to the destination path.

    :param source: The input file to read from. It can be a string representing the path to a file, or an open file-like object (IO).
    :param destination: The path to the destination file where the contents should be written.
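
    Example usage (a minimal sketch; the paths are hypothetical)::

        atomic_write(source="/tmp/input.bin", destination="/data/output.bin")
        atomic_write(source=BytesIO(b"payload"), destination="/data/output.bin")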
39 """
40 with tempfile.NamedTemporaryFile(mode="wb", delete=False, dir=os.path.dirname(destination), prefix=".") as fp:
41 temp_file_path = fp.name
42 if isinstance(source, str):
43 with open(source, mode="rb") as src:
44 fp.write(src.read())
45 else:
46 fp.write(source.read())
47 os.rename(src=temp_file_path, dst=destination)
48
49
class PosixFileStorageProvider(BaseStorageProvider):
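    """
    A storage provider backed by a POSIX-compliant local file system.

    All keys are resolved relative to ``base_path``, which must be an absolute path.
    """
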
    def __init__(self, base_path: str, **kwargs: Any) -> None:
        # Validate POSIX path
        if base_path == "":
            base_path = "/"

        if not base_path.startswith("/"):
            raise ValueError(f"The base_path {base_path} must be an absolute path.")

        super().__init__(base_path=base_path, provider_name=PROVIDER)

    def _collect_metrics(
        self,
        func: Callable,
        operation: str,
        path: str,
        put_object_size: Optional[int] = None,
        get_object_size: Optional[int] = None,
    ) -> Any:
        """
        Collects and records performance metrics around file operations such as PUT, GET, DELETE, etc.

        This method wraps a file operation and measures the time it takes to complete, along with recording
        the size of the object if applicable.

        :param func: The function that performs the actual file operation.
        :param operation: The type of operation being performed (e.g., "PUT", "GET", "DELETE").
        :param path: The path to the object.
        :param put_object_size: The size of the object being uploaded, if applicable (for PUT operations).
        :param get_object_size: The size of the object being downloaded, if applicable (for GET operations).

        :return: The result of the file operation, typically the return value of the `func` callable.
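
        Example of the call pattern used by the methods below (illustrative only)::

            return self._collect_metrics(_invoke_api, operation="GET", path=path)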
82 """
83 start_time = time.time()
84 status_code = 200
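        # HTTP-style status codes are recorded: 200 on success, 404 for a missing file, -1 for any other failure.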

        object_size = None
        if operation == "PUT":
            object_size = put_object_size
        elif operation == "GET" and get_object_size is not None:
            object_size = get_object_size

        try:
            result = func()
            if operation == "GET" and object_size is None:
                object_size = len(result)
            return result
        except FileNotFoundError as error:
            status_code = 404
            raise error
        except Exception as error:
            status_code = -1
            raise RuntimeError(f"Failed to {operation} object(s) at {path}") from error
        finally:
            elapsed_time = time.time() - start_time
            self._metric_helper.record_duration(
                elapsed_time, provider=PROVIDER, operation=operation, bucket="", status_code=status_code
            )
            if object_size:
                self._metric_helper.record_object_size(
                    object_size, provider=PROVIDER, operation=operation, bucket="", status_code=status_code
                )

    def _put_object(self, path: str, body: bytes) -> None:
        def _invoke_api() -> None:
            os.makedirs(os.path.dirname(path), exist_ok=True)
            atomic_write(source=BytesIO(body), destination=path)

        return self._collect_metrics(_invoke_api, operation="PUT", path=path, put_object_size=len(body))

    def _get_object(self, path: str, byte_range: Optional[Range] = None) -> bytes:
        def _invoke_api() -> bytes:
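            # A byte_range carries an offset and a size, which map to a seek followed by a bounded read.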
            if byte_range:
                with open(path, "rb") as f:
                    f.seek(byte_range.offset)
                    return f.read(byte_range.size)
            else:
                with open(path, "rb") as f:
                    return f.read()

        return self._collect_metrics(_invoke_api, operation="GET", path=path)

    def _copy_object(self, src_path: str, dest_path: str) -> None:
        def _invoke_api() -> None:
            os.makedirs(os.path.dirname(dest_path), exist_ok=True)
            atomic_write(source=src_path, destination=dest_path)

        src_object = self._get_object_metadata(src_path)

        return self._collect_metrics(
            _invoke_api,
            operation="COPY",
            path=src_path,
            put_object_size=src_object.content_length,
        )

    def _delete_object(self, path: str) -> None:
        def _invoke_api() -> None:
            if os.path.exists(path) and os.path.isfile(path):
                os.remove(path)

        return self._collect_metrics(_invoke_api, operation="DELETE", path=path)

    def _get_object_metadata(self, path: str) -> ObjectMetadata:
        is_dir = os.path.isdir(path)
        if is_dir:
            path = self._append_delimiter(path)

        def _invoke_api() -> ObjectMetadata:
            return ObjectMetadata(
                key=path,
                type="directory" if is_dir else "file",
                content_length=os.path.getsize(path),
                last_modified=datetime.fromtimestamp(os.path.getmtime(path), tz=timezone.utc),
            )

        return self._collect_metrics(_invoke_api, operation="HEAD", path=path)

    def _list_objects(
        self,
        prefix: str,
        start_after: Optional[str] = None,
        end_at: Optional[str] = None,
        include_directories: bool = False,
    ) -> Iterator[ObjectMetadata]:
        def _invoke_api() -> Iterator[ObjectMetadata]:
            # The listing assumes lexicographical order; some file systems do not guarantee it, so file names are sorted below.
            for root, dirs, files in os.walk(prefix):
                if include_directories:
                    for dir in dirs:
                        full_path = os.path.join(root, dir)
                        relative_path = os.path.relpath(full_path, self._base_path)
                        yield ObjectMetadata(
                            key=relative_path,
                            content_length=0,
                            type="directory",
                            last_modified=datetime.min,
                        )

                # os.walk may return file names in reverse lexicographical order on some systems, so sort them explicitly.
                for name in sorted(files):
                    full_path = os.path.join(root, name)
                    # Keys are reported relative to the base path rather than the listing prefix.
                    relative_path = os.path.relpath(full_path, self._base_path)
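                    # start_after is exclusive and end_at is inclusive, per the comparisons below.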
                    if (start_after is None or start_after < relative_path) and (
                        end_at is None or relative_path <= end_at
                    ):
                        yield ObjectMetadata(
                            key=relative_path,
                            content_length=os.path.getsize(full_path),
                            last_modified=datetime.fromtimestamp(os.path.getmtime(full_path), tz=timezone.utc),
                        )
                    elif end_at is not None and end_at < relative_path:
                        return

                # When directories are included, only the first level under the prefix is walked.
                if include_directories:
                    break

        return self._collect_metrics(_invoke_api, operation="LIST", path=prefix)

    def _upload_file(self, remote_path: str, f: Union[str, IO]) -> None:
        os.makedirs(os.path.dirname(remote_path), exist_ok=True)

        def _invoke_api() -> None:
            atomic_write(source=f, destination=remote_path)

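        # The upload size reported to metrics depends on the input type: a path on disk, a text buffer, or a bytes-like buffer.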
        if isinstance(f, str):
            filesize = os.path.getsize(f)
            return self._collect_metrics(_invoke_api, operation="PUT", path=remote_path, put_object_size=filesize)
        elif isinstance(f, StringIO):
            filesize = len(f.getvalue().encode("utf-8"))
            return self._collect_metrics(_invoke_api, operation="PUT", path=remote_path, put_object_size=filesize)
        else:
            filesize = len(f.getvalue())  # type: ignore
            return self._collect_metrics(_invoke_api, operation="PUT", path=remote_path, put_object_size=filesize)

    def _download_file(self, remote_path: str, f: Union[str, IO], metadata: Optional[ObjectMetadata] = None) -> None:
        filesize = metadata.content_length if metadata else os.path.getsize(remote_path)

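        # The destination may be a local path, a text buffer, or a binary buffer; each needs a different copy strategy.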
        if isinstance(f, str):

            def _invoke_api() -> None:
                os.makedirs(os.path.dirname(f), exist_ok=True)
                atomic_write(source=remote_path, destination=f)

            return self._collect_metrics(_invoke_api, operation="GET", path=remote_path, get_object_size=filesize)
        elif isinstance(f, StringIO):

            def _invoke_api() -> None:
                with open(remote_path, "r", encoding="utf-8") as src:
                    f.write(src.read())

            return self._collect_metrics(_invoke_api, operation="GET", path=remote_path, get_object_size=filesize)
        else:

            def _invoke_api() -> None:
                with open(remote_path, "rb") as src:
                    f.write(src.read())

            return self._collect_metrics(_invoke_api, operation="GET", path=remote_path, get_object_size=filesize)

    def glob(self, pattern: str) -> List[str]:
        pattern = self._realpath(pattern)
        keys = list(glob.glob(pattern, recursive=True))
        if self._base_path == "/":
            return keys
        else:
            # NOTE: PosixFileStorageProvider does not have the concept of a bucket and prefix,
            # so the base_path is dropped from the returned keys.
            return [key.replace(self._base_path, "", 1).lstrip("/") for key in keys]

    def is_file(self, path: str) -> bool:
        path = self._realpath(path)
        return os.path.isfile(path)