Section API usage example with DDP
Warning
This example loads checkpoints with torch.load(..., weights_only=True)
because it saves only plain state dictionaries. For PyTorch versions before
2.10.0, CVE-2026-24747 affects the weights_only unpickler. Do not load
untrusted checkpoint files with affected PyTorch versions; use PyTorch
2.10.0 or newer when checkpoint provenance is not fully trusted.
1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""
17Demo of DDP training with fault tolerance, using FT package sections API
18
19It should be run with `ft_launcher`. E.g.
20`ft_launcher --nproc-per-node=2 --ft-cfg-path=./examples/fault_tolerance/fault_tol_cfg_sections.yaml examples/fault_tolerance/train_ddp_sections_api.py --device=cpu`
21
22This example uses the following custom FT sections:
23- 'init' - covers workload initialization
24- 'step' - covers training/evaluation steps (fwd/bwd, loss calculation, etc.)
25- 'checkpoint' - covers checkpoint saving
26
27Timeout for each section is calculated when enough data is collected.
28FT "out-of-section" timeout is calculated when the training run ends normally.
29FT state is saved in a JSON file.
30
31This example can simulate a training fault:
32- a selected rank hangs
33- a selected rank is terminated
34
35Security note:
36This example loads checkpoints with torch.load(..., weights_only=True)
37because it saves only plain state dictionaries. For PyTorch versions before
382.10.0, CVE-2026-24747 affects the weights_only unpickler. Do not load
39untrusted checkpoint files with affected PyTorch versions; use PyTorch 2.10.0
40or newer when checkpoint provenance is not fully trusted.
41"""
42import argparse
43import json
44import logging
45import os
46import random
47import signal
48import sys
49import threading
50import time
51
52import dist_utils
53import log_utils
54import numpy as np
55import torch
56import torch.nn as nn
57
58import nvidia_resiliency_ext.fault_tolerance as fault_tolerance
59
60
class Dataset(torch.utils.data.Dataset):
    """Dummy dataset: item ``idx`` is a float32 CPU vector filled with ``idx``.

    Args:
        size: number of items reported by ``len()``.
        hidden: length of every returned vector.
    """

    def __init__(self, size, hidden):
        self.size, self.hidden = size, hidden

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # A constant-valued vector makes it trivial to tell which sample a
        # batch row came from.
        shape = (self.hidden,)
        return torch.full(shape, idx, dtype=torch.float32, device='cpu')
78
79
class Model(nn.Module):
    """Dummy model: two stacked ``hidden -> hidden`` linear layers (no activation).

    Args:
        hidden: input and output feature size of both layers.
    """

    def __init__(self, hidden):
        super().__init__()
        # Creation order (l1 before l2) fixes the parameter-init RNG order and
        # the state_dict key order, so it is kept as is.
        self.l1 = nn.Linear(hidden, hidden)
        self.l2 = nn.Linear(hidden, hidden)

    def forward(self, x):
        return self.l2(self.l1(x))
91
92
def parse_args():
    """Parse the command-line arguments of this example.

    ``--simulated_fault`` takes ``"<fault>,<delay>"`` (e.g. ``rank_hung,10``)
    and is converted to ``{'fault': <str>, 'delay': <float>}``.

    Returns:
        argparse.Namespace with the parsed options.
    """

    def fault_desc(strings):
        # Raise ArgumentTypeError rather than using `assert`: argparse turns it
        # into a clean usage error (exit code 2), whereas an AssertionError
        # escapes as a traceback and is stripped entirely under `python -O`.
        # A non-numeric delay raises ValueError from float(), which argparse
        # also reports as a usage error.
        parts = strings.split(",")
        if len(parts) != 2:
            raise argparse.ArgumentTypeError("Fault description must be in format 'fault,delay'")
        return {'fault': parts[0], 'delay': float(parts[1])}

    parser = argparse.ArgumentParser(
        description='Example of PyTorch DDP training with the Fault Tolerance package',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # fmt: off
    parser.add_argument('--hidden', type=int, default=4096,
                        help='Hidden size')
    parser.add_argument('--batch', type=int, default=8,
                        help='Batch size')
    parser.add_argument('--epochs', type=int, default=4,
                        help='Number of training epochs')
    parser.add_argument('--train_dataset_size', type=int, default=1000000,
                        help='Train dataset size')
    parser.add_argument('--val_dataset_size', type=int, default=2000,
                        help='Validation dataset size')
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cpu', 'cuda'],
                        help='Device')
    parser.add_argument('--save_interval', type=int, default=-1,
                        help='Interval for saving periodic checkpoints.')
    parser.add_argument('--logging_interval', type=int, default=1,
                        help='Interval for log entries')
    parser.add_argument('--log_all_ranks', action='store_true',
                        help='Enable logging from all distributed ranks')
    parser.add_argument('--output_dir', type=str, default='results/output',
                        help='Output dir')
    parser.add_argument('--checkpoint_fname', type=str, default='checkpoint.pt',
                        help='Name of a checkpoint file')
    # A string default (LOCAL_RANK from the environment) is converted by
    # argparse through `type=int`, just like a command-line value.
    parser.add_argument('--local_rank', type=int,
                        default=os.getenv('LOCAL_RANK', 0))
    parser.add_argument('--simulated_fault', type=fault_desc,
                        help='Description of a fault to be simulated')
    # fmt: on

    args = parser.parse_args()
    return args
136
137
def load_checkpoint(path):
    """Load a checkpoint file onto this process's devices.

    CPU tensors stay on CPU. When CUDA is available, tensors saved on
    ``cuda:0`` are remapped to the current CUDA device.

    Security: ``weights_only=True`` is used because checkpoints written by this
    example contain only plain state dictionaries. PyTorch versions before
    2.10.0 are affected by CVE-2026-24747 in the weights_only unpickler, so do
    not load untrusted files with those versions.

    Args:
        path: checkpoint file to read.

    Returns:
        The object returned by ``torch.load``.
    """
    remap = {'cpu': 'cpu'}
    if torch.cuda.is_available():
        remap['cuda:0'] = 'cuda:{}'.format(torch.cuda.current_device())

    logging.info(f'Loading checkpoint from {path}')
    return torch.load(path, map_location=remap, weights_only=True)
148
149
def save_checkpoint(
    progress,
    model,
    optimizer,
    ft_client,
    output_dir,
    checkpoint_fname,
):
    """Save training progress plus model and optimizer state; rank 0 writes the file.

    The whole operation is wrapped in the "checkpoint" FT section. FT state is
    NOT stored in the checkpoint; it is kept in a separate JSON file.

    The file is first written to ``<path>.tmp`` and then renamed over
    ``<path>`` with ``os.replace``. A process killed mid-write therefore
    leaves the previous checkpoint intact. This example deliberately kills
    ranks to simulate faults, and a truncated file would make every restart
    fail in ``torch.load``.

    Args:
        progress: dict holding the resume position ('epoch_idx', 'iter_idx').
        model: module whose ``state_dict`` is saved (main passes the
            unwrapped, non-DDP model).
        optimizer: optimizer whose ``state_dict`` is saved.
        ft_client: FT client used to open/close the "checkpoint" section.
        output_dir: directory the checkpoint is written to.
        checkpoint_fname: checkpoint file name inside ``output_dir``.
    """
    ft_client.start_section('checkpoint')

    state = {
        'progress': progress,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
    }

    checkpoint_path = os.path.join(output_dir, checkpoint_fname)

    # Every rank enters sync_workers(); presumably it synchronizes the ranks
    # around the write (confirm in dist_utils). Only rank 0 touches the file.
    with dist_utils.sync_workers() as rank:
        if rank == 0:
            logging.info(f'Saving checkpoint to {checkpoint_path}')
            tmp_path = f'{checkpoint_path}.tmp'
            torch.save(state, tmp_path)
            # Same directory => same filesystem, so the rename is atomic on POSIX.
            os.replace(tmp_path, checkpoint_path)

    ft_client.end_section('checkpoint')
176
177
def maybe_load_ft_state(path):
    """Return the parsed JSON FT state stored at ``path``, or None if the file is missing.

    Args:
        path: JSON file, expected to have been written by ``save_ft_state``.
    """
    if not os.path.exists(path):
        logging.info(f'FT state file not found at: {path}')
        return None

    logging.info(f'FT state loading from: {path}')
    with open(path, 'r') as fp:
        state = json.load(fp)
    return state
187
188
def save_ft_state(ft_client, path):
    """Write ``ft_client.state_dict()`` as JSON to ``path``; only rank 0 writes.

    The JSON is written to ``<path>.tmp`` first and then renamed over ``path``
    with ``os.replace``. A process killed mid-write therefore leaves the
    previous FT state readable. ``maybe_load_ft_state`` does not handle a
    truncated file and would fail on it at the next restart.

    Args:
        ft_client: FT client whose state is saved.
        path: destination JSON file.
    """
    # Every rank enters sync_workers(); presumably it synchronizes the ranks
    # around the write (confirm in dist_utils).
    with dist_utils.sync_workers() as rank:
        if rank == 0:
            logging.info(f'Saving FT state into: {path}')
            ft_state = ft_client.state_dict()
            tmp_path = f'{path}.tmp'
            with open(tmp_path, 'w') as f:
                json.dump(ft_state, f)
            # Same directory => same filesystem, so the rename is atomic on POSIX.
            os.replace(tmp_path, path)
197
198
def update_ft_section_timeouts(ft_client, selected_sections, calc_out_of_section, ft_state_path):
    """Recompute FT timeouts from the collected data and persist the FT state.

    Args:
        ft_client: FT client whose timeouts are recalculated.
        selected_sections: names of the sections to update (may be empty).
        calc_out_of_section: whether the "out-of-section" timeout is updated too.
        ft_state_path: JSON file the updated FT state is saved to.
    """
    logging.info(
        f'Updating FT section timeouts for: {selected_sections} '
        f'will update out-of-section: {calc_out_of_section}'
    )
    ft_client.calculate_and_set_section_timeouts(
        calc_out_of_section=calc_out_of_section,
        selected_sections=selected_sections,
    )
    # Persist right away so restarted ranks pick up the new timeouts.
    save_ft_state(ft_client, ft_state_path)
208
209
def training_loop(
    ft_client,
    para_model,
    model,
    optimizer,
    device,
    dataloader,
    progress,
    args,
):
    """Run one training epoch and return the number of iterations executed.

    Each forward/backward/optimizer step is wrapped in the "step" FT section.
    ``progress['iter_idx']`` is advanced every iteration so that a periodic
    checkpoint records where to resume.

    Args:
        ft_client: FT client used to open/close the "step" section.
        para_model: DDP-wrapped model used for forward/backward.
        model: unwrapped model, passed to ``save_checkpoint``.
        optimizer: optimizer stepped every iteration.
        device: device each input batch is moved to.
        dataloader: iterable of input batches.
        progress: dict with 'epoch_idx' and 'iter_idx'; 'iter_idx' is updated in place.
        args: parsed CLI args. Reads ``logging_interval``, ``save_interval``,
            ``output_dir`` and ``checkpoint_fname``. A non-positive
            ``logging_interval``/``save_interval`` disables logging/periodic
            saving (previously 0 raised ZeroDivisionError).

    Returns:
        Number of iterations run by this call.
    """
    epoch_idx = progress['epoch_idx']
    log_interval = args.logging_interval
    save_interval = args.save_interval

    para_model.train()

    last_log_time = time.monotonic()

    num_iters_made = 0

    # NOTE(review): on a mid-epoch resume, `start` only offsets the index; the
    # dataloader still yields the epoch from its first batch, so batches
    # processed before the checkpoint are trained on again. Skipping them
    # (e.g. with itertools.islice) would spend out-of-section time before the
    # first "step" section - confirm the intended resume semantics.
    for iter_idx, x in enumerate(dataloader, start=progress['iter_idx']):

        # fwd/bwd and optimizer step are wrapped into "step" FT section
        ft_client.start_section('step')

        optimizer.zero_grad()
        x = x.to(device)
        y = para_model(x)
        loss = y.mean()
        loss.backward()

        if log_interval > 0 and iter_idx % log_interval == 0:
            # Read the loss only when it is logged: .item() copies the value
            # to the host, which forces a device sync on CUDA.
            avg_train_loss = dist_utils.all_reduce_item(loss.item(), op='mean')
            logging.info(
                f'CHECK TRAIN epoch: {epoch_idx:4d} '
                f'iter: {iter_idx:5d} '
                f'loss: {avg_train_loss} '
                f'input: {x[:, 0]}'
            )
            if iter_idx > 0:
                time_per_iter = (time.monotonic() - last_log_time) / log_interval
                last_log_time = time.monotonic()
                logging.debug(f'Avg time per iter: {time_per_iter:.3f} [sec]')

        progress['iter_idx'] = iter_idx + 1

        optimizer.step()

        ft_client.end_section('step')

        # Periodic checkpoint after every `save_interval` iterations; the
        # `> 0` guard also keeps the default -1 meaning "disabled".
        if save_interval > 0 and iter_idx % save_interval == save_interval - 1:
            save_checkpoint(
                progress=progress,
                model=model,
                optimizer=optimizer,
                ft_client=ft_client,
                output_dir=args.output_dir,
                checkpoint_fname=args.checkpoint_fname,
            )

        num_iters_made += 1

    return num_iters_made
276
277
def validation_loop(ft_client, model, val_dataloader, epoch_idx, device):
    """Run one validation pass and log this rank's mean per-batch loss.

    Each forward pass is wrapped in the "step" FT section, the same section
    used by training steps. The pass runs under ``torch.no_grad()`` because
    nothing is back-propagated. The logged loss is local to this rank and is
    not reduced across ranks. If the loader yields no batches, ``nan`` is
    logged; previously this raised UnboundLocalError. The model is left in
    eval mode.

    Args:
        ft_client: FT client used to open/close the "step" section.
        model: model evaluated (main passes the unwrapped model).
        val_dataloader: iterable of input batches.
        epoch_idx: epoch number shown in the log line.
        device: device each input batch is moved to.
    """
    model.eval()

    total_val_loss = 0
    num_batches = 0

    with torch.no_grad():
        for x in val_dataloader:
            # fwd and loss are wrapped into "step" FT section
            # 'step' section is used for both: training and eval steps
            ft_client.start_section('step')

            x = x.to(device)
            y = model(x)
            total_val_loss += y.mean().item()
            num_batches += 1

            ft_client.end_section('step')

    avg_val_loss = total_val_loss / num_batches if num_batches else float('nan')
    logging.info(f'CHECK VAL SUMMARY: epoch: {epoch_idx:4d} loss: {avg_val_loss}')
301
302
# Set by _cancel_simulated_fault(); the fault thread checks it after its delay
# and does nothing when it is True.
_sim_fault_canceled = False
# Set by _setup_simulated_fault(); main() checks it so the fault is armed once.
_sim_fault_is_set = False


def _cancel_simulated_fault():
    """Prevent a pending simulated fault from firing."""
    global _sim_fault_canceled
    _sim_fault_canceled = True


def _setup_simulated_fault(fault_desc, device):
    """Arm a simulated fault on one randomly selected rank.

    Must be called on every rank. A victim rank is drawn and broadcast with
    ``dist_utils.broadcast(..., 0)`` (presumably from rank 0), so all ranks
    agree on it. On the victim, a daemon thread sleeps for
    ``fault_desc['delay']`` plus a random 0-4 s. It then sends the process
    SIGKILL ('rank_killed') or SIGSTOP ('rank_hung'), unless
    ``_cancel_simulated_fault()`` was called in the meantime.

    Args:
        fault_desc: dict with 'fault' ('rank_killed', 'rank_hung' or 'random')
            and 'delay' (seconds).
        device: device of the tensor used for the broadcast.

    Raises:
        ValueError: unknown fault type. It is raised on every rank before the
            broadcast, so a bad value fails all ranks consistently instead of
            only the randomly selected one.
    """
    global _sim_fault_is_set
    _sim_fault_is_set = True  # should be True on all ranks

    rng = random.Random()

    logging.info(f"Initializing simulated fault: {fault_desc}")

    fault_type = fault_desc['fault']
    if fault_type == 'random':
        fault_type = rng.choice(['rank_killed', 'rank_hung'])
    if fault_type not in ('rank_killed', 'rank_hung'):
        raise ValueError(f"Unknown fault type {fault_type}")

    rank_to_fail = rng.randint(0, dist_utils.get_world_size() - 1)
    rank_to_fail = torch.tensor([rank_to_fail], device=device)
    dist_utils.broadcast(rank_to_fail, 0)
    rank_to_fail = int(rank_to_fail.item())

    rank = torch.distributed.get_rank()
    if rank != rank_to_fail:
        return

    target_pid = os.getpid()
    if fault_type == 'rank_killed':
        target_sig = signal.SIGKILL
    else:  # 'rank_hung' - validated above
        target_sig = signal.SIGSTOP

    delay = fault_desc['delay'] + 4.0 * rng.random()

    logging.info(
        f"Selected fault={fault_type}; target rank={rank_to_fail}; delay={delay}",
    )

    def __fault_thread():
        time.sleep(delay)
        if _sim_fault_canceled:
            return
        print(
            f"\n####\nSimulating fault: {fault_type}; rank to fail: {rank_to_fail}\n#####\n",
            file=sys.stderr,
        )
        os.kill(target_pid, target_sig)

    # Daemon thread so it never keeps the process alive on its own.
    fault_sim_thread = threading.Thread(target=__fault_thread)
    fault_sim_thread.daemon = True
    fault_sim_thread.start()
362
363
# Flipped by _sig_handler; main() checks it on all ranks once per epoch.
_signal_received = False


def _sig_handler(*args, **kwargs):
    """Signal handler: record that a termination signal arrived.

    The signal number and frame passed by the ``signal`` module are ignored.
    """
    global _signal_received
    print("Signal received!", file=sys.stderr)
    _signal_received = True
371
372
def main():
    """Entry point: set up DDP and FT monitoring, then train, validate and checkpoint.

    Meant to be started by ``ft_launcher``. Resumes from the checkpoint in
    ``--output_dir`` if one exists. Exits with code 0 when training is already
    complete, when all epochs finish, or after a SIGTERM seen on any rank.
    """
    # The SIGTERM handler only sets a flag that the epoch loop checks.
    signal.signal(signal.SIGTERM, _sig_handler)

    args = parse_args()

    torch.manual_seed(123)
    np.random.seed(123)
    random.seed(123)

    if args.device == 'cuda':
        device = torch.device('cuda')
        torch.cuda.set_device(args.local_rank)
    elif args.device == 'cpu':
        device = torch.device('cpu')
    else:
        raise RuntimeError('Unknown device')

    os.makedirs(args.output_dir, exist_ok=True)

    dist_utils.init_distributed_with_tcp_store(device)
    rank = dist_utils.get_rank()

    if args.log_all_ranks:
        log_file_name = f'train_log_rank_{rank}.log'
    else:
        log_file_name = 'train_log.log'
    log_file_path = os.path.join(args.output_dir, log_file_name)

    # NOTE: logging appends outputs to an existing log file if it already
    # exists. Results from a single training run (potentially with many
    # restarts from a checkpoint) are stored in a single log file.
    log_utils.setup_logging(args.log_all_ranks, filename=log_file_path, filemode='a')

    logging.info(args)
    logging.info(f"SLURM_JOB_ID={os.getenv('SLURM_JOB_ID','<none>')} RANK={rank} PID={os.getpid()}")

    # Dummy datasets
    train_dataset = Dataset(args.train_dataset_size, args.hidden)
    val_dataset = Dataset(args.val_dataset_size, args.hidden)

    train_sampler = torch.utils.data.DistributedSampler(
        train_dataset,
        drop_last=True,
    )

    val_sampler = torch.utils.data.DistributedSampler(
        val_dataset,
    )

    # A dummy model and an optimizer
    model = Model(args.hidden).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Initial value for start epoch - will be overwritten if training is resumed from a checkpoint
    progress = {
        'epoch_idx': 0,
        'iter_idx': 0,
    }

    checkpoint_path = os.path.join(args.output_dir, args.checkpoint_fname)

    # Initialize fault tolerance.
    ft_client = fault_tolerance.RankMonitorClient()
    ft_client.init_workload_monitoring()

    # try to load FT state from a JSON file
    ft_state_path = os.path.join(args.output_dir, 'ft_state.json')
    ft_state = maybe_load_ft_state(ft_state_path)
    if ft_state:
        ft_client.load_state_dict(ft_state)

    # Open "init" FT section that covers workload initialization
    ft_client.start_section('init')

    is_checkpoint_loaded = False

    # try to load checkpoint from disk
    if os.path.exists(checkpoint_path):
        checkpoint = load_checkpoint(checkpoint_path)
        if checkpoint:
            logging.info(f'Checkpoint was loaded from file: {checkpoint_path}')
            is_checkpoint_loaded = True
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            progress.update(checkpoint['progress'])

    # Return with zero exit code if model is already fully trained.
    # `>=` also covers a checkpoint from a run configured with more epochs.
    if progress['epoch_idx'] >= args.epochs:
        ft_client.end_section('init')  # explicitly end "init" section, to avoid FT warning
        ft_client.shutdown_workload_monitoring()
        torch.distributed.destroy_process_group()
        logging.info('Training finished.')
        sys.exit(0)

    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch,
        sampler=train_sampler,
        num_workers=4,
        persistent_workers=True,
        pin_memory=False,
    )

    val_dataloader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=args.batch,
        sampler=val_sampler,
        num_workers=4,
    )

    # Regular DDP init
    # NOTE: for convenience code is keeping both wrapped and unwrapped model and
    # uses wrapped model for training and unwrapped model for saving the
    # checkpoint and validation. It doesn't increase memory consumption
    # since both models are holding references to the same parameters.
    # Additionally saved checkpoint is ready for inference and doesn't have to
    # be manually unwrapped by accessing the (undocumented) "module" attribute
    # of DDP-wrapped model.
    if device.type == 'cuda':
        device_ids = [args.local_rank]
        output_device = args.local_rank
    elif device.type == 'cpu':
        device_ids = None
        output_device = None
    else:
        raise RuntimeError('Unsupported device type')
    para_model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=device_ids, output_device=output_device
    )

    # "init" FT section ends here
    ft_client.end_section('init')

    if is_checkpoint_loaded:
        # init time can be longer if there was checkpoint loading
        # so we update "init" section timeout if a checkpoint was loaded
        update_ft_section_timeouts(ft_client, ['init'], False, ft_state_path)

    # Iteration over epochs, notice that it starts from 'epoch_idx'
    # which was previously loaded from the checkpoint
    for epoch_idx in range(progress['epoch_idx'], args.epochs):

        # DistributedSampler derives its shuffle order from the epoch set
        # here; without this call every epoch reuses the same sample order.
        train_sampler.set_epoch(epoch_idx)

        num_tr_iters_made = training_loop(
            ft_client,
            para_model,
            model,
            optimizer,
            device,
            train_dataloader,
            progress,
            args,
        )

        # If there were some training iterations observed, update "step" section timeout
        if num_tr_iters_made > 0:
            update_ft_section_timeouts(ft_client, ['step'], False, ft_state_path)

        # epoch_idx is incremented because the current epoch is finished
        # and potential resume from this checkpoint should start a new training epoch.
        progress['epoch_idx'] += 1
        progress['iter_idx'] = 0

        validation_loop(ft_client, model, val_dataloader, epoch_idx, device)

        # Checkpoint contains everything needed for deterministic resume:
        # state of the model, optimizer and other components,
        save_checkpoint(
            progress=progress,
            model=model,
            optimizer=optimizer,
            ft_client=ft_client,
            output_dir=args.output_dir,
            checkpoint_fname=args.checkpoint_fname,
        )

        # Update checkpointing section timeout after checkpoint saving was seen
        update_ft_section_timeouts(ft_client, ['checkpoint'], False, ft_state_path)

        # NOTE: SIGTERM is used by SLURM to initiate graceful job termination
        # if _any_ rank received SIGTERM, we leave the main loop
        if dist_utils.is_true_on_any_rank(_signal_received):
            logging.info('Leaving the main loop, due to SIGTERM')
            break

        # Setup simulated fault
        if args.simulated_fault and not _sim_fault_is_set:
            _setup_simulated_fault(args.simulated_fault, device)

    _cancel_simulated_fault()

    # update "out-of-section" FT timeout when the training run ends normally
    update_ft_section_timeouts(ft_client, [], True, ft_state_path)
    ft_client.shutdown_workload_monitoring()
    torch.distributed.destroy_process_group()
    logging.info('Leaving main, ret_code=0')
    sys.exit(0)


if __name__ == "__main__":
    main()