# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os
from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Union, cast

from torch.distributed.checkpoint.filesystem import FileSystemBase, FileSystemReader, FileSystemWriter
from torch.distributed.checkpoint.planner import (
    LoadPlan,
    LoadPlanner,
    ReadItem,
)
from torch.futures import Future

from ...pathlib import MultiStoragePath


class MultiStorageFileSystem(FileSystemBase):
    """
    A filesystem implementation that uses the MultiStoragePath class to handle paths.
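
    Example (a minimal sketch; assumes the path is one that ``MultiStoragePath``
    understands, e.g. a local directory or a configured object-storage URL; the
    URL below is a placeholder)::

        fs = MultiStorageFileSystem()
        with fs.create_stream("s3://bucket/checkpoint/some-file.bin", "rb") as stream:
            data = stream.read()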
    """

    @contextmanager
    def create_stream(self, path: Union[str, os.PathLike], mode: str) -> Generator[io.IOBase, None, None]:
        # always download the file here (used for checkpointing)
        with MultiStoragePath(path).open(mode=mode, prefetch_file=True) as fp:
            yield fp

    def concat_path(self, path: Union[str, os.PathLike], suffix: str) -> Union[str, os.PathLike]:
        return MultiStoragePath(path) / suffix

    def rename(self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]) -> None:
        MultiStoragePath(path).rename(new_path)

    def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
        return MultiStoragePath(path)

    def mkdir(self, path: Union[str, os.PathLike]) -> None:
        MultiStoragePath(path).mkdir(parents=True, exist_ok=True)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        try:
            MultiStoragePath(checkpoint_id)
        except ValueError:
            return False

        return True

    def exists(self, path: Union[str, os.PathLike]) -> bool:
        return MultiStoragePath(path).exists()

    def rm_file(self, path: Union[str, os.PathLike]) -> None:
        MultiStoragePath(path).unlink()

    def ls(self, path: Union[str, os.PathLike]) -> list[str]:
        return [str(p) for p in MultiStoragePath(path).iterdir()]


def _prefetch_objects(fs: MultiStorageFileSystem, urls: list[MultiStoragePath], thread_count: int) -> None:
    """
    Pre-download files from object storage in parallel threads so that subsequent
    reads are served from the local cache when caching is enabled.
    """

    def _prefetch(url: MultiStoragePath) -> None:
        # Opening the stream triggers the download; the contents are read later.
        with fs.create_stream(url, "rb") as _:
            pass

    with ThreadPoolExecutor(max_workers=thread_count) as executor:
        futures = [executor.submit(_prefetch, url) for url in urls]
        for future in futures:
            future.result()


class MultiStorageFileSystemReader(FileSystemReader):
    """
    A reader implementation that uses the MultiStorageFileSystem class to handle file system operations.
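
    Example (a minimal sketch; assumes a ``torch.distributed`` process group is
    initialized, ``state_dict`` matches the layout of the saved checkpoint, and
    the placeholder URL is one that ``MultiStoragePath`` understands)::

        import torch.distributed.checkpoint as dcp

        reader = MultiStorageFileSystemReader("s3://bucket/checkpoints/step-1000", thread_count=8)
        dcp.load(state_dict, storage_reader=reader)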
    """

    def __init__(self, path: Union[str, os.PathLike], thread_count: int = 1) -> None:
        """
        Initialize the MultiStorageFileSystemReader with the MultiStorageFileSystem.

        :param path: The path to the checkpoint.
        :param thread_count: The number of threads to use for prefetching.
        """
        super().__init__(path)
        self.fs = MultiStorageFileSystem()
        self.path = self.fs.init_path(path)
        self.thread_count = thread_count

    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        """
        Prefetch objects from object storage, then delegate to the base class to read the data.
        """
        if self.thread_count > 1:
            # group requests by file
            per_file: dict[str, list[ReadItem]] = {}
            for read_item in plan.items:
                item_md = self.storage_data[read_item.storage_index]
                path = item_md.relative_path
                per_file.setdefault(path, []).append(read_item)

            # prefetch objects
            urls = [cast(MultiStoragePath, self.path) / rel_path for rel_path, _ in per_file.items()]
            _prefetch_objects(self.fs, urls, self.thread_count)

        return super().read_data(plan, planner)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        return MultiStorageFileSystem.validate_checkpoint_id(checkpoint_id)


class MultiStorageFileSystemWriter(FileSystemWriter):
    """
    A writer implementation that uses the MultiStorageFileSystem class to handle file system operations.
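
    Example (a minimal sketch; assumes a ``torch.distributed`` process group is
    initialized, ``state_dict`` holds the objects to checkpoint, and the
    placeholder URL is one that ``MultiStoragePath`` understands)::

        import torch.distributed.checkpoint as dcp

        writer = MultiStorageFileSystemWriter("s3://bucket/checkpoints/step-1000", thread_count=8)
        dcp.save(state_dict, storage_writer=writer)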
    """

    def __init__(
        self,
        path: Union[str, os.PathLike],
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: int = 1,
        per_thread_copy_ahead: int = 10_000_000,
        cache_staged_state_dict: bool = False,
        overwrite: bool = True,
    ) -> None:
        """
        Initialize the MultiStorageFileSystemWriter with the MultiStorageFileSystem.
        """
        super().__init__(
            path,
            single_file_per_rank,
            sync_files,
            thread_count,
            per_thread_copy_ahead,
            cache_staged_state_dict,
            overwrite=overwrite,
        )
        self.fs = MultiStorageFileSystem()
        self.path = self.fs.init_path(path)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        return MultiStorageFileSystem.validate_checkpoint_id(checkpoint_id)