Basic usage example
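
The example below builds a small model, saves its state dict asynchronously with TorchAsyncCheckpoint, blocks until the save has finished, and then loads the checkpoint back to verify that the saved tensors match the originals.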

import argparse
import logging
import os

import torch
import torch.distributed as dist
import torch.nn as nn

from nvidia_resiliency_ext.checkpointing.async_ckpt.torch_ckpt import TorchAsyncCheckpoint

# Set up basic logging configuration
logging.basicConfig(level=logging.INFO)


def parse_args():
    parser = argparse.ArgumentParser(
        description='Async Checkpointing Basic Example',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        '--ckpt_dir',
        default="/tmp/test_async_ckpt/",
        help="Checkpoint directory for async checkpoints",
    )
    parser.add_argument(
        '--persistent_queue',
        action='store_true',
        help="Enables a persistent version of AsyncCallsQueue.",
    )
    return parser.parse_args()


# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)  # Linear layer: input size 10, output size 5
        self.fc2 = nn.Linear(5, 2)  # Linear layer: input size 5, output size 2
        self.activation = nn.ReLU()  # Activation function: ReLU

    def forward(self, x):
        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        return x


def init_distributed_backend(backend="nccl"):
 48    """
 49    Initialize the distributed process group for NCCL backend.
 50    Assumes the environment variables (CUDA_VISIBLE_DEVICES, etc.) are already set.
 51    """
 52    try:
 53        dist.init_process_group(
 54            backend=backend,  # Use NCCL backend
 55            init_method="env://",  # Use environment variables for initialization
 56        )
 57        logging.info(f"Rank {dist.get_rank()} initialized with {backend} backend.")
 58
 59        # Ensure each process uses a different GPU
 60        torch.cuda.set_device(dist.get_rank())
 61    except Exception as e:
 62        logging.error(f"Error initializing the distributed backend: {e}")
 63        raise
 64
 65
 66def cleanup(ckpt_dir):
    if dist.get_rank() == 0:
        logging.info(f"Cleaning up checkpoint directory: {ckpt_dir}")
        for file_item in os.scandir(ckpt_dir):
            if file_item.is_file():
                os.remove(file_item.path)


def main():
    args = parse_args()
    logging.info(f'{args}')

    # Initialize the distributed backend
    init_distributed_backend(backend="nccl")

    # Instantiate the model and move to CUDA
    model = SimpleModel().to("cuda")
    org_sd = model.state_dict()
    # Define checkpoint directory and manager
    ckpt_dir = args.ckpt_dir
    if not os.path.isdir(ckpt_dir):
        raise Exception(f"{ckpt_dir} directory doesn't exist")
    ckpt_file_name = os.path.join(ckpt_dir, f"ckpt_rank{dist.get_rank()}.pt")

    ckpt_impl = TorchAsyncCheckpoint(persistent_queue=args.persistent_queue)

    # Schedule the checkpoint save; the write proceeds asynchronously
    ckpt_impl.async_save(org_sd, ckpt_file_name)

    # Block until the pending async save has completed
    ckpt_impl.finalize_async_save(blocking=True, no_dist=True, terminate=True)

    # Load the checkpoint back and verify it matches the original state dict
    loaded_sd = torch.load(ckpt_file_name, map_location="cuda")

    for k in loaded_sd.keys():
        assert torch.equal(loaded_sd[k], org_sd[k]), f"loaded_sd[{k}] != org_sd[{k}]"

    # Synchronize processes to ensure all have completed the loading
    dist.barrier()

    # Clean up checkpoint directory only on rank 0
    cleanup(ckpt_dir)

    # Ensure NCCL process group is properly destroyed
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
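
To run the example, launch it with torchrun, which sets the environment variables (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE) expected by init_method="env://". Assuming the script is saved as basic_async_ckpt.py (an illustrative filename) and the checkpoint directory already exists, a two-GPU run could look like:

torchrun --nproc_per_node=2 basic_async_ckpt.py --ckpt_dir /tmp/test_async_ckpt/

Pass --persistent_queue to enable the persistent version of AsyncCallsQueue.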