Basic usage example

import argparse
import logging
import shutil

import torch
import torch.distributed as dist
import torch.nn as nn

from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue
from nvidia_resiliency_ext.checkpointing.local.basic_state_dict import BasicTensorAwareStateDict
from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.local_manager import (
    LocalCheckpointManager,
)
from nvidia_resiliency_ext.checkpointing.local.replication.strategies import (
    CliqueReplicationStrategy,
)

# Set up basic logging configuration
logging.basicConfig(level=logging.INFO)


def parse_args():
    parser = argparse.ArgumentParser(
        description='Local Checkpointing Basic Example',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        '--ckpt_dir',
        default="/tmp/test_local_checkpointing/",
        help="Checkpoint directory for local checkpoints",
    )
    parser.add_argument(
        '--async_save',
        action='store_true',
        help="Enable asynchronous saving of checkpoints.",
    )
    parser.add_argument(
        '--replication',
        action='store_true',
        help="If set, replication of local checkpoints is enabled. "
        "Needs to be enabled on all ranks.",
    )
    parser.add_argument(
        '--persistent_queue',
        action='store_true',
        help=(
            "Enables a persistent version of AsyncCallsQueue. "
            "Effective only when --async_save is set."
        ),
    )
    parser.add_argument(
        '--replication_jump',
        default=4,
        type=int,
        help=(
            "Specifies `J`, the spacing between ranks storing replicas of a given rank's data. "
            "Replicas for rank `n` may be on ranks `n+J`, `n+2J`, ..., or `n-J`, `n-2J`, etc. "
            "This flag has an effect only if --replication is used "
            "and must be consistent across all ranks. "
            "The default value of 4 is for demonstration purposes and can be adjusted as needed."
        ),
    )
    parser.add_argument(
        '--replication_factor',
        default=2,
        type=int,
        help="Number of machines storing the replica of a given rank's data",
    )
    return parser.parse_args()


# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)  # Linear layer: input size 10, output size 5
        self.fc2 = nn.Linear(5, 2)  # Linear layer: input size 5, output size 2
        self.activation = nn.ReLU()  # Activation function: ReLU

    def forward(self, x):
        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        return x


def init_distributed_backend(backend="nccl"):
    """
    Initialize the distributed process group for the given backend.
    Assumes the required environment variables (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE)
    are already set, e.g. by torchrun.
    """
    try:
        dist.init_process_group(
            backend=backend,  # Use NCCL backend
            init_method="env://",  # Use environment variables for initialization
        )
        logging.info(f"Rank {dist.get_rank()} initialized with {backend} backend.")

        # Ensure each process uses a different GPU
        torch.cuda.set_device(dist.get_rank())
    except Exception as e:
        logging.error(f"Error initializing the distributed backend: {e}")
        raise


def create_checkpoint_manager(args):
    if args.replication:
        logging.info("Creating CliqueReplicationStrategy.")
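        # Ranks form replication groups ("cliques") of size `replication_factor`,
        # spaced `replication_jump` ranks apart, and exchange checkpoint replicas
        # within each group (see the --replication_jump help text above).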
        repl_strategy = CliqueReplicationStrategy.from_replication_params(
            args.replication_jump, args.replication_factor
        )
    else:
        repl_strategy = None

    return LocalCheckpointManager(args.ckpt_dir, repl_strategy=repl_strategy)


def save(args, ckpt_manager, async_queue, model, iteration):
    # Create Tensor-Aware State Dict
    ta_state_dict = BasicTensorAwareStateDict(model.state_dict())

    if args.async_save:
        logging.info("Creating save request.")
        save_request = ckpt_manager.save(ta_state_dict, iteration, is_async=True)

        logging.info("Saving TASD checkpoint...")
        async_queue.schedule_async_request(save_request)
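        # The save request is executed asynchronously; it is finalized later
        # in main() via async_queue.maybe_finalize_async_calls().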

    else:
        logging.info("Saving TASD checkpoint...")
        ckpt_manager.save(ta_state_dict, iteration)


def load(args, ckpt_manager):
    logging.info("Loading TASD checkpoint...")
    iteration = ckpt_manager.find_latest()
    assert iteration != -1, "No local checkpoint has been found"
    logging.info(f"Found checkpoint from iteration: {iteration}")

    ta_state_dict, ckpt_part_id = ckpt_manager.load()
    logging.info(f"Successfully loaded checkpoint part (id: {ckpt_part_id})")
    return ta_state_dict.state_dict


def main():
    args = parse_args()
    assert (
        not args.persistent_queue or args.async_save
    ), "--persistent_queue requires --async_save to be enabled."
    assert (
        not args.persistent_queue or not args.replication
    ), "persistent_queue is currently incompatible with replication due to object pickling issues."
    logging.info(f'{args}')

    # Initialize the distributed backend
    init_distributed_backend(backend="nccl")

    # Instantiate the model and move to CUDA
    model = SimpleModel().to("cuda")

    # Instantiate the checkpointing classes needed for local checkpointing
    ckpt_manager = create_checkpoint_manager(args)
    async_queue = AsyncCallsQueue(persistent=args.persistent_queue) if args.async_save else None

    iteration = 123  # training iteration (used as training state id)

    # Local checkpointing save
    save(args, ckpt_manager, async_queue, model, iteration)

    if args.async_save:
        # Other operations can happen here

        logging.info("Finalizing TASD checkpoint save.")
        async_queue.maybe_finalize_async_calls(blocking=True, no_dist=False)
        async_queue.close()  # Explicitly close queue (optional)

    # Synchronize processes to ensure all ranks have completed saving
    dist.barrier()

    # Local checkpointing load
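    # (in real training code, the returned state dict would typically be restored
    # into the model, e.g. via model.load_state_dict)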
    load(args, ckpt_manager)

    # Synchronize processes to ensure all ranks have completed loading
    dist.barrier()

    # Clean up checkpoint directory only on rank 0
    if dist.get_rank() == 0:
        logging.info(f"Cleaning up checkpoint directory: {args.ckpt_dir}")
        shutil.rmtree(args.ckpt_dir)


if __name__ == "__main__":
    main()
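
The example assumes a torchrun-style launch, which sets the environment variables required by init_process_group. Assuming the snippet above is saved as basic_example.py (a placeholder name), a run on two GPUs of a single node could look like:

    torchrun --nproc_per_node=2 basic_example.py --async_save --persistent_queue

Replication can be exercised by passing --replication (optionally together with --replication_jump and --replication_factor) on all ranks; note that the example currently asserts that --persistent_queue and --replication are not used together.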