import argparse
import logging
import os

import torch
import torch.distributed as dist
import torch.nn as nn

from nvidia_resiliency_ext.checkpointing.async_ckpt.torch_ckpt import TorchAsyncCheckpoint

# Set up basic logging configuration
logging.basicConfig(level=logging.INFO)
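
# NOTE: this example assumes one process per GPU and is typically launched with
# torchrun, e.g. `torchrun --nproc_per_node=<num_gpus> <this_script>.py`, which sets
# the env:// variables (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT) consumed by
# init_distributed_backend() below.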


def parse_args():
    parser = argparse.ArgumentParser(
        description='Async Checkpointing Basic Example',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        '--ckpt_dir',
        default="/tmp/test_async_ckpt/",
        help="Checkpoint directory for async checkpoints",
    )
    parser.add_argument(
        '--no_persistent_queue',
        action='store_false',
        default=True,
        dest='persistent_queue',
        help=(
            "Disables the persistent version of AsyncCallsQueue "
            "used for asynchronous checkpoint saves."
        ),
    )
    return parser.parse_args()


# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 5)  # Linear layer: input size 10, output size 5
        self.fc2 = nn.Linear(5, 2)  # Linear layer: input size 5, output size 2
        self.activation = nn.ReLU()  # Activation function: ReLU

    def forward(self, x):
        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        return x


def init_distributed_backend(backend="nccl"):
    """
    Initialize the distributed process group for the NCCL backend.
    Assumes the environment variables used by env:// initialization
    (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE, etc.) are already set.
    """
    try:
        dist.init_process_group(
            backend=backend,  # Use NCCL backend
            init_method="env://",  # Use environment variables for initialization
        )
        logging.info(f"Rank {dist.get_rank()} initialized with {backend} backend.")

        # Ensure each process uses a different GPU
        # (single-node assumption: the global rank doubles as the local device index)
        torch.cuda.set_device(dist.get_rank())
    except Exception as e:
        logging.error(f"Error initializing the distributed backend: {e}")
        raise


def cleanup(ckpt_dir):
    if dist.get_rank() == 0:
        logging.info(f"Cleaning up checkpoint directory: {ckpt_dir}")
        for file_item in os.scandir(ckpt_dir):
            if file_item.is_file():
                os.remove(file_item.path)


def main():
    args = parse_args()
    logging.info(f'{args}')

    # Initialize the distributed backend
    init_distributed_backend(backend="nccl")

    # Instantiate the model and move it to CUDA
    model = SimpleModel().to("cuda")
    org_sd = model.state_dict()

    # Define the checkpoint directory and per-rank checkpoint file
    ckpt_dir = args.ckpt_dir
    os.makedirs(ckpt_dir, exist_ok=True)
    logging.info(f"Created checkpoint directory: {ckpt_dir}")
    ckpt_file_name = os.path.join(ckpt_dir, f"ckpt_rank{dist.get_rank()}.pt")

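    # TorchAsyncCheckpoint offers a torch.save-like interface whose writes are handed
    # off to a background worker (an AsyncCallsQueue), so checkpointing does not block
    # the main process. persistent_queue=True (the default here) reuses the same worker
    # across saves; pass --no_persistent_queue to disable that.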
    ckpt_impl = TorchAsyncCheckpoint(persistent_queue=args.persistent_queue)

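    # async_save takes the same (state_dict, path) arguments as torch.save but returns
    # without waiting for the data to reach disk; a real training loop could keep
    # running while the checkpoint is written in the background.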
    ckpt_impl.async_save(org_sd, ckpt_file_name)

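    # finalize_async_save with blocking=True waits for the in-flight save to finish;
    # terminate=True shuts down the background worker since no further saves follow,
    # and no_dist=True lets each rank finalize without cross-rank coordination.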
    ckpt_impl.finalize_async_save(blocking=True, no_dist=True, terminate=True)

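    # The saved file is an ordinary torch.save artifact, so it can be read back with
    # plain torch.load and compared against the original state_dict below.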
    loaded_sd = torch.load(ckpt_file_name, map_location="cuda")

    for k in loaded_sd.keys():
        assert torch.equal(loaded_sd[k], org_sd[k]), f"loaded_sd[{k}] != org_sd[{k}]"

    # Synchronize processes to ensure all have completed the loading
    dist.barrier()

    # Clean up the checkpoint directory only on rank 0
    cleanup(ckpt_dir)

    # Ensure the NCCL process group is properly destroyed
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()