import argparse
import logging
import os

import torch
import torch.distributed as dist
import torch.nn as nn

from nvidia_resiliency_ext.checkpointing.async_ckpt.torch_ckpt import TorchAsyncCheckpoint

# Set up basic logging configuration
logging.basicConfig(level=logging.INFO)
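
# How to run this example (illustrative command; the script filename and GPU
# count below are placeholders, not prescribed by the library):
#   torchrun --nproc_per_node=<num_gpus> <this_script>.py --ckpt_dir /tmp/test_async_ckpt/
# A launcher such as torchrun supplies the env:// variables (MASTER_ADDR,
# MASTER_PORT, RANK, WORLD_SIZE) that init_distributed_backend() relies on.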


def parse_args():
    parser = argparse.ArgumentParser(
        description='Async Checkpointing Basic Example',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        '--ckpt_dir',
        default="/tmp/test_async_ckpt/",
        help="Checkpoint directory for async checkpoints",
    )
    parser.add_argument(
        '--persistent_queue',
        action='store_true',
        help="Enables a persistent version of AsyncCallsQueue.",
    )
    return parser.parse_args()


# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)  # Linear layer: input size 10, output size 5
        self.fc2 = nn.Linear(5, 2)  # Linear layer: input size 5, output size 2
        self.activation = nn.ReLU()  # Activation function: ReLU

    def forward(self, x):
        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        return x


def init_distributed_backend(backend="nccl"):
    """
    Initialize the distributed process group for the given backend.
    Assumes the launcher has already set the required environment variables
    (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE, CUDA_VISIBLE_DEVICES, etc.).
    """
    try:
        dist.init_process_group(
            backend=backend,  # NCCL backend by default
            init_method="env://",  # Use environment variables for initialization
        )
        logging.info(f"Rank {dist.get_rank()} initialized with {backend} backend.")

        # Ensure each process uses a different GPU (assumes a single-node run,
        # where the global rank matches the local GPU index)
        torch.cuda.set_device(dist.get_rank())
    except Exception as e:
        logging.error(f"Error initializing the distributed backend: {e}")
        raise


def cleanup(ckpt_dir):
    if dist.get_rank() == 0:
        logging.info(f"Cleaning up checkpoint directory: {ckpt_dir}")
        for file_item in os.scandir(ckpt_dir):
            if file_item.is_file():
                os.remove(file_item.path)


def main():
    args = parse_args()
    logging.info(f'{args}')

    # Initialize the distributed backend
    init_distributed_backend(backend="nccl")

    # Instantiate the model and move it to CUDA
    model = SimpleModel().to("cuda")
    org_sd = model.state_dict()

    # Define the checkpoint directory and a per-rank checkpoint file
    ckpt_dir = args.ckpt_dir
    if not os.path.isdir(ckpt_dir):
        raise Exception(f"{ckpt_dir} directory doesn't exist")
    ckpt_file_name = os.path.join(ckpt_dir, f"ckpt_rank{dist.get_rank()}.pt")

    ckpt_impl = TorchAsyncCheckpoint(persistent_queue=args.persistent_queue)

    # Schedule an asynchronous save of the model state dict
    ckpt_impl.async_save(org_sd, ckpt_file_name)

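    # Block here until the in-flight asynchronous save completes, then shut the
    # async worker down (this reading of blocking/terminate is inferred from the
    # argument names used in this call). In a real training loop the save would
    # typically be finalized later, so training can overlap with checkpoint I/O.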
    ckpt_impl.finalize_async_save(blocking=True, no_dist=True, terminate=True)

    # Load the checkpoint back and verify it matches the original state dict
    loaded_sd = torch.load(ckpt_file_name, map_location="cuda")

    for k in loaded_sd.keys():
        assert torch.equal(loaded_sd[k], org_sd[k]), f"loaded_sd[{k}] != org_sd[{k}]"

    # Synchronize processes to ensure all have completed the loading
    dist.barrier()

    # Clean up the checkpoint directory only on rank 0
    cleanup(ckpt_dir)

    # Ensure the NCCL process group is properly destroyed
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()