Source code for multistorageclient.contrib.torch.filesystem

  1# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16import io
 17import os
 18from collections.abc import Generator
 19from concurrent.futures import ThreadPoolExecutor
 20from contextlib import contextmanager
 21from typing import Union, cast
 22
 23from torch.distributed.checkpoint.filesystem import FileSystemBase, FileSystemReader, FileSystemWriter
 24from torch.distributed.checkpoint.planner import (
 25    LoadPlan,
 26    LoadPlanner,
 27    ReadItem,
 28)
 29from torch.futures import Future
 30
 31from ...pathlib import MultiStoragePath
 32
 33
class MultiStorageFileSystem(FileSystemBase):
    """
    ``FileSystemBase`` backend that routes every operation through :class:`MultiStoragePath`.

    Each method wraps its path argument(s) in a ``MultiStoragePath`` and delegates to the
    path method of the same purpose (``open``, ``/``, ``rename``, ``mkdir``, ``exists``,
    ``unlink``, ``iterdir``).
    """

    @contextmanager
    def create_stream(self, path: Union[str, os.PathLike], mode: str) -> Generator[io.IOBase, None, None]:
        """
        Open ``path`` with ``mode`` and yield the file object.

        The file object is closed when the surrounding ``with`` block exits.
        """
        target = MultiStoragePath(path)
        with target.open(mode=mode) as stream:
            yield stream

    def concat_path(self, path: Union[str, os.PathLike], suffix: str) -> Union[str, os.PathLike]:
        """
        Return ``path`` joined with ``suffix`` using the ``/`` operator of ``MultiStoragePath``.
        """
        base = MultiStoragePath(path)
        return base / suffix

    def rename(self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]) -> None:
        """
        Rename ``path`` to ``new_path``.
        """
        source = MultiStoragePath(path)
        source.rename(new_path)

    def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
        """
        Wrap ``path`` in a ``MultiStoragePath``.
        """
        wrapped = MultiStoragePath(path)
        return wrapped

    def mkdir(self, path: Union[str, os.PathLike]) -> None:
        """
        Create ``path`` via ``MultiStoragePath.mkdir`` with ``parents=True`` and ``exist_ok=True``.
        """
        MultiStoragePath(path).mkdir(exist_ok=True, parents=True)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """
        Report whether ``checkpoint_id`` can be turned into a ``MultiStoragePath``.

        :return: ``False`` if construction raises ``ValueError``, otherwise ``True``.
        """
        try:
            MultiStoragePath(checkpoint_id)
        except ValueError:
            return False
        else:
            return True

    def exists(self, path: Union[str, os.PathLike]) -> bool:
        """
        Return whether ``path`` exists, as reported by ``MultiStoragePath.exists()``.
        """
        target = MultiStoragePath(path)
        return target.exists()

    def rm_file(self, path: Union[str, os.PathLike]) -> None:
        """
        Delete the file at ``path`` via ``MultiStoragePath.unlink()``.
        """
        target = MultiStoragePath(path)
        target.unlink()

    def ls(self, path: Union[str, os.PathLike]) -> list[str]:
        """
        Return the string form of each entry yielded by ``iterdir()`` on ``path``.
        """
        entries = MultiStoragePath(path).iterdir()
        return list(map(str, entries))
def _prefetch_objects(fs: MultiStorageFileSystem, urls: list[MultiStoragePath], thread_count: int) -> None:
    """
    Efficiently pre-downloads files from object storage using parallel threads, storing them in cache when enabled for optimized subsequent access.

    Each URL is opened for binary reading through ``fs.create_stream`` and immediately closed.
    The contents are not consumed here; any caching is a side effect of opening the stream.

    :param fs: Filesystem used to open each URL.
    :param urls: Objects to prefetch. An empty list is a no-op.
    :param thread_count: Maximum number of worker threads.
    :raises Exception: The exception of the earliest-submitted prefetch that failed. Prefetches
        still queued at the time of the first failure are cancelled rather than run.
    """
    # Local import: only this function needs these names.
    from concurrent.futures import FIRST_EXCEPTION, wait

    if not urls:
        # Nothing to fetch; avoid spinning up a thread pool.
        return

    def _prefetch(url: MultiStoragePath) -> None:
        # Opening (and closing) the stream is the prefetch itself.
        with fs.create_stream(url, "rb"):
            pass

    with ThreadPoolExecutor(max_workers=thread_count) as executor:
        futures = [executor.submit(_prefetch, url) for url in urls]
        # Returns once every prefetch has finished, or as soon as any of them raises.
        done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
        # Outstanding futures mean a prefetch failed. Cancel the queued ones so leaving the
        # ``with`` block only waits for prefetches already running, instead of downloading
        # every remaining object before the error surfaces.
        for future in not_done:
            future.cancel()
        # Re-raise (in submission order) the first failure among completed prefetches.
        # When everything succeeded this just collects ``None`` results.
        for future in futures:
            if future in done:
                future.result()
class MultiStorageFileSystemReader(FileSystemReader):
    """
    A reader implementation that uses the MultiStorageFileSystem class to handle file system operations.

    When ``thread_count`` is greater than 1, :meth:`read_data` first prefetches every distinct file
    referenced by the load plan in parallel, then defers to ``FileSystemReader.read_data``.
    """

    def __init__(self, path: Union[str, os.PathLike], thread_count: int = 1) -> None:
        """
        Initialize the MultiStorageFileSystemReader with the MultiStorageFileSystem.

        :param path: The path to the checkpoint.
        :param thread_count: The number of threads to use for prefetching. Values of 1 or less
            disable prefetching.
        """
        super().__init__(path)
        # Overwrite whatever filesystem/path the base initializer set, so every later operation
        # goes through MultiStorageFileSystem / MultiStoragePath.
        self.fs = MultiStorageFileSystem()
        self.path = self.fs.init_path(path)
        self.thread_count = thread_count

    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        """
        Override the method to prefetch objects from object storage.

        :param plan: The load plan whose items will be read.
        :param planner: Passed through unchanged to the base implementation.
        :return: The future returned by ``FileSystemReader.read_data``.
        """
        if self.thread_count > 1:
            read_items: list[ReadItem] = plan.items
            # Several read items may live in the same file; only the distinct files need
            # prefetching. ``dict.fromkeys`` de-duplicates while keeping first-seen order.
            relative_paths = dict.fromkeys(
                self.storage_data[read_item.storage_index].relative_path for read_item in read_items
            )
            if relative_paths:
                base = cast(MultiStoragePath, self.path)
                urls = [base / rel_path for rel_path in relative_paths]
                _prefetch_objects(self.fs, urls, self.thread_count)

        return super().read_data(plan, planner)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """
        Delegate to ``MultiStorageFileSystem.validate_checkpoint_id``.
        """
        return MultiStorageFileSystem.validate_checkpoint_id(checkpoint_id)
128 129
class MultiStorageFileSystemWriter(FileSystemWriter):
    """
    ``FileSystemWriter`` variant whose storage operations go through ``MultiStorageFileSystem``.
    """

    def __init__(
        self,
        path: Union[str, os.PathLike],
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: int = 1,
        per_thread_copy_ahead: int = 10_000_000,
        cache_staged_state_dict: bool = False,
        overwrite: bool = True,
    ) -> None:
        """
        Initialize the MultiStorageFileSystemWriter with the MultiStorageFileSystem.

        Every argument is forwarded unchanged to ``FileSystemWriter`` (see its documentation for
        their meaning). Afterwards the writer's filesystem and path are replaced with
        MultiStorageFileSystem-backed equivalents.

        :param path: The path to the checkpoint.
        """
        # Forward by keyword so the call stays correct even if the base signature gains or
        # reorders optional parameters.
        super().__init__(
            path,
            single_file_per_rank=single_file_per_rank,
            sync_files=sync_files,
            thread_count=thread_count,
            per_thread_copy_ahead=per_thread_copy_ahead,
            cache_staged_state_dict=cache_staged_state_dict,
            overwrite=overwrite,
        )
        storage = MultiStorageFileSystem()
        self.fs = storage
        self.path = storage.init_path(path)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """
        Delegate to ``MultiStorageFileSystem.validate_checkpoint_id``.
        """
        is_valid = MultiStorageFileSystem.validate_checkpoint_id(checkpoint_id)
        return is_valid