1import argparse
2import logging
3import shutil
4
5try:
6 import multistorageclient as msc
7except ImportError:
8 msc = None
9import torch
10import torch.distributed as dist
11import torch.nn as nn
12import torch.optim as optim
13from torch.distributed import checkpoint
14from torch.distributed.checkpoint import DefaultLoadPlanner, DefaultSavePlanner, FileSystemReader
15from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
16from torch.utils.data import DataLoader, DistributedSampler
17
18from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue, AsyncRequest
19from nvidia_resiliency_ext.checkpointing.async_ckpt.filesystem_async import FileSystemWriterAsync
20from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import (
21 init_checkpoint_metadata_cache,
22 save_state_dict_async_finalize,
23 save_state_dict_async_plan,
24)
25
# Set up basic logging configuration
# Try setting `DEBUG` to see detailed steps of NVRx checkpointing
logging.basicConfig(level=logging.INFO)

# Constants for the synthetic training workload below.
FEAT_SIZE = 4096  # input feature dimension of each sample (also fc1 width)
DNN_OUT_SIZE = 128  # model output / target dimension
BATCH_SIZE = 100  # per-rank batch size fed to the DataLoader
NUM_EPOCHS = 10  # NOTE(review): defined but unused — main() runs a single epoch
DATASET_LEN = 100000  # number of synthetic samples in SimpleDataset
CKPT_INTERVAL = 100  # NOTE(review): unused — save cadence is derived from epoch length in main()
36
37
def print_on_rank0(msg):
    """Print ``msg`` from the coordinator rank only, keeping logs uncluttered."""
    if dist.get_rank() != 0:
        return
    print(msg)
41
42
def parse_args(argv=None):
    """Parse command-line options for the async checkpointing example.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            ``sys.argv[1:]`` is parsed, preserving the original behavior;
            passing an explicit list makes the parser testable and reusable.

    Returns:
        argparse.Namespace with attributes ``ckpt_dir``, ``thread_count``,
        ``persistent_queue`` (True unless --no_persistent_queue is given),
        and ``enable_msc``.
    """
    parser = argparse.ArgumentParser(description="Async Checkpointing Example")
    parser.add_argument(
        "--ckpt_dir",
        default="/tmp/test_checkpointing/",
        help="Directory for saving checkpoints",
    )
    parser.add_argument(
        "--thread_count",
        default=2,
        type=int,
        help="Threads to use during saving. Affects the number of files in the checkpoint (saving ranks * num_threads).",
    )
    # store_false + default=True: the flag *disables* the persistent queue.
    parser.add_argument(
        '--no_persistent_queue',
        action='store_false',
        default=True,
        dest='persistent_queue',
        help=(
            "Disables a persistent version of AsyncCallsQueue. "
            "Effective only when --async_save is set."
        ),
    )
    parser.add_argument(
        '--enable_msc',
        action='store_true',
        help="Enables MSC for checkpoint saving and loading. See Usage Guide in Async Checkpointing documentation for detailed instructions.",
    )

    return parser.parse_args(argv)
73
74
class SimpleDataset(torch.utils.data.Dataset):
    """Synthetic dataset that fabricates random (input, target) pairs on GPU."""

    def __init__(self, size):
        # Number of samples this dataset claims to contain.
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Data is generated on the fly; `idx` is deliberately ignored, so
        # every access yields a fresh random pair allocated directly on CUDA.
        features = torch.rand((FEAT_SIZE,), dtype=torch.float32, device='cuda')
        target = torch.rand((DNN_OUT_SIZE,), dtype=torch.float32, device='cuda')
        return features, target
86
87
class SimpleModel(nn.Module):
    """Two-layer fully connected network with a ReLU nonlinearity in between."""

    def __init__(self):
        super().__init__()
        # Hidden layer preserves the feature width; the head projects to outputs.
        self.fc1 = nn.Linear(FEAT_SIZE, FEAT_SIZE)
        self.fc2 = nn.Linear(FEAT_SIZE, DNN_OUT_SIZE)

    def forward(self, x):
        # fc1 -> ReLU -> fc2, written as a compact pipeline.
        hidden = nn.functional.relu(self.fc1(x))
        return self.fc2(hidden)
99
100
def init_distributed_backend(backend="nccl"):
    """Initialize the torch.distributed process group from env:// variables.

    Args:
        backend: Name of the process-group backend (default ``"nccl"``).

    Returns:
        Tuple ``(rank, world_size)`` for the calling process.

    Raises:
        Exception: Any initialization failure is logged and re-raised.
    """
    try:
        dist.init_process_group(backend=backend, init_method="env://")
        this_rank = dist.get_rank()
        # Pin this process to the GPU matching its node-local rank.
        torch.cuda.set_device(dist.get_node_local_rank())
        logging.info(f"Process {this_rank} initialized with {backend} backend.")
        return this_rank, dist.get_world_size()
    except Exception as e:
        logging.error(f"Failed to initialize distributed backend: {e}")
        raise
112
113
def get_save_and_finalize_callbacks(writer, save_state_dict_ret) -> AsyncRequest:
    """Package the writer's async save function plus a finalize step.

    Args:
        writer: FileSystemWriterAsync that supplies the save callable.
        save_state_dict_ret: Result of ``save_state_dict_async_plan``;
            forwarded verbatim to the finalize step.

    Returns:
        AsyncRequest ready to be scheduled on an AsyncCallsQueue.
    """
    async_fn, preload_fn, async_args = writer.get_save_function_and_args()

    def _finalize():
        """Commit checkpoint metadata, then synchronize all ranks."""
        save_state_dict_async_finalize(*save_state_dict_ret)
        dist.barrier()

    return AsyncRequest(
        async_fn, async_args, [_finalize], async_fn_kwargs={}, preload_fn=preload_fn
    )
126
127
def save_checkpoint(checkpoint_dir, async_queue, model, thread_count, enable_msc):
    """Plan and schedule an asynchronous save of the model's state dict.

    The actual file writes happen in the background; completion is driven by
    ``async_queue.maybe_finalize_async_calls`` on the caller's side.
    """
    writer = FileSystemWriterAsync(checkpoint_dir, thread_count=thread_count, use_msc=enable_msc)
    planned = save_state_dict_async_plan(
        model.state_dict(),
        writer,
        None,
        0,  # coordinator rank
        planner=DefaultSavePlanner(),
        enable_cache=True,
    )
    request = get_save_and_finalize_callbacks(writer, planned)
    async_queue.schedule_async_request(request)
140
141
def load_checkpoint(checkpoint_dir, model, thread_count, enable_msc):
    """Restore the model's state dict from ``checkpoint_dir`` synchronously.

    Returns:
        The state dict populated in place by ``checkpoint.load``.
    """
    if enable_msc:
        # MSC path: multi-threaded reader targeting (possibly remote) storage.
        storage_reader = msc.torch.MultiStorageFileSystemReader(
            checkpoint_dir, thread_count=thread_count
        )
    else:
        storage_reader = FileSystemReader(checkpoint_dir)
    state_dict = model.state_dict()
    checkpoint.load(
        state_dict=state_dict,
        storage_reader=storage_reader,
        planner=DefaultLoadPlanner(),
    )
    return state_dict
155
156
def main():
    """Run one epoch of FSDP training with periodic async checkpointing.

    Workflow: initialize distributed training, train a single epoch while
    triggering an asynchronous checkpoint save roughly every 10% of the
    epoch, finalize pending saves, reload the last checkpoint, and clean up.
    """
    args = parse_args()
    logging.info(f"Arguments: {args}")

    if args.enable_msc and msc is None:
        raise ImportError(
            "multistorageclient is required when --enable_msc is set. "
            "Install it with: pip install multistorageclient"
        )

    # Initialize distributed training
    rank, world_size = init_distributed_backend(backend="nccl")

    # Synthetic dataset, sharded across ranks.
    dataset = SimpleDataset(size=DATASET_LEN)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler)

    # Model, optimizer, and FSDP wrapper
    model = SimpleModel().to("cuda")
    fsdp_model = FSDP(model)
    optimizer = optim.SGD(fsdp_model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()

    # Queue that owns the background checkpoint-save workers.
    async_queue = AsyncCallsQueue(persistent=args.persistent_queue)

    iteration = 0
    num_iters_in_epoch = len(dataloader)
    print_on_rank0(f"num_iters_in_epoch: {num_iters_in_epoch}")

    # Checkpoint every ~10% of the epoch. Guard with max(1, ...) so a
    # per-rank dataloader shorter than 10 batches cannot make the modulo
    # operations below divide by zero.
    num_iters_for_10pct = max(1, num_iters_in_epoch // 10)
    checkpoint_dir = None
    sampler.set_epoch(0)

    init_checkpoint_metadata_cache()

    for batch_idx, (data, target) in enumerate(dataloader):
        # Reap any checkpoint saves that completed in the background.
        async_queue.maybe_finalize_async_calls(blocking=False, no_dist=False)
        if (batch_idx % num_iters_for_10pct) == 0 and rank == 0:
            print(f"Epoch 0 progress: {100 * batch_idx / num_iters_in_epoch:.2f}%")
        optimizer.zero_grad()
        output = fsdp_model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % num_iters_for_10pct == 0:
            iteration = batch_idx
            checkpoint_dir = f"{args.ckpt_dir}/iter_{iteration:07d}"
            # Save the model asynchronously
            save_checkpoint(
                checkpoint_dir, async_queue, fsdp_model, args.thread_count, args.enable_msc
            )
            print_on_rank0(f"Checkpoint Save triggered: {checkpoint_dir}, iteration: {iteration}")
    print_on_rank0(f"Epoch 0 complete. Loss: {loss.item()}")

    logging.info("Finalizing checkpoint save...")
    # Block until every scheduled save has been written and finalized.
    async_queue.maybe_finalize_async_calls(blocking=True, no_dist=False)
    async_queue.close()  # Explicitly close queue

    # Synchronize processes before loading
    dist.barrier()
    print_on_rank0(f"loading from {checkpoint_dir}")
    # Load the checkpoint (state dict is restored into fsdp_model in place).
    load_checkpoint(checkpoint_dir, fsdp_model, args.thread_count, args.enable_msc)

    # Synchronize again to ensure all ranks have completed loading
    dist.barrier()

    # Clean up checkpoint directory (only on rank 0)
    if dist.get_rank() == 0:
        logging.info(f"Cleaning up checkpoint directory: {args.ckpt_dir}")
        if args.enable_msc:
            msc.delete(args.ckpt_dir, recursive=True)
        else:
            shutil.rmtree(args.ckpt_dir)

    # Ensure NCCL process group is properly destroyed
    if dist.is_initialized():
        dist.destroy_process_group()
238
239
# Script entry point. Expects torch.distributed env variables (MASTER_ADDR,
# RANK, WORLD_SIZE, ...) since init uses init_method="env://" — e.g. launch
# with torchrun.
if __name__ == "__main__":
    main()