Heartbeat API usage example with DDP
Warning
This example loads checkpoints with torch.load(..., weights_only=True)
because it saves only plain state dictionaries. For PyTorch versions before
2.10.0, CVE-2026-24747 affects the weights_only unpickler. Do not load
untrusted checkpoint files with affected PyTorch versions; use PyTorch
2.10.0 or newer when checkpoint provenance is not fully trusted.
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
"""
Demo of fault tolerance with DDP training, using FT package heartbeats API

This script demonstrates how to use the FT heartbeats API for hang detection in
distributed training. It should be run with the ft_launcher command. E.g.:

`ft_launcher --nproc-per-node=2 --ft-cfg-path=./examples/fault_tolerance/fault_tol_cfg_heartbeats.yaml examples/fault_tolerance/train_ddp_heartbeats_api.py --device=cpu`

Fault tolerance features demonstrated:
1. Heartbeat sending during training
2. Timeout calculation and setting
3. State persistence through checkpoints
4. Simulated fault injection

Security note:
This example loads checkpoints with torch.load(..., weights_only=True)
because it saves only plain state dictionaries. For PyTorch versions before
2.10.0, CVE-2026-24747 affects the weights_only unpickler. Do not load
untrusted checkpoint files with affected PyTorch versions; use PyTorch 2.10.0
or newer when checkpoint provenance is not fully trusted.
"""
37
38import argparse
39import logging
40import os
41import random
42import signal
43import sys
44import threading
45import time
46
47import dist_utils
48import log_utils
49import numpy as np
50import torch
51import torch.nn as nn
52
53import nvidia_resiliency_ext.fault_tolerance as fault_tolerance
54
55
# Dummy dataset.
class Dataset(torch.utils.data.Dataset):
    """Synthetic map-style dataset: sample ``idx`` is a constant vector.

    Every element of the returned vector equals ``idx``, which makes it easy
    to see in the logs which samples a rank has consumed.

    Args:
        size: number of samples reported by ``len()``.
        hidden: length of each sample vector.
    """

    def __init__(self, size, hidden):
        self.size, self.hidden = size, hidden

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Samples are always created on the CPU; the training/validation
        # loops move them to the target device.
        sample_shape = (self.hidden,)
        return torch.full(sample_shape, fill_value=idx, dtype=torch.float32, device='cpu')
73
74
# Dummy model
class Model(nn.Module):
    """Two stacked ``hidden -> hidden`` linear layers with no activation.

    Args:
        hidden: input and output feature size of both layers.
    """

    def __init__(self, hidden):
        super().__init__()
        # Attribute names are part of the checkpoint format (state-dict keys).
        self.l1 = nn.Linear(in_features=hidden, out_features=hidden)
        self.l2 = nn.Linear(in_features=hidden, out_features=hidden)

    def forward(self, x):
        return self.l2(self.l1(x))
86
87
def parse_args():
    """Parse the command-line options of the example from ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed options. ``interrupt_at`` is always a
        ``set`` of int tuples (empty when the option was not given) and
        ``simulated_fault`` is either ``None`` or a dict with the keys
        ``'fault'`` (str) and ``'delay'`` (float).

    Raises:
        SystemExit: raised by argparse on invalid command-line input.
    """

    def tuple_type(strings):
        # Accepts "(1,2)" as well as "1,2" and yields (1, 2).
        strings = strings.replace("(", "").replace(")", "")
        return tuple(int(item) for item in strings.split(","))

    def fault_desc(strings):
        # Expected form: "<fault>,<delay>", e.g. "rank_killed,30".
        parts = strings.split(",")
        if len(parts) != 2:
            # Raise instead of assert: asserts are stripped under ``-O`` and
            # would otherwise surface as a raw traceback rather than a
            # regular argparse usage error.
            raise argparse.ArgumentTypeError(
                f"expected '<fault>,<delay>', got {strings!r}"
            )
        return {'fault': parts[0], 'delay': float(parts[1])}

    parser = argparse.ArgumentParser(
        description='Example of PyTorch DDP training with the Fault Tolerance package',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # fmt: off
    parser.add_argument('--hidden', type=int, default=4096,
                        help='Hidden size')
    parser.add_argument('--batch', type=int, default=8,
                        help='Batch size')
    parser.add_argument('--epochs', type=int, default=4,
                        help='Number of training epochs')
    parser.add_argument('--train_dataset_size', type=int, default=1000000,
                        help='Train dataset size')
    parser.add_argument('--val_dataset_size', type=int, default=2000,
                        help='Validation dataset size')
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cpu', 'cuda'],
                        help='Device')

    parser.add_argument('--interrupt_at', type=tuple_type, nargs='*',
                        help='Manual interruption after (epoch, iteration), '
                             'for testing only')
    parser.add_argument('--save_interval', type=int, default=-1,
                        help='Interval for saving periodic checkpoints')
    parser.add_argument('--logging_interval', type=int, default=1,
                        help='Interval for log entries')
    parser.add_argument('--log_all_ranks', action='store_true',
                        help='Enable logging from all distributed ranks')
    parser.add_argument('--output_dir', type=str, default='results/output',
                        help='Output dir')
    parser.add_argument('--checkpoint_fname', type=str, default='checkpoint.pt',
                        help='Name of a checkpoint file')

    # A string default (taken from the environment) is converted by argparse
    # through ``type``, so LOCAL_RANK ends up as an int as well.
    parser.add_argument('--local_rank', type=int,
                        default=os.getenv('LOCAL_RANK', 0))
    parser.add_argument('--init_distributed_method', type=str, default='tcp',
                        help='Init distributed group with TCP store ("tcp") or file store ("file")')

    parser.add_argument('--simulated_fault', type=fault_desc,
                        help='Description of a fault to be simulated')
    # fmt: on

    args = parser.parse_args()

    # Normalize to a set so that membership tests in the training loop are
    # cheap and the "option not given" case needs no special handling.
    args.interrupt_at = set(args.interrupt_at) if args.interrupt_at else set()

    return args
150
151
def load_checkpoint(path):
    """Load a checkpoint file written by ``torch.save``.

    Args:
        path: location of the checkpoint file.

    Returns:
        The object stored in the file, loaded with ``weights_only=True``.
    """
    # CPU tensors stay on the CPU; tensors tagged 'cuda:0' are redirected to
    # the CUDA device that is current for this process.
    device_remap = {'cpu': 'cpu'}
    if torch.cuda.is_available():
        device_remap['cuda:0'] = f'cuda:{torch.cuda.current_device()}'

    logging.info(f'Loading checkpoint from {path}')
    return torch.load(path, map_location=device_remap, weights_only=True)
162
163
def save_checkpoint(
    progress,
    model,
    optimizer,
    ft_client,
    output_dir,
    checkpoint_fname,
):
    """Write a resumable checkpoint to ``output_dir/checkpoint_fname``.

    The state dictionary is built on every rank, but only rank 0 writes it.

    Args:
        progress: dict with the training position ('epoch_idx', 'iter_idx').
        model: model whose ``state_dict()`` is stored.
        optimizer: optimizer whose ``state_dict()`` is stored.
        ft_client: fault tolerance client whose ``state_dict()`` is stored.
        output_dir: directory receiving the checkpoint.
        checkpoint_fname: checkpoint file name inside ``output_dir``.
    """
    state = {
        'progress': progress,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'ft_state': ft_client.state_dict(),
    }

    checkpoint_path = os.path.join(output_dir, checkpoint_fname)
    tmp_path = checkpoint_path + '.tmp'

    # NOTE(review): sync_workers yields this process' rank; presumably it also
    # synchronizes the ranks around the block - confirm in dist_utils.
    with dist_utils.sync_workers() as rank:
        if rank == 0:
            logging.info(f'Saving checkpoint to {checkpoint_path}')
            # Write to a temporary file and rename it into place. A rank that
            # is killed mid-save (this example injects such faults) then
            # leaves the previous checkpoint intact instead of a truncated
            # file that would break the next resume.
            torch.save(state, tmp_path)
            os.replace(tmp_path, checkpoint_path)
185
186
def training_loop(
    ft_client,
    para_model,
    model,
    optimizer,
    device,
    dataloader,
    sampler,
    progress,
    args,
):
    """Train for the rest of the current epoch, sending a heartbeat per iteration.

    Args:
        ft_client: fault tolerance client used for heartbeats and timeouts.
        para_model: wrapped model used for the forward/backward pass.
        model: unwrapped model, passed on to ``save_checkpoint``.
        optimizer: optimizer stepped once per iteration.
        device: device the input batches are moved to.
        dataloader: iterable yielding the training batches.
        sampler: sampler of ``dataloader``; told where to resume.
        progress: dict with 'epoch_idx' and 'iter_idx'; 'iter_idx' is updated
            in place after every iteration.
        args: parsed command-line options (``batch``, ``logging_interval``,
            ``save_interval``, ``interrupt_at``, ``output_dir``,
            ``checkpoint_fname``).

    Raises:
        SystemExit: with code 0 when ``(epoch_idx, iter_idx)`` is listed in
            ``args.interrupt_at`` (after the checkpoint was saved).
    """
    epoch_idx = progress['epoch_idx']

    # NOTE: torch.utils.data.DistributedSampler must be prepared for current epoch
    # need to do it before starting iteration
    sampler.start_sample_idx = progress['iter_idx'] * args.batch
    sampler.set_epoch(epoch_idx)

    para_model.train()

    last_log_time = time.monotonic()

    for iter_idx, x in enumerate(dataloader, start=progress['iter_idx']):
        if ft_client.hb_timeouts.are_valid is False and epoch_idx == 1 and iter_idx == 1:
            # after 0th epoch is completed and we've done 0th iteration of the 1st epoch,
            # we can calculate and set timeouts. this is a good moment to do so,
            # because now we've seen the possibly long interval where checkpoint was saved.
            ft_client.calculate_and_set_hb_timeouts()

        optimizer.zero_grad()
        x = x.to(device)
        y = para_model(x)
        loss = y.mean()
        train_loss = loss.item()
        loss.backward()

        if iter_idx % args.logging_interval == 0:
            avg_train_loss = dist_utils.all_reduce_item(train_loss, op='mean')
            logging.info(
                f'CHECK TRAIN epoch: {epoch_idx:4d} '
                f'iter: {iter_idx:5d} '
                f'loss: {avg_train_loss} '
                f'input: {x[:, 0]}'
            )
            if iter_idx > 0:
                time_per_iter = (time.monotonic() - last_log_time) / args.logging_interval
                last_log_time = time.monotonic()
                logging.debug(f'Avg time per iter: {time_per_iter:.3f} [sec]')

        # The stored position points at the NEXT iteration, so a resume does
        # not repeat the one that is being completed here.
        progress['iter_idx'] = iter_idx + 1

        ft_client.send_heartbeat()
        optimizer.step()

        # Whether to do a periodic checkpointing. Non-positive intervals
        # disable it: negative values never matched the comparison anyway,
        # and 0 used to raise ZeroDivisionError in the modulo.
        periodic_save = (
            args.save_interval > 0
            and iter_idx % args.save_interval == args.save_interval - 1
        )
        interrupt_now = (epoch_idx, iter_idx) in args.interrupt_at

        if periodic_save or interrupt_now:
            save_checkpoint(
                progress=progress,
                model=model,
                optimizer=optimizer,
                ft_client=ft_client,
                output_dir=args.output_dir,
                checkpoint_fname=args.checkpoint_fname,
            )
            if interrupt_now:
                logging.info('Manual interruption, exiting')
                sys.exit(0)
256
257
def validation_loop(ft_client, model, val_dataloader, epoch_idx, device):
    """Evaluate ``model`` on the validation data and log the mean batch loss.

    A heartbeat is sent after every validation batch. The logged loss is the
    mean over the batches seen by this process.

    Args:
        ft_client: fault tolerance client used to send heartbeats.
        model: model to evaluate; switched to eval mode.
        val_dataloader: iterable yielding the validation batches.
        epoch_idx: epoch number, used only in the log message.
        device: device the input batches are moved to.
    """
    total_val_loss = 0
    num_batches = 0
    model.eval()

    # No gradients are needed for validation; skipping autograd bookkeeping
    # saves memory and time without changing the computed loss.
    with torch.no_grad():
        for x in val_dataloader:
            x = x.to(device)
            total_val_loss += model(x).mean().item()
            num_batches += 1
            ft_client.send_heartbeat()

    if num_batches == 0:
        # An empty loader previously failed on the undefined loop index.
        logging.warning(f'CHECK VAL SUMMARY: epoch: {epoch_idx:4d} no validation batches')
        return

    logging.info(
        f'CHECK VAL SUMMARY: epoch: {epoch_idx:4d} ' f'loss: {total_val_loss / num_batches}'
    )
272
273
# Module-level bookkeeping for the simulated fault:
#   _sim_fault_is_set   - set by _setup_simulated_fault, read by main()
#   _sim_fault_canceled - set below, checked by the fault thread before firing
_sim_fault_is_set = False
_sim_fault_canceled = False


def _cancel_simulated_fault():
    """Mark the simulated fault as canceled.

    The fault thread checks this flag after its delay and returns without
    sending the signal when the flag is set.
    """
    global _sim_fault_canceled
    _sim_fault_canceled = True
281
282
def _setup_simulated_fault(ft_client, fault_desc, device):
    """Arm a delayed, signal-based fault on one randomly selected rank.

    Every rank records that the fault has been set up. Only the selected rank
    starts a daemon thread that sends the signal to its own process after the
    delay, unless the fault was canceled in the meantime.

    Args:
        ft_client: fault tolerance client (not used by this function).
        fault_desc: dict with 'fault' ('rank_killed', 'rank_hung' or
            'random') and 'delay' (seconds; up to 4 more seconds are added
            at random).
        device: device for the tensor used to broadcast the selected rank.

    Raises:
        Exception: on the selected rank, when the fault type is unknown.
    """
    # FIXME: hanging rank with SIGTSTP results in rank monitor
    # blocked when trying to receive the data in _on_ipc_data_from_rank

    global _sim_fault_is_set
    _sim_fault_is_set = True  # should be True on all ranks

    rng = random.Random()

    logging.info(f"Initializing simulated fault: {fault_desc}")

    fault_type = fault_desc['fault']
    if fault_type == 'random':
        fault_type = rng.choice(['rank_killed', 'rank_hung'])

    # Each rank draws a candidate, then the value is broadcast so that all
    # ranks agree (presumably rank 0 is the source - confirm in dist_utils).
    candidate = rng.randint(0, dist_utils.get_world_size() - 1)
    rank_tensor = torch.tensor([candidate], device=device)
    dist_utils.broadcast(rank_tensor, 0)
    rank_to_fail = int(rank_tensor.item())

    if torch.distributed.get_rank() != rank_to_fail:
        return

    signal_by_fault = {
        'rank_killed': signal.SIGKILL,
        'rank_hung': signal.SIGSTOP,
    }
    if fault_type not in signal_by_fault:
        raise Exception(f"Unknown fault type {fault_type}")
    target_sig = signal_by_fault[fault_type]
    target_pid = os.getpid()

    delay = fault_desc['delay'] + 4.0 * rng.random()

    logging.info(
        f"Selected fault={fault_type}; target rank={rank_to_fail}; delay={delay}",
    )

    def _deliver_fault():
        time.sleep(delay)
        if _sim_fault_canceled:
            return
        print(
            f"\n####\nSimulating fault: {fault_type}; rank to fail: {rank_to_fail}\n#####\n",
            file=sys.stderr,
        )
        os.kill(target_pid, target_sig)

    # Daemon thread: it must not keep the process alive at exit.
    threading.Thread(target=_deliver_fault, daemon=True).start()
335
336
# Becomes True once _sig_handler has run in this process.
_signal_received = False


def _sig_handler(*args, **kwargs):
    """Signal handler: report the signal on stderr and remember it.

    Only sets a flag; main() checks it at the end of each epoch.
    """
    global _signal_received
    print("Signal received!", file=sys.stderr)
    _signal_received = True
344
345
def main():
    """Run the example: set up, optionally resume, train, validate, checkpoint.

    Steps, in order: install the SIGTERM handler, parse options, seed the
    RNGs, select the device, initialize the distributed group and logging,
    build datasets/model/optimizer, initialize fault tolerance monitoring,
    resume from a checkpoint if one exists, then loop over the epochs.

    Raises:
        RuntimeError: for an unavailable/unknown device, a world size of 1,
            or an unknown distributed init method.
        SystemExit: always exits with code 0 on normal completion.
    """
    signal.signal(signal.SIGTERM, _sig_handler)

    args = parse_args()

    torch.manual_seed(123)
    np.random.seed(123)
    random.seed(123)

    if args.device == 'cuda':
        if not torch.cuda.is_available():
            raise RuntimeError("Selected 'cuda' device but torch.cuda is not available.")
        device = torch.device('cuda')
        torch.cuda.set_device(args.local_rank)
    elif args.device == 'cpu':
        device = torch.device('cpu')
    else:
        raise RuntimeError('Unknown device')

    world_size = int(os.getenv('WORLD_SIZE', '1'))
    if world_size == 1:
        raise RuntimeError('This example supports only multi-gpu training')

    os.makedirs(args.output_dir, exist_ok=True)

    init_method = args.init_distributed_method
    if init_method == 'tcp':
        # NOTE: when running tests with tcp init method we noticed
        # occasional "address already in use" errors, after workload
        # is restarted
        dist_utils.init_distributed_with_tcp_store(device)
    elif init_method == 'file':
        dist_utils.init_distributed_with_file_store(device, store_file_dir=args.output_dir)
    else:
        raise RuntimeError(
            f"--init_distributed_method should be ['tcp','file'] it is {init_method}"
        )

    log_file_name = (
        f'train_log_rank_{dist_utils.get_rank()}.log' if args.log_all_ranks else 'train_log.log'
    )
    log_file_path = os.path.join(args.output_dir, log_file_name)

    # NOTE: logging appends outputs to an existing log file if it already
    # exists. Results from a single training run (potentially with many
    # restarts from a checkpoint) are stored in a single log file.
    log_utils.setup_logging(args.log_all_ranks, filename=log_file_path, filemode='a')
    logging.info(args)

    rank = dist_utils.get_rank()
    slurm_job_id = os.getenv('SLURM_JOB_ID', '<none>')
    logging.info(f"SLURM_JOB_ID={slurm_job_id} RANK={rank} PID={os.getpid()}")

    # Dummy datasets
    train_dataset = Dataset(args.train_dataset_size, args.hidden)
    val_dataset = Dataset(args.val_dataset_size, args.hidden)

    # ResumableDistributedSampler is needed to skip consumed samples
    train_sampler = dist_utils.ResumableDistributedSampler(train_dataset, drop_last=True)
    val_sampler = torch.utils.data.DistributedSampler(val_dataset)

    # A dummy model and an optimizer
    model = Model(args.hidden).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Initial training position - overwritten if training is resumed from a checkpoint
    progress = {'epoch_idx': 0, 'iter_idx': 0}

    checkpoint_path = os.path.join(args.output_dir, args.checkpoint_fname)

    # Initialize fault tolerance.
    ft_client = fault_tolerance.RankMonitorClient()
    ft_client.init_workload_monitoring()

    # Try to load a checkpoint from disk and restore all components from it.
    checkpoint = load_checkpoint(checkpoint_path) if os.path.exists(checkpoint_path) else None
    if checkpoint:
        logging.info(f'Checkpoint was loaded from file: {checkpoint_path}')
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        ft_client.load_state_dict(checkpoint['ft_state'])
        progress.update(checkpoint['progress'])
        # Return with zero exit code if model is already fully trained.
        if progress['epoch_idx'] == args.epochs:
            logging.info('Training finished.')
            ft_client.shutdown_workload_monitoring()
            torch.distributed.destroy_process_group()
            sys.exit(0)

    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch,
        sampler=train_sampler,
        num_workers=4,
        persistent_workers=True,
        pin_memory=False,
    )

    val_dataloader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=args.batch,
        sampler=val_sampler,
        num_workers=4,
    )

    # Regular DDP init
    # NOTE: for convenience code is keeping both wrapped and unwrapped model and
    # uses wrapped model for training and unwrapped model for saving the
    # checkpoint and validation. It doesn't increase memory consumption
    # since both models are holding references to the same parameters.
    # Additionally saved checkpoint is ready for inference and doesn't have to
    # be manually unwrapped by accessing the (undocumented) "module" attribute
    # of DDP-wrapped model.
    if device.type not in ('cuda', 'cpu'):
        raise RuntimeError('Unsupported device type')
    on_cuda = device.type == 'cuda'
    device_ids = [args.local_rank] if on_cuda else None
    output_device = args.local_rank if on_cuda else None
    para_model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=device_ids, output_device=output_device
    )

    # Iteration over epochs, notice that it starts from 'epoch_idx'
    # which was previously loaded from the checkpoint
    first_epoch = progress['epoch_idx']
    for epoch_idx in range(first_epoch, args.epochs):
        training_loop(
            ft_client,
            para_model,
            model,
            optimizer,
            device,
            train_dataloader,
            train_sampler,
            progress,
            args,
        )

        # epoch_idx is incremented because the current epoch is finished
        # and potential resume from this checkpoint should start a new training epoch.
        progress['epoch_idx'] += 1
        progress['iter_idx'] = 0

        validation_loop(ft_client, model, val_dataloader, epoch_idx, device)

        # Checkpoint contains everything needed for deterministic resume:
        # state of the model, optimizer and other components,
        save_checkpoint(
            progress=progress,
            model=model,
            optimizer=optimizer,
            ft_client=ft_client,
            output_dir=args.output_dir,
            checkpoint_fname=args.checkpoint_fname,
        )

        # NOTE: SIGTERM is used by SLURM to initiate graceful job termination
        # if _any_ rank received SIGTERM, we leave the main loop
        if dist_utils.is_true_on_any_rank(_signal_received):
            logging.info('Leaving the main loop, due to SIGTERM')
            break

        # Setup simulated fault as soon as we have valid timeouts
        fault_pending = args.simulated_fault and not _sim_fault_is_set
        if fault_pending and ft_client.hb_timeouts.are_valid:
            _setup_simulated_fault(ft_client, args.simulated_fault, device)

    _cancel_simulated_fault()
    ft_client.shutdown_workload_monitoring()
    torch.distributed.destroy_process_group()
    logging.info('Leaving main, ret_code=0')
    sys.exit(0)


if __name__ == "__main__":
    main()