# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
import io
import os
from collections.abc import Generator
from concurrent.futures import FIRST_EXCEPTION, ThreadPoolExecutor, wait
from contextlib import contextmanager
from typing import Union, cast

from torch.distributed.checkpoint.filesystem import FileSystemBase, FileSystemReader, FileSystemWriter
from torch.distributed.checkpoint.planner import (
    LoadPlan,
    LoadPlanner,
    ReadItem,
)
from torch.futures import Future

from ...pathlib import MultiStoragePath
32
33
class MultiStorageFileSystem(FileSystemBase):
    """
    ``FileSystemBase`` implementation backed by ``MultiStoragePath``.

    Every operation converts its path argument to a ``MultiStoragePath`` and forwards the call to the
    matching method on that object.
    """

    @contextmanager
    def create_stream(self, path: Union[str, os.PathLike], mode: str) -> Generator[io.IOBase, None, None]:
        """
        Yield a stream for ``path`` opened with ``mode``.

        The stream comes from ``MultiStoragePath.open`` and is used as a context manager, so its context is
        exited when this one exits.
        """
        target = MultiStoragePath(path)
        with target.open(mode=mode) as stream:
            yield stream

    def concat_path(self, path: Union[str, os.PathLike], suffix: str) -> Union[str, os.PathLike]:
        """Return ``path`` joined with ``suffix`` via the ``/`` operator of ``MultiStoragePath``."""
        base = MultiStoragePath(path)
        joined = base / suffix
        return joined

    def rename(self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]) -> None:
        """Rename ``path`` to ``new_path`` using ``MultiStoragePath.rename``."""
        source = MultiStoragePath(path)
        source.rename(new_path)

    def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
        """Wrap ``path`` in a ``MultiStoragePath``."""
        wrapped = MultiStoragePath(path)
        return wrapped

    def mkdir(self, path: Union[str, os.PathLike]) -> None:
        """Create the directory ``path`` with ``parents=True`` and ``exist_ok=True``."""
        directory = MultiStoragePath(path)
        directory.mkdir(exist_ok=True, parents=True)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """
        Report whether ``checkpoint_id`` can be turned into a ``MultiStoragePath``.

        :param checkpoint_id: Candidate checkpoint location.
        :return: ``False`` if the ``MultiStoragePath`` constructor raises ``ValueError``, otherwise ``True``.
        """
        try:
            MultiStoragePath(checkpoint_id)
        except ValueError:
            return False
        else:
            return True

    def exists(self, path: Union[str, os.PathLike]) -> bool:
        """Return whether ``path`` exists according to ``MultiStoragePath.exists``."""
        target = MultiStoragePath(path)
        return target.exists()

    def rm_file(self, path: Union[str, os.PathLike]) -> None:
        """Remove the file at ``path`` via ``MultiStoragePath.unlink``."""
        target = MultiStoragePath(path)
        target.unlink()

    def ls(self, path: Union[str, os.PathLike]) -> list[str]:
        """Return the string form of every entry yielded by ``MultiStoragePath(path).iterdir()``."""
        entries = MultiStoragePath(path).iterdir()
        return list(map(str, entries))
73
74
def _prefetch_objects(fs: MultiStorageFileSystem, urls: list[MultiStoragePath], thread_count: int) -> None:
    """
    Pre-download objects in parallel so later reads can be served from the cache (when caching is enabled).

    Each object is opened for binary reading and immediately closed without reading any data.
    NOTE(review): this relies on ``MultiStoragePath.open`` populating the cache on open; confirm there.

    The call blocks until every object has been fetched. If any fetch fails, downloads that have not started
    yet are cancelled. Downloads already in progress are allowed to finish when the executor shuts down.
    The error from the earliest-submitted failed fetch is then re-raised.

    :param fs: Filesystem used to open each object.
    :param urls: Objects to prefetch.
    :param thread_count: Maximum number of concurrent downloads.
    :raises Exception: Whatever exception a failed fetch raised.
    """

    def _prefetch(url: MultiStoragePath) -> None:
        # Open-and-close only; nothing is read here.
        with fs.create_stream(url, "rb"):
            pass

    with ThreadPoolExecutor(max_workers=thread_count) as executor:
        futures = [executor.submit(_prefetch, url) for url in urls]
        # Return as soon as any fetch fails instead of blocking on futures in submission order.
        done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
        # Without this, the executor's shutdown would still wait for every queued download.
        for future in not_done:
            future.cancel()
        # Re-raise in submission order for a deterministic error; only completed futures are inspected,
        # because result() on a cancelled future would raise CancelledError instead of the real error.
        for future in futures:
            if future in done:
                future.result()
88
89
class MultiStorageFileSystemReader(FileSystemReader):
    """
    A reader implementation that uses the MultiStorageFileSystem class to handle file system operations.

    When ``thread_count > 1``, the files referenced by a load plan are prefetched in parallel before the
    base-class ``read_data`` runs.
    """

    def __init__(self, path: Union[str, os.PathLike], thread_count: int = 1) -> None:
        """
        Initialize the MultiStorageFileSystemReader with the MultiStorageFileSystem.

        :param path: The path to the checkpoint.
        :param thread_count: The number of threads to use for prefetching. Values of 1 or less disable
            prefetching.
        """
        super().__init__(path)
        # Use a MultiStoragePath-backed filesystem and path in place of whatever the base initializer set.
        self.fs = MultiStorageFileSystem()
        self.path = self.fs.init_path(path)
        self.thread_count = thread_count

    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        """
        Optionally prefetch the files referenced by ``plan``, then delegate to ``FileSystemReader.read_data``.

        :param plan: Load plan whose items will be read.
        :param planner: Planner forwarded unchanged to the base class.
        :return: The future returned by the base-class ``read_data``.
        """
        if self.thread_count > 1:
            # Several read items can live in the same file. Fetch each file once, in first-seen order.
            relative_paths = dict.fromkeys(
                self.storage_data[read_item.storage_index].relative_path for read_item in plan.items
            )
            base_path = cast(MultiStoragePath, self.path)
            urls = [base_path / rel_path for rel_path in relative_paths]
            _prefetch_objects(self.fs, urls, self.thread_count)

        return super().read_data(plan, planner)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """Return ``MultiStorageFileSystem.validate_checkpoint_id(checkpoint_id)``."""
        return MultiStorageFileSystem.validate_checkpoint_id(checkpoint_id)
128
129
class MultiStorageFileSystemWriter(FileSystemWriter):
    """
    A writer implementation that uses the MultiStorageFileSystem class to handle file system operations.
    """

    def __init__(
        self,
        path: Union[str, os.PathLike],
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: int = 1,
        per_thread_copy_ahead: int = 10_000_000,
        cache_staged_state_dict: bool = False,
        overwrite: bool = True,
    ) -> None:
        """
        Initialize the MultiStorageFileSystemWriter with the MultiStorageFileSystem.

        All options are forwarded unchanged to ``FileSystemWriter.__init__``; see its documentation for their
        meaning.

        :param path: The path to the checkpoint.
        :param single_file_per_rank: Forwarded to ``FileSystemWriter``.
        :param sync_files: Forwarded to ``FileSystemWriter``.
        :param thread_count: Forwarded to ``FileSystemWriter``.
        :param per_thread_copy_ahead: Forwarded to ``FileSystemWriter``.
        :param cache_staged_state_dict: Forwarded to ``FileSystemWriter``.
        :param overwrite: Forwarded to ``FileSystemWriter``.
        """
        # Forward by keyword so the binding stays correct if the base signature inserts or reorders parameters.
        super().__init__(
            path=path,
            single_file_per_rank=single_file_per_rank,
            sync_files=sync_files,
            thread_count=thread_count,
            per_thread_copy_ahead=per_thread_copy_ahead,
            cache_staged_state_dict=cache_staged_state_dict,
            overwrite=overwrite,
        )
        # Use a MultiStoragePath-backed filesystem and path in place of whatever the base initializer set.
        self.fs = MultiStorageFileSystem()
        self.path = self.fs.init_path(path)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """Return ``MultiStorageFileSystem.validate_checkpoint_id(checkpoint_id)``."""
        return MultiStorageFileSystem.validate_checkpoint_id(checkpoint_id)