import argparse
import logging
import os
import shutil

import torch
import torch.distributed as dist
import torch.nn as nn

from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue
from nvidia_resiliency_ext.checkpointing.local.basic_state_dict import BasicTensorAwareStateDict
from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.local_manager import (
    LocalCheckpointManager,
)
from nvidia_resiliency_ext.checkpointing.local.replication.strategies import (
    CliqueReplicationStrategy,
)
17
# Configure the root logger once, at import time, at INFO level so that the
# progress messages this example emits through the module-level `logging.info`
# calls are actually printed.
logging.basicConfig(level=logging.INFO)
20
21
def parse_args(argv=None):
    """Parse the command-line options of the local checkpointing example.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, i.e. the original behavior.

    Returns:
        argparse.Namespace with ``ckpt_dir``, ``async_save``, ``replication``,
        ``persistent_queue``, ``replication_jump`` and ``replication_factor``.
    """
    parser = argparse.ArgumentParser(
        description='Local Checkpointing Basic Example',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        '--ckpt_dir',
        default="/tmp/test_local_checkpointing/",
        help="Checkpoint directory for local checkpoints",
    )
    parser.add_argument(
        '--async_save',
        action='store_true',
        help="Enable asynchronous saving of checkpoints.",
    )
    parser.add_argument(
        '--replication',
        action='store_true',
        # The two adjacent literals are implicitly concatenated, so the first one
        # must end with ". " or the sentences run together ("enabledNeeds").
        help=(
            "If set, replication of local checkpoints is enabled. "
            "Needs to be enabled on all ranks."
        ),
    )
    parser.add_argument(
        '--persistent_queue',
        action='store_true',
        help=(
            "Enables a persistent version of AsyncCallsQueue. "
            "Effective only when --async_save is set."
        ),
    )
    parser.add_argument(
        '--replication_jump',
        default=4,
        type=int,
        help=(
            "Specifies `J`, the spacing between ranks storing replicas of a given rank's data. "
            "Replicas for rank `n` may be on ranks `n+J`, `n+2J`, ..., or `n-J`, `n-2J`, etc. "
            "This flag has an effect only if --replication is used "
            "and must be consistent across all ranks. "
            "The default value of 4 is for demonstration purposes and can be adjusted as needed."
        ),
    )
    parser.add_argument(
        '--replication_factor',
        default=2,
        type=int,
        help="Number of machines storing the replica of a given rank's data",
    )
    return parser.parse_args(argv)
71
72
class SimpleModel(nn.Module):
    """Minimal two-layer MLP used as the checkpointed model in this example.

    Computes ``fc2(relu(fc1(x)))``. The default sizes reproduce the original
    10 -> 5 -> 2 network; they can be overridden for other input/output widths.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Output size of ``fc1`` and input size of ``fc2``.
        out_features: Size of the last dimension of the output.
    """

    def __init__(self, in_features=10, hidden_features=5, out_features=2):
        super().__init__()
        # Attribute names become the state_dict keys (fc1.*, fc2.*) stored in the
        # checkpoint, and registration order fixes parameter/init order, so both
        # are kept as in the original model.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.activation = nn.ReLU()

    def forward(self, x):
        """Apply ``fc1``, then ReLU, then ``fc2`` to ``x``."""
        x = self.activation(self.fc1(x))
        return self.fc2(x)
85
86
def init_distributed_backend(backend="nccl"):
    """Initialize the default process group and bind this process to a CUDA device.

    Uses the ``env://`` init method, so the rendezvous/rank information must
    already be exported by the launcher (e.g. ``torchrun``).

    The CUDA device index comes from ``LOCAL_RANK`` (the rank within the node,
    exported by ``torchrun``), falling back to the global rank if it is unset.
    The index is reduced modulo the number of visible devices. This keeps it
    valid on multi-node jobs, where the global rank exceeds the per-node GPU
    count. It also stays valid when each process sees only its own GPU through
    ``CUDA_VISIBLE_DEVICES``.

    Args:
        backend: ``torch.distributed`` backend name; defaults to ``"nccl"``.

    Raises:
        Exception: Any error raised during initialization is logged together
            with its traceback and re-raised unchanged.
    """
    try:
        dist.init_process_group(
            backend=backend,
            init_method="env://",  # Read rendezvous info from environment variables
        )
        rank = dist.get_rank()
        logging.info(f"Rank {rank} initialized with {backend} backend.")

        # The global rank is not a valid device ordinal on multi-node jobs, so
        # bind by local rank. Skip binding entirely when CUDA is unavailable
        # (e.g. CPU-only gloo runs).
        if torch.cuda.is_available():
            local_rank = int(os.environ.get("LOCAL_RANK", rank))
            torch.cuda.set_device(local_rank % torch.cuda.device_count())
    except Exception as e:
        # logging.exception also records the traceback, which logging.error drops.
        logging.exception(f"Error initializing the distributed backend: {e}")
        raise
104
105
def create_checkpoint_manager(args):
    """Build the LocalCheckpointManager used by this example.

    With ``--replication`` a CliqueReplicationStrategy is constructed from
    ``args.replication_jump`` and ``args.replication_factor`` and attached to
    the manager. Otherwise the manager gets ``repl_strategy=None``.

    Args:
        args: Parsed CLI namespace. Reads ``ckpt_dir``, ``replication``,
            ``replication_jump`` and ``replication_factor``.

    Returns:
        A LocalCheckpointManager rooted at ``args.ckpt_dir``.
    """
    strategy = None
    if args.replication:
        logging.info("Creating CliqueReplicationStrategy.")
        strategy = CliqueReplicationStrategy.from_replication_params(
            args.replication_jump,
            args.replication_factor,
        )
    return LocalCheckpointManager(args.ckpt_dir, repl_strategy=strategy)
116
117
def save(args, ckpt_manager, async_queue, model, iteration):
    """Save ``model``'s state as a local checkpoint for ``iteration``.

    The model's ``state_dict()`` is wrapped in a BasicTensorAwareStateDict.

    * Synchronous mode (``args.async_save`` false): ``ckpt_manager.save`` is
      called directly.
    * Async mode: ``ckpt_manager.save(..., is_async=True)`` returns a save
      request, which is handed to ``async_queue``. The caller is responsible
      for finalizing the queue later.

    Args:
        args: Parsed CLI namespace (reads ``async_save``).
        ckpt_manager: The local checkpoint manager.
        async_queue: Queue that schedules the async request; only used when
            ``args.async_save`` is set.
        model: Module whose ``state_dict()`` is checkpointed.
        iteration: Identifier passed to ``ckpt_manager.save``.
    """
    tasd = BasicTensorAwareStateDict(model.state_dict())

    if not args.async_save:
        logging.info("Saving TASD checkpoint...")
        ckpt_manager.save(tasd, iteration)
        return

    logging.info("Creating save request.")
    request = ckpt_manager.save(tasd, iteration, is_async=True)
    logging.info("Saving TASD checkpoint...")
    async_queue.schedule_async_request(request)
132
133
def load(args, ckpt_manager):
    """Load the latest local checkpoint and return its plain state dict.

    Args:
        args: Parsed CLI namespace (not used by this function).
        ckpt_manager: Checkpoint manager exposing ``find_latest()`` and ``load()``.

    Returns:
        The ``state_dict`` attribute of the loaded tensor-aware state dict.

    Raises:
        RuntimeError: If ``find_latest()`` returns -1, meaning no local
            checkpoint was found.
    """
    logging.info("Loading TASD checkpoint...")
    iteration = ckpt_manager.find_latest()
    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # which would let load() proceed without any checkpoint.
    if iteration == -1:
        raise RuntimeError("Local checkpoint has not been found")
    logging.info(f"Found checkpoint from iteration: {iteration}")

    ta_state_dict, ckpt_part_id = ckpt_manager.load()
    logging.info(f"Successfully loaded checkpoint part (id: {ckpt_part_id})")
    return ta_state_dict.state_dict
143
144
def main():
    """Run one save -> load round trip of a local checkpoint on every rank.

    The function validates flag combinations and initializes the NCCL process
    group. It then saves a checkpoint of ``SimpleModel``, either synchronously
    or through an ``AsyncCallsQueue``. After a barrier it loads the checkpoint
    back, and rank 0 finally removes ``args.ckpt_dir``. The process group is
    destroyed on exit, including when a step after initialization fails.

    Raises:
        ValueError: If incompatible command-line flags are combined.
    """
    args = parse_args()
    # Explicit raises instead of `assert` so validation survives `python -O`.
    if args.persistent_queue and not args.async_save:
        raise ValueError("--persistent_queue requires --async_save to be enabled.")
    if args.persistent_queue and args.replication:
        raise ValueError(
            "persistent_queue is currently incompatible with replication due to object pickling issues."
        )
    logging.info(f'{args}')

    # Initialize the distributed backend
    init_distributed_backend(backend="nccl")

    try:
        # Instantiate the model and move to CUDA
        model = SimpleModel().to("cuda")

        # Instantiate checkpointing classes needed for local checkpointing
        ckpt_manager = create_checkpoint_manager(args)
        async_queue = AsyncCallsQueue(persistent=args.persistent_queue) if args.async_save else None

        iteration = 123  # training iteration (used as training state id)

        # Local checkpointing save
        save(args, ckpt_manager, async_queue, model, iteration)

        if args.async_save:
            # Other operations can happen here

            logging.info("Finalize TASD checkpoint saving.")
            async_queue.maybe_finalize_async_calls(blocking=True, no_dist=False)
            async_queue.close()  # Explicitly close queue (optional)

        # Synchronize processes to ensure all have completed the saving
        dist.barrier()

        # Local checkpointing load
        load(args, ckpt_manager)

        # Synchronize processes to ensure all have completed the loading
        dist.barrier()

        # Clean up checkpoint directory only on rank 0.
        # NOTE(review): if ckpt_dir is node-local storage on a multi-node job,
        # the other nodes' copies are not removed here. Confirm whether this is
        # intended.
        if dist.get_rank() == 0:
            logging.info(f"Cleaning up checkpoint directory: {args.ckpt_dir}")
            shutil.rmtree(args.ckpt_dir)
    finally:
        # Tear down the process group explicitly instead of leaving it to
        # interpreter shutdown.
        dist.destroy_process_group()
190
191
# Run the example only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()