Heartbeat API usage example with DDP

Warning

This example loads checkpoints with torch.load(..., weights_only=True) because it saves only plain state dictionaries. For PyTorch versions before 2.10.0, CVE-2026-24747 affects the weights_only unpickler. Do not load untrusted checkpoint files with affected PyTorch versions; use PyTorch 2.10.0 or newer when checkpoint provenance is not fully trusted.

  1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16"""
 17Demo of fault tolerance with DDP training, using FT package heartbeats API
 18
 19This script demonstrates how to use the FT heartbeats API for hang detection in
 20distributed training. It should be run with the ft_launcher command. E.g.:
 21
 22`ft_launcher --nproc-per-node=2 --ft-cfg-path=./examples/fault_tolerance/fault_tol_cfg_heartbeats.yaml examples/fault_tolerance/train_ddp_heartbeats_api.py --device=cpu`
 23
 24Fault tolerance features demonstrated:
 251. Heartbeat sending during training
 262. Timeout calculation and setting
 273. State persistence through checkpoints
 284. Simulated fault injection
 29
 30Security note:
 31This example loads checkpoints with torch.load(..., weights_only=True)
 32because it saves only plain state dictionaries. For PyTorch versions before
 332.10.0, CVE-2026-24747 affects the weights_only unpickler. Do not load
 34untrusted checkpoint files with affected PyTorch versions; use PyTorch 2.10.0
 35or newer when checkpoint provenance is not fully trusted.
 36"""
 37
 38import argparse
 39import logging
 40import os
 41import random
 42import signal
 43import sys
 44import threading
 45import time
 46
 47import dist_utils
 48import log_utils
 49import numpy as np
 50import torch
 51import torch.nn as nn
 52
 53import nvidia_resiliency_ext.fault_tolerance as fault_tolerance
 54
 55
 56# Dummy dataset.
 57class Dataset(torch.utils.data.Dataset):
 58    def __init__(self, size, hidden):
 59        self.size = size
 60        self.hidden = hidden
 61
 62    def __len__(self):
 63        return self.size
 64
 65    def __getitem__(self, idx):
 66        data = torch.full(
 67            (self.hidden,),
 68            fill_value=idx,
 69            dtype=torch.float32,
 70            device='cpu',
 71        )
 72        return data
 73
 74
 75# Dummy model
 76class Model(nn.Module):
 77    def __init__(self, hidden):
 78        super().__init__()
 79        self.l1 = nn.Linear(hidden, hidden)
 80        self.l2 = nn.Linear(hidden, hidden)
 81
 82    def forward(self, x):
 83        x = self.l1(x)
 84        x = self.l2(x)
 85        return x
 86
 87
 88def parse_args():
 89    def tuple_type(strings):
 90        strings = strings.replace("(", "").replace(")", "")
 91        mapped_int = map(int, strings.split(","))
 92        return tuple(mapped_int)
 93
 94    def fault_desc(strings):
 95        parts = strings.split(",")
 96        assert len(parts) == 2
 97        return {'fault': parts[0], 'delay': float(parts[1])}
 98
 99    parser = argparse.ArgumentParser(
100        description='Example of PyTorch DDP training with the Fault Tolerance package',
101        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
102    )
103
104    # fmt: off
105    parser.add_argument('--hidden', type=int, default=4096,
106                        help='Hidden size')
107    parser.add_argument('--batch', type=int, default=8,
108                        help='Batch size')
109    parser.add_argument('--epochs', type=int, default=4,
110                        help='Number of training epochs')
111    parser.add_argument('--train_dataset_size', type=int, default=1000000,
112                        help='Train dataset size')
113    parser.add_argument('--val_dataset_size', type=int, default=2000,
114                        help='Validation dataset size')
115    parser.add_argument('--device', type=str, default='cuda',
116                        choices=['cpu', 'cuda'],
117                        help='Device')
118    
119    parser.add_argument('--interrupt_at', type=tuple_type, nargs='*',
120                        help='Manual interruption after (epoch, iteration), '
121                        'for testing only')
122    parser.add_argument('--save_interval', type=int, default=-1,
123                        help='Interval for saving periodic checkpoints')
124    parser.add_argument('--logging_interval', type=int, default=1,
125                        help='Interval for log entries')
126    parser.add_argument('--log_all_ranks', action='store_true',
127                        help='Enable logging from all distributed ranks')
128    parser.add_argument('--output_dir', type=str, default='results/output',
129                        help='Output dir')
130    parser.add_argument('--checkpoint_fname', type=str, default='checkpoint.pt',
131                        help='Name of a checkpoint file')
132    
133    parser.add_argument('--local_rank', type=int,
134                        default=os.getenv('LOCAL_RANK', 0))
135    parser.add_argument('--init_distributed_method', type=str, default='tcp',
136                        help='Init distributed group with TCP store ("tcp") or file store ("file")')
137
138    parser.add_argument('--simulated_fault', type=fault_desc,
139                        help='Description of a fault to be simulated')
140    # fmt: on
141
142    args = parser.parse_args()
143
144    if args.interrupt_at:
145        args.interrupt_at = set(args.interrupt_at)
146    else:
147        args.interrupt_at = set()
148
149    return args
150
151
def load_checkpoint(path):
    """Load a checkpoint file written by ``save_checkpoint``.

    Args:
        path: location of the checkpoint file.

    Returns:
        The object stored in the file, as returned by ``torch.load``.
    """
    # CPU storages stay on the CPU; storages tagged 'cuda:0' are redirected to
    # the CUDA device currently selected in this process.
    device_remap = {'cpu': 'cpu'}
    if torch.cuda.is_available():
        device_remap['cuda:0'] = f'cuda:{torch.cuda.current_device()}'

    logging.info(f'Loading checkpoint from {path}')
    # weights_only=True restricts unpickling to tensors and plain containers.
    return torch.load(path, map_location=device_remap, weights_only=True)
162
163
def save_checkpoint(
    progress,
    model,
    optimizer,
    ft_client,
    output_dir,
    checkpoint_fname,
):
    """Write the training state to ``output_dir/checkpoint_fname``.

    Args:
        progress: dict with the current 'epoch_idx' and 'iter_idx'.
        model: module whose ``state_dict()`` is stored under 'model_state'.
        optimizer: optimizer whose ``state_dict()`` is stored under
            'optimizer_state'.
        ft_client: fault tolerance client whose ``state_dict()`` is stored
            under 'ft_state'.
        output_dir: directory that receives the checkpoint file.
        checkpoint_fname: file name of the checkpoint inside ``output_dir``.

    The state is collected on every rank, but only rank 0 writes the file.
    """
    state = {
        'progress': progress,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'ft_state': ft_client.state_dict(),
    }

    checkpoint_path = os.path.join(output_dir, checkpoint_fname)
    # The file is written next to its final location so that the rename below
    # stays on one filesystem.
    tmp_checkpoint_path = f'{checkpoint_path}.tmp'

    # NOTE(review): sync_workers() presumably synchronizes all ranks around
    # the block and yields the rank of the caller - confirm in dist_utils.
    with dist_utils.sync_workers() as rank:
        if rank == 0:
            logging.info(f'Saving checkpoint to {checkpoint_path}')
            # Write-then-rename: if this process is killed while saving (which
            # this example does on purpose), the previous checkpoint is still
            # intact instead of being left truncated.
            torch.save(state, tmp_checkpoint_path)
            os.replace(tmp_checkpoint_path, checkpoint_path)
185
186
def training_loop(
    ft_client,
    para_model,
    model,
    optimizer,
    device,
    dataloader,
    sampler,
    progress,
    args,
):
    """Train for the rest of the epoch described by ``progress``.

    Args:
        ft_client: fault tolerance client; receives one heartbeat per
            iteration and is asked to compute heartbeat timeouts once.
        para_model: DDP-wrapped model used for forward/backward.
        model: unwrapped model, only used when saving checkpoints.
        optimizer: optimizer stepping the model parameters.
        device: device the input batches are moved to.
        dataloader: iterable yielding input batches.
        sampler: resumable sampler; gets ``start_sample_idx`` and the epoch.
        progress: dict with 'epoch_idx' and 'iter_idx'; 'iter_idx' is updated
            in place after every iteration.
        args: parsed command-line options (``batch``, ``logging_interval``,
            ``save_interval``, ``interrupt_at``, ``output_dir``,
            ``checkpoint_fname``).

    A non-positive ``save_interval`` disables periodic checkpoints and a zero
    ``logging_interval`` disables the periodic log entries.

    Raises:
        SystemExit: with code 0 when ``(epoch, iteration)`` is listed in
            ``args.interrupt_at`` (after a checkpoint has been saved).
    """
    epoch_idx = progress['epoch_idx']

    # NOTE: torch.utils.data.DistributedSampler must be prepared for current epoch
    # need to do it before starting iteration
    sampler.start_sample_idx = progress['iter_idx'] * args.batch
    sampler.set_epoch(epoch_idx)

    para_model.train()

    last_log_time = time.monotonic()

    for iter_idx, x in enumerate(dataloader, start=progress['iter_idx']):
        if ft_client.hb_timeouts.are_valid is False and epoch_idx == 1 and iter_idx == 1:
            # after 0th epoch is completed and we've done 0th iteration of the 1st epoch,
            # we can calculate and set timeouts. this is a good moment to do so,
            # because now we've seen the possibly long interval where checkpoint was saved.
            ft_client.calculate_and_set_hb_timeouts()

        optimizer.zero_grad()
        x = x.to(device)
        y = para_model(x)
        loss = y.mean()
        train_loss = loss.item()
        loss.backward()

        # A zero interval used to raise ZeroDivisionError here; it now simply
        # turns the periodic log entries off.
        if args.logging_interval and iter_idx % args.logging_interval == 0:
            avg_train_loss = dist_utils.all_reduce_item(train_loss, op='mean')
            logging.info(
                f'CHECK TRAIN epoch: {epoch_idx:4d} '
                f'iter: {iter_idx:5d} '
                f'loss: {avg_train_loss} '
                f'input: {x[:, 0]}'
            )
            if iter_idx > 0:
                time_per_iter = (time.monotonic() - last_log_time) / args.logging_interval
                last_log_time = time.monotonic()
                logging.debug(f'Avg time per iter: {time_per_iter:.3f} [sec]')

        # Stored before the checkpoint below, so a resume continues with the
        # next iteration.
        progress['iter_idx'] = iter_idx + 1

        ft_client.send_heartbeat()
        optimizer.step()

        # Whether to do a periodic checkpointing. Negative intervals never
        # matched the modulo test and zero raised ZeroDivisionError, so every
        # non-positive value is treated as "disabled".
        periodic_save = (
            args.save_interval > 0
            and iter_idx % args.save_interval == args.save_interval - 1
        )
        interrupt_now = (epoch_idx, iter_idx) in args.interrupt_at

        if periodic_save or interrupt_now:
            save_checkpoint(
                progress=progress,
                model=model,
                optimizer=optimizer,
                ft_client=ft_client,
                output_dir=args.output_dir,
                checkpoint_fname=args.checkpoint_fname,
            )
            if interrupt_now:
                logging.info('Manual interruption, exiting')
                sys.exit(0)
256
257
def validation_loop(ft_client, model, val_dataloader, epoch_idx, device):
    """Evaluate ``model`` on the validation data and log the mean batch loss.

    Args:
        ft_client: fault tolerance client; receives one heartbeat per batch.
        model: model to evaluate; it is switched to eval mode.
        val_dataloader: iterable yielding validation batches.
        epoch_idx: epoch number, used only in the log message.
        device: device the batches are moved to.

    With an empty ``val_dataloader`` the logged loss is ``nan``.
    """
    total_val_loss = 0
    num_batches = 0
    model.eval()

    # Gradients are not needed for evaluation; skipping the autograd graph
    # saves memory and does not change the computed loss values.
    with torch.no_grad():
        for x in val_dataloader:
            x = x.to(device)
            total_val_loss += model(x).mean().item()
            num_batches += 1
            ft_client.send_heartbeat()

    # Counting batches explicitly avoids referencing an unbound loop variable
    # (and the resulting UnboundLocalError) when the dataloader is empty.
    avg_val_loss = total_val_loss / num_batches if num_batches else float('nan')

    logging.info(
        f'CHECK VAL SUMMARY: epoch: {epoch_idx:4d} ' f'loss: {avg_val_loss}'
    )
272
273
# Module-level state of the simulated-fault machinery.
# _sim_fault_canceled: checked by the fault thread after its delay; when True
# the thread returns without sending the signal.
_sim_fault_canceled = False
# _sim_fault_is_set: set by _setup_simulated_fault(); main() uses it to set up
# the simulated fault at most once.
_sim_fault_is_set = False


def _cancel_simulated_fault():
    """Mark the pending simulated fault (if any) as canceled."""
    global _sim_fault_canceled
    _sim_fault_canceled = True
281
282
def _setup_simulated_fault(ft_client, fault_desc, device):
    """Schedule a simulated fault on one rank.

    The target rank is the value drawn on rank 0 and broadcast to all ranks.
    On that rank a daemon thread is started which, after the configured delay
    plus a random extra of up to 4 seconds, sends SIGKILL ('rank_killed') or
    SIGSTOP ('rank_hung') to the current process, unless the fault has been
    canceled in the meantime.

    Args:
        ft_client: fault tolerance client (not used by this function).
        fault_desc: dict with 'fault' ('rank_killed', 'rank_hung' or
            'random') and 'delay' (seconds, float).
        device: device used for the tensor that carries the broadcast rank.

    Raises:
        ValueError: if the fault type is not recognized. The check runs on
            every rank before any collective call, so all ranks fail
            together instead of only the selected one.
    """
    # FIXME: hanging rank with SIGTSTP results in rank monitor
    # blocked when trying to receive the data in _on_ipc_data_from_rank

    global _sim_fault_is_set
    _sim_fault_is_set = True  # should be True on all ranks

    rng = random.Random()

    logging.info(f"Initializing simulated fault: {fault_desc}")

    # Signal delivered to the current process for every supported fault type.
    signal_by_fault = {
        'rank_killed': signal.SIGKILL,
        'rank_hung': signal.SIGSTOP,
    }

    fault_type = fault_desc['fault']
    if fault_type == 'random':
        fault_type = rng.choice(['rank_killed', 'rank_hung'])

    if fault_type not in signal_by_fault:
        raise ValueError(f"Unknown fault type {fault_type}")

    # Every rank draws a candidate, but only the value from rank 0 survives
    # the broadcast.
    rank_to_fail = rng.randint(0, dist_utils.get_world_size() - 1)
    rank_to_fail = torch.tensor([rank_to_fail], device=device)
    dist_utils.broadcast(rank_to_fail, 0)
    rank_to_fail = int(rank_to_fail.item())

    rank = torch.distributed.get_rank()
    if rank != rank_to_fail:
        return

    target_pid = os.getpid()
    target_sig = signal_by_fault[fault_type]

    delay = fault_desc['delay'] + 4.0 * rng.random()

    logging.info(
        f"Selected fault={fault_type}; target rank={rank_to_fail}; delay={delay}",
    )

    def __fault_thread():
        time.sleep(delay)
        # Training may have finished while this thread was sleeping.
        if _sim_fault_canceled:
            return
        print(
            f"\n####\nSimulating fault: {fault_type}; rank to fail: {rank_to_fail}\n#####\n",
            file=sys.stderr,
        )
        os.kill(target_pid, target_sig)

    # Daemon thread: it must not keep the process alive on normal exit.
    fault_sim_thread = threading.Thread(target=__fault_thread)
    fault_sim_thread.daemon = True
    fault_sim_thread.start()
335
336
# Set to True by _sig_handler(); main() checks it at the end of every epoch.
_signal_received = False


def _sig_handler(*args, **kwargs):
    """Signal handler: print a notice to stderr and set ``_signal_received``."""
    print("Signal received!", file=sys.stderr)
    global _signal_received
    _signal_received = True
344
345
def main():
    """Entry point of the example: set up, optionally resume, then train.

    The function parses the command line, initializes the distributed process
    group and logging, creates the fault tolerance client, restores the state
    from an existing checkpoint file (if there is one) and then runs training,
    validation and checkpoint saving for every remaining epoch. Both the
    "already fully trained" path and the regular end of training leave the
    process through ``sys.exit(0)``.

    Raises:
        RuntimeError: if 'cuda' was requested but CUDA is not available, if
            ``WORLD_SIZE`` is 1 (or unset), or if the distributed init method
            is neither 'tcp' nor 'file'.
    """
    # SIGTERM only sets a flag; the flag is checked at the end of each epoch.
    signal.signal(signal.SIGTERM, _sig_handler)

    args = parse_args()

    # Fixed seeds for all random number generators used by the example.
    torch.manual_seed(123)
    np.random.seed(123)
    random.seed(123)

    if args.device == 'cuda':
        if torch.cuda.is_available():
            device = torch.device('cuda')
            # Bind this process to the GPU that matches its local rank.
            torch.cuda.set_device(args.local_rank)
        else:
            raise RuntimeError("Selected 'cuda' device but torch.cuda is not available.")
    elif args.device == 'cpu':
        device = torch.device('cpu')
    else:
        raise RuntimeError('Unknown device')

    # The example refuses to run as a single process.
    if int(os.getenv('WORLD_SIZE', '1')) == 1:
        raise RuntimeError('This example supports only multi-gpu training')

    os.makedirs(args.output_dir, exist_ok=True)

    if args.init_distributed_method == 'tcp':
        # NOTE: when running tests with tcp init method we noticed
        # occasional "address already in use" errors, after workload
        # is restarted
        dist_utils.init_distributed_with_tcp_store(device)
    elif args.init_distributed_method == 'file':
        dist_utils.init_distributed_with_file_store(device, store_file_dir=args.output_dir)
    else:
        raise RuntimeError(
            f"--init_distributed_method should be ['tcp','file'] it is {args.init_distributed_method}"
        )

    # One log file per rank when logging from all ranks, a shared one otherwise.
    if args.log_all_ranks:
        log_file_name = f'train_log_rank_{dist_utils.get_rank()}.log'
    else:
        log_file_name = 'train_log.log'
    log_file_path = os.path.join(args.output_dir, log_file_name)

    # NOTE: logging appends outputs to an existing log file if it already
    # exists. Results from a single training run (potentially with many
    # restarts from a checkpoint) are stored in a single log file.
    log_utils.setup_logging(args.log_all_ranks, filename=log_file_path, filemode='a')
    logging.info(args)

    rank = dist_utils.get_rank()

    logging.info(f"SLURM_JOB_ID={os.getenv('SLURM_JOB_ID','<none>')} RANK={rank} PID={os.getpid()}")

    # Dummy datasets
    train_dataset = Dataset(args.train_dataset_size, args.hidden)
    val_dataset = Dataset(args.val_dataset_size, args.hidden)

    # ResumableDistributedSampler is needed to skip consumed samples
    train_sampler = dist_utils.ResumableDistributedSampler(
        train_dataset,
        drop_last=True,
    )

    val_sampler = torch.utils.data.DistributedSampler(
        val_dataset,
    )

    # A dummy model and an optimizer
    model = Model(args.hidden).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Initial value for start epoch - will be overwritten if training is resumed from a checkpoint
    progress = {
        'epoch_idx': 0,
        'iter_idx': 0,
    }

    checkpoint_path = os.path.join(args.output_dir, args.checkpoint_fname)

    # Initialize fault tolerance.
    ft_client = fault_tolerance.RankMonitorClient()
    ft_client.init_workload_monitoring()

    checkpoint = None

    # try to load checkpoint from disk
    if os.path.exists(checkpoint_path):
        checkpoint = load_checkpoint(checkpoint_path)
        if checkpoint:
            logging.info(f'Checkpoint was loaded from file: {checkpoint_path}')

    # Restore model, optimizer, fault tolerance client and progress counters.
    if checkpoint:
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        ft_client.load_state_dict(checkpoint['ft_state'])
        progress.update(checkpoint['progress'])
        # Return with zero exit code if model is already fully trained.
        if progress['epoch_idx'] == args.epochs:
            logging.info('Training finished.')
            ft_client.shutdown_workload_monitoring()
            torch.distributed.destroy_process_group()
            sys.exit(0)

    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch,
        sampler=train_sampler,
        num_workers=4,
        persistent_workers=True,
        pin_memory=False,
    )

    val_dataloader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=args.batch,
        sampler=val_sampler,
        num_workers=4,
    )

    # Regular DDP init
    # NOTE: for convenience code is keeping both wrapped and unwrapped model and
    # uses wrapped model for training and unwrapped model for saving the
    # checkpoint and validation. It doesn't increase memory consumption
    # since both models are holding references to the same parameters.
    # Additionally saved checkpoint is ready for inference and doesn't have to
    # be manually unwrapped by accessing the (undocumented) "module" attribute
    # of DDP-wrapped model.
    if device.type == 'cuda':
        device_ids = [args.local_rank]
        output_device = args.local_rank
    elif device.type == 'cpu':
        device_ids = None
        output_device = None
    else:
        raise RuntimeError('Unsupported device type')
    para_model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=device_ids, output_device=output_device
    )

    # Iteration over epochs, notice that it starts from 'epoch_idx'
    # which was previously loaded from the checkpoint
    for epoch_idx in range(progress['epoch_idx'], args.epochs):
        training_loop(
            ft_client,
            para_model,
            model,
            optimizer,
            device,
            train_dataloader,
            train_sampler,
            progress,
            args,
        )

        # epoch_idx is incremented because the current epoch is finished
        # and potential resume from this checkpoint should start a new training epoch.
        progress['epoch_idx'] += 1
        progress['iter_idx'] = 0

        # The loop variable epoch_idx still holds the number of the epoch
        # that has just been trained.
        validation_loop(ft_client, model, val_dataloader, epoch_idx, device)

        # Checkpoint contains everything needed for deterministic resume:
        # state of the model, optimizer and other components,
        save_checkpoint(
            progress=progress,
            model=model,
            optimizer=optimizer,
            ft_client=ft_client,
            output_dir=args.output_dir,
            checkpoint_fname=args.checkpoint_fname,
        )

        # NOTE: SIGTERM is used by SLURM to initiate graceful job termination
        # if _any_ rank received SIGTERM, we leave the main loop
        if dist_utils.is_true_on_any_rank(_signal_received):
            logging.info('Leaving the main loop, due to SIGTERM')
            break

        # Setup simulated fault as soon as we have valid timeouts
        if args.simulated_fault and not _sim_fault_is_set and ft_client.hb_timeouts.are_valid:
            _setup_simulated_fault(ft_client, args.simulated_fault, device)

    # Training is over: make sure a still-sleeping fault thread does not fire.
    _cancel_simulated_fault()
    ft_client.shutdown_workload_monitoring()
    torch.distributed.destroy_process_group()
    logging.info('Leaving main, ret_code=0')
    sys.exit(0)
533
534
535if __name__ == "__main__":
536    main()