import argparse
import logging
import shutil

import torch
import torch.distributed as dist
import torch.nn as nn

from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue
from nvidia_resiliency_ext.checkpointing.local.basic_state_dict import BasicTensorAwareStateDict
from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.local_manager import (
    LocalCheckpointManager,
)
from nvidia_resiliency_ext.checkpointing.local.replication.strategies import (
    CliqueReplicationStrategy,
)

# Set up basic logging configuration
logging.basicConfig(level=logging.INFO)
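
# Example launch command (assumption: single node, one process per GPU; the script
# name below is a placeholder for this file):
#   torchrun --nproc_per_node=2 local_checkpointing_example.py --async_save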


def parse_args():
    parser = argparse.ArgumentParser(
        description='Local Checkpointing Basic Example',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        '--ckpt_dir',
        default="/tmp/test_local_checkpointing/",
        help="Checkpoint directory for local checkpoints",
    )
    parser.add_argument(
        '--async_save',
        action='store_true',
        help="Enable asynchronous saving of checkpoints.",
    )
    parser.add_argument(
        '--replication',
        action='store_true',
        help="If set, replication of local checkpoints is enabled. "
        "Needs to be enabled on all ranks.",
    )
    parser.add_argument(
        '--replication_jump',
        default=4,
        type=int,
        help=(
            "Specifies `J`, the spacing between ranks storing replicas of a given rank's data. "
            "Replicas for rank `n` may be on ranks `n+J`, `n+2J`, ..., or `n-J`, `n-2J`, etc. "
            "This flag has an effect only if --replication is used, "
            "and must be consistent across all ranks. "
            "The default value of 4 is for demonstration purposes and can be adjusted as needed."
        ),
    )
    parser.add_argument(
        '--replication_factor',
        default=2,
        type=int,
        help="Number of machines storing replicas of a given rank's data.",
    )
    return parser.parse_args()


# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)  # Linear layer: input size 10, output size 5
        self.fc2 = nn.Linear(5, 2)  # Linear layer: input size 5, output size 2
        self.activation = nn.ReLU()  # Activation function: ReLU

    def forward(self, x):
        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        return x


def init_distributed_backend(backend="nccl"):
    """
    Initialize the distributed process group for the given backend (NCCL by default).
    Assumes the launcher has already set the required environment variables
    (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE, CUDA_VISIBLE_DEVICES, etc.).
    """
    try:
        dist.init_process_group(
            backend=backend,  # Use the requested backend (NCCL by default)
            init_method="env://",  # Use environment variables for initialization
        )
        logging.info(f"Rank {dist.get_rank()} initialized with {backend} backend.")

        # Ensure each process uses a different GPU
        # (assumes one process per GPU on a single node, so rank == local device index)
        torch.cuda.set_device(dist.get_rank())
    except Exception as e:
        logging.error(f"Error initializing the distributed backend: {e}")
        raise


def create_checkpoint_manager(args):
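    """
    Build a LocalCheckpointManager rooted at args.ckpt_dir, optionally with a
    CliqueReplicationStrategy so each rank's local checkpoint is also stored on
    other ranks (controlled by --replication_jump and --replication_factor).
    """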
    if args.replication:
        logging.info("Creating CliqueReplicationStrategy.")
        repl_strategy = CliqueReplicationStrategy.from_replication_params(
            args.replication_jump, args.replication_factor
        )
    else:
        repl_strategy = None

    return LocalCheckpointManager(args.ckpt_dir, repl_strategy=repl_strategy)


def save(args, ckpt_manager, async_queue, model, iteration):
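    """
    Wrap the model state dict in a BasicTensorAwareStateDict and save it with the
    LocalCheckpointManager, either synchronously or by scheduling an asynchronous
    save request on the AsyncCallsQueue (when --async_save is used).
    """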
    # Create Tensor-Aware State Dict
    ta_state_dict = BasicTensorAwareStateDict(model.state_dict())

    if args.async_save:
        logging.info("Creating save request.")
        save_request = ckpt_manager.save(ta_state_dict, iteration, is_async=True)

        logging.info("Saving TASD checkpoint...")
        async_queue.schedule_async_request(save_request)

    else:
        logging.info("Saving TASD checkpoint...")
        ckpt_manager.save(ta_state_dict, iteration)


def load(args, ckpt_manager):
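    """
    Find the most recent local checkpoint iteration and load this rank's part of it,
    returning the underlying state dict.
    """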
    logging.info("Loading TASD checkpoint...")
    iteration = ckpt_manager.find_latest()
    assert iteration != -1, "No local checkpoint has been found"
    logging.info(f"Found checkpoint from iteration: {iteration}")

    ta_state_dict, ckpt_part_id = ckpt_manager.load()
    logging.info(f"Successfully loaded checkpoint part (id: {ckpt_part_id})")
    return ta_state_dict.state_dict


def main():
    args = parse_args()
    logging.info(f'{args}')

    # Initialize the distributed backend
    init_distributed_backend(backend="nccl")

    # Instantiate the model and move it to CUDA
    model = SimpleModel().to("cuda")

    # Instantiate the checkpointing classes needed for local checkpointing
    ckpt_manager = create_checkpoint_manager(args)
    # A persistent queue is incompatible with local checkpointing because some routines are not pickleable.
    async_queue = AsyncCallsQueue(persistent=False) if args.async_save else None

    iteration = 123  # training iteration (used as the training state id)

    # Local checkpointing save
    save(args, ckpt_manager, async_queue, model, iteration)

    if args.async_save:
        # Other operations can happen here

        logging.info("Finalizing TASD checkpoint save...")
        async_queue.maybe_finalize_async_calls(blocking=True, no_dist=False)
        async_queue.close()  # Explicitly close the queue (optional)

    # Synchronize processes to ensure all ranks have completed the save
    dist.barrier()

    # Local checkpointing load
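    # In real training the returned state dict would typically be restored into the
    # model, e.g. model.load_state_dict(load(args, ckpt_manager)); this example
    # discards the result and only demonstrates the save/load round trip.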
    load(args, ckpt_manager)

    # Synchronize processes to ensure all ranks have completed the load
    dist.barrier()

    # Clean up the checkpoint directory only on rank 0
    if dist.get_rank() == 0:
        logging.info(f"Cleaning up checkpoint directory: {args.ckpt_dir}")
        shutil.rmtree(args.ckpt_dir)

    # Ensure the NCCL process group is properly destroyed
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()