Basic usage example
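The script below demonstrates local checkpointing end to end: it builds a small model, saves its state through a LocalCheckpointManager (optionally asynchronously and/or with replication), and loads the latest local checkpoint back. Since torch.distributed is initialized with init_method="env://", the script can be launched with any launcher that sets the usual environment variables, for example torchrun --nproc_per_node=<num_gpus> example.py --async_save (the script name here is illustrative).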

import argparse
import logging
import shutil

import torch
import torch.distributed as dist
import torch.nn as nn

from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue
from nvidia_resiliency_ext.checkpointing.local.basic_state_dict import BasicTensorAwareStateDict
from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.local_manager import (
    LocalCheckpointManager,
)
from nvidia_resiliency_ext.checkpointing.local.replication.strategies import (
    CliqueReplicationStrategy,
)

# Set up basic logging configuration
logging.basicConfig(level=logging.INFO)


def parse_args():
    parser = argparse.ArgumentParser(
        description='Local Checkpointing Basic Example',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        '--ckpt_dir',
        default="/tmp/test_local_checkpointing/",
        help="Checkpoint directory for local checkpoints",
    )
    parser.add_argument(
        '--async_save',
        action='store_true',
        help="Enable asynchronous saving of checkpoints.",
    )
    parser.add_argument(
        '--replication',
        action='store_true',
        help=(
            "If set, replication of local checkpoints is enabled. "
            "Needs to be enabled on all ranks."
        ),
    )
    parser.add_argument(
        '--replication_jump',
        default=4,
        type=int,
        help=(
            "Specifies `J`, the spacing between ranks storing replicas of a given rank's data. "
            "Replicas for rank `n` may be on ranks `n+J`, `n+2J`, ..., or `n-J`, `n-2J`, etc. "
            "This flag has an effect only if --replication is used, "
            "and must be consistent across all ranks. "
            "The default value of 4 is for demonstration purposes and can be adjusted as needed."
        ),
    )
    parser.add_argument(
        '--replication_factor',
        default=2,
        type=int,
        help="Number of machines storing the replica of a given rank's data",
    )
    return parser.parse_args()


# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)  # Linear layer: input size 10, output size 5
        self.fc2 = nn.Linear(5, 2)  # Linear layer: input size 5, output size 2
        self.activation = nn.ReLU()  # Activation function: ReLU

    def forward(self, x):
        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        return x


def init_distributed_backend(backend="nccl"):
    """
    Initialize the distributed process group with the given backend.
    Assumes the launcher has already set the environment variables used by
    init_method="env://" (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT, etc.).
    """
    try:
        dist.init_process_group(
            backend=backend,
            init_method="env://",  # Use environment variables for initialization
        )
        logging.info(f"Rank {dist.get_rank()} initialized with {backend} backend.")

        # Ensure each process uses a different GPU (assumes a single node;
        # on multiple nodes, use the local rank instead)
        torch.cuda.set_device(dist.get_rank())
    except Exception as e:
        logging.error(f"Error initializing the distributed backend: {e}")
        raise


def create_checkpoint_manager(args):
    if args.replication:
        logging.info("Creating CliqueReplicationStrategy.")
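        # Illustration: with the defaults (jump=4, factor=2) and 8 ranks,
        # replication cliques are {0, 4}, {1, 5}, {2, 6} and {3, 7}; ranks
        # in a clique store replicas of each other's local checkpoints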
        repl_strategy = CliqueReplicationStrategy.from_replication_params(
            args.replication_jump, args.replication_factor
        )
    else:
        repl_strategy = None

    return LocalCheckpointManager(args.ckpt_dir, repl_strategy=repl_strategy)


def save(args, ckpt_manager, async_queue, model, iteration):
    # Create Tensor-Aware State Dict
    ta_state_dict = BasicTensorAwareStateDict(model.state_dict())

    if args.async_save:
        logging.info("Creating save request.")
        save_request = ckpt_manager.save(ta_state_dict, iteration, is_async=True)

        logging.info("Saving TASD checkpoint...")
        async_queue.schedule_async_request(save_request)

    else:
        logging.info("Saving TASD checkpoint...")
        ckpt_manager.save(ta_state_dict, iteration)


def load(args, ckpt_manager):
    logging.info("Loading TASD checkpoint...")
    iteration = ckpt_manager.find_latest()
    assert iteration != -1, "No local checkpoint found"
    logging.info(f"Found checkpoint from iteration: {iteration}")

    ta_state_dict, ckpt_part_id = ckpt_manager.load()
    logging.info(f"Successfully loaded checkpoint part (id: {ckpt_part_id})")
    return ta_state_dict.state_dict


def main():
    args = parse_args()
    logging.info(f'{args}')

    # Initialize the distributed backend
    init_distributed_backend(backend="nccl")

    # Instantiate the model and move it to CUDA
    model = SimpleModel().to("cuda")

    # Instantiate the checkpointing classes needed for local checkpointing
    ckpt_manager = create_checkpoint_manager(args)
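    # With --async_save, the queue lets the checkpoint write proceed in the
    # background so training can continue; it is finalized later in main()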
    async_queue = AsyncCallsQueue() if args.async_save else None

    iteration = 123  # training iteration (used as the training state id)

    # Local checkpointing save
    save(args, ckpt_manager, async_queue, model, iteration)

    if args.async_save:
        # Other operations can happen here

        logging.info("Finalizing TASD checkpoint saving.")
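        # blocking=True waits until the scheduled save request has completed;
        # no_dist=False additionally synchronizes finalization across ranks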
        async_queue.maybe_finalize_async_calls(blocking=True, no_dist=False)

    # Synchronize processes to ensure all have completed the save
    dist.barrier()

    # Local checkpointing load
    load(args, ckpt_manager)

    # Synchronize processes to ensure all have completed the load
    dist.barrier()

    # Clean up the checkpoint directory, on rank 0 only
    if dist.get_rank() == 0:
        logging.info(f"Cleaning up checkpoint directory: {args.ckpt_dir}")
        shutil.rmtree(args.ckpt_dir)


if __name__ == "__main__":
    main()
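
Since load() returns the underlying PyTorch state dict, restoring the model afterwards is a single call. A minimal sketch (not part of the example above, reusing the model and ckpt_manager created in main()):

    # Hypothetical follow-up: restore model parameters from the local checkpoint
    state_dict = load(args, ckpt_manager)
    model.load_state_dict(state_dict)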