Basic usage example

import argparse
import logging
import os
import shutil

import torch
import torch.distributed as dist
import torch.nn as nn

from nvidia_resiliency_ext.checkpointing.async_ckpt.torch_ckpt import TorchAsyncCheckpoint

# Set up basic logging configuration
logging.basicConfig(level=logging.INFO)

def parse_args():
    parser = argparse.ArgumentParser(
        description='Async Checkpointing Basic Example',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    return parser.parse_args()


# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 5)  # Linear layer: input size 10, output size 5
        self.fc2 = nn.Linear(5, 2)  # Linear layer: input size 5, output size 2
        self.activation = nn.ReLU()  # Activation function: ReLU

    def forward(self, x):
        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        return x


def init_distributed_backend(backend="nccl"):
    """
    Initialize the distributed process group for the NCCL backend.
    Assumes the environment variables (CUDA_VISIBLE_DEVICES, etc.) are already set.
    """
    try:
        dist.init_process_group(
            backend=backend,  # Use NCCL backend
            init_method="env://",  # Use environment variables for initialization
        )
        logging.info(f"Rank {dist.get_rank()} initialized with {backend} backend.")

        # Ensure each process uses a different GPU
        torch.cuda.set_device(dist.get_rank())
    except Exception as e:
        logging.error(f"Error initializing the distributed backend: {e}")
        raise


def main():
    args = parse_args()
    logging.info(f'{args}')

    # Initialize the distributed backend
    init_distributed_backend(backend="nccl")

    # Instantiate the model and move to CUDA
    model = SimpleModel().to("cuda")

    # Define the checkpoint directory and file path
    ckpt_dir = "/tmp/test_local_checkpointing"
    ckpt_path = os.path.join(ckpt_dir, "ckpt.pt")
    os.makedirs(ckpt_dir, exist_ok=True)

    # Create the asynchronous checkpointing wrapper
    ckpt_impl = TorchAsyncCheckpoint()

    # Kick off the asynchronous save of the model state dict
    ckpt_impl.async_save(model.state_dict(), ckpt_path)

    # Block until the asynchronous save has completed
    ckpt_impl.finalize_async_save(blocking=True, no_dist=True)

    # Load the checkpoint back to verify that the asynchronous save completed
    loaded_sd = torch.load(ckpt_path)
    model.load_state_dict(loaded_sd)

    # Synchronize processes to ensure all have completed the loading
    dist.barrier()

    # Clean up checkpoint directory only on rank 0
    if dist.get_rank() == 0:
        logging.info(f"Cleaning up checkpoint directory: {ckpt_dir}")
        shutil.rmtree(ckpt_dir)


if __name__ == "__main__":
    main()
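
Because the process group is initialized with the env:// method, the script is expected to be launched with a distributed launcher that sets RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT, for example torchrun --nproc_per_node=<num_gpus> basic_usage_example.py (the script name here is illustrative, not part of the library).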