FileSystemWriterAsync example
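
This example trains a small FSDP-wrapped model and checkpoints it asynchronously: FileSystemWriterAsync writes the checkpoint files in the background while training continues, AsyncCallsQueue schedules the save requests and finalizes them once the writes complete, and torch.distributed.checkpoint is used to read the resulting checkpoint back.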

import argparse
import logging
import shutil

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.distributed import checkpoint
from torch.distributed.checkpoint import DefaultLoadPlanner, DefaultSavePlanner, FileSystemReader
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader, DistributedSampler

from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue, AsyncRequest
from nvidia_resiliency_ext.checkpointing.async_ckpt.filesystem_async import FileSystemWriterAsync
from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import (
    save_state_dict_async_finalize,
    save_state_dict_async_plan,
)

# Set up basic logging configuration
logging.basicConfig(level=logging.INFO)

FEAT_SIZE = 4096
DNN_OUT_SIZE = 128
BATCH_SIZE = 100
NUM_EPOCHS = 10
DATASET_LEN = 100000
CKPT_INTERVAL = 100


def print_on_rank0(msg):
    if dist.get_rank() == 0:
        print(msg)


def parse_args():
    parser = argparse.ArgumentParser(description="Async Checkpointing Example")
    parser.add_argument(
        "--ckpt_dir",
        default="/tmp/test_checkpointing/",
        help="Directory for saving checkpoints",
    )
    parser.add_argument(
        "--thread_count",
        default=2,
        type=int,
        help="Threads to use during saving. Affects the number of files in the checkpoint (saving ranks * num_threads).",
    )
    parser.add_argument(
        "--persistent_queue",
        action="store_true",
        help="Enables a persistent version of AsyncCallsQueue.",
    )

    return parser.parse_args()


class SimpleDataset(torch.utils.data.Dataset):
    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        x = torch.rand((FEAT_SIZE,), dtype=torch.float32, device='cuda')
        y = torch.rand((DNN_OUT_SIZE,), dtype=torch.float32, device='cuda')
        return x, y


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(FEAT_SIZE, FEAT_SIZE)
        self.fc2 = nn.Linear(FEAT_SIZE, DNN_OUT_SIZE)

    def forward(self, x):
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        return x


def init_distributed_backend(backend="nccl"):
    """Initializes the distributed process group using the specified backend."""
    try:
        dist.init_process_group(backend=backend, init_method="env://")
        rank = dist.get_rank()
        torch.cuda.set_device(rank)
        logging.info(f"Process {rank} initialized with {backend} backend.")
        return rank, dist.get_world_size()
    except Exception as e:
        logging.error(f"Failed to initialize distributed backend: {e}")
        raise


def get_save_and_finalize_callbacks(writer, save_state_dict_ret) -> AsyncRequest:
    """Creates an async save request with a finalize function."""
    save_fn, preload_fn, save_args = writer.get_save_function_and_args()

    def finalize_fn():
        """Finalizes async checkpointing and synchronizes processes."""
        save_state_dict_async_finalize(*save_state_dict_ret)
        dist.barrier()

    # The queue runs save_fn in the background; finalize_fn is invoked in the training
    # process (via maybe_finalize_async_calls) once the write has completed.
    return AsyncRequest(save_fn, save_args, [finalize_fn], preload_fn=preload_fn)


def save_checkpoint(checkpoint_dir, async_queue, model, thread_count):
    """Asynchronously saves a model checkpoint."""
    state_dict = model.state_dict()
    planner = DefaultSavePlanner()
    writer = FileSystemWriterAsync(checkpoint_dir, thread_count=thread_count)
    coordinator_rank = 0

    # Plan the save synchronously; the actual file writes are deferred to the
    # async request scheduled below.
    save_state_dict_ret = save_state_dict_async_plan(
        state_dict, writer, None, coordinator_rank, planner=planner, enable_cache=True
    )
    save_request = get_save_and_finalize_callbacks(writer, save_state_dict_ret)
    async_queue.schedule_async_request(save_request)


def load_checkpoint(checkpoint_dir, model):
    """Loads a model checkpoint synchronously."""
    state_dict = model.state_dict()
    checkpoint.load(
        state_dict=state_dict,
        storage_reader=FileSystemReader(checkpoint_dir),
        planner=DefaultLoadPlanner(),
    )
    return state_dict


def main():
    args = parse_args()
    logging.info(f"Arguments: {args}")

    # Initialize distributed training
    rank, world_size = init_distributed_backend(backend="nccl")

    # Dataset, distributed sampler, and dataloader
    dataset = SimpleDataset(size=DATASET_LEN)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler)

    # Model, optimizer, and FSDP wrapper
    model = SimpleModel().to("cuda")
    fsdp_model = FSDP(model)
    optimizer = optim.SGD(fsdp_model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()

    # Create an async queue for handling asynchronous operations
    async_queue = AsyncCallsQueue(persistent=args.persistent_queue)

    iteration = 0
    num_iters_in_epoch = len(dataloader)
    print_on_rank0(f"num_iters_in_epoch: {num_iters_in_epoch}")

    num_iters_for_10pct = num_iters_in_epoch // 10  # iters for 1/10 of epoch
    checkpoint_dir = None
    sampler.set_epoch(0)
    for batch_idx, (data, target) in enumerate(dataloader):
        # Finalize any previously scheduled async saves that have completed (non-blocking)
        async_queue.maybe_finalize_async_calls(blocking=False, no_dist=False)
        if (batch_idx % num_iters_for_10pct) == 0 and rank == 0:
            print(f"Epoch 0 progress: {100 * batch_idx / num_iters_in_epoch:.2f}%")
        optimizer.zero_grad()
        output = fsdp_model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % num_iters_for_10pct == 0:
            iteration = batch_idx
            checkpoint_dir = f"{args.ckpt_dir}/iter_{iteration:07d}"
            # Save the model asynchronously
            save_checkpoint(checkpoint_dir, async_queue, fsdp_model, args.thread_count)
            print_on_rank0(f"Checkpoint Save triggered: {checkpoint_dir}, iteration: {iteration}")
    print_on_rank0(f"Epoch 0 complete. Loss: {loss.item()}")

    logging.info("Finalizing checkpoint save...")
    async_queue.maybe_finalize_async_calls(blocking=True, no_dist=False)
    async_queue.close()  # Explicitly close queue (optional)

    # Synchronize processes before loading
    dist.barrier()
    print_on_rank0(f"loading from {checkpoint_dir}")
    # Load the checkpoint
    loaded_sd = load_checkpoint(checkpoint_dir, fsdp_model)

    # Synchronize again to ensure all ranks have completed loading
    dist.barrier()

    # Clean up checkpoint directory (only on rank 0)
    if dist.get_rank() == 0:
        logging.info(f"Cleaning up checkpoint directory: {args.ckpt_dir}")
        shutil.rmtree(args.ckpt_dir)

    # Ensure NCCL process group is properly destroyed
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
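
To try it out, save the example as a standalone script and launch it with torchrun, one process per GPU (the script name below is just a placeholder):

    torchrun --nproc_per_node=2 async_ckpt_example.py --ckpt_dir /tmp/test_checkpointing --thread_count 2 --persistent_queue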