# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#
# This is a basic example of straggler detection usage with a simple DDP workload.
# It uses the straggler detection API to wrap the forward pass and measure GPU performance.
# GPU performance scores are printed at regular intervals.
# You can try "nvidia-smi -i <GPU idx> -lgc 800" to slow down some GPUs and see the effect.
#
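# Run this script directly with Python; the worker processes are spawned
# via the torch.distributed elastic launcher configured in main().
#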

import argparse
import os
import time
import uuid

import torch
import torch.distributed as dist
import torch.distributed.launcher as pet
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms

from nvidia_resiliency_ext import straggler


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 10),
        )

    def forward(self, x):
        x = torch.flatten(x, 1)
        return self.layers(x)


def train(args) -> None:
    print(args)

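    # Initialize the detector before entering any detection sections.
    # With gather_on_rank0=True, per-rank results are gathered on rank 0,
    # which is why only rank 0 prints the report scores below.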
    straggler.Detector.initialize(gather_on_rank0=True)

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    torch.manual_seed(42)

    print(f"Running basic straggler det. DDP example on device {device}.")
    model = Model().to(device)

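    # Wrap the model in DistributedDataParallel so gradients are averaged
    # across all workers during the backward pass.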
    ddp_model = DDP(model, device_ids=[local_rank])

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
    )
    dataset = datasets.MNIST("data", train=True, download=True, transform=transform)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler)

    optim = torch.optim.SGD(ddp_model.parameters(), lr=0.01, momentum=0.5)
    loss_fn = torch.nn.CrossEntropyLoss()
    epoch_num = 0

    ddp_model.train()
    total_iters_made = 0
    training_start_time = time.monotonic()

    while epoch_num < args.num_epochs:
        for batch_idx, (data, target) in enumerate(loader):
            data, target = data.to(device), target.to(device)

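            # Wrap the forward pass in a named detection section; with
            # profile_cuda=True the detector collects CUDA timing data for
            # this section, which is used to compute the GPU performance scores.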
            with straggler.Detector.detection_section("fwd", profile_cuda=True):
                output = ddp_model(data)

            loss = loss_fn(output, target)
            optim.zero_grad()
            loss.backward()
            optim.step()

            if (batch_idx % args.log_interval) == 0 and rank == 0:
                print(
                    f"Rank {local_rank}, Epoch {epoch_num}, Batch {batch_idx}, Loss {loss.item()}"
                )

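            # Periodically generate a report with per-GPU performance scores.
            # Since the detector was initialized with gather_on_rank0=True,
            # the gathered scores are available on rank 0 only.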
            if (batch_idx % args.report_interval) == 0:
                report = straggler.Detector.generate_report()
                if rank == 0:
                    print(
                        f"Rank {local_rank} GPUs relative perf: {report.gpu_relative_perf_scores}"
                    )
                    print(
                        f"Rank {local_rank} GPUs individual perf: {report.gpu_individual_perf_scores}"
                    )

            total_iters_made += 1
        epoch_num += 1

    training_stop_time = time.monotonic()
    time_per_iter = (training_stop_time - training_start_time) / total_iters_made
    print(f"Time per iteration [sec]: {time_per_iter:.5f}")

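    # Release detector resources before tearing down the process group.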
    straggler.Detector.shutdown()
    dist.destroy_process_group()


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-processes", type=int, default=4)
    parser.add_argument("--num-epochs", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=100)
    parser.add_argument("--log-interval", type=int, default=100)
    parser.add_argument("--report-interval", type=int, default=300)

    args: argparse.Namespace = parser.parse_args()

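    # Single-node elastic launch: spawn args.num_processes local workers;
    # "localhost:0" lets the c10d rendezvous pick a free port.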
    lc = pet.LaunchConfig(
        min_nodes=1,
        max_nodes=1,
        nproc_per_node=args.num_processes,
        run_id=str(uuid.uuid4()),
        rdzv_backend="c10d",
        rdzv_endpoint="localhost:0",
        max_restarts=0,
        monitor_interval=1,
    )

    pet.elastic_launch(lc, entrypoint=train)(args)


if __name__ == "__main__":
    main()