1import argparse
2import logging
3import os
4import shutil
5from typing import Union
6
7import torch
8import torch.distributed as dist
9import torch.nn as nn
10
11from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue
12from nvidia_resiliency_ext.checkpointing.local.basic_state_dict import BasicTensorAwareStateDict
13from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.local_manager import (
14 LocalCheckpointManager,
15)
16from nvidia_resiliency_ext.checkpointing.local.replication.strategies import (
17 CliqueReplicationStrategy,
18)
19
20# Set up basic logging configuration
21logging.basicConfig(level=logging.INFO)
22
23
24def parse_args():
25 parser = argparse.ArgumentParser(
26 description='Local Checkpointing Basic Example',
27 formatter_class=argparse.ArgumentDefaultsHelpFormatter,
28 )
29
30 parser.add_argument(
31 '--ckpt_dir',
32 default="/tmp/test_local_checkpointing/",
33 help="Checkpoint directory for local checkpoints",
34 )
35 parser.add_argument(
36 '--async_save',
37 action='store_true',
38 help="Enable asynchronous saving of checkpoints.",
39 )
40 parser.add_argument(
41 '--replication',
42 action='store_true',
43 help="If set, replication of local checkpoints is enabled"
44 "Needs to be enabled on all ranks."
45 )
46 parser.add_argument(
47 '--replication_jump',
48 default=4,
49 type=int,
50 help=(
51 "Specifies `J`, the spacing between ranks storing replicas of a given rank's data. "
52 "Replicas for rank `n` may be on ranks `n+J`, `n+2J`, ..., or `n-J`, `n-2J`, etc. "
53 "This flag has an effect only if --replication is used. "
54 "and must be consistent across all ranks. "
55 "The default value of 4 is for demonstration purposes and can be adjusted as needed."
56 ),
57 )
58 parser.add_argument(
59 '--replication_factor',
60 default=2,
61 type=int,
62 help="Number of machines storing the replica of a given rank's data",
63 )
64 return parser.parse_args()
65
66
67# Define a simple model
68class SimpleModel(nn.Module):
69 def __init__(self):
70 super(SimpleModel, self).__init__()
71 self.fc1 = nn.Linear(10, 5) # Linear layer: input size 10, output size 5
72 self.fc2 = nn.Linear(5, 2) # Linear layer: input size 5, output size 2
73 self.activation = nn.ReLU() # Activation function: ReLU
74
75 def forward(self, x):
76 x = self.activation(self.fc1(x))
77 x = self.fc2(x)
78 return x
79
80
81def init_distributed_backend(backend="nccl"):
82 """
83 Initialize the distributed process group for NCCL backend.
84 Assumes the environment variables (CUDA_VISIBLE_DEVICES, etc.) are already set.
85 """
86 try:
87 dist.init_process_group(
88 backend=backend, # Use NCCL backend
89 init_method="env://", # Use environment variables for initialization
90 )
91 logging.info(f"Rank {dist.get_rank()} initialized with {backend} backend.")
92
93 # Ensure each process uses a different GPU
94 torch.cuda.set_device(dist.get_rank())
95 except Exception as e:
96 logging.error(f"Error initializing the distributed backend: {e}")
97 raise
98
99
100def create_checkpoint_manager(args):
101 if args.replication:
102 logging.info("Creating CliqueReplicationStrategy.")
103 repl_strategy = CliqueReplicationStrategy.from_replication_params(
104 args.replication_jump, args.replication_factor
105 )
106 else:
107 repl_strategy = None
108
109 return LocalCheckpointManager(args.ckpt_dir, repl_strategy=repl_strategy)
110
111
112def save(args, ckpt_manager, async_queue, model, iteration):
113 # Create Tensor-Aware State Dict
114 ta_state_dict = BasicTensorAwareStateDict(model.state_dict())
115
116 if args.async_save:
117 logging.info("Creating save request.")
118 save_request = ckpt_manager.save(ta_state_dict, iteration, is_async=True)
119
120 logging.info("Saving TASD checkpoint...")
121 async_queue.schedule_async_request(save_request)
122
123 else:
124 logging.info("Saving TASD checkpoint...")
125 ckpt_manager.save(ta_state_dict, iteration)
126
127
128def load(args, ckpt_manager):
129 logging.info("Loading TASD checkpoint...")
130 iteration = ckpt_manager.find_latest()
131 assert iteration != -1, "Local checkpoint has not been found"
132 logging.info(f"Found checkpoint from iteration: {iteration}")
133
134 ta_state_dict, ckpt_part_id = ckpt_manager.load()
135 logging.info(f"Successfully loaded checkpoint part (id: {ckpt_part_id})")
136 return ta_state_dict.state_dict
137
138
139def main():
140 args = parse_args()
141 logging.info(f'{args}')
142
143 # Initialize the distributed backend
144 init_distributed_backend(backend="nccl")
145
146 # Instantiate the model and move to CUDA
147 model = SimpleModel().to("cuda")
148
149 # Instantiate checkpointing classess needed for local checkpointing
150 ckpt_manager = create_checkpoint_manager(args)
151 async_queue = AsyncCallsQueue() if args.async_save else None
152
153 iteration = 123 # training iteration (used as training state id)
154
155 # Local checkpointing save
156 save(args, ckpt_manager, async_queue, model, iteration)
157
158 if args.async_save:
159 # Other operations can happen here
160
161 logging.info("Finalize TASD checkpoint saving.")
162 async_queue.maybe_finalize_async_calls(blocking=True, no_dist=False)
163
164 # Synchronize processes to ensure all have completed the saving
165 dist.barrier()
166
167 # Local checkpointing load
168 load(args, ckpt_manager)
169
170 # Synchronize processes to ensure all have completed the loading
171 dist.barrier()
172
173 # Clean up checkpoint directory only on rank 0
174 if dist.get_rank() == 0:
175 logging.info(f"Cleaning up checkpoint directory: {args.ckpt_dir}")
176 shutil.rmtree(args.ckpt_dir)
177
178
179if __name__ == "__main__":
180 main()