Basic usage example

The script below saves a model state dict asynchronously with TorchAsyncCheckpoint, waits for the write to complete, verifies the loaded tensors against the originals, and cleans up. Because the process group is initialized from environment variables (init_method="env://"), launch it with torchrun or another launcher that sets MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE, and LOCAL_RANK.

import argparse
import logging
import os

import torch
import torch.distributed as dist
import torch.nn as nn

from nvidia_resiliency_ext.checkpointing.async_ckpt.torch_ckpt import TorchAsyncCheckpoint

# Set up basic logging configuration
logging.basicConfig(level=logging.INFO)


def parse_args():
    parser = argparse.ArgumentParser(
        description='Async Checkpointing Basic Example',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        '--ckpt_dir',
        default="/tmp/test_async_ckpt/",
        help="Checkpoint directory for async checkpoints",
    )
    parser.add_argument(
        '--no_persistent_queue',
        action='store_false',
        dest='persistent_queue',
        help="Disable the persistent version of AsyncCallsQueue used for async checkpoint saves.",
    )
    return parser.parse_args()


# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)  # Linear layer: input size 10, output size 5
        self.fc2 = nn.Linear(5, 2)  # Linear layer: input size 5, output size 2
        self.activation = nn.ReLU()  # Activation function: ReLU

    def forward(self, x):
        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        return x


def init_distributed_backend(backend="nccl"):
    """
    Initialize the distributed process group.
    Assumes the launcher (e.g. torchrun) has already set the required
    environment variables (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE, LOCAL_RANK).
    """
    try:
        dist.init_process_group(
            backend=backend,
            init_method="env://",  # Read connection info from environment variables
        )
        logging.info(f"Rank {dist.get_rank()} initialized with {backend} backend.")

        # Bind each process to a different local GPU
        local_rank = int(os.environ.get("LOCAL_RANK", dist.get_rank()))
        torch.cuda.set_device(local_rank)
    except Exception as e:
        logging.error(f"Error initializing the distributed backend: {e}")
        raise


def cleanup(ckpt_dir):
    if dist.get_rank() == 0:
        logging.info(f"Cleaning up checkpoint directory: {ckpt_dir}")
        for file_item in os.scandir(ckpt_dir):
            if file_item.is_file():
                os.remove(file_item.path)


def main():
    args = parse_args()
    logging.info(f'{args}')

    # Initialize the distributed backend
    init_distributed_backend(backend="nccl")

    # Instantiate the model and move it to CUDA
    model = SimpleModel().to("cuda")
    org_sd = model.state_dict()

    # Define the checkpoint directory and per-rank checkpoint file
    ckpt_dir = args.ckpt_dir
    os.makedirs(ckpt_dir, exist_ok=True)
    logging.info(f"Created checkpoint directory: {ckpt_dir}")
    ckpt_file_name = os.path.join(ckpt_dir, f"ckpt_rank{dist.get_rank()}.pt")

    ckpt_impl = TorchAsyncCheckpoint(persistent_queue=args.persistent_queue)

    # Start the checkpoint save in the background
    ckpt_impl.async_save(org_sd, ckpt_file_name)

    # Wait for the in-flight save to finish and shut down the async worker
    ckpt_impl.finalize_async_save(blocking=True, no_dist=True, terminate=True)

    # Load the checkpoint back and verify it matches the original state dict
    loaded_sd = torch.load(ckpt_file_name, map_location="cuda")

    for k in loaded_sd:
        assert torch.equal(loaded_sd[k], org_sd[k]), f"loaded_sd[{k}] != org_sd[{k}]"

    # Synchronize processes to ensure all have completed the loading
    dist.barrier()

    # Clean up the checkpoint directory only on rank 0
    cleanup(ckpt_dir)

    # Ensure the NCCL process group is properly destroyed
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
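
In the example above, finalize_async_save(blocking=True) is called immediately after async_save, so the main process simply waits for the write. In a training loop, the point of the asynchronous save is to overlap the write with subsequent steps. The sketch below illustrates that pattern on a single process; the training step and checkpoint paths are hypothetical, only async_save and finalize_async_save are taken from the API used above, and the assumption that finalize_async_save(blocking=False) finalizes only writes that have already completed should be checked against the installed nvidia_resiliency_ext version.

import torch
import torch.nn as nn

from nvidia_resiliency_ext.checkpointing.async_ckpt.torch_ckpt import TorchAsyncCheckpoint

model = nn.Linear(10, 2).to("cuda")  # stand-in for a real model
ckpt_impl = TorchAsyncCheckpoint(persistent_queue=True)

for step in range(100):
    # ... forward/backward/optimizer step would go here ...

    if step % 25 == 0:
        # Kick off the checkpoint write in the background and keep training.
        ckpt_impl.async_save(model.state_dict(), f"/tmp/test_async_ckpt/ckpt_step{step}.pt")

    # Assumption: with blocking=False, this only finalizes saves that have
    # already finished, without stalling the training loop.
    ckpt_impl.finalize_async_save(blocking=False, no_dist=True)

# Before exiting, wait for any in-flight save and shut down the async worker,
# mirroring the blocking call in the example above.
ckpt_impl.finalize_async_save(blocking=True, no_dist=True, terminate=True)

With persistent_queue=True, the checkpointer keeps one background worker alive across saves rather than setting one up for each checkpoint; the --no_persistent_queue flag in the example switches this behavior off.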