import argparse
import logging
import shutil

import multistorageclient as msc
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.distributed import checkpoint
from torch.distributed.checkpoint import DefaultLoadPlanner, DefaultSavePlanner, FileSystemReader
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader, DistributedSampler

from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue, AsyncRequest
from nvidia_resiliency_ext.checkpointing.async_ckpt.filesystem_async import FileSystemWriterAsync
from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import (
    init_checkpoint_metadata_cache,
    save_state_dict_async_finalize,
    save_state_dict_async_plan,
)

# Set up basic logging configuration
# Try setting `DEBUG` to see detailed steps of NVRx checkpointing
logging.basicConfig(level=logging.INFO)
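
# Launch note: this example expects one process per GPU (env:// initialization).
# A typical command (assumed here; adjust the script name and GPU count to your setup):
#   torchrun --nproc_per_node=<num_gpus> <this_script>.py --ckpt_dir /tmp/test_checkpointing/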

FEAT_SIZE = 4096
DNN_OUT_SIZE = 128
BATCH_SIZE = 100
NUM_EPOCHS = 10
DATASET_LEN = 100000
CKPT_INTERVAL = 100
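# The constants above define a toy workload. Note that the training loop in main()
# triggers a checkpoint save every 10% of the epoch; NUM_EPOCHS and CKPT_INTERVAL
# are kept for reference but are not used below.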


def print_on_rank0(msg):
    if dist.get_rank() == 0:
        print(msg)


def parse_args():
    parser = argparse.ArgumentParser(description="Async Checkpointing Example")
    parser.add_argument(
        "--ckpt_dir",
        default="/tmp/test_checkpointing/",
        help="Directory for saving checkpoints",
    )
    parser.add_argument(
        "--thread_count",
        default=2,
        type=int,
        help="Threads to use during saving. Affects the number of files in the checkpoint (saving ranks * num_threads).",
    )
    parser.add_argument(
        '--no_persistent_queue',
        action='store_false',
        default=True,
        dest='persistent_queue',
        help=(
            "Disables the persistent version of AsyncCallsQueue. "
            "Checkpoint saves in this example are always asynchronous."
        ),
    )
    parser.add_argument(
        '--enable_msc',
        action='store_true',
        help="Enables MSC for checkpoint saving and loading. See Usage Guide in Async Checkpointing documentation for detailed instructions.",
    )

    return parser.parse_args()


class SimpleDataset(torch.utils.data.Dataset):
    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        x = torch.rand((FEAT_SIZE,), dtype=torch.float32, device='cuda')
        y = torch.rand((DNN_OUT_SIZE,), dtype=torch.float32, device='cuda')
        return x, y


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(FEAT_SIZE, FEAT_SIZE)
        self.fc2 = nn.Linear(FEAT_SIZE, DNN_OUT_SIZE)

    def forward(self, x):
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        return x


def init_distributed_backend(backend="nccl"):
    """Initializes the distributed process group using the specified backend."""
    try:
        dist.init_process_group(backend=backend, init_method="env://")
        rank = dist.get_rank()
        torch.cuda.set_device(dist.get_node_local_rank())
        logging.info(f"Process {rank} initialized with {backend} backend.")
        return rank, torch.distributed.get_world_size()
    except Exception as e:
        logging.error(f"Failed to initialize distributed backend: {e}")
        raise


def get_save_and_finalize_callbacks(writer, save_state_dict_ret) -> AsyncRequest:
    """Creates an async save request with a finalize function."""
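    # The writer returns the write routine, its arguments, and an optional preload
    # function that stages tensor data (e.g. device-to-host copies) before the
    # background write begins.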
    save_fn, preload_fn, save_args = writer.get_save_function_and_args()

    def finalize_fn():
        """Finalizes async checkpointing and synchronizes processes."""
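        # Writes the global checkpoint metadata once the per-rank writes have
        # finished, then synchronizes so no rank moves on before the checkpoint
        # is complete.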
        save_state_dict_async_finalize(*save_state_dict_ret)
        dist.barrier()

    return AsyncRequest(save_fn, save_args, [finalize_fn], preload_fn=preload_fn)


def save_checkpoint(checkpoint_dir, async_queue, model, thread_count, enable_msc):
    """Asynchronously saves a model checkpoint."""
    state_dict = model.state_dict()
    planner = DefaultSavePlanner()
    writer = FileSystemWriterAsync(checkpoint_dir, thread_count=thread_count, use_msc=enable_msc)
    coordinator_rank = 0

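    # The planning step runs the collective preparation (local and global save plans)
    # up front, so only the actual file writes are deferred to the background save.
    # enable_cache=True lets the planning metadata be reused across iterations when the
    # state dict structure is unchanged (see init_checkpoint_metadata_cache() in main).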
    save_state_dict_ret = save_state_dict_async_plan(
        state_dict, writer, None, coordinator_rank, planner=planner, enable_cache=True
    )
    save_request = get_save_and_finalize_callbacks(writer, save_state_dict_ret)
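    # Hand the request to the queue: the writes run in a background worker, and
    # finalize_fn is invoked later via maybe_finalize_async_calls() in the training loop.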
    async_queue.schedule_async_request(save_request)


def load_checkpoint(checkpoint_dir, model, thread_count, enable_msc):
    """Loads a model checkpoint synchronously."""
    state_dict = model.state_dict()
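    # Choose a storage reader: the MSC reader resolves multistorageclient paths and
    # profiles, while the default FileSystemReader reads from a local or shared filesystem.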
    if enable_msc:
        reader = msc.torch.MultiStorageFileSystemReader(checkpoint_dir, thread_count=thread_count)
    else:
        reader = FileSystemReader(checkpoint_dir)
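    # torch.distributed.checkpoint loads the saved shards into `state_dict` in place.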
    checkpoint.load(
        state_dict=state_dict,
        storage_reader=reader,
        planner=DefaultLoadPlanner(),
    )
    return state_dict


def main():
    args = parse_args()
    logging.info(f"Arguments: {args}")

    # Initialize distributed training
    rank, world_size = init_distributed_backend(backend="nccl")

    # Create the dataset, distributed sampler, and dataloader
    dataset = SimpleDataset(size=DATASET_LEN)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler)

    # Model, optimizer, and FSDP wrapper
    model = SimpleModel().to("cuda")
    fsdp_model = FSDP(model)
    optimizer = optim.SGD(fsdp_model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()

    # Create an async queue for handling asynchronous operations
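    # With persistent=True (the default here), the queue keeps a long-lived background
    # worker for checkpoint saves instead of starting a new one for every save;
    # pass --no_persistent_queue to disable this behavior.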
    async_queue = AsyncCallsQueue(persistent=args.persistent_queue)

    iteration = 0
    num_iters_in_epoch = len(dataloader)
    print_on_rank0(f"num_iters_in_epoch: {num_iters_in_epoch}")

    num_iters_for_10pct = num_iters_in_epoch // 10  # iters for 1/10 of epoch
    checkpoint_dir = None
    sampler.set_epoch(0)

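    # Initialize the checkpoint metadata cache used by save_state_dict_async_plan
    # (enable_cache=True), so repeated saves of the same state dict structure can
    # skip redundant planning work.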
    init_checkpoint_metadata_cache()

    for batch_idx, (data, target) in enumerate(dataloader):
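        # Non-blocking check: if a previously scheduled save has finished, run its
        # finalize callbacks; otherwise continue training without waiting.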
        async_queue.maybe_finalize_async_calls(blocking=False, no_dist=False)
        if (batch_idx % num_iters_for_10pct) == 0 and rank == 0:
            print(f"Epoch 0 progress: {100 * batch_idx / num_iters_in_epoch:.2f}%")
        optimizer.zero_grad()
        output = fsdp_model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % num_iters_for_10pct == 0:
            iteration = batch_idx
            checkpoint_dir = f"{args.ckpt_dir}/iter_{iteration:07d}"
            # Save the model asynchronously
            save_checkpoint(
                checkpoint_dir, async_queue, fsdp_model, args.thread_count, args.enable_msc
            )
            print_on_rank0(f"Checkpoint Save triggered: {checkpoint_dir}, iteration: {iteration}")
    iteration += batch_idx
    print_on_rank0(f"Epoch 0 complete. Loss: {loss.item()}")

    logging.info("Finalizing checkpoint save...")
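    # Wait (blocking) for any in-flight save to finish and run its finalize callbacks
    # before shutting the async queue down.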
    async_queue.maybe_finalize_async_calls(blocking=True, no_dist=False)
    async_queue.close()  # Explicitly close queue

    # Synchronize processes before loading
    dist.barrier()
    print_on_rank0(f"loading from {checkpoint_dir}")
    # Load the checkpoint
    loaded_sd = load_checkpoint(checkpoint_dir, fsdp_model, args.thread_count, args.enable_msc)

    # Synchronize again to ensure all ranks have completed loading
    dist.barrier()

    # Clean up checkpoint directory (only on rank 0)
    if dist.get_rank() == 0:
        logging.info(f"Cleaning up checkpoint directory: {args.ckpt_dir}")
        if args.enable_msc:
            msc.delete(args.ckpt_dir, recursive=True)
        else:
            shutil.rmtree(args.ckpt_dir)

    # Ensure NCCL process group is properly destroyed
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()