Basic usage example

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

# FT: import NVRx
import nvidia_resiliency_ext.fault_tolerance as ft

# Simple example of using the FT library with PyTorch DDP.
# This script trains a dummy model on dummy data. CPU is used for training.
# After each epoch, FT timeouts are calculated and saved to the file "./ft_state.json".
#
# You can run it using:
# `ft_launcher --nproc_per_node=4 --max-restarts=3 --ft-param-initial_rank_heartbeat_timeout=30 --ft-param-rank_heartbeat_timeout=15 examples/fault_tolerance/basic_ft_example.py`
# In this example configuration, at most 3 training restarts are allowed.
#
# To find rank PIDs, use:
# `ps aux | grep basic_ft_example.py | grep -v grep`
#
# Examples:
#
# 1. Hang detection using predefined timeouts:
#    - Remove `ft_state.json` if it exists (`rm ft_state.json`).
#    - During epoch 0, stop a rank using `kill -SIGSTOP <rank_pid>`.
#    - After approximately 15 seconds, a "Did not get subsequent heartbeat." error should be raised.
#    - All ranks will be restarted.
#
# 2. Hang detection using computed timeouts:
#    - Run the example for more than 1 epoch to allow FT timeouts to be calculated.
#    - Stop a rank using `kill -SIGSTOP <rank_pid>`.
#    - After the computed timeout elapses, a "Did not get subsequent heartbeat." error should be raised.
#    - All ranks will be restarted.
#
# 3. Rank error handling:
#    - Kill a rank using `kill -SIGKILL <rank_pid>`.
#    - All ranks will be restarted.


FEAT_SIZE = 4096
DNN_OUT_SIZE = 128
BATCH_SIZE = 100
NUM_EPOCHS = 10
DATASET_LEN = 100000


class SimpleDataset(torch.utils.data.Dataset):
    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        x = torch.rand((FEAT_SIZE,), dtype=torch.float32, device='cpu')
        y = torch.rand((DNN_OUT_SIZE,), dtype=torch.float32, device='cpu')
        return x, y


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(FEAT_SIZE, FEAT_SIZE)
        self.fc2 = nn.Linear(FEAT_SIZE, DNN_OUT_SIZE)

    def forward(self, x):
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        return x


def print_on_rank0(msg):
    if dist.get_rank() == 0:
        print(msg)


def main(rank, world_size):
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    print_on_rank0(f"Starting new training run... World size={dist.get_world_size()}")

    # FT: initialize the client
    ft_client = ft.RankMonitorClient()
    ft_client.init_workload_monitoring()
    print_on_rank0(f"FT initialized. Timeouts: {ft_client.timeouts}")
    # FT: load state (calculated timeouts)
    if os.path.exists("ft_state.json"):
        with open("ft_state.json", "r") as f:
            ft_state = json.load(f)
            ft_client.load_state_dict(ft_state)
        print_on_rank0(f"FT timeouts {ft_client.timeouts} loaded from ft_state.json")

    # Dataset and DataLoader with DistributedSampler
    dataset = SimpleDataset(size=DATASET_LEN)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler)

    # Model, optimizer, and DDP
    model = SimpleModel()
    ddp_model = DDP(model)
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()

    num_iters_in_epoch = len(dataloader)
    num_iters_for_10pct = num_iters_in_epoch // 10  # iters for 1/10 of epoch

    for epoch in range(NUM_EPOCHS):
        sampler.set_epoch(epoch)
        for batch_idx, (data, target) in enumerate(dataloader):
            if (batch_idx % num_iters_for_10pct) == 0 and rank == 0:
                print(f"Epoch {epoch} progress: {100 * batch_idx / num_iters_in_epoch:.2f}%")
            optimizer.zero_grad()
            output = ddp_model(data)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()
            # FT: send heartbeat to the server
            ft_client.send_heartbeat()
        print_on_rank0(f"Epoch {epoch} complete. Loss: {loss.item()}")
        # FT: calculate and set new timeouts
        ft_client.calculate_and_set_timeouts()
        # FT: save the state (calculated timeouts)
        with open("ft_state.json", "w") as f:
            json.dump(ft_client.state_dict(), f)
        print_on_rank0(f"FT timeouts {ft_client.timeouts} saved to ft_state.json")

    # FT: shutdown the client
    ft_client.shutdown_workload_monitoring()
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])
    main(rank, world_size)
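
For reference, the fault-tolerance integration reduces to a handful of calls wrapped around an ordinary training loop. The condensed sketch below repeats only the nvidia_resiliency_ext calls used in the example above; it assumes the script is started via ft_launcher (so a rank monitor is available), and NUM_EPOCHS and train_one_epoch are hypothetical placeholders standing in for your own training configuration and per-epoch loop.

import json
import os

import nvidia_resiliency_ext.fault_tolerance as ft

NUM_EPOCHS = 10  # placeholder; match your own training configuration


def train_one_epoch(ft_client):
    # Placeholder for your training loop; the important part is sending a
    # heartbeat regularly, e.g. once per training step.
    ft_client.send_heartbeat()


ft_client = ft.RankMonitorClient()
ft_client.init_workload_monitoring()  # start workload monitoring for this rank

# Restore previously computed timeouts, if a state file exists.
if os.path.exists("ft_state.json"):
    with open("ft_state.json") as f:
        ft_client.load_state_dict(json.load(f))

for epoch in range(NUM_EPOCHS):
    train_one_epoch(ft_client)
    # Compute timeouts from the heartbeats observed so far and persist them
    # so subsequent runs (or restarts) can reuse them.
    ft_client.calculate_and_set_timeouts()
    with open("ft_state.json", "w") as f:
        json.dump(ft_client.state_dict(), f)

ft_client.shutdown_workload_monitoring()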