FileSystemWriterAsync example

  1import argparse
  2import logging
  3import shutil
  4
  5try:
  6    import multistorageclient as msc
  7except ImportError:
  8    msc = None
  9import torch
 10import torch.distributed as dist
 11import torch.nn as nn
 12import torch.optim as optim
 13from torch.distributed import checkpoint
 14from torch.distributed.checkpoint import DefaultLoadPlanner, DefaultSavePlanner, FileSystemReader
 15from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 16from torch.utils.data import DataLoader, DistributedSampler
 17
 18from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue, AsyncRequest
 19from nvidia_resiliency_ext.checkpointing.async_ckpt.filesystem_async import FileSystemWriterAsync
 20from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import (
 21    init_checkpoint_metadata_cache,
 22    save_state_dict_async_finalize,
 23    save_state_dict_async_plan,
 24)
 25
# Set up basic logging configuration
# Try setting `DEBUG` to see detailed steps of NVRx checkpointing
logging.basicConfig(level=logging.INFO)

FEAT_SIZE = 4096  # Input feature width; also the hidden width of fc1
DNN_OUT_SIZE = 128  # Model output width (target vector size)
BATCH_SIZE = 100  # Per-rank mini-batch size
NUM_EPOCHS = 10  # NOTE(review): appears unused below — only epoch 0 is run
DATASET_LEN = 100000  # Number of synthetic samples in the dataset
CKPT_INTERVAL = 100  # NOTE(review): appears unused — saving is driven by 10%-of-epoch intervals
 36
 37
 38def print_on_rank0(msg):
 39    if dist.get_rank() == 0:
 40        print(msg)
 41
 42
 43def parse_args():
 44    parser = argparse.ArgumentParser(description="Async Checkpointing Example")
 45    parser.add_argument(
 46        "--ckpt_dir",
 47        default="/tmp/test_checkpointing/",
 48        help="Directory for saving checkpoints",
 49    )
 50    parser.add_argument(
 51        "--thread_count",
 52        default=2,
 53        type=int,
 54        help="Threads to use during saving. Affects the number of files in the checkpoint (saving ranks * num_threads).",
 55    )
 56    parser.add_argument(
 57        '--no_persistent_queue',
 58        action='store_false',
 59        default=True,
 60        dest='persistent_queue',
 61        help=(
 62            "Disables a persistent version of AsyncCallsQueue. "
 63            "Effective only when --async_save is set."
 64        ),
 65    )
 66    parser.add_argument(
 67        '--enable_msc',
 68        action='store_true',
 69        help="Enables MSC for checkpoint saving and loading. See Usage Guide in Async Checkpointing documentation for detailed instructions.",
 70    )
 71
 72    return parser.parse_args()
 73
 74
 75class SimpleDataset(torch.utils.data.Dataset):
 76    def __init__(self, size):
 77        self.size = size
 78
 79    def __len__(self):
 80        return self.size
 81
 82    def __getitem__(self, idx):
 83        x = torch.rand((FEAT_SIZE,), dtype=torch.float32, device='cuda')
 84        y = torch.rand((DNN_OUT_SIZE,), dtype=torch.float32, device='cuda')
 85        return x, y
 86
 87
 88class SimpleModel(nn.Module):
 89    def __init__(self):
 90        super().__init__()
 91        self.fc1 = nn.Linear(FEAT_SIZE, FEAT_SIZE)
 92        self.fc2 = nn.Linear(FEAT_SIZE, DNN_OUT_SIZE)
 93
 94    def forward(self, x):
 95        x = self.fc1(x)
 96        x = nn.functional.relu(x)
 97        x = self.fc2(x)
 98        return x
 99
100
def init_distributed_backend(backend="nccl"):
    """Initializes the distributed process group using the specified backend.

    Reads rendezvous info from the environment (``env://``), pins this process
    to its node-local GPU, and returns ``(rank, world_size)``. Any failure is
    logged and re-raised.
    """
    try:
        dist.init_process_group(backend=backend, init_method="env://")
        rank = dist.get_rank()
        # Bind this process to its node-local CUDA device.
        torch.cuda.set_device(dist.get_node_local_rank())
        logging.info(f"Process {rank} initialized with {backend} backend.")
        world_size = torch.distributed.get_world_size()
    except Exception as err:
        logging.error(f"Failed to initialize distributed backend: {err}")
        raise
    return rank, world_size
112
113
def get_save_and_finalize_callbacks(writer, save_state_dict_ret) -> AsyncRequest:
    """Bundle the writer's save callable and a finalize step into an AsyncRequest.

    The finalize callback completes the async checkpoint (via
    ``save_state_dict_async_finalize``) and then barriers so all ranks agree
    the checkpoint is done.
    """
    async_fn, preload_fn, async_args = writer.get_save_function_and_args()

    def _finalize():
        # Commit checkpoint metadata, then synchronize every rank.
        save_state_dict_async_finalize(*save_state_dict_ret)
        dist.barrier()

    return AsyncRequest(
        async_fn, async_args, [_finalize], async_fn_kwargs={}, preload_fn=preload_fn
    )
126
127
def save_checkpoint(checkpoint_dir, async_queue, model, thread_count, enable_msc):
    """Asynchronously saves a model checkpoint.

    Plans the save collectively, then schedules the actual write on the
    async queue so training can continue while data is persisted.
    """
    writer = FileSystemWriterAsync(checkpoint_dir, thread_count=thread_count, use_msc=enable_msc)
    # Rank 0 coordinates the collective planning step.
    plan_result = save_state_dict_async_plan(
        model.state_dict(),
        writer,
        None,
        0,
        planner=DefaultSavePlanner(),
        enable_cache=True,
    )
    request = get_save_and_finalize_callbacks(writer, plan_result)
    async_queue.schedule_async_request(request)
140
141
def load_checkpoint(checkpoint_dir, model, thread_count, enable_msc):
    """Loads a model checkpoint synchronously.

    Populates (and returns) the model's state dict in place from
    ``checkpoint_dir``, using the MSC reader when enabled.
    """
    state_dict = model.state_dict()
    reader = (
        msc.torch.MultiStorageFileSystemReader(checkpoint_dir, thread_count=thread_count)
        if enable_msc
        else FileSystemReader(checkpoint_dir)
    )
    checkpoint.load(
        state_dict=state_dict,
        storage_reader=reader,
        planner=DefaultLoadPlanner(),
    )
    return state_dict
155
156
def main():
    """Run one epoch of FSDP training with periodic async checkpointing.

    Trains on synthetic data, saves a checkpoint every ~10% of the epoch via
    the async queue, reloads the last checkpoint, and cleans up.
    """
    args = parse_args()
    logging.info(f"Arguments: {args}")

    if args.enable_msc and msc is None:
        raise ImportError(
            "multistorageclient is required when --enable_msc is set. "
            "Install it with: pip install multistorageclient"
        )

    # Initialize distributed training
    rank, world_size = init_distributed_backend(backend="nccl")

    # Data pipeline: each rank sees a disjoint shard of the synthetic dataset.
    dataset = SimpleDataset(size=DATASET_LEN)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler)

    # Model, optimizer, and FSDP wrapper
    model = SimpleModel().to("cuda")
    fsdp_model = FSDP(model)
    optimizer = optim.SGD(fsdp_model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()

    # Create an async queue for handling asynchronous operations
    async_queue = AsyncCallsQueue(persistent=args.persistent_queue)

    num_iters_in_epoch = len(dataloader)
    print_on_rank0(f"num_iters_in_epoch: {num_iters_in_epoch}")

    # Iterations per 1/10 of the epoch. Fix: clamp to >= 1 so a dataloader
    # with fewer than 10 batches does not cause a modulo-by-zero below.
    num_iters_for_10pct = max(1, num_iters_in_epoch // 10)
    checkpoint_dir = None
    sampler.set_epoch(0)

    init_checkpoint_metadata_cache()

    for batch_idx, (data, target) in enumerate(dataloader):
        # Harvest any completed async checkpoint requests without blocking.
        async_queue.maybe_finalize_async_calls(blocking=False, no_dist=False)
        if (batch_idx % num_iters_for_10pct) == 0 and rank == 0:
            print(f"Epoch 0 progress: {100 * batch_idx / num_iters_in_epoch:.2f}%")
        optimizer.zero_grad()
        output = fsdp_model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % num_iters_for_10pct == 0:
            iteration = batch_idx
            checkpoint_dir = f"{args.ckpt_dir}/iter_{iteration:07d}"
            # Save the model asynchronously
            save_checkpoint(
                checkpoint_dir, async_queue, fsdp_model, args.thread_count, args.enable_msc
            )
            print_on_rank0(f"Checkpoint Save triggered: {checkpoint_dir}, iteration: {iteration}")
            # Fix: removed dead `iteration += batch_idx`, which left `iteration`
            # at twice the batch index and was never read afterwards.
    print_on_rank0(f"Epoch 0 complete. Loss: {loss.item()}")

    logging.info("Finalizing checkpoint save...")
    # Block until all outstanding async saves are fully persisted.
    async_queue.maybe_finalize_async_calls(blocking=True, no_dist=False)
    async_queue.close()  # Explicitly close queue

    # Synchronize processes before loading
    dist.barrier()
    print_on_rank0(f"loading from {checkpoint_dir}")
    # Load the checkpoint
    loaded_sd = load_checkpoint(checkpoint_dir, fsdp_model, args.thread_count, args.enable_msc)

    # Synchronize again to ensure all ranks have completed loading
    dist.barrier()

    # Clean up checkpoint directory (only on rank 0)
    if dist.get_rank() == 0:
        logging.info(f"Cleaning up checkpoint directory: {args.ckpt_dir}")
        if args.enable_msc:
            msc.delete(args.ckpt_dir, recursive=True)
        else:
            shutil.rmtree(args.ckpt_dir)

    # Ensure NCCL process group is properly destroyed
    if dist.is_initialized():
        dist.destroy_process_group()
238
239
240if __name__ == "__main__":
241    main()