FileSystemWriter example
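
The example below trains a small FSDP-wrapped model and checkpoints it asynchronously: FileSystemWriterAsync and an AsyncCallsQueue perform the file writes in the background while training continues, and the checkpoint is later read back with the standard torch.distributed.checkpoint FileSystemReader.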

import argparse
import logging
import shutil

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.distributed import checkpoint
from torch.distributed.checkpoint import DefaultLoadPlanner, DefaultSavePlanner, FileSystemReader
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader, DistributedSampler

from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue, AsyncRequest
from nvidia_resiliency_ext.checkpointing.async_ckpt.filesystem_async import FileSystemWriterAsync
from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import (
    init_checkpoint_metadata_cache,
    save_state_dict_async_finalize,
    save_state_dict_async_plan,
)

# Set up basic logging configuration
# Try setting `DEBUG` to see detailed steps of NVRx checkpointing
logging.basicConfig(level=logging.INFO)

FEAT_SIZE = 4096
DNN_OUT_SIZE = 128
BATCH_SIZE = 100
NUM_EPOCHS = 10
DATASET_LEN = 100000
CKPT_INTERVAL = 100


def print_on_rank0(msg):
    if dist.get_rank() == 0:
        print(msg)


def parse_args():
    parser = argparse.ArgumentParser(description="Async Checkpointing Example")
    parser.add_argument(
        "--ckpt_dir",
        default="/tmp/test_checkpointing/",
        help="Directory for saving checkpoints",
    )
    parser.add_argument(
        "--thread_count",
        default=2,
        type=int,
        help="Threads to use during saving. Affects the number of files in the checkpoint (saving ranks * num_threads).",
    )
    parser.add_argument(
        '--persistent_queue',
        action='store_true',
        help="Enables a persistent version of AsyncCallsQueue.",
    )

    return parser.parse_args()


class SimpleDataset(torch.utils.data.Dataset):
    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        x = torch.rand((FEAT_SIZE,), dtype=torch.float32, device='cuda')
        y = torch.rand((DNN_OUT_SIZE,), dtype=torch.float32, device='cuda')
        return x, y


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(FEAT_SIZE, FEAT_SIZE)
        self.fc2 = nn.Linear(FEAT_SIZE, DNN_OUT_SIZE)

    def forward(self, x):
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        return x


def init_distributed_backend(backend="nccl"):
    """Initializes the distributed process group using the specified backend."""
    try:
        dist.init_process_group(backend=backend, init_method="env://")
        rank = dist.get_rank()
        torch.cuda.set_device(rank)
        logging.info(f"Process {rank} initialized with {backend} backend.")
        return rank, dist.get_world_size()
    except Exception as e:
        logging.error(f"Failed to initialize distributed backend: {e}")
        raise


def get_save_and_finalize_callbacks(writer, save_state_dict_ret) -> AsyncRequest:
    """Creates an async save request with a finalize function."""
    save_fn, preload_fn, save_args = writer.get_save_function_and_args()
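    # save_fn performs the actual file writes and is run asynchronously by the
    # AsyncCallsQueue worker; preload_fn stages tensor data out of the training
    # process (e.g. device-to-host copies) before the write starts. This is a
    # summary of the intended roles; see the NVRx docs for the exact semantics.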

    def finalize_fn():
        """Finalizes async checkpointing and synchronizes processes."""
        save_state_dict_async_finalize(*save_state_dict_ret)
        dist.barrier()

    return AsyncRequest(save_fn, save_args, [finalize_fn], preload_fn=preload_fn)


def save_checkpoint(checkpoint_dir, async_queue, model, thread_count):
    """Asynchronously saves a model checkpoint."""
    state_dict = model.state_dict()
    planner = DefaultSavePlanner()
    writer = FileSystemWriterAsync(checkpoint_dir, thread_count=thread_count)
    coordinator_rank = 0

    save_state_dict_ret = save_state_dict_async_plan(
        state_dict, writer, None, coordinator_rank, planner=planner, enable_cache=True
    )
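    # The planning step above runs synchronously: save plans and metadata are
    # prepared here (with enable_cache=True they can be reused while the state
    # dict structure stays the same), while the heavy file writes are deferred
    # to the async request scheduled below. The positional None is assumed to
    # select the default process group.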
    save_request = get_save_and_finalize_callbacks(writer, save_state_dict_ret)
    async_queue.schedule_async_request(save_request)


def load_checkpoint(checkpoint_dir, model):
    """Loads a model checkpoint synchronously."""
    state_dict = model.state_dict()
    checkpoint.load(
        state_dict=state_dict,
        storage_reader=FileSystemReader(checkpoint_dir),
        planner=DefaultLoadPlanner(),
    )
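    # torch.distributed.checkpoint.load populates the provided state_dict in place,
    # so the returned dictionary already contains the loaded values.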
    return state_dict


def main():
    args = parse_args()
    logging.info(f"Arguments: {args}")

    # Initialize distributed training
    rank, world_size = init_distributed_backend(backend="nccl")

    # Dataset, distributed sampler, and dataloader
    dataset = SimpleDataset(size=DATASET_LEN)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler)

    # Model, optimizer, and FSDP wrapper
    model = SimpleModel().to("cuda")
    fsdp_model = FSDP(model)
    optimizer = optim.SGD(fsdp_model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()

    # Create an async queue for handling asynchronous operations
    async_queue = AsyncCallsQueue(persistent=args.persistent_queue)
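    # Assumption about queue behavior: with --persistent_queue a single long-lived
    # worker is reused for all save requests, while the default queue starts a new
    # worker per scheduled request; consult the NVRx docs for details.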

    iteration = 0
    num_iters_in_epoch = len(dataloader)
    print_on_rank0(f"num_iters_in_epoch: {num_iters_in_epoch}")

    num_iters_for_10pct = num_iters_in_epoch // 10  # iters for 1/10 of epoch
    checkpoint_dir = None
    sampler.set_epoch(0)

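    # Assumption based on the call below and enable_cache=True in save_checkpoint():
    # initializing the metadata cache lets repeated saves of a structurally identical
    # state dict skip part of the planning work on later checkpoints.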
    init_checkpoint_metadata_cache()

    for batch_idx, (data, target) in enumerate(dataloader):
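        # Non-blocking check: if a previously scheduled async save has finished,
        # run its finalize callbacks (save_state_dict_async_finalize plus a barrier,
        # as set up in get_save_and_finalize_callbacks); otherwise keep training.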
        async_queue.maybe_finalize_async_calls(blocking=False, no_dist=False)
        if (batch_idx % num_iters_for_10pct) == 0 and rank == 0:
            print(f"Epoch 0 progress: {100 * batch_idx / num_iters_in_epoch:.2f}%")
        optimizer.zero_grad()
        output = fsdp_model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % num_iters_for_10pct == 0:
            iteration = batch_idx
            checkpoint_dir = f"{args.ckpt_dir}/iter_{iteration:07d}"
            # Save the model asynchronously
            save_checkpoint(checkpoint_dir, async_queue, fsdp_model, args.thread_count)
            print_on_rank0(f"Checkpoint Save triggered: {checkpoint_dir}, iteration: {iteration}")
    print_on_rank0(f"Epoch 0 complete. Loss: {loss.item()}")

    logging.info("Finalizing checkpoint save...")
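    # blocking=True waits for any in-flight save to finish and runs its finalize
    # callbacks before the checkpoint is read back below.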
    async_queue.maybe_finalize_async_calls(blocking=True, no_dist=False)
    async_queue.close()  # Explicitly close queue (optional)

    # Synchronize processes before loading
    dist.barrier()
    print_on_rank0(f"loading from {checkpoint_dir}")
    # Load the checkpoint
    loaded_sd = load_checkpoint(checkpoint_dir, fsdp_model)

    # Synchronize again to ensure all ranks have completed loading
    dist.barrier()

    # Clean up checkpoint directory (only on rank 0)
    if dist.get_rank() == 0:
        logging.info(f"Cleaning up checkpoint directory: {args.ckpt_dir}")
        shutil.rmtree(args.ckpt_dir)

    # Ensure NCCL process group is properly destroyed
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
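
The script initializes the process group with init_method="env://" and calls torch.cuda.set_device(rank), so it expects one GPU per rank on a single node and the usual torchrun environment variables (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT). Assuming the example is saved as a standalone script, a typical launch might look like: torchrun --nproc_per_node=<num_gpus> <example_script>.py --ckpt_dir /tmp/test_checkpointing --persistent_queue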