FileSystemWriter example
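
The example below builds a small model, saves its checkpoint asynchronously with FileSystemWriterAsync and an AsyncCallsQueue, finalizes the pending save, and then loads the checkpoint back synchronously via torch.distributed.checkpoint.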

import argparse
import logging
import shutil

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import checkpoint
from torch.distributed.checkpoint import DefaultLoadPlanner, DefaultSavePlanner, FileSystemReader

from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue, AsyncRequest
from nvidia_resiliency_ext.checkpointing.async_ckpt.filesystem_async import FileSystemWriterAsync
from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import (
    save_state_dict_async_finalize,
    save_state_dict_async_plan,
)

# Set up basic logging configuration
logging.basicConfig(level=logging.INFO)


def parse_args():
    parser = argparse.ArgumentParser(description="Async Checkpointing Example")
    parser.add_argument(
        "--ckpt_dir",
        default="/tmp/test_checkpointing/",
        help="Directory for saving checkpoints",
    )
    parser.add_argument(
        "--thread_count",
        default=2,
        type=int,
        help="Threads to use during saving. Affects the number of files in the checkpoint (saving ranks * num_threads).",
    )
    parser.add_argument(
        "--persistent_queue",
        action="store_true",
        help="Enables a persistent version of AsyncCallsQueue.",
    )

    return parser.parse_args()


class SimpleModel(nn.Module):
    """A simple feedforward neural network for demonstration purposes."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)
        self.activation = nn.ReLU()

    def forward(self, x):
        x = self.activation(self.fc1(x))
        return self.fc2(x)


def init_distributed_backend(backend="nccl"):
    """Initializes the distributed process group using the specified backend."""
    try:
        dist.init_process_group(backend=backend, init_method="env://")
        rank = dist.get_rank()
        torch.cuda.set_device(rank)
        logging.info(f"Process {rank} initialized with {backend} backend.")
    except Exception as e:
        logging.error(f"Failed to initialize distributed backend: {e}")
        raise


def get_save_and_finalize_callbacks(writer, save_state_dict_ret) -> AsyncRequest:
    """Creates an async save request with a finalize function."""
    save_fn, preload_fn, save_args = writer.get_save_function_and_args()

    def finalize_fn():
        """Finalizes async checkpointing and synchronizes processes."""
        save_state_dict_async_finalize(*save_state_dict_ret)
        dist.barrier()

    return AsyncRequest(save_fn, save_args, [finalize_fn], preload_fn=preload_fn)


def save_checkpoint(checkpoint_dir, async_queue, model, thread_count):
    """Asynchronously saves a model checkpoint."""
    state_dict = model.state_dict()
    planner = DefaultSavePlanner()
    writer = FileSystemWriterAsync(checkpoint_dir, thread_count=thread_count)
    coordinator_rank = 0

    save_state_dict_ret, *_ = save_state_dict_async_plan(
        state_dict, writer, None, coordinator_rank, planner=planner
    )
    save_request = get_save_and_finalize_callbacks(writer, save_state_dict_ret)
    async_queue.schedule_async_request(save_request)


def load_checkpoint(checkpoint_dir, model):
    """Loads a model checkpoint synchronously."""
    state_dict = model.state_dict()
    checkpoint.load(
        state_dict=state_dict,
        storage_reader=FileSystemReader(checkpoint_dir),
        planner=DefaultLoadPlanner(),
    )
    return state_dict


def main():
    args = parse_args()
    logging.info(f"Arguments: {args}")

    # Initialize distributed training
    init_distributed_backend(backend="nccl")

    # Initialize model and move to GPU
    model = SimpleModel().to("cuda")

    # Create an async queue for handling asynchronous operations
    async_queue = AsyncCallsQueue(persistent=args.persistent_queue)

    # Define checkpoint directory based on iteration number
    iteration = 123
    checkpoint_dir = f"{args.ckpt_dir}/iter_{iteration:07d}"

    # Save the model asynchronously
    save_checkpoint(checkpoint_dir, async_queue, model, args.thread_count)

    logging.info("Finalizing checkpoint save...")
    async_queue.maybe_finalize_async_calls(blocking=True, no_dist=False)
    async_queue.close()  # Explicitly close queue (optional)

    # Synchronize processes before loading
    dist.barrier()

    # Load the checkpoint
    loaded_sd = load_checkpoint(checkpoint_dir, model)

    # Synchronize again to ensure all ranks have completed loading
    dist.barrier()

    # Clean up checkpoint directory (only on rank 0)
    if dist.get_rank() == 0:
        logging.info(f"Cleaning up checkpoint directory: {args.ckpt_dir}")
        shutil.rmtree(args.ckpt_dir)

    # Ensure NCCL process group is properly destroyed
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
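
The script is launched like any other torch.distributed program. A minimal sketch, assuming a single node with two GPUs and that the listing above is saved as async_ckpt_example.py (an illustrative filename): torchrun --nproc_per_node=2 async_ckpt_example.py --thread_count 2. Because init_distributed_backend() uses init_method="env://", the launcher must provide the usual RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT environment variables, which torchrun sets automatically.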