import argparse
import logging
import shutil

import multistorageclient as msc
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.distributed import checkpoint
from torch.distributed.checkpoint import DefaultLoadPlanner, DefaultSavePlanner, FileSystemReader
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader, DistributedSampler

from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue, AsyncRequest
from nvidia_resiliency_ext.checkpointing.async_ckpt.filesystem_async import FileSystemWriterAsync
from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import (
    init_checkpoint_metadata_cache,
    save_state_dict_async_finalize,
    save_state_dict_async_plan,
)

# Set up basic logging configuration
# Try setting `DEBUG` to see detailed steps of NVRx checkpointing
logging.basicConfig(level=logging.INFO)

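# Model/dataset dimensions and training-loop constants used throughout the example.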
FEAT_SIZE = 4096
DNN_OUT_SIZE = 128
BATCH_SIZE = 100
NUM_EPOCHS = 10
DATASET_LEN = 100000
CKPT_INTERVAL = 100


def print_on_rank0(msg):
    if dist.get_rank() == 0:
        print(msg)


def parse_args():
    parser = argparse.ArgumentParser(description="Async Checkpointing Example")
    parser.add_argument(
        "--ckpt_dir",
        default="/tmp/test_checkpointing/",
        help="Directory for saving checkpoints",
    )
    parser.add_argument(
        "--thread_count",
        default=2,
        type=int,
        help="Threads to use during saving. Affects the number of files in the checkpoint (saving ranks * num_threads).",
    )
    parser.add_argument(
        '--no_persistent_queue',
        action='store_false',
        default=True,
        dest='persistent_queue',
        help=(
            "Disables the persistent version of AsyncCallsQueue, which otherwise reuses "
            "a single background worker across checkpoint saves."
        ),
    )
    parser.add_argument(
        '--enable_msc',
        action='store_true',
        help="Enables MSC for checkpoint saving and loading. See Usage Guide in Async Checkpointing documentation for detailed instructions.",
    )

    return parser.parse_args()


class SimpleDataset(torch.utils.data.Dataset):
    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
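        # Random features/targets are generated directly on the GPU; this is fine here
        # because the DataLoader below uses the default num_workers=0 (no worker processes).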
        x = torch.rand((FEAT_SIZE,), dtype=torch.float32, device='cuda')
        y = torch.rand((DNN_OUT_SIZE,), dtype=torch.float32, device='cuda')
        return x, y


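# A small two-layer MLP, just enough model to exercise FSDP and async checkpointing.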
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(FEAT_SIZE, FEAT_SIZE)
        self.fc2 = nn.Linear(FEAT_SIZE, DNN_OUT_SIZE)

    def forward(self, x):
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        return x


def init_distributed_backend(backend="nccl"):
    """Initializes the distributed process group using the specified backend."""
    try:
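        # init_method="env://" reads MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE from
        # the environment, as set by launchers such as torchrun.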
        dist.init_process_group(backend=backend, init_method="env://")
        rank = dist.get_rank()
        torch.cuda.set_device(dist.get_node_local_rank())
        logging.info(f"Process {rank} initialized with {backend} backend.")
        return rank, dist.get_world_size()
    except Exception as e:
        logging.error(f"Failed to initialize distributed backend: {e}")
        raise


def get_save_and_finalize_callbacks(writer, save_state_dict_ret) -> AsyncRequest:
    """Creates an async save request with a finalize function."""
    save_fn, preload_fn, save_args = writer.get_save_function_and_args()

    def finalize_fn():
        """Finalizes async checkpointing and synchronizes processes."""
        save_state_dict_async_finalize(*save_state_dict_ret)
        dist.barrier()

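    # Bundle everything into an AsyncRequest: `save_fn` performs the actual tensor writes in
    # the background, the writer-provided `preload_fn` stages data for that write, and
    # `finalize_fn` is called once the asynchronous save has completed.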
    return AsyncRequest(
        save_fn, save_args, [finalize_fn], async_fn_kwargs={}, preload_fn=preload_fn
    )


def save_checkpoint(checkpoint_dir, async_queue, model, thread_count, enable_msc):
    """Asynchronously saves a model checkpoint."""
    state_dict = model.state_dict()
    planner = DefaultSavePlanner()
    writer = FileSystemWriterAsync(checkpoint_dir, thread_count=thread_count, use_msc=enable_msc)
    coordinator_rank = 0

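    # The planning step below runs synchronously; the actual tensor writes are deferred to
    # the AsyncRequest scheduled afterwards. enable_cache=True allows planning metadata to be
    # reused across iterations (paired with init_checkpoint_metadata_cache() in main()).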
    save_state_dict_ret = save_state_dict_async_plan(
        state_dict, writer, None, coordinator_rank, planner=planner, enable_cache=True
    )
    save_request = get_save_and_finalize_callbacks(writer, save_state_dict_ret)
    async_queue.schedule_async_request(save_request)


def load_checkpoint(checkpoint_dir, model, thread_count, enable_msc):
    """Loads a model checkpoint synchronously."""
    state_dict = model.state_dict()
    if enable_msc:
        reader = msc.torch.MultiStorageFileSystemReader(checkpoint_dir, thread_count=thread_count)
    else:
        reader = FileSystemReader(checkpoint_dir)
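    # Loading is synchronous: torch.distributed.checkpoint reads the shards back into the
    # provided state_dict in place.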
    checkpoint.load(
        state_dict=state_dict,
        storage_reader=reader,
        planner=DefaultLoadPlanner(),
    )
    return state_dict


def main():
    args = parse_args()
    logging.info(f"Arguments: {args}")

    # Initialize distributed training
    rank, world_size = init_distributed_backend(backend="nccl")

    # Dataset, distributed sampler, and dataloader
    dataset = SimpleDataset(size=DATASET_LEN)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler)

    # Model, optimizer, and FSDP wrapper
    model = SimpleModel().to("cuda")
    fsdp_model = FSDP(model)
    optimizer = optim.SGD(fsdp_model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()

    # Create an async queue for handling asynchronous operations
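    # With persistent=True (the default here), a single background worker is reused across
    # checkpoint saves; pass --no_persistent_queue to disable this behavior.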
    async_queue = AsyncCallsQueue(persistent=args.persistent_queue)

    iteration = 0
    num_iters_in_epoch = len(dataloader)
    print_on_rank0(f"num_iters_in_epoch: {num_iters_in_epoch}")

    num_iters_for_10pct = num_iters_in_epoch // 10  # iters for 1/10 of epoch
    checkpoint_dir = None
    sampler.set_epoch(0)

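    # Initialize the metadata cache used by save_state_dict_async_plan(enable_cache=True),
    # so planning work can be reused by subsequent checkpoint saves.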
    init_checkpoint_metadata_cache()

    for batch_idx, (data, target) in enumerate(dataloader):
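        # Finalize any previously scheduled async save that has finished; blocking=False
        # keeps training moving if the save is still in flight.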
        async_queue.maybe_finalize_async_calls(blocking=False, no_dist=False)
        if (batch_idx % num_iters_for_10pct) == 0 and rank == 0:
            print(f"Epoch 0 progress: {100 * batch_idx / num_iters_in_epoch:.2f}%")
        optimizer.zero_grad()
        output = fsdp_model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % num_iters_for_10pct == 0:
            iteration = batch_idx
            checkpoint_dir = f"{args.ckpt_dir}/iter_{iteration:07d}"
            # Save the model asynchronously
            save_checkpoint(
                checkpoint_dir, async_queue, fsdp_model, args.thread_count, args.enable_msc
            )
            print_on_rank0(f"Checkpoint Save triggered: {checkpoint_dir}, iteration: {iteration}")
    iteration += batch_idx
    print_on_rank0(f"Epoch 0 complete. Loss: {loss.item()}")

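    # Drain the async queue: block until the last scheduled save has fully completed and
    # its metadata has been committed.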
    logging.info("Finalizing checkpoint save...")
    async_queue.maybe_finalize_async_calls(blocking=True, no_dist=False)
    async_queue.close()  # Explicitly close queue

    # Synchronize processes before loading
    dist.barrier()
    print_on_rank0(f"loading from {checkpoint_dir}")
    # Load the checkpoint
    loaded_sd = load_checkpoint(checkpoint_dir, fsdp_model, args.thread_count, args.enable_msc)

    # Synchronize again to ensure all ranks have completed loading
    dist.barrier()

    # Clean up checkpoint directory (only on rank 0)
    if dist.get_rank() == 0:
        logging.info(f"Cleaning up checkpoint directory: {args.ckpt_dir}")
        if args.enable_msc:
            msc.delete(args.ckpt_dir, recursive=True)
        else:
            shutil.rmtree(args.ckpt_dir)

    # Ensure NCCL process group is properly destroyed
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()