Section API usage example with DDP

Warning

This example loads checkpoints with torch.load(..., weights_only=True) because it saves only plain state dictionaries. For PyTorch versions before 2.10.0, CVE-2026-24747 affects the weights_only unpickler. Do not load untrusted checkpoint files with affected PyTorch versions; use PyTorch 2.10.0 or newer when checkpoint provenance is not fully trusted.

  1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2# SPDX-License-Identifier: Apache-2.0
  3#
  4# Licensed under the Apache License, Version 2.0 (the "License");
  5# you may not use this file except in compliance with the License.
  6# You may obtain a copy of the License at
  7#
  8# http://www.apache.org/licenses/LICENSE-2.0
  9#
 10# Unless required by applicable law or agreed to in writing, software
 11# distributed under the License is distributed on an "AS IS" BASIS,
 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13# See the License for the specific language governing permissions and
 14# limitations under the License.
 15
 16"""
 17Demo of DDP training with fault tolerance, using FT package sections API
 18
 19It should be run with `ft_launcher`. E.g.
 20`ft_launcher --nproc-per-node=2 --ft-cfg-path=./examples/fault_tolerance/fault_tol_cfg_sections.yaml examples/fault_tolerance/train_ddp_sections_api.py --device=cpu`
 21
 22This example uses following custom FT sections
 23- 'init' - covers workload initialization
 24- 'step' - covers training/evaluation step (fwd/bwd, loss calculation etc)
 25- 'checkpoint' - covers checkpoint saving
 26
 27Timeout for each section is calculated when enough data is collected.
 28FT "out-of-section" timeout is calculated when the training run ends normally.
 29FT state is saved in a JSON file.
 30
 31This example allows to simulate a training fault:
 32- selected rank hung
 33- selected rank terminated
 34
 35Security note:
 36This example loads checkpoints with torch.load(..., weights_only=True)
 37because it saves only plain state dictionaries. For PyTorch versions before
 382.10.0, CVE-2026-24747 affects the weights_only unpickler. Do not load
 39untrusted checkpoint files with affected PyTorch versions; use PyTorch 2.10.0
 40or newer when checkpoint provenance is not fully trusted.
 41"""
 42import argparse
 43import json
 44import logging
 45import os
 46import random
 47import signal
 48import sys
 49import threading
 50import time
 51
 52import dist_utils
 53import log_utils
 54import numpy as np
 55import torch
 56import torch.nn as nn
 57
 58import nvidia_resiliency_ext.fault_tolerance as fault_tolerance
 59
 60
 61# Dummy dataset.
 62class Dataset(torch.utils.data.Dataset):
 63    def __init__(self, size, hidden):
 64        self.size = size
 65        self.hidden = hidden
 66
 67    def __len__(self):
 68        return self.size
 69
 70    def __getitem__(self, idx):
 71        data = torch.full(
 72            (self.hidden,),
 73            fill_value=idx,
 74            dtype=torch.float32,
 75            device='cpu',
 76        )
 77        return data
 78
 79
 80# Dummy model
 81class Model(nn.Module):
 82    def __init__(self, hidden):
 83        super().__init__()
 84        self.l1 = nn.Linear(hidden, hidden)
 85        self.l2 = nn.Linear(hidden, hidden)
 86
 87    def forward(self, x):
 88        x = self.l1(x)
 89        x = self.l2(x)
 90        return x
 91
 92
 93def parse_args():
 94    def fault_desc(strings):
 95        parts = strings.split(",")
 96        assert len(parts) == 2, "Fault description must be in format 'fault,delay'"
 97        return {'fault': parts[0], 'delay': float(parts[1])}
 98
 99    parser = argparse.ArgumentParser(
100        description='Example of PyTorch DDP training with the Fault Tolerance package',
101        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
102    )
103
104    # fmt: off
105    parser.add_argument('--hidden', type=int, default=4096,
106                        help='Hidden size')
107    parser.add_argument('--batch', type=int, default=8,
108                        help='Batch size')
109    parser.add_argument('--epochs', type=int, default=4,
110                        help='Number of training epochs')
111    parser.add_argument('--train_dataset_size', type=int, default=1000000,
112                        help='Train dataset size')
113    parser.add_argument('--val_dataset_size', type=int, default=2000,
114                        help='Validation dataset size')
115    parser.add_argument('--device', type=str, default='cuda',
116                        choices=['cpu', 'cuda'],
117                        help='Device')
118    parser.add_argument('--save_interval', type=int, default=-1,
119                        help='Interval for saving periodic checkpoints.')
120    parser.add_argument('--logging_interval', type=int, default=1,
121                        help='Interval for log entries')
122    parser.add_argument('--log_all_ranks', action='store_true',
123                        help='Enable logging from all distributed ranks')
124    parser.add_argument('--output_dir', type=str, default='results/output',
125                        help='Output dir')
126    parser.add_argument('--checkpoint_fname', type=str, default='checkpoint.pt',
127                        help='Name of a checkpoint file')
128    parser.add_argument('--local_rank', type=int,
129                        default=os.getenv('LOCAL_RANK', 0))
130    parser.add_argument('--simulated_fault', type=fault_desc,
131                        help='Description of a fault to be simulated')
132    # fmt: on
133
134    args = parser.parse_args()
135    return args
136
137
def load_checkpoint(path):
    """Load a checkpoint dictionary previously written by ``save_checkpoint``.

    CPU tensors stay on CPU. When CUDA is available, tensors stored on
    ``cuda:0`` are remapped onto this process's current CUDA device.
    ``weights_only=True`` restricts unpickling to tensors and plain containers.
    Read the module security note before loading files of unknown provenance.
    """
    location_map = {'cpu': 'cpu'}
    if torch.cuda.is_available():
        location_map['cuda:0'] = 'cuda:{}'.format(torch.cuda.current_device())

    logging.info(f'Loading checkpoint from {path}')
    return torch.load(path, map_location=location_map, weights_only=True)
148
149
def save_checkpoint(
    progress,
    model,
    optimizer,
    ft_client,
    output_dir,
    checkpoint_fname,
):
    """Save training progress plus model and optimizer state.

    The whole operation is wrapped in the "checkpoint" FT section. Only rank 0
    writes the file, inside ``dist_utils.sync_workers()``.

    The checkpoint is first written to ``<path>.tmp`` and then moved over the
    final path with ``os.replace``. A rank that is killed or hung mid-save (for
    example by the simulated fault) therefore cannot leave a truncated
    checkpoint behind that would make the next resume fail in ``torch.load``.

    Args:
        progress: dict with 'epoch_idx' and 'iter_idx' used to resume training.
        model: unwrapped model whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.
        ft_client: FT client used to open/close the "checkpoint" section.
        output_dir: directory the checkpoint is written into.
        checkpoint_fname: file name of the checkpoint inside ``output_dir``.
    """
    # NOTE: FT state is not stored in the checkpoint, but in a separate JSON file
    ft_client.start_section('checkpoint')

    state = {
        'progress': progress,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
    }

    checkpoint_path = os.path.join(output_dir, checkpoint_fname)

    with dist_utils.sync_workers() as rank:
        if rank == 0:
            logging.info(f'Saving checkpoint to {checkpoint_path}')
            tmp_path = f'{checkpoint_path}.tmp'
            torch.save(state, tmp_path)
            # Atomic replacement (on POSIX, same filesystem): readers see either
            # the previous complete checkpoint or the new complete one.
            os.replace(tmp_path, checkpoint_path)

    ft_client.end_section('checkpoint')
176
177
def maybe_load_ft_state(path):
    """Return the FT state dict stored as JSON at ``path``, or ``None`` if no such file exists."""
    if not os.path.exists(path):
        logging.info(f'FT state file not found at: {path}')
        return None

    logging.info(f'FT state loading from: {path}')
    with open(path, 'r') as state_file:
        loaded_state = json.load(state_file)
    return loaded_state
187
188
def save_ft_state(ft_client, path):
    """Serialize ``ft_client.state_dict()`` as JSON to ``path``. Only rank 0 writes.

    The JSON is written to ``<path>.tmp`` and then moved into place with
    ``os.replace``. An interrupted write therefore cannot leave a truncated file,
    which ``json.load`` would reject when the FT state is reloaded on restart.
    """
    with dist_utils.sync_workers() as rank:
        if rank == 0:
            logging.info(f'Saving FT state into: {path}')
            ft_state = ft_client.state_dict()
            tmp_path = f'{path}.tmp'
            with open(tmp_path, 'w') as f:
                json.dump(ft_state, f)
            os.replace(tmp_path, path)
197
198
def update_ft_section_timeouts(ft_client, selected_sections, calc_out_of_section, ft_state_path):
    """Recompute FT timeouts for the given sections, then persist the FT state.

    Args:
        ft_client: FT client whose timeouts are recalculated.
        selected_sections: names of the sections whose timeouts are recomputed.
        calc_out_of_section: whether the out-of-section timeout is recomputed too.
        ft_state_path: JSON file the updated FT state is saved to.
    """
    message = (
        f'Updating FT section timeouts for: {selected_sections} '
        f'will update out-of-section: {calc_out_of_section}'
    )
    logging.info(message)

    ft_client.calculate_and_set_section_timeouts(
        selected_sections=selected_sections,
        calc_out_of_section=calc_out_of_section,
    )
    save_ft_state(ft_client, ft_state_path)
208
209
def training_loop(
    ft_client,
    para_model,
    model,
    optimizer,
    device,
    dataloader,
    progress,
    args,
):
    """Run one training epoch over ``dataloader``.

    Each iteration's forward/backward and optimizer step are wrapped in the
    "step" FT section. ``progress['iter_idx']`` is advanced after every
    iteration, so a periodic checkpoint records where to resume. Periodic
    checkpoints are written every ``args.save_interval`` iterations and are
    disabled when ``save_interval <= 0``.

    Args:
        ft_client: FT client used for the "step" section.
        para_model: DDP-wrapped model used for training.
        model: unwrapped model, saved in periodic checkpoints.
        optimizer: optimizer stepping the model parameters.
        device: device the input batches are moved to.
        dataloader: iterable of input batches.
        progress: dict with 'epoch_idx' and 'iter_idx'; 'iter_idx' is updated in place.
        args: parsed command-line arguments.

    Returns:
        Number of iterations executed in this call.
    """
    epoch_idx = progress['epoch_idx']
    save_interval = args.save_interval

    para_model.train()

    last_log_time = time.monotonic()

    num_iters_made = 0

    # NOTE: ``start=`` only offsets the iteration counter. It does not skip the
    # batches consumed before a mid-epoch checkpoint, so a resumed epoch still
    # iterates over the whole dataloader.
    for iter_idx, x in enumerate(dataloader, start=progress['iter_idx']):

        # fwd/bwd and optimizer step are wrapped into "step" FT section
        ft_client.start_section('step')

        optimizer.zero_grad()
        x = x.to(device)
        y = para_model(x)
        loss = y.mean()
        train_loss = loss.item()
        loss.backward()

        if iter_idx % args.logging_interval == 0:
            avg_train_loss = dist_utils.all_reduce_item(train_loss, op='mean')
            logging.info(
                f'CHECK TRAIN epoch: {epoch_idx:4d} '
                f'iter: {iter_idx:5d} '
                f'loss: {avg_train_loss} '
                f'input: {x[:, 0]}'
            )
            if iter_idx > 0:
                time_per_iter = (time.monotonic() - last_log_time) / args.logging_interval
                last_log_time = time.monotonic()
                logging.debug(f'Avg time per iter: {time_per_iter:.3f} [sec]')

        progress['iter_idx'] = iter_idx + 1

        optimizer.step()

        ft_client.end_section('step')

        # Periodic checkpoint after every ``save_interval``-th iteration. The
        # explicit ``> 0`` guard disables saving for non-positive intervals and
        # avoids a ZeroDivisionError when save_interval == 0.
        periodic_save = save_interval > 0 and (iter_idx + 1) % save_interval == 0
        if periodic_save:
            save_checkpoint(
                progress=progress,
                model=model,
                optimizer=optimizer,
                ft_client=ft_client,
                output_dir=args.output_dir,
                checkpoint_fname=args.checkpoint_fname,
            )

        num_iters_made += 1

    return num_iters_made
276
277
def validation_loop(ft_client, model, val_dataloader, epoch_idx, device):
    """Evaluate ``model`` on ``val_dataloader`` and log the mean per-batch loss.

    Each batch's forward pass and loss computation are wrapped in the "step" FT
    section, the same section the training step uses. The logged loss is this
    rank's local mean; it is not reduced across ranks. If the dataloader yields
    no batches, ``nan`` is logged.

    Args:
        ft_client: FT client used for the "step" section.
        model: model to evaluate; it is switched to eval mode.
        val_dataloader: iterable of input batches.
        epoch_idx: epoch number used in the log line.
        device: device the input batches are moved to.
    """
    total_val_loss = 0
    num_batches = 0
    model.eval()

    # No gradients are needed for evaluation; no_grad avoids building autograd graphs.
    with torch.no_grad():
        for x in val_dataloader:

            # fwd and loss are wrapped into "step" FT section
            # 'step' section is used for both: training and eval steps
            ft_client.start_section('step')

            x = x.to(device)
            y = model(x)
            total_val_loss += y.mean().item()
            num_batches += 1

            ft_client.end_section('step')

    # Without this guard an empty dataloader would leave nothing to divide by.
    avg_val_loss = total_val_loss / num_batches if num_batches else float('nan')
    logging.info(f'CHECK VAL SUMMARY: epoch: {epoch_idx:4d} ' f'loss: {avg_val_loss}')
301
302
# Simulated-fault bookkeeping.
# Set by _cancel_simulated_fault(); the fault thread checks it before firing.
_sim_fault_canceled = False
# Set by _setup_simulated_fault(); main() uses it to arm the fault only once.
_sim_fault_is_set = False


def _cancel_simulated_fault():
    """Disarm a pending simulated fault; the fault thread checks this flag after its delay."""
    global _sim_fault_canceled
    _sim_fault_canceled = True
310
311
def _setup_simulated_fault(fault_desc, device):
    """Arm a simulated fault on one randomly selected rank.

    Must be called on every rank. ``dist_utils.broadcast(..., 0)`` is used so
    all ranks agree on the rank to fail. On that rank, a daemon thread sleeps
    for ``fault_desc['delay']`` plus up to 4 extra seconds. Unless
    ``_cancel_simulated_fault()`` was called in the meantime, it then sends
    SIGKILL ('rank_killed') or SIGSTOP ('rank_hung') to the current process.

    Args:
        fault_desc: dict with 'fault' ('rank_killed', 'rank_hung' or 'random')
            and 'delay' (seconds).
        device: device on which the broadcast tensor is allocated.

    Raises:
        ValueError: if the fault type is unknown. The check runs before any
            collective call, so with the same ``fault_desc`` every rank raises
            instead of only the selected one.
    """
    global _sim_fault_is_set
    _sim_fault_is_set = True  # should be True on all ranks

    rng = random.Random()

    logging.info(f"Initializing simulated fault: {fault_desc}")

    supported_faults = ('rank_killed', 'rank_hung')
    fault_type = fault_desc['fault']
    if fault_type == 'random':
        fault_type = rng.choice(supported_faults)
    elif fault_type not in supported_faults:
        raise ValueError(f"Unknown fault type {fault_type}")

    rank_to_fail = rng.randint(0, dist_utils.get_world_size() - 1)
    rank_to_fail = torch.tensor([rank_to_fail], device=device)
    dist_utils.broadcast(rank_to_fail, 0)
    rank_to_fail = int(rank_to_fail.item())

    rank = torch.distributed.get_rank()
    if rank != rank_to_fail:
        return

    # Both fault kinds target this very process; only the signal differs.
    target_pid = os.getpid()
    target_sig = signal.SIGKILL if fault_type == 'rank_killed' else signal.SIGSTOP

    delay = fault_desc['delay'] + 4.0 * rng.random()

    logging.info(
        f"Selected fault={fault_type}; target rank={rank_to_fail}; delay={delay}",
    )

    def _fault_thread():
        time.sleep(delay)
        if _sim_fault_canceled:
            return
        print(
            f"\n####\nSimulating fault: {fault_type}; rank to fail: {rank_to_fail}\n#####\n",
            file=sys.stderr,
        )
        os.kill(target_pid, target_sig)

    fault_sim_thread = threading.Thread(target=_fault_thread)
    fault_sim_thread.daemon = True
    fault_sim_thread.start()
362
363
# Set by _sig_handler when SIGTERM arrives; main() checks it after every epoch.
_signal_received = False


def _sig_handler(*args, **kwargs):
    """SIGTERM handler: report the signal on stderr and raise the ``_signal_received`` flag.

    Only a flag is set here. The main loop checks it across ranks after each
    epoch and leaves the loop gracefully.
    """
    global _signal_received
    print("Signal received!", file=sys.stderr)
    _signal_received = True
371
372
def main():
    """Run the fault-tolerant DDP training demo.

    Steps:
    - Install the SIGTERM handler, parse arguments and initialize the process
      group and logging.
    - Build datasets, model and optimizer.
    - Connect the FT client and restore its state from ``ft_state.json``.
    - Resume from the checkpoint if one exists.
    - Train, validate and checkpoint epoch by epoch, refreshing the FT section
      timeouts from the observed timings.

    Exits the process with status 0 when done.
    """
    signal.signal(signal.SIGTERM, _sig_handler)

    args = parse_args()

    torch.manual_seed(123)
    np.random.seed(123)
    random.seed(123)

    if args.device == 'cuda':
        device = torch.device('cuda')
        torch.cuda.set_device(args.local_rank)
    elif args.device == 'cpu':
        device = torch.device('cpu')
    else:
        raise RuntimeError('Unknown device')

    os.makedirs(args.output_dir, exist_ok=True)

    dist_utils.init_distributed_with_tcp_store(device)
    rank = dist_utils.get_rank()

    if args.log_all_ranks:
        log_file_name = f'train_log_rank_{rank}.log'
    else:
        log_file_name = 'train_log.log'
    log_file_path = os.path.join(args.output_dir, log_file_name)

    # NOTE: logging appends outputs to an existing log file if it already
    # exists. Results from a single training run (potentially with many
    # restarts from a checkpoint) are stored in a single log file.
    log_utils.setup_logging(args.log_all_ranks, filename=log_file_path, filemode='a')

    logging.info(args)
    logging.info(f"SLURM_JOB_ID={os.getenv('SLURM_JOB_ID','<none>')} RANK={rank} PID={os.getpid()}")

    # Dummy datasets
    train_dataset = Dataset(args.train_dataset_size, args.hidden)
    val_dataset = Dataset(args.val_dataset_size, args.hidden)

    train_sampler = torch.utils.data.DistributedSampler(
        train_dataset,
        drop_last=True,
    )

    val_sampler = torch.utils.data.DistributedSampler(
        val_dataset,
    )

    # A dummy model and an optimizer
    model = Model(args.hidden).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Initial value for start epoch - will be overwritten if training is resumed from a checkpoint
    progress = {
        'epoch_idx': 0,
        'iter_idx': 0,
    }

    checkpoint_path = os.path.join(args.output_dir, args.checkpoint_fname)

    # Initialize fault tolerance.
    ft_client = fault_tolerance.RankMonitorClient()
    ft_client.init_workload_monitoring()

    # try to load FT state from a JSON file
    ft_state_path = os.path.join(args.output_dir, 'ft_state.json')
    ft_state = maybe_load_ft_state(ft_state_path)
    if ft_state:
        ft_client.load_state_dict(ft_state)

    # Open "init" FT section that covers workload initialization
    ft_client.start_section('init')

    is_checkpoint_loaded = False

    # try to load checkpoint from disk
    if os.path.exists(checkpoint_path):
        checkpoint = load_checkpoint(checkpoint_path)
        if checkpoint:
            logging.info(f'Checkpoint was loaded from file: {checkpoint_path}')
            is_checkpoint_loaded = True
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            progress.update(checkpoint['progress'])

    # Return with zero exit code if model is already fully trained. ``>=`` also
    # covers resuming a checkpoint from a longer run with a smaller --epochs.
    if progress['epoch_idx'] >= args.epochs:
        ft_client.end_section('init')  # explicitly end "init" section, to avoid FT warning
        ft_client.shutdown_workload_monitoring()
        torch.distributed.destroy_process_group()
        logging.info('Training finished.')
        sys.exit(0)

    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch,
        sampler=train_sampler,
        num_workers=4,
        persistent_workers=True,
        pin_memory=False,
    )

    val_dataloader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=args.batch,
        sampler=val_sampler,
        num_workers=4,
    )

    # Regular DDP init
    # NOTE: for convenience code is keeping both wrapped and unwrapped model and
    # uses wrapped model for training and unwrapped model for saving the
    # checkpoint and validation. It doesn't increase memory consumption
    # since both models are holding references to the same parameters.
    # Additionally saved checkpoint is ready for inference and doesn't have to
    # be manually unwrapped by accessing the (undocumented) "module" attribute
    # of DDP-wrapped model.
    if device.type == 'cuda':
        device_ids = [args.local_rank]
        output_device = args.local_rank
    elif device.type == 'cpu':
        device_ids = None
        output_device = None
    else:
        raise RuntimeError('Unsupported device type')
    para_model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=device_ids, output_device=output_device
    )

    # "init" FT section ends here
    ft_client.end_section('init')

    if is_checkpoint_loaded:
        # init time can be longer if there was checkpoint loading
        # so we update "init" section timeout if a checkpoint was loaded
        update_ft_section_timeouts(ft_client, ['init'], False, ft_state_path)

    # Iteration over epochs, notice that it starts from 'epoch_idx'
    # which was previously loaded from the checkpoint
    for epoch_idx in range(progress['epoch_idx'], args.epochs):

        # DistributedSampler derives its shuffle order from its epoch; without
        # set_epoch() every epoch would iterate the samples in the same order.
        train_sampler.set_epoch(epoch_idx)

        num_tr_iters_made = training_loop(
            ft_client,
            para_model,
            model,
            optimizer,
            device,
            train_dataloader,
            progress,
            args,
        )

        # If there were some training iterations observed, update "step" section timeout
        if num_tr_iters_made > 0:
            update_ft_section_timeouts(ft_client, ['step'], False, ft_state_path)

        # epoch_idx is incremented because the current epoch is finished
        # and potential resume from this checkpoint should start a new training epoch.
        progress['epoch_idx'] += 1
        progress['iter_idx'] = 0

        validation_loop(ft_client, model, val_dataloader, epoch_idx, device)

        # Checkpoint contains everything needed for deterministic resume:
        # state of the model, optimizer and other components,
        save_checkpoint(
            progress=progress,
            model=model,
            optimizer=optimizer,
            ft_client=ft_client,
            output_dir=args.output_dir,
            checkpoint_fname=args.checkpoint_fname,
        )

        # Update checkpointing section timeout after checkpoint saving was seen
        update_ft_section_timeouts(ft_client, ['checkpoint'], False, ft_state_path)

        # NOTE: SIGTERM is used by SLURM to initiate graceful job termination
        # if _any_ rank received SIGTERM, we leave the main loop
        if dist_utils.is_true_on_any_rank(_signal_received):
            logging.info('Leaving the main loop, due to SIGTERM')
            break

        # Setup simulated fault
        if args.simulated_fault and not _sim_fault_is_set:
            _setup_simulated_fault(args.simulated_fault, device)

    _cancel_simulated_fault()

    # update "out-of-section" FT timeout when the training run ends normally
    update_ft_section_timeouts(ft_client, [], True, ft_state_path)
    ft_client.shutdown_workload_monitoring()
    torch.distributed.destroy_process_group()
    logging.info('Leaving main, ret_code=0')
    sys.exit(0)
569
570
if __name__ == '__main__':
    # Run the demo only when executed as a script (normally via ft_launcher).
    main()