Basic usage example
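
The example below walks through the complete flow: it initializes torch.distributed, builds a small model, saves its state through a LocalCheckpointManager (optionally asynchronously and/or with replication), loads it back, and cleans up. A note on launching the script follows the listing.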

import argparse
import logging
import shutil

import torch
import torch.distributed as dist
import torch.nn as nn

from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue
from nvidia_resiliency_ext.checkpointing.local.basic_state_dict import BasicTensorAwareStateDict
from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.local_manager import (
    LocalCheckpointManager,
)
from nvidia_resiliency_ext.checkpointing.local.replication.strategies import (
    CliqueReplicationStrategy,
)

# Set up basic logging configuration
logging.basicConfig(level=logging.INFO)


def parse_args():
    parser = argparse.ArgumentParser(
        description='Local Checkpointing Basic Example',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        '--ckpt_dir',
        default="/tmp/test_local_checkpointing/",
        help="Checkpoint directory for local checkpoints",
    )
    parser.add_argument(
        '--async_save',
        action='store_true',
        help="Enable asynchronous saving of checkpoints.",
    )
    parser.add_argument(
        '--replication',
        action='store_true',
        help="If set, replication of local checkpoints is enabled. "
        "Needs to be enabled on all ranks.",
    )
    parser.add_argument(
        '--replication_jump',
        default=4,
        type=int,
        help=(
            "Specifies `J`, the spacing between ranks storing replicas of a given rank's data. "
            "Replicas for rank `n` may be on ranks `n+J`, `n+2J`, ..., or `n-J`, `n-2J`, etc. "
            "This flag has an effect only if --replication is used, "
            "and must be consistent across all ranks. "
            "The default value of 4 is for demonstration purposes and can be adjusted as needed."
        ),
    )
    parser.add_argument(
        '--replication_factor',
        default=2,
        type=int,
        help="Number of machines storing the replica of a given rank's data",
    )
    return parser.parse_args()


# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)  # Linear layer: input size 10, output size 5
        self.fc2 = nn.Linear(5, 2)  # Linear layer: input size 5, output size 2
        self.activation = nn.ReLU()  # Activation function: ReLU

    def forward(self, x):
        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        return x


def init_distributed_backend(backend="nccl"):
    """
    Initialize the distributed process group with the requested backend.
    Assumes the environment variables (RANK, WORLD_SIZE, MASTER_ADDR, etc.) are already set.
    """
    try:
        dist.init_process_group(
            backend=backend,
            init_method="env://",  # Use environment variables for initialization
        )
        logging.info(f"Rank {dist.get_rank()} initialized with {backend} backend.")

        # Ensure each process uses a different GPU (assumes one process per GPU per node)
        torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    except Exception as e:
        logging.error(f"Error initializing the distributed backend: {e}")
        raise


def create_checkpoint_manager(args):
    if args.replication:
        logging.info("Creating CliqueReplicationStrategy.")
        # Ranks spaced `replication_jump` apart form a clique of size
        # `replication_factor` that stores replicas of each other's data.
        repl_strategy = CliqueReplicationStrategy.from_replication_params(
            args.replication_jump, args.replication_factor
        )
    else:
        repl_strategy = None

    return LocalCheckpointManager(args.ckpt_dir, repl_strategy=repl_strategy)


def save(args, ckpt_manager, async_queue, model, iteration):
    # Create Tensor-Aware State Dict
    ta_state_dict = BasicTensorAwareStateDict(model.state_dict())

    if args.async_save:
        logging.info("Creating save request.")
        save_request = ckpt_manager.save(ta_state_dict, iteration, is_async=True)

        logging.info("Saving TASD checkpoint...")
        async_queue.schedule_async_request(save_request)
    else:
        logging.info("Saving TASD checkpoint...")
        ckpt_manager.save(ta_state_dict, iteration)


def load(args, ckpt_manager):
    logging.info("Loading TASD checkpoint...")
    iteration = ckpt_manager.find_latest()
    assert iteration != -1, "No local checkpoint was found"
    logging.info(f"Found checkpoint from iteration: {iteration}")

    ta_state_dict, ckpt_part_id = ckpt_manager.load()
    logging.info(f"Successfully loaded checkpoint part (id: {ckpt_part_id})")
    return ta_state_dict.state_dict


def main():
    args = parse_args()
    logging.info(f'{args}')

    # Initialize the distributed backend
    init_distributed_backend(backend="nccl")

    # Instantiate the model and move it to CUDA
    model = SimpleModel().to("cuda")

    # Instantiate the checkpointing classes needed for local checkpointing
    ckpt_manager = create_checkpoint_manager(args)
    # A persistent queue is incompatible with local checkpointing because some routines are not pickleable.
    async_queue = AsyncCallsQueue(persistent=False) if args.async_save else None

    iteration = 123  # training iteration (used as training state id)

    # Local checkpointing save
    save(args, ckpt_manager, async_queue, model, iteration)

    if args.async_save:
        # Other operations can happen here

        logging.info("Finalizing TASD checkpoint saving...")
        async_queue.maybe_finalize_async_calls(blocking=True, no_dist=False)
        async_queue.close()  # Explicitly close the queue (optional)

    # Synchronize processes to ensure all have completed the saving
    dist.barrier()

    # Local checkpointing load
    load(args, ckpt_manager)

    # Synchronize processes to ensure all have completed the loading
    dist.barrier()

    # Clean up the checkpoint directory, only on rank 0
    if dist.get_rank() == 0:
        logging.info(f"Cleaning up checkpoint directory: {args.ckpt_dir}")
        shutil.rmtree(args.ckpt_dir)

    # Ensure the NCCL process group is properly destroyed
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
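
The script expects the standard torch.distributed environment variables, so it is easiest to launch with torchrun, one process per GPU. Assuming the listing is saved as basic_local_ckpt.py (an illustrative file name), a two-GPU run with asynchronous saves could look like:

    torchrun --nproc_per_node=2 basic_local_ckpt.py --async_save

Adding --replication enables replica exchange between ranks; --replication_jump and --replication_factor must then be set consistently on all ranks.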