import argparse
import logging
import shutil

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.distributed import checkpoint
from torch.distributed.checkpoint import DefaultLoadPlanner, DefaultSavePlanner, FileSystemReader
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader, DistributedSampler

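# nvidia_resiliency_ext (NVRx) async checkpointing building blocks used in this example:
# - AsyncCallsQueue / AsyncRequest: schedule a checkpoint save (the request bundles the save
#   function, its arguments, and finalize callbacks) so training can continue while the
#   checkpoint is written in the background.
# - FileSystemWriterAsync: a torch.distributed.checkpoint storage writer whose actual file
#   writes are deferred and executed through the async queue.
# - save_state_dict_async_plan / save_state_dict_async_finalize: split a DCP save into a
#   synchronous planning step and a finalize step that runs once the background writes are
#   done; init_checkpoint_metadata_cache sets up the metadata cache used together with
#   enable_cache=True below.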
from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue, AsyncRequest
from nvidia_resiliency_ext.checkpointing.async_ckpt.filesystem_async import FileSystemWriterAsync
from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import (
    init_checkpoint_metadata_cache,
    save_state_dict_async_finalize,
    save_state_dict_async_plan,
)

# Set up basic logging configuration.
# Switch the level to DEBUG to see the detailed steps of NVRx checkpointing.
logging.basicConfig(level=logging.INFO)

FEAT_SIZE = 4096
DNN_OUT_SIZE = 128
BATCH_SIZE = 100
DATASET_LEN = 100000


def print_on_rank0(msg):
    if dist.get_rank() == 0:
        print(msg)


def parse_args():
    parser = argparse.ArgumentParser(description="Async Checkpointing Example")
    parser.add_argument(
        "--ckpt_dir",
        default="/tmp/test_checkpointing/",
        help="Directory for saving checkpoints",
    )
    parser.add_argument(
        "--thread_count",
        default=2,
        type=int,
        help="Threads to use during saving. Affects the number of files in the checkpoint (saving ranks * num_threads).",
    )
    parser.add_argument(
        "--persistent_queue",
        action="store_true",
        help="Enables a persistent version of AsyncCallsQueue.",
    )

    return parser.parse_args()


class SimpleDataset(torch.utils.data.Dataset):
    """Synthetic dataset that generates random feature/target pairs on the fly."""

    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # `idx` is ignored: every sample is freshly generated random data, created directly
        # on the GPU (safe here because the DataLoader uses no worker processes).
        x = torch.rand((FEAT_SIZE,), dtype=torch.float32, device="cuda")
        y = torch.rand((DNN_OUT_SIZE,), dtype=torch.float32, device="cuda")
        return x, y


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(FEAT_SIZE, FEAT_SIZE)
        self.fc2 = nn.Linear(FEAT_SIZE, DNN_OUT_SIZE)

    def forward(self, x):
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        return x


def init_distributed_backend(backend="nccl"):
    """Initializes the distributed process group using the specified backend."""
    try:
        dist.init_process_group(backend=backend, init_method="env://")
        rank = dist.get_rank()
        # Single-node assumption: the global rank doubles as the local GPU index.
        torch.cuda.set_device(rank)
        logging.info(f"Process {rank} initialized with {backend} backend.")
        return rank, dist.get_world_size()
    except Exception as e:
        logging.error(f"Failed to initialize distributed backend: {e}")
        raise


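# The AsyncRequest built below pairs the writer's deferred save function with a finalize
# callback. The finalize step is invoked by the queue only after the background write has
# completed: it commits the checkpoint metadata (save_state_dict_async_finalize) and then
# synchronizes all ranks with a barrier so no rank proceeds on a partially finalized
# checkpoint.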
def get_save_and_finalize_callbacks(writer, save_state_dict_ret) -> AsyncRequest:
    """Creates an async save request with a finalize function."""
    save_fn, preload_fn, save_args = writer.get_save_function_and_args()

    def finalize_fn():
        """Finalizes async checkpointing and synchronizes processes."""
        save_state_dict_async_finalize(*save_state_dict_ret)
        dist.barrier()

    return AsyncRequest(save_fn, save_args, [finalize_fn], preload_fn=preload_fn)


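# Asynchronous save flow: save_state_dict_async_plan runs the synchronous part of a
# torch.distributed.checkpoint save (planning/coordination) and prepares the actual file
# writes; the resulting AsyncRequest is then scheduled on the AsyncCallsQueue, which
# performs the writes without blocking the training loop. The checkpoint becomes usable
# once maybe_finalize_async_calls later executes the finalize callback.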
def save_checkpoint(checkpoint_dir, async_queue, model, thread_count):
    """Asynchronously saves a model checkpoint."""
    state_dict = model.state_dict()
    planner = DefaultSavePlanner()
    writer = FileSystemWriterAsync(checkpoint_dir, thread_count=thread_count)
    coordinator_rank = 0

    save_state_dict_ret = save_state_dict_async_plan(
        state_dict, writer, None, coordinator_rank, planner=planner, enable_cache=True
    )
    save_request = get_save_and_finalize_callbacks(writer, save_state_dict_ret)
    async_queue.schedule_async_request(save_request)


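# Loading needs no NVRx-specific machinery: the files written above form a regular
# torch.distributed.checkpoint (DCP) checkpoint, so a standard synchronous DCP load with
# FileSystemReader is enough. Note that checkpoint.load() loads in place into the tensors
# of the provided state_dict.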
def load_checkpoint(checkpoint_dir, model):
    """Loads a model checkpoint synchronously."""
    state_dict = model.state_dict()
    checkpoint.load(
        state_dict=state_dict,
        storage_reader=FileSystemReader(checkpoint_dir),
        planner=DefaultLoadPlanner(),
    )
    return state_dict


def main():
    args = parse_args()
    logging.info(f"Arguments: {args}")

    # Initialize distributed training
    rank, world_size = init_distributed_backend(backend="nccl")

    # Synthetic dataset and a distributed dataloader
    dataset = SimpleDataset(size=DATASET_LEN)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler)

    # Model, optimizer, and FSDP wrapper
    model = SimpleModel().to("cuda")
    fsdp_model = FSDP(model)
    optimizer = optim.SGD(fsdp_model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()

    # Create an async queue for handling asynchronous checkpoint requests.
    # With --persistent_queue, a single background worker is reused across saves
    # instead of starting a new one for every checkpoint.
    async_queue = AsyncCallsQueue(persistent=args.persistent_queue)

    iteration = 0
    num_iters_in_epoch = len(dataloader)
    print_on_rank0(f"num_iters_in_epoch: {num_iters_in_epoch}")

    num_iters_for_10pct = num_iters_in_epoch // 10  # iterations per 1/10 of the epoch
    checkpoint_dir = None
    sampler.set_epoch(0)

    # Initialize the checkpoint metadata cache (used with enable_cache=True in save_checkpoint)
    init_checkpoint_metadata_cache()

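    # Training loop for a single epoch: each iteration first gives the async queue a chance
    # to finalize any checkpoint save that has finished in the background (non-blocking),
    # then runs a regular FSDP training step; every 10% of the epoch a new asynchronous
    # checkpoint save is triggered.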
    for batch_idx, (data, target) in enumerate(dataloader):
        async_queue.maybe_finalize_async_calls(blocking=False, no_dist=False)
        if (batch_idx % num_iters_for_10pct) == 0 and rank == 0:
            print(f"Epoch 0 progress: {100 * batch_idx / num_iters_in_epoch:.2f}%")
        optimizer.zero_grad()
        output = fsdp_model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % num_iters_for_10pct == 0:
            iteration = batch_idx
            checkpoint_dir = f"{args.ckpt_dir}/iter_{iteration:07d}"
            # Save the model asynchronously
            save_checkpoint(checkpoint_dir, async_queue, fsdp_model, args.thread_count)
            print_on_rank0(f"Checkpoint save triggered: {checkpoint_dir}, iteration: {iteration}")
    print_on_rank0(f"Epoch 0 complete. Loss: {loss.item()}")

    logging.info("Finalizing checkpoint save...")
    async_queue.maybe_finalize_async_calls(blocking=True, no_dist=False)
    async_queue.close()  # Explicitly close the queue (optional)

    # Synchronize processes before loading
    dist.barrier()
    print_on_rank0(f"Loading checkpoint from {checkpoint_dir}")
    # Load the checkpoint
    load_checkpoint(checkpoint_dir, fsdp_model)

    # Synchronize again to ensure all ranks have completed loading
    dist.barrier()

    # Clean up the checkpoint directory (only on rank 0)
    if dist.get_rank() == 0:
        logging.info(f"Cleaning up checkpoint directory: {args.ckpt_dir}")
        shutil.rmtree(args.ckpt_dir)

    # Ensure the NCCL process group is properly destroyed
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()