1import argparse
2import logging
3import shutil
4
5import torch
6import torch.distributed as dist
7import torch.nn as nn
8from torch.distributed import checkpoint
9from torch.distributed.checkpoint import DefaultLoadPlanner, DefaultSavePlanner, FileSystemReader
10
11from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue, AsyncRequest
12from nvidia_resiliency_ext.checkpointing.async_ckpt.filesystem_async import FileSystemWriterAsync
13from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import (
14 save_state_dict_async_finalize,
15 save_state_dict_async_plan,
16)
17
18# Set up basic logging configuration
19logging.basicConfig(level=logging.INFO)
20
21
22def parse_args():
23 parser = argparse.ArgumentParser(description="Async Checkpointing Example")
24 parser.add_argument(
25 "--ckpt_dir",
26 default="/tmp/test_checkpointing/",
27 help="Directory for saving checkpoints",
28 )
29 parser.add_argument(
30 "--thread_count",
31 default=2,
32 type=int,
33 help="Threads to use during saving. Affects the number of files in the checkpoint (saving ranks * num_threads).",
34 )
35 parser.add_argument(
36 '--persistent_queue',
37 action='store_true',
38 help="Enables a persistent version of AsyncCallsQueue.",
39 )
40
41 return parser.parse_args()
42
43
44class SimpleModel(nn.Module):
45 """A simple feedforward neural network for demonstration purposes."""
46
47 def __init__(self):
48 super(SimpleModel, self).__init__()
49 self.fc1 = nn.Linear(10, 5)
50 self.fc2 = nn.Linear(5, 2)
51 self.activation = nn.ReLU()
52
53 def forward(self, x):
54 x = self.activation(self.fc1(x))
55 return self.fc2(x)
56
57
58def init_distributed_backend(backend="nccl"):
59 """Initializes the distributed process group using the specified backend."""
60 try:
61 dist.init_process_group(backend=backend, init_method="env://")
62 rank = dist.get_rank()
63 torch.cuda.set_device(rank)
64 logging.info(f"Process {rank} initialized with {backend} backend.")
65 except Exception as e:
66 logging.error(f"Failed to initialize distributed backend: {e}")
67 raise
68
69
70def get_save_and_finalize_callbacks(writer, save_state_dict_ret) -> AsyncRequest:
71 """Creates an async save request with a finalize function."""
72 save_fn, preload_fn, save_args = writer.get_save_function_and_args()
73
74 def finalize_fn():
75 """Finalizes async checkpointing and synchronizes processes."""
76 save_state_dict_async_finalize(*save_state_dict_ret)
77 dist.barrier()
78
79 return AsyncRequest(save_fn, save_args, [finalize_fn], preload_fn=preload_fn)
80
81
82def save_checkpoint(checkpoint_dir, async_queue, model, thread_count):
83 """Asynchronously saves a model checkpoint."""
84 state_dict = model.state_dict()
85 planner = DefaultSavePlanner()
86 writer = FileSystemWriterAsync(checkpoint_dir, thread_count=thread_count)
87 coordinator_rank = 0
88
89 save_state_dict_ret, *_ = save_state_dict_async_plan(
90 state_dict, writer, None, coordinator_rank, planner=planner
91 )
92 save_request = get_save_and_finalize_callbacks(writer, save_state_dict_ret)
93 async_queue.schedule_async_request(save_request)
94
95
96def load_checkpoint(checkpoint_dir, model):
97 """Loads a model checkpoint synchronously."""
98 state_dict = model.state_dict()
99 checkpoint.load(
100 state_dict=state_dict,
101 storage_reader=FileSystemReader(checkpoint_dir),
102 planner=DefaultLoadPlanner(),
103 )
104 return state_dict
105
106
107def main():
108 args = parse_args()
109 logging.info(f"Arguments: {args}")
110
111 # Initialize distributed training
112 init_distributed_backend(backend="nccl")
113
114 # Initialize model and move to GPU
115 model = SimpleModel().to("cuda")
116
117 # Create an async queue for handling asynchronous operations
118 async_queue = AsyncCallsQueue(persistent=args.persistent_queue)
119
120 # Define checkpoint directory based on iteration number
121 iteration = 123
122 checkpoint_dir = f"{args.ckpt_dir}/iter_{iteration:07d}"
123
124 # Save the model asynchronously
125 save_checkpoint(checkpoint_dir, async_queue, model, args.thread_count)
126
127 logging.info("Finalizing checkpoint save...")
128 async_queue.maybe_finalize_async_calls(blocking=True, no_dist=False)
129 async_queue.close() # Explicitly close queue (optional)
130
131 # Synchronize processes before loading
132 dist.barrier()
133
134 # Load the checkpoint
135 loaded_sd = load_checkpoint(checkpoint_dir, model)
136
137 # Synchronize again to ensure all ranks have completed loading
138 dist.barrier()
139
140 # Clean up checkpoint directory (only on rank 0)
141 if dist.get_rank() == 0:
142 logging.info(f"Cleaning up checkpoint directory: {args.ckpt_dir}")
143 shutil.rmtree(args.ckpt_dir)
144
145 # Ensure NCCL process group is properly destroyed
146 if dist.is_initialized():
147 dist.destroy_process_group()
148
149
150if __name__ == "__main__":
151 main()