import argparse
import logging
import shutil

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.distributed import checkpoint
from torch.distributed.checkpoint import DefaultLoadPlanner, DefaultSavePlanner, FileSystemReader
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader, DistributedSampler

from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue, AsyncRequest
from nvidia_resiliency_ext.checkpointing.async_ckpt.filesystem_async import FileSystemWriterAsync
from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import (
    save_state_dict_async_finalize,
    save_state_dict_async_plan,
)
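
# nvidia_resiliency_ext pieces used below: AsyncCallsQueue/AsyncRequest schedule checkpoint
# work in the background, FileSystemWriterAsync supplies the file-writing routine executed by
# that background work, and the state_dict_saver helpers split a save into a synchronous
# planning step and an asynchronous finalize step.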

# Set up basic logging configuration
logging.basicConfig(level=logging.INFO)

FEAT_SIZE = 4096
DNN_OUT_SIZE = 128
BATCH_SIZE = 100
NUM_EPOCHS = 10
DATASET_LEN = 100000
CKPT_INTERVAL = 100


def print_on_rank0(msg):
    if dist.get_rank() == 0:
        print(msg)


def parse_args():
    parser = argparse.ArgumentParser(description="Async Checkpointing Example")
    parser.add_argument(
        "--ckpt_dir",
        default="/tmp/test_checkpointing/",
        help="Directory for saving checkpoints",
    )
    parser.add_argument(
        "--thread_count",
        default=2,
        type=int,
        help="Threads to use during saving. Affects the number of files in the checkpoint (saving ranks * num_threads).",
    )
    parser.add_argument(
        '--persistent_queue',
        action='store_true',
        help="Enables a persistent version of AsyncCallsQueue.",
    )

    return parser.parse_args()


class SimpleDataset(torch.utils.data.Dataset):
    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        x = torch.rand((FEAT_SIZE,), dtype=torch.float32, device='cuda')
        y = torch.rand((DNN_OUT_SIZE,), dtype=torch.float32, device='cuda')
        return x, y


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(FEAT_SIZE, FEAT_SIZE)
        self.fc2 = nn.Linear(FEAT_SIZE, DNN_OUT_SIZE)

    def forward(self, x):
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        return x


def init_distributed_backend(backend="nccl"):
    """Initializes the distributed process group using the specified backend."""
    try:
        dist.init_process_group(backend=backend, init_method="env://")
        rank = dist.get_rank()
        torch.cuda.set_device(rank)
        logging.info(f"Process {rank} initialized with {backend} backend.")
        return rank, dist.get_world_size()
    except Exception as e:
        logging.error(f"Failed to initialize distributed backend: {e}")
        raise


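# Async save flow: save_state_dict_async_plan() does the synchronous planning and prepares the
# write data, the returned AsyncRequest performs the actual file writes in the background, and
# finalize_fn() completes the save (metadata write) and synchronizes ranks once writes are done.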
def get_save_and_finalize_callbacks(writer, save_state_dict_ret) -> AsyncRequest:
    """Creates an async save request with a finalize function."""
    save_fn, preload_fn, save_args = writer.get_save_function_and_args()

    def finalize_fn():
        """Finalizes async checkpointing and synchronizes processes."""
        save_state_dict_async_finalize(*save_state_dict_ret)
        dist.barrier()

    return AsyncRequest(save_fn, save_args, [finalize_fn], preload_fn=preload_fn)


def save_checkpoint(checkpoint_dir, async_queue, model, thread_count):
    """Asynchronously saves a model checkpoint."""
    state_dict = model.state_dict()
    planner = DefaultSavePlanner()
    writer = FileSystemWriterAsync(checkpoint_dir, thread_count=thread_count)
    coordinator_rank = 0

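    # Plan the distributed save; with enable_cache=True the planning results can be reused
    # across checkpoints that share the same structure.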
    save_state_dict_ret = save_state_dict_async_plan(
        state_dict, writer, None, coordinator_rank, planner=planner, enable_cache=True
    )
    save_request = get_save_and_finalize_callbacks(writer, save_state_dict_ret)
    async_queue.schedule_async_request(save_request)


def load_checkpoint(checkpoint_dir, model):
    """Loads a model checkpoint synchronously."""
    state_dict = model.state_dict()
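    # torch.distributed.checkpoint loads in place into the tensors of the provided state_dict.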
    checkpoint.load(
        state_dict=state_dict,
        storage_reader=FileSystemReader(checkpoint_dir),
        planner=DefaultLoadPlanner(),
    )
    return state_dict


def main():
    args = parse_args()
    logging.info(f"Arguments: {args}")

    # Initialize distributed training
    rank, world_size = init_distributed_backend(backend="nccl")

    # Dataset, sampler, and dataloader for distributed training
    dataset = SimpleDataset(size=DATASET_LEN)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler)

    # Model, optimizer, and FSDP wrapper
    model = SimpleModel().to("cuda")
    fsdp_model = FSDP(model)
    optimizer = optim.SGD(fsdp_model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()

    # Create an async queue for handling asynchronous operations
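    # persistent=True reuses a single background checkpoint worker across saves instead of
    # starting a new one for every checkpoint.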
    async_queue = AsyncCallsQueue(persistent=args.persistent_queue)

    iteration = 0
    num_iters_in_epoch = len(dataloader)
    print_on_rank0(f"num_iters_in_epoch: {num_iters_in_epoch}")

    num_iters_for_10pct = num_iters_in_epoch // 10  # iters for 1/10 of epoch
    checkpoint_dir = None
    sampler.set_epoch(0)
    for batch_idx, (data, target) in enumerate(dataloader):
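        # Non-blocking check: finalize any previously scheduled saves that have completed.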
        async_queue.maybe_finalize_async_calls(blocking=False, no_dist=False)
        if (batch_idx % num_iters_for_10pct) == 0 and rank == 0:
            print(f"Epoch 0 progress: {100 * batch_idx / num_iters_in_epoch:.2f}%")
        optimizer.zero_grad()
        output = fsdp_model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
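        # Trigger an asynchronous checkpoint every 10% of the epoch; training continues
        # while the checkpoint is written in the background.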
        if batch_idx % num_iters_for_10pct == 0:
            iteration = batch_idx
            checkpoint_dir = f"{args.ckpt_dir}/iter_{iteration:07d}"
            # Save the model asynchronously
            save_checkpoint(checkpoint_dir, async_queue, fsdp_model, args.thread_count)
            print_on_rank0(f"Checkpoint Save triggered: {checkpoint_dir}, iteration: {iteration}")
    print_on_rank0(f"Epoch 0 complete. Loss: {loss.item()}")

    logging.info("Finalizing checkpoint save...")
    async_queue.maybe_finalize_async_calls(blocking=True, no_dist=False)
    async_queue.close()  # Explicitly close queue (optional)

    # Synchronize processes before loading
    dist.barrier()
    print_on_rank0(f"loading from {checkpoint_dir}")
    # Load the checkpoint
    loaded_sd = load_checkpoint(checkpoint_dir, fsdp_model)

    # Synchronize again to ensure all ranks have completed loading
    dist.barrier()

    # Clean up checkpoint directory (only on rank 0)
    if dist.get_rank() == 0:
        logging.info(f"Cleaning up checkpoint directory: {args.ckpt_dir}")
        shutil.rmtree(args.ckpt_dir)

    # Ensure NCCL process group is properly destroyed
    if dist.is_initialized():
        dist.destroy_process_group()


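# The script relies on env:// initialization, so it should be started with a distributed
# launcher. Example launch (assumption: single node with 2 GPUs):
#   torchrun --nproc_per_node=2 <this_script>.py --persistent_queue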
if __name__ == "__main__":
    main()