import argparse
import logging
import shutil

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.distributed import checkpoint
from torch.distributed.checkpoint import DefaultLoadPlanner, DefaultSavePlanner, FileSystemReader
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader, DistributedSampler

from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue, AsyncRequest
from nvidia_resiliency_ext.checkpointing.async_ckpt.filesystem_async import FileSystemWriterAsync
from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import (
    init_checkpoint_metadata_cache,
    save_state_dict_async_finalize,
    save_state_dict_async_plan,
)
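# This example demonstrates asynchronous checkpointing of an FSDP model with
# nvidia_resiliency_ext: checkpoint writes are scheduled on an AsyncCallsQueue
# so training can continue while files are persisted in the background.
# Launch with torchrun (or any launcher that sets RANK, WORLD_SIZE, MASTER_ADDR
# and MASTER_PORT for the env:// init method used below), for example:
#   torchrun --nproc_per_node=<gpus_per_node> <this_script>.py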

# Set up basic logging configuration.
# Set the level to `DEBUG` to see the detailed steps of NVRx checkpointing.
logging.basicConfig(level=logging.INFO)

FEAT_SIZE = 4096
DNN_OUT_SIZE = 128
BATCH_SIZE = 100
NUM_EPOCHS = 10
DATASET_LEN = 100000
CKPT_INTERVAL = 100


def print_on_rank0(msg):
    if dist.get_rank() == 0:
        print(msg)


def parse_args():
    parser = argparse.ArgumentParser(description="Async Checkpointing Example")
    parser.add_argument(
        "--ckpt_dir",
        default="/tmp/test_checkpointing/",
        help="Directory for saving checkpoints",
    )
    parser.add_argument(
        "--thread_count",
        default=2,
        type=int,
        help=(
            "Number of threads to use while saving. Affects the number of files "
            "in the checkpoint (saving ranks * thread_count)."
        ),
    )
    parser.add_argument(
        "--no_persistent_queue",
        action="store_false",
        default=True,
        dest="persistent_queue",
        help="Disables the persistent version of AsyncCallsQueue.",
    )

    return parser.parse_args()



class SimpleDataset(torch.utils.data.Dataset):
    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Random features/targets are created directly on the GPU; this is fine
        # with the default single-process DataLoader (num_workers=0) used below.
        x = torch.rand((FEAT_SIZE,), dtype=torch.float32, device='cuda')
        y = torch.rand((DNN_OUT_SIZE,), dtype=torch.float32, device='cuda')
        return x, y


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(FEAT_SIZE, FEAT_SIZE)
        self.fc2 = nn.Linear(FEAT_SIZE, DNN_OUT_SIZE)

    def forward(self, x):
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        return x


def init_distributed_backend(backend="nccl"):
    """Initializes the distributed process group using the specified backend."""
    try:
        dist.init_process_group(backend=backend, init_method="env://")
        rank = dist.get_rank()
        # Using the global rank as the CUDA device index assumes one GPU per rank
        # on a single node; multi-node runs would typically use LOCAL_RANK instead.
        torch.cuda.set_device(rank)
        logging.info(f"Process {rank} initialized with {backend} backend.")
        return rank, dist.get_world_size()
    except Exception as e:
        logging.error(f"Failed to initialize distributed backend: {e}")
        raise


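# An AsyncRequest bundles the pieces of one asynchronous save: the function that
# performs the actual file writes (save_fn with save_args), an optional preload_fn
# (presumably used to stage tensor data out of the training process before the
# background write starts), and a list of finalize functions that run once the
# write has completed (here: save_state_dict_async_finalize plus a barrier).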
def get_save_and_finalize_callbacks(writer, save_state_dict_ret) -> AsyncRequest:
    """Creates an async save request with a finalize function."""
    save_fn, preload_fn, save_args = writer.get_save_function_and_args()

    def finalize_fn():
        """Finalizes async checkpointing and synchronizes processes."""
        save_state_dict_async_finalize(*save_state_dict_ret)
        dist.barrier()

    return AsyncRequest(save_fn, save_args, [finalize_fn], preload_fn=preload_fn)


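# Saving proceeds in two phases: save_state_dict_async_plan() runs the (collective)
# planning step up front, then the actual file writes are wrapped in an AsyncRequest
# and scheduled on the AsyncCallsQueue, so the training loop does not block on I/O.
# With enable_cache=True the planning metadata is cached and reused by later saves
# (see init_checkpoint_metadata_cache() in main()), which is expected to reduce the
# per-checkpoint planning overhead.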
def save_checkpoint(checkpoint_dir, async_queue, model, thread_count):
    """Asynchronously saves a model checkpoint."""
    state_dict = model.state_dict()
    planner = DefaultSavePlanner()
    writer = FileSystemWriterAsync(checkpoint_dir, thread_count=thread_count)
    coordinator_rank = 0

    save_state_dict_ret = save_state_dict_async_plan(
        state_dict, writer, None, coordinator_rank, planner=planner, enable_cache=True
    )
    save_request = get_save_and_finalize_callbacks(writer, save_state_dict_ret)
    async_queue.schedule_async_request(save_request)


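# Loading uses the standard torch.distributed.checkpoint reader: each rank passes
# its current (sharded) state_dict as the target, and the checkpoint data is read
# into those tensors in place, so every rank only loads the shards it owns.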
def load_checkpoint(checkpoint_dir, model):
    """Loads a model checkpoint synchronously."""
    state_dict = model.state_dict()
    checkpoint.load(
        state_dict=state_dict,
        storage_reader=FileSystemReader(checkpoint_dir),
        planner=DefaultLoadPlanner(),
    )
    return state_dict


def main():
    args = parse_args()
    logging.info(f"Arguments: {args}")

    # Initialize distributed training
    rank, world_size = init_distributed_backend(backend="nccl")

    # Dataset, distributed sampler, and dataloader
    dataset = SimpleDataset(size=DATASET_LEN)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler)

    # Model, optimizer, and FSDP wrapper
    model = SimpleModel().to("cuda")
    fsdp_model = FSDP(model)
    optimizer = optim.SGD(fsdp_model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()

    # Create an async queue for handling asynchronous checkpoint operations
    async_queue = AsyncCallsQueue(persistent=args.persistent_queue)
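    # With persistent=True (the default here), the AsyncCallsQueue is expected to
    # keep a single long-lived background worker for checkpoint writes instead of
    # spawning a new one for every save; pass --no_persistent_queue to disable it.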

    iteration = 0
    num_iters_in_epoch = len(dataloader)
    print_on_rank0(f"num_iters_in_epoch: {num_iters_in_epoch}")

    num_iters_for_10pct = max(1, num_iters_in_epoch // 10)  # iters for 1/10 of epoch; at least 1
    checkpoint_dir = None
    sampler.set_epoch(0)

    # Initialize the metadata cache reused by save_state_dict_async_plan(..., enable_cache=True)
    init_checkpoint_metadata_cache()

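    # Training loop for a single epoch. Each step first polls the async queue
    # (non-blocking) so finalization callbacks of previously scheduled saves can
    # run as soon as their writes finish; a checkpoint save is triggered roughly
    # every 10% of the epoch, and a final blocking finalize happens after the loop.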
    for batch_idx, (data, target) in enumerate(dataloader):
        async_queue.maybe_finalize_async_calls(blocking=False, no_dist=False)
        if (batch_idx % num_iters_for_10pct) == 0 and rank == 0:
            print(f"Epoch 0 progress: {100 * batch_idx / num_iters_in_epoch:.2f}%")
        optimizer.zero_grad()
        output = fsdp_model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % num_iters_for_10pct == 0:
            iteration = batch_idx
            checkpoint_dir = f"{args.ckpt_dir}/iter_{iteration:07d}"
            # Save the model asynchronously
            save_checkpoint(checkpoint_dir, async_queue, fsdp_model, args.thread_count)
            print_on_rank0(f"Checkpoint Save triggered: {checkpoint_dir}, iteration: {iteration}")
        iteration += batch_idx
    print_on_rank0(f"Epoch 0 complete. Loss: {loss.item()}")

    logging.info("Finalizing checkpoint save...")
    async_queue.maybe_finalize_async_calls(blocking=True, no_dist=False)
    async_queue.close()  # Explicitly close queue (optional)

    # Synchronize processes before loading
    dist.barrier()
    print_on_rank0(f"Loading checkpoint from {checkpoint_dir}")
    # Load the checkpoint
    loaded_sd = load_checkpoint(checkpoint_dir, fsdp_model)

    # Synchronize again to ensure all ranks have completed loading
    dist.barrier()

    # Clean up the checkpoint directory (only on rank 0)
    if dist.get_rank() == 0:
        logging.info(f"Cleaning up checkpoint directory: {args.ckpt_dir}")
        shutil.rmtree(args.ckpt_dir)

    # Ensure the NCCL process group is properly destroyed
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()