Basic usage example

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#
# This is a basic example of straggler detection usage with a simple DDP workload.
# It uses the straggler detection API to wrap the forward pass and measure GPU performance.
# GPU performance scores are printed at regular intervals.
# You can try "nvidia-smi -i <GPU idx> -lgc 800" to slow down some GPUs and see the effect.
#

import argparse
import os
import time
import uuid

import torch
import torch.distributed as dist
import torch.distributed.launcher as pet
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms

from nvidia_resiliency_ext import straggler


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 10),
        )

    def forward(self, x):
        x = torch.flatten(x, 1)
        return self.layers(x)


def train(args) -> None:
    print(args)

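    # Initialize the straggler detector before the training loop; with gather_on_rank0=True
    # the collected scores are gathered on rank 0, which prints the reports below.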
    straggler.Detector.initialize(gather_on_rank0=True)

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    torch.manual_seed(42)

    print(f"Running basic straggler det. DDP example on device {device}.")
    model = Model().to(device)

    ddp_model = DDP(model, device_ids=[local_rank])

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
    )
    dataset = datasets.MNIST("data", train=True, download=True, transform=transform)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)  # shard by global rank
    loader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler)

    optim = torch.optim.SGD(ddp_model.parameters(), lr=0.01, momentum=0.5)
    loss_fn = torch.nn.CrossEntropyLoss()
    epoch_num = 0

    ddp_model.train()
    total_iters_made = 0
    training_start_time = time.monotonic()

    while epoch_num < args.num_epochs:
        for batch_idx, (data, target) in enumerate(loader):
            data, target = data.to(device), target.to(device)

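            # Wrap the forward pass in a named detection section; with profile_cuda=True
            # the GPU time spent inside the section is measured and used to compute the
            # performance scores reported below.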
            with straggler.Detector.detection_section("fwd", profile_cuda=True):
                output = ddp_model(data)

            loss = loss_fn(output, target)
            optim.zero_grad()
            loss.backward()
            optim.step()

            if (batch_idx % args.log_interval) == 0 and rank == 0:
                print(
                    f"Rank {local_rank}, Epoch {epoch_num}, Batch {batch_idx}, Loss {loss.item()}"
                )

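            # Periodically generate a straggler report; generate_report() is invoked on
            # every rank, and rank 0 prints the gathered GPU performance scores.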
            if (batch_idx % args.report_interval) == 0:
                report = straggler.Detector.generate_report()
                if rank == 0:
                    print(
                        f"Rank {local_rank} GPUs relative perf: {report.gpu_relative_perf_scores}"
                    )
                    print(
                        f"Rank {local_rank} GPUs individual perf: {report.gpu_individual_perf_scores}"
                    )

            total_iters_made += 1
        epoch_num += 1

    training_stop_time = time.monotonic()
    time_per_iter = (training_stop_time - training_start_time) / total_iters_made
    print(f"Time per iteration [sec]: {time_per_iter:.5f}")

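    # Clean up: shut down the straggler detector, then tear down the process group.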
    straggler.Detector.shutdown()
    dist.destroy_process_group()


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-processes", type=int, default=4)
    parser.add_argument("--num-epochs", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=100)
    parser.add_argument("--log-interval", type=int, default=100)
    parser.add_argument("--report-interval", type=int, default=300)

    args: argparse.Namespace = parser.parse_args()

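    # Launch args.num_processes workers on this single node via torch.distributed's
    # elastic launcher; each worker runs train(args).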
    lc = pet.LaunchConfig(
        min_nodes=1,
        max_nodes=1,
        nproc_per_node=args.num_processes,
        run_id=str(uuid.uuid4()),
        rdzv_backend="c10d",
        rdzv_endpoint="localhost:0",
        max_restarts=0,
        monitor_interval=1,
    )

    pet.elastic_launch(lc, entrypoint=train)(args)


if __name__ == "__main__":
    main()
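
The example spawns its own worker processes via torch.distributed.launcher, so it can be started as a plain Python script, for example (assuming the listing is saved as straggler_example.py):

    python straggler_example.py --num-processes 2

One CUDA-capable GPU per process is needed, since each worker binds to cuda:LOCAL_RANK with the NCCL backend. To see a straggler being detected, slow one GPU down with "nvidia-smi -i <GPU idx> -lgc 800" as suggested in the header comment.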