Source code for multistorageclient.contrib.torch.filesystem

  1# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16import io
 17import os
 18from collections.abc import Generator
 19from concurrent.futures import ThreadPoolExecutor
 20from contextlib import contextmanager
 21from typing import Union, cast
 22
 23from torch.distributed.checkpoint.filesystem import FileSystemBase, FileSystemReader, FileSystemWriter
 24from torch.distributed.checkpoint.planner import (
 25    LoadPlan,
 26    LoadPlanner,
 27    ReadItem,
 28)
 29from torch.futures import Future
 30
 31from ...pathlib import MultiStoragePath
 32
 33
class MultiStorageFileSystem(FileSystemBase):
    """
    ``FileSystemBase`` implementation for ``torch.distributed.checkpoint`` in which every
    path is wrapped in a :class:`MultiStoragePath` and each operation is delegated to it.
    """

    @contextmanager
    def create_stream(self, path: Union[str, os.PathLike], mode: str) -> Generator[io.IOBase, None, None]:
        """
        Open ``path`` with ``mode`` and yield the resulting file object.

        ``prefetch_file=True`` is always requested, so the whole file is downloaded when it
        is opened (the access pattern used for checkpointing).
        """
        target = MultiStoragePath(path)
        with target.open(mode=mode, prefetch_file=True) as stream:
            yield stream

    def concat_path(self, path: Union[str, os.PathLike], suffix: str) -> Union[str, os.PathLike]:
        """Return ``path / suffix`` as a :class:`MultiStoragePath`."""
        base = MultiStoragePath(path)
        return base / suffix

    def rename(self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]) -> None:
        """Rename ``path`` to ``new_path``."""
        source = MultiStoragePath(path)
        source.rename(new_path)

    def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
        """Wrap ``path`` in a :class:`MultiStoragePath`."""
        return MultiStoragePath(path)

    def mkdir(self, path: Union[str, os.PathLike]) -> None:
        """Create ``path`` and any missing parents; an already-existing directory is not an error."""
        directory = MultiStoragePath(path)
        directory.mkdir(exist_ok=True, parents=True)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """
        Report whether ``checkpoint_id`` is usable by this filesystem.

        :return: ``False`` if constructing a :class:`MultiStoragePath` from it raises
            ``ValueError``, otherwise ``True``.
        """
        try:
            MultiStoragePath(checkpoint_id)
        except ValueError:
            return False
        else:
            return True

    def exists(self, path: Union[str, os.PathLike]) -> bool:
        """Return whether ``path`` exists."""
        target = MultiStoragePath(path)
        return target.exists()

    def rm_file(self, path: Union[str, os.PathLike]) -> None:
        """Delete the file at ``path`` (via ``MultiStoragePath.unlink``)."""
        target = MultiStoragePath(path)
        target.unlink()

    def ls(self, path: Union[str, os.PathLike]) -> list[str]:
        """Return the string form of every entry yielded by iterating the directory ``path``."""
        entries = MultiStoragePath(path).iterdir()
        return list(map(str, entries))
74 75 76def _prefetch_objects(fs: MultiStorageFileSystem, urls: list[MultiStoragePath], thread_count: int) -> None: 77 """ 78 Efficiently pre-downloads files from object storage using parallel threads, storing them in cache when enabled for optimized subsequent access. 79 """ 80 81 def _prefetch(url: MultiStoragePath) -> None: 82 with fs.create_stream(url, "rb") as _: 83 pass 84 85 with ThreadPoolExecutor(max_workers=thread_count) as executor: 86 futures = [executor.submit(_prefetch, url) for url in urls] 87 for future in futures: 88 future.result() 89 90
class MultiStorageFileSystemReader(FileSystemReader):
    """
    ``FileSystemReader`` whose file operations go through :class:`MultiStorageFileSystem`,
    optionally prefetching the checkpoint files referenced by a load plan in parallel.
    """

    def __init__(self, path: Union[str, os.PathLike], thread_count: int = 1) -> None:
        """
        Build a reader for the checkpoint at ``path``.

        :param path: The path to the checkpoint.
        :param thread_count: Number of threads used to prefetch checkpoint files in
            :meth:`read_data`. Prefetching only happens when this is greater than 1.
        """
        super().__init__(path)
        # Replace the filesystem chosen by the base initializer and re-resolve the
        # checkpoint path through the multi-storage one.
        self.fs = MultiStorageFileSystem()
        self.path = self.fs.init_path(path)
        self.thread_count = thread_count

    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        """
        Prefetch the files referenced by ``plan`` (only when ``thread_count > 1``), then
        delegate the actual load to :meth:`FileSystemReader.read_data`.
        """
        if self.thread_count > 1:
            self._prefetch_plan_files(plan)
        return super().read_data(plan, planner)

    def _prefetch_plan_files(self, plan: LoadPlan) -> None:
        """Download every distinct checkpoint file that ``plan`` reads from."""
        # Bucket read items by the file that stores them; only the keys (distinct relative
        # paths, in first-seen order) are used to build the prefetch list.
        items_by_file: dict[str, list[ReadItem]] = {}
        for item in plan.items:
            relative_path = self.storage_data[item.storage_index].relative_path
            items_by_file.setdefault(relative_path, []).append(item)

        root = cast(MultiStoragePath, self.path)
        targets = [root / name for name in items_by_file]
        _prefetch_objects(self.fs, targets, self.thread_count)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """Delegate to :meth:`MultiStorageFileSystem.validate_checkpoint_id`."""
        return MultiStorageFileSystem.validate_checkpoint_id(checkpoint_id)
class MultiStorageFileSystemWriter(FileSystemWriter):
    """
    A writer implementation that uses the MultiStorageFileSystem class to handle file system operations.
    """

    def __init__(
        self,
        path: Union[str, os.PathLike],
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: int = 1,
        per_thread_copy_ahead: int = 10_000_000,
        cache_staged_state_dict: bool = False,
        overwrite: bool = True,
    ) -> None:
        """
        Initialize the MultiStorageFileSystemWriter with the MultiStorageFileSystem.

        Every option is forwarded unchanged to :class:`FileSystemWriter`; afterwards the
        filesystem and checkpoint path are replaced with their multi-storage equivalents.

        :param path: The path to the checkpoint.
        :param single_file_per_rank: Forwarded to ``FileSystemWriter``.
        :param sync_files: Forwarded to ``FileSystemWriter``.
        :param thread_count: Forwarded to ``FileSystemWriter``.
        :param per_thread_copy_ahead: Forwarded to ``FileSystemWriter``.
        :param cache_staged_state_dict: Forwarded to ``FileSystemWriter``.
        :param overwrite: Forwarded to ``FileSystemWriter``.
        """
        # Forward by keyword: positional forwarding would silently mis-bind options if the
        # upstream FileSystemWriter signature ever gains or reorders parameters.
        super().__init__(
            path=path,
            single_file_per_rank=single_file_per_rank,
            sync_files=sync_files,
            thread_count=thread_count,
            per_thread_copy_ahead=per_thread_copy_ahead,
            cache_staged_state_dict=cache_staged_state_dict,
            overwrite=overwrite,
        )
        # Replace the filesystem chosen by the base initializer and re-resolve the
        # checkpoint path through the multi-storage one.
        self.fs = MultiStorageFileSystem()
        self.path = self.fs.init_path(path)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """Delegate to :meth:`MultiStorageFileSystem.validate_checkpoint_id`."""
        return MultiStorageFileSystem.validate_checkpoint_id(checkpoint_id)