FileSystemWriter example
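
This example trains a small FSDP-wrapped model and saves checkpoints asynchronously using FileSystemWriterAsync and AsyncCallsQueue from nvidia_resiliency_ext, then loads the last checkpoint back with the standard torch.distributed.checkpoint reader.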

import argparse
import logging
import shutil

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.distributed import checkpoint
from torch.distributed.checkpoint import DefaultLoadPlanner, DefaultSavePlanner, FileSystemReader
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader, DistributedSampler

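# NVRx (nvidia_resiliency_ext) building blocks used below: the async-call queue,
# an asynchronous file-system writer, and the async save/finalize helpers.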
from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue, AsyncRequest
from nvidia_resiliency_ext.checkpointing.async_ckpt.filesystem_async import FileSystemWriterAsync
from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import (
    init_checkpoint_metadata_cache,
    save_state_dict_async_finalize,
    save_state_dict_async_plan,
)

# Set up basic logging configuration
# Try setting `DEBUG` to see detailed steps of NVRx checkpointing
logging.basicConfig(level=logging.INFO)

FEAT_SIZE = 4096
DNN_OUT_SIZE = 128
BATCH_SIZE = 100
NUM_EPOCHS = 10
DATASET_LEN = 100000
CKPT_INTERVAL = 100


def print_on_rank0(msg):
    if dist.get_rank() == 0:
        print(msg)


def parse_args():
    parser = argparse.ArgumentParser(description="Async Checkpointing Example")
    parser.add_argument(
        "--ckpt_dir",
        default="/tmp/test_checkpointing/",
        help="Directory for saving checkpoints",
    )
    parser.add_argument(
        "--thread_count",
        default=2,
        type=int,
        help="Threads to use during saving. Affects the number of files in the checkpoint (saving ranks * thread_count).",
    )
    parser.add_argument(
        '--no_persistent_queue',
        action='store_false',
        default=True,
        dest='persistent_queue',
        help="Disables the persistent version of AsyncCallsQueue.",
    )

    return parser.parse_args()


class SimpleDataset(torch.utils.data.Dataset):
    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        x = torch.rand((FEAT_SIZE,), dtype=torch.float32, device='cuda')
        y = torch.rand((DNN_OUT_SIZE,), dtype=torch.float32, device='cuda')
        return x, y


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(FEAT_SIZE, FEAT_SIZE)
        self.fc2 = nn.Linear(FEAT_SIZE, DNN_OUT_SIZE)

    def forward(self, x):
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        return x


def init_distributed_backend(backend="nccl"):
    """Initializes the distributed process group using the specified backend."""
    try:
        dist.init_process_group(backend=backend, init_method="env://")
        rank = dist.get_rank()
        torch.cuda.set_device(rank)
        logging.info(f"Process {rank} initialized with {backend} backend.")
        return rank, dist.get_world_size()
    except Exception as e:
        logging.error(f"Failed to initialize distributed backend: {e}")
        raise


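# An AsyncRequest bundles everything the AsyncCallsQueue needs for one checkpoint:
# the function that performs the file writes (save_fn), its arguments, a list of
# finalize callbacks invoked once the writes are done, and an optional preload_fn
# that stages tensor data (e.g. device-to-host copies) before the background write.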
def get_save_and_finalize_callbacks(writer, save_state_dict_ret) -> AsyncRequest:
    """Creates an async save request with a finalize function."""
    save_fn, preload_fn, save_args = writer.get_save_function_and_args()

    def finalize_fn():
        """Finalizes async checkpointing and synchronizes processes."""
        save_state_dict_async_finalize(*save_state_dict_ret)
        dist.barrier()

    return AsyncRequest(save_fn, save_args, [finalize_fn], preload_fn=preload_fn)


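# The save is split into two phases: save_state_dict_async_plan performs the
# (synchronous, collective) planning step, while the actual file writes are
# scheduled on the AsyncCallsQueue and happen in the background. With
# enable_cache=True, planning metadata is intended to be reused across saves
# (see init_checkpoint_metadata_cache() in main()).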
def save_checkpoint(checkpoint_dir, async_queue, model, thread_count):
    """Asynchronously saves a model checkpoint."""
    state_dict = model.state_dict()
    planner = DefaultSavePlanner()
    writer = FileSystemWriterAsync(checkpoint_dir, thread_count=thread_count)
    coordinator_rank = 0

    save_state_dict_ret = save_state_dict_async_plan(
        state_dict, writer, None, coordinator_rank, planner=planner, enable_cache=True
    )
    save_request = get_save_and_finalize_callbacks(writer, save_state_dict_ret)
    async_queue.schedule_async_request(save_request)


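# Loading uses the regular torch.distributed.checkpoint API: the files written
# by FileSystemWriterAsync are read back with the standard FileSystemReader.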
def load_checkpoint(checkpoint_dir, model):
    """Loads a model checkpoint synchronously."""
    state_dict = model.state_dict()
    checkpoint.load(
        state_dict=state_dict,
        storage_reader=FileSystemReader(checkpoint_dir),
        planner=DefaultLoadPlanner(),
    )
    return state_dict


def main():
    args = parse_args()
    logging.info(f"Arguments: {args}")

    # Initialize distributed training
    rank, world_size = init_distributed_backend(backend="nccl")

    # Dataset, sampler, and dataloader
    dataset = SimpleDataset(size=DATASET_LEN)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler)

    # Model, optimizer, and FSDP wrapper
    model = SimpleModel().to("cuda")
    fsdp_model = FSDP(model)
    optimizer = optim.SGD(fsdp_model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()

    # Create an async queue for handling asynchronous operations
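    # With persistent=True (the default here), the queue is meant to reuse its
    # background save worker across checkpoints; --no_persistent_queue disables this.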
    async_queue = AsyncCallsQueue(persistent=args.persistent_queue)

    iteration = 0
    num_iters_in_epoch = len(dataloader)
    print_on_rank0(f"num_iters_in_epoch: {num_iters_in_epoch}")

    num_iters_for_10pct = num_iters_in_epoch // 10  # iters for 1/10 of epoch
    checkpoint_dir = None
    sampler.set_epoch(0)

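    # Initialize the checkpoint metadata cache; together with enable_cache=True in
    # save_state_dict_async_plan, this is intended to let later saves reuse planning metadata.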
    init_checkpoint_metadata_cache()

    for batch_idx, (data, target) in enumerate(dataloader):
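        # Non-blocking check: finalize any previously scheduled save that has
        # already finished (running its finalize_fn) without stalling training.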
        async_queue.maybe_finalize_async_calls(blocking=False, no_dist=False)
        if (batch_idx % num_iters_for_10pct) == 0 and rank == 0:
            print(f"Epoch 0 progress: {100 * batch_idx / num_iters_in_epoch:.2f}%")
        optimizer.zero_grad()
        output = fsdp_model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % num_iters_for_10pct == 0:
            iteration = batch_idx
            checkpoint_dir = f"{args.ckpt_dir}/iter_{iteration:07d}"
            # Save the model asynchronously
            save_checkpoint(checkpoint_dir, async_queue, fsdp_model, args.thread_count)
            print_on_rank0(f"Checkpoint Save triggered: {checkpoint_dir}, iteration: {iteration}")
    print_on_rank0(f"Epoch 0 complete. Loss: {loss.item()}")

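    # blocking=True waits for any in-flight save to complete and runs its finalize_fn
    # (save_state_dict_async_finalize + barrier) before the checkpoint is read back.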
    logging.info("Finalizing checkpoint save...")
    async_queue.maybe_finalize_async_calls(blocking=True, no_dist=False)
    async_queue.close()  # Explicitly close queue (optional)

    # Synchronize processes before loading
    dist.barrier()
    print_on_rank0(f"loading from {checkpoint_dir}")
    # Load the checkpoint
    loaded_sd = load_checkpoint(checkpoint_dir, fsdp_model)

    # Synchronize again to ensure all ranks have completed loading
    dist.barrier()

    # Clean up checkpoint directory (only on rank 0)
    if dist.get_rank() == 0:
        logging.info(f"Cleaning up checkpoint directory: {args.ckpt_dir}")
        shutil.rmtree(args.ckpt_dir)

    # Ensure NCCL process group is properly destroyed
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
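
The script relies on init_method="env://", so the usual torchrun environment variables (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT) must be set. A minimal launch on a single node, assuming the example is saved as async_ckpt_example.py (the file name is just a placeholder), could look like:

    torchrun --nproc_per_node=2 async_ckpt_example.py --ckpt_dir /tmp/test_checkpointing

Note that torch.cuda.set_device(rank) in init_distributed_backend assumes a single-node run where the global rank matches the local GPU index.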