# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

# FT: import NVRx
import nvidia_resiliency_ext.fault_tolerance as ft

# Simple example of using the FT library with PyTorch DDP.
# This script trains a dummy model on dummy data. CPU is used for training.
# After each epoch, FT timeouts are calculated and saved to the file "./ft_state.json".
#
# You can run it using:
# `ft_launcher --nproc_per_node=4 --max-restarts=3 --ft-param-initial_rank_heartbeat_timeout=30 --ft-param-rank_heartbeat_timeout=15 examples/fault_tolerance/basic_ft_example.py`
# In this example configuration, at most 3 training restarts are allowed.
#
# To find rank PIDs, use:
# `ps aux | grep basic_ft_example.py | grep -v grep`
#
# Examples:
#
# 1. Hang detection using predefined timeouts:
#    - Remove `ft_state.json` if it exists (`rm ft_state.json`).
#    - During epoch 0, stop a rank using `kill -SIGSTOP <rank_pid>`.
#    - After approximately 15 seconds, a "Did not get subsequent heartbeat." error should be raised.
#    - All ranks will be restarted.
#
# 2. Hang detection using computed timeouts:
#    - Run the example for more than 1 epoch to allow FT timeouts to be calculated.
#    - Stop a rank using `kill -SIGSTOP <rank_pid>`.
#    - After the computed timeout elapses, a "Did not get subsequent heartbeat." error should be raised.
#    - All ranks will be restarted.
#
# 3. Rank error handling:
#    - Kill a rank using `kill -SIGKILL <rank_pid>`.
#    - All ranks will be restarted.

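# The FT integration in this script boils down to four steps:
#   1. each rank creates a `RankMonitorClient` and starts workload monitoring,
#   2. previously computed timeouts (if any) are restored from "ft_state.json",
#   3. a heartbeat is sent after every training iteration,
#   4. after each epoch, new timeouts are calculated and the FT state is saved to
#      "ft_state.json", so later runs (and automatic restarts) can pick them up.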

FEAT_SIZE = 4096
DNN_OUT_SIZE = 128
BATCH_SIZE = 100
NUM_EPOCHS = 10
DATASET_LEN = 100000


class SimpleDataset(torch.utils.data.Dataset):
    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        x = torch.rand((FEAT_SIZE,), dtype=torch.float32, device='cpu')
        y = torch.rand((DNN_OUT_SIZE,), dtype=torch.float32, device='cpu')
        return x, y


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(FEAT_SIZE, FEAT_SIZE)
        self.fc2 = nn.Linear(FEAT_SIZE, DNN_OUT_SIZE)

    def forward(self, x):
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        return x


def print_on_rank0(msg):
    if dist.get_rank() == 0:
        print(msg)


def main(rank, world_size):
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    print_on_rank0(f"Starting new training run... World size={dist.get_world_size()}")

    # FT: initialize the client
    ft_client = ft.RankMonitorClient()
    ft_client.init_workload_monitoring()
    print_on_rank0(f"FT initialized. Timeouts: {ft_client.timeouts}")
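    # The timeouts printed above are the initial values, configured via the
    # `--ft-param-initial_rank_heartbeat_timeout` and `--ft-param-rank_heartbeat_timeout`
    # arguments of `ft_launcher`. If a previous run already computed timeouts, they are
    # restored from "ft_state.json" below, so they survive restarts.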
    # FT: load state (calculated timeouts)
    if os.path.exists("ft_state.json"):
        with open("ft_state.json", "r") as f:
            ft_state = json.load(f)
        ft_client.load_state_dict(ft_state)
        print_on_rank0(f"FT timeouts {ft_client.timeouts} loaded from ft_state.json")

    # Dataset and DataLoader with DistributedSampler
    dataset = SimpleDataset(size=DATASET_LEN)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler)

    # Model, optimizer, and DDP
    model = SimpleModel()
    ddp_model = DDP(model)
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()

    num_iters_in_epoch = len(dataloader)
    num_iters_for_10pct = num_iters_in_epoch // 10  # iters for 1/10 of epoch

    for epoch in range(NUM_EPOCHS):
        sampler.set_epoch(epoch)
        for batch_idx, (data, target) in enumerate(dataloader):
            if (batch_idx % num_iters_for_10pct) == 0 and rank == 0:
                print(f"Epoch {epoch} progress: {100 * batch_idx / num_iters_in_epoch:.2f}%")
            optimizer.zero_grad()
            output = ddp_model(data)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()
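            # A heartbeat is sent once per training iteration. If a rank stops sending
            # heartbeats for longer than the rank heartbeat timeout (e.g. it was stopped
            # with SIGSTOP), the hang is detected and all ranks are restarted by
            # `ft_launcher`.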
            # FT: send heartbeat to the server
            ft_client.send_heartbeat()
        print_on_rank0(f"Epoch {epoch} complete. Loss: {loss.item()}")
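        # calculate_and_set_timeouts() computes new timeouts from the heartbeat
        # intervals observed so far and applies them in place of the initial,
        # launcher-configured values.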
        # FT: calculate and set new timeouts
        ft_client.calculate_and_set_timeouts()
        # FT: save the state (calculated timeouts); only rank 0 writes the file,
        # to avoid concurrent writes to the same path
        if rank == 0:
            with open("ft_state.json", "w") as f:
                json.dump(ft_client.state_dict(), f)
        print_on_rank0(f"FT timeouts {ft_client.timeouts} saved to ft_state.json")

    # FT: shutdown the client
    ft_client.shutdown_workload_monitoring()
    dist.destroy_process_group()

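# `ft_launcher` starts one process per rank and, similarly to torchrun, provides the
# WORLD_SIZE and RANK environment variables read below.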
if __name__ == "__main__":
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])
    main(rank, world_size)