1# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""
17Demo of fault tolerance with DDP training, using FT package heartbeats API
18
19This script demonstrates how to use the FT heartbeats API for hang detection in
20distributed training. It should be run with the ft_launcher command. E.g.:
21
22`ft_launcher --nproc-per-node=2 --ft-cfg-path=./examples/fault_tolerance/fault_tol_cfg_heartbeats.yaml examples/fault_tolerance/train_ddp_heartbeats_api.py --device=cpu`
23
24Fault tolerance features demonstrated:
251. Heartbeat sending during training
262. Timeout calculation and setting
273. State persistence through checkpoints
284. Simulated fault injection
29"""
30
31import argparse
32import logging
33import os
34import random
35import signal
36import sys
37import threading
38import time
39
40import dist_utils
41import log_utils
42import numpy as np
43import torch
44import torch.nn as nn
45
46import nvidia_resiliency_ext.fault_tolerance as fault_tolerance
47
48
49# Dummy dataset.
50class Dataset(torch.utils.data.Dataset):
51 def __init__(self, size, hidden):
52 self.size = size
53 self.hidden = hidden
54
55 def __len__(self):
56 return self.size
57
58 def __getitem__(self, idx):
59 data = torch.full(
60 (self.hidden,),
61 fill_value=idx,
62 dtype=torch.float32,
63 device='cpu',
64 )
65 return data
66
67
68# Dummy model
69class Model(nn.Module):
70 def __init__(self, hidden):
71 super().__init__()
72 self.l1 = nn.Linear(hidden, hidden)
73 self.l2 = nn.Linear(hidden, hidden)
74
75 def forward(self, x):
76 x = self.l1(x)
77 x = self.l2(x)
78 return x
79
80
81def parse_args():
82 def tuple_type(strings):
83 strings = strings.replace("(", "").replace(")", "")
84 mapped_int = map(int, strings.split(","))
85 return tuple(mapped_int)
86
87 def fault_desc(strings):
88 parts = strings.split(",")
89 assert len(parts) == 2
90 return {'fault': parts[0], 'delay': float(parts[1])}
91
92 parser = argparse.ArgumentParser(
93 description='Example of PyTorch DDP training with the Fault Tolerance package',
94 formatter_class=argparse.ArgumentDefaultsHelpFormatter,
95 )
96
97 # fmt: off
98 parser.add_argument('--hidden', type=int, default=4096,
99 help='Hidden size')
100 parser.add_argument('--batch', type=int, default=8,
101 help='Batch size')
102 parser.add_argument('--epochs', type=int, default=4,
103 help='Number of training epochs')
104 parser.add_argument('--train_dataset_size', type=int, default=1000000,
105 help='Train dataset size')
106 parser.add_argument('--val_dataset_size', type=int, default=2000,
107 help='Validation dataset size')
108 parser.add_argument('--device', type=str, default='cuda',
109 choices=['cpu', 'cuda'],
110 help='Device')
111
112 parser.add_argument('--interrupt_at', type=tuple_type, nargs='*',
113 help='Manual interruption after (epoch, iteration), '
114 'for testing only')
115 parser.add_argument('--save_interval', type=int, default=-1,
116 help='Interval for saving periodic checkpoints')
117 parser.add_argument('--logging_interval', type=int, default=1,
118 help='Interval for log entries')
119 parser.add_argument('--log_all_ranks', action='store_true',
120 help='Enable logging from all distributed ranks')
121 parser.add_argument('--output_dir', type=str, default='results/output',
122 help='Output dir')
123 parser.add_argument('--checkpoint_fname', type=str, default='checkpoint.pt',
124 help='Name of a checkpoint file')
125
126 parser.add_argument('--local_rank', type=int,
127 default=os.getenv('LOCAL_RANK', 0))
128 parser.add_argument('--init_distributed_method', type=str, default='tcp',
129 help='Init distributed group with TCP store ("tcp") or file store ("file")')
130
131 parser.add_argument('--simulated_fault', type=fault_desc,
132 help='Description of a fault to be simulated')
133 # fmt: on
134
135 args = parser.parse_args()
136
137 if args.interrupt_at:
138 args.interrupt_at = set(args.interrupt_at)
139 else:
140 args.interrupt_at = set()
141
142 return args
143
144
145def load_checkpoint(path):
146 map_location = {
147 'cpu': 'cpu',
148 }
149 if torch.cuda.is_available():
150 map_location['cuda:0'] = f'cuda:{torch.cuda.current_device()}'
151
152 logging.info(f'Loading checkpoint from {path}')
153 checkpoint = torch.load(path, map_location=map_location, weights_only=True)
154 return checkpoint
155
156
157def save_checkpoint(
158 progress,
159 model,
160 optimizer,
161 ft_client,
162 output_dir,
163 checkpoint_fname,
164):
165 state = {
166 'progress': progress,
167 'model_state': model.state_dict(),
168 'optimizer_state': optimizer.state_dict(),
169 'ft_state': ft_client.state_dict(),
170 }
171
172 checkpoint_path = os.path.join(output_dir, checkpoint_fname)
173
174 with dist_utils.sync_workers() as rank:
175 if rank == 0:
176 logging.info(f'Saving checkpoint to {checkpoint_path}')
177 torch.save(state, checkpoint_path)
178
179
180def training_loop(
181 ft_client,
182 para_model,
183 model,
184 optimizer,
185 device,
186 dataloader,
187 sampler,
188 progress,
189 args,
190):
191 epoch_idx = progress['epoch_idx']
192
193 # NOTE: torch.utils.data.DistributedSampler must be prepared for current epoch
194 # need to do it before starting iteration
195 sampler.start_sample_idx = progress['iter_idx'] * args.batch
196 sampler.set_epoch(epoch_idx)
197
198 para_model.train()
199
200 last_log_time = time.monotonic()
201
202 for iter_idx, x in enumerate(dataloader, start=progress['iter_idx']):
203 if ft_client.hb_timeouts.are_valid is False and epoch_idx == 1 and iter_idx == 1:
204 # after 0th epoch is completed and we've done 0th iteration of the 1st epoch,
205 # we can calculate and set timeouts. this is a good moment to do so,
206 # because now we've seen the possibly long interval where checkpoint was saved.
207 ft_client.calculate_and_set_hb_timeouts()
208
209 optimizer.zero_grad()
210 x = x.to(device)
211 y = para_model(x)
212 loss = y.mean()
213 train_loss = loss.item()
214 loss.backward()
215
216 if iter_idx % args.logging_interval == 0:
217 avg_train_loss = dist_utils.all_reduce_item(train_loss, op='mean')
218 logging.info(
219 f'CHECK TRAIN epoch: {epoch_idx:4d} '
220 f'iter: {iter_idx:5d} '
221 f'loss: {avg_train_loss} '
222 f'input: {x[:, 0]}'
223 )
224 if iter_idx > 0:
225 time_per_iter = (time.monotonic() - last_log_time) / args.logging_interval
226 last_log_time = time.monotonic()
227 logging.debug(f'Avg time per iter: {time_per_iter:.3f} [sec]')
228
229 progress['iter_idx'] = iter_idx + 1
230
231 ft_client.send_heartbeat()
232 optimizer.step()
233
234 # Whether to do a periodic checkpointing
235 periodic_save = iter_idx % args.save_interval == args.save_interval - 1
236
237 if periodic_save or (epoch_idx, iter_idx) in args.interrupt_at:
238 save_checkpoint(
239 progress=progress,
240 model=model,
241 optimizer=optimizer,
242 ft_client=ft_client,
243 output_dir=args.output_dir,
244 checkpoint_fname=args.checkpoint_fname,
245 )
246 if (epoch_idx, iter_idx) in args.interrupt_at:
247 logging.info('Manual interruption, exiting')
248 sys.exit(0)
249
250
251def validation_loop(ft_client, model, val_dataloader, epoch_idx, device):
252 total_val_loss = 0
253 model.eval()
254
255 for iter_idx, x in enumerate(val_dataloader):
256 x = x.to(device)
257 y = model(x)
258 loss = y.mean().item()
259 total_val_loss += loss
260 ft_client.send_heartbeat()
261
262 logging.info(
263 f'CHECK VAL SUMMARY: epoch: {epoch_idx:4d} ' f'loss: {total_val_loss / (iter_idx + 1)}'
264 )
265
266
267_sim_fault_canceled = False
268_sim_fault_is_set = False
269
270
271def _cancel_simulated_fault():
272 global _sim_fault_canceled
273 _sim_fault_canceled = True
274
275
276def _setup_simulated_fault(ft_client, fault_desc, device):
277 # FIXME: hanging rank with SIGTSTP results in rank monitor
278 # blocked when trying to receive the data in _on_ipc_data_from_rank
279
280 global _sim_fault_is_set
281 _sim_fault_is_set = True # should be True on all ranks
282
283 rng = random.Random()
284
285 logging.info(f"Initializing simulated fault: {fault_desc}")
286
287 fault_type = fault_desc['fault']
288 if fault_type == 'random':
289 fault_type = rng.choice(['rank_killed', 'rank_hung'])
290
291 rank_to_fail = rng.randint(0, dist_utils.get_world_size() - 1)
292 rank_to_fail = torch.tensor([rank_to_fail], device=device)
293 dist_utils.broadcast(rank_to_fail, 0)
294 rank_to_fail = int(rank_to_fail.item())
295
296 rank = torch.distributed.get_rank()
297 if rank != rank_to_fail:
298 return
299
300 if fault_type == 'rank_killed':
301 target_pid = os.getpid()
302 target_sig = signal.SIGKILL
303 elif fault_type == 'rank_hung':
304 target_pid = os.getpid()
305 target_sig = signal.SIGSTOP
306 else:
307 raise Exception(f"Unknown fault type {fault_type}")
308
309 delay = fault_desc['delay'] + 4.0 * rng.random()
310
311 logging.info(
312 f"Selected fault={fault_type}; target rank={rank_to_fail}; delay={delay}",
313 )
314
315 def __fault_thread():
316 time.sleep(delay)
317 if _sim_fault_canceled:
318 return
319 print(
320 f"\n####\nSimulating fault: {fault_type}; rank to fail: {rank_to_fail}\n#####\n",
321 file=sys.stderr,
322 )
323 os.kill(target_pid, target_sig)
324
325 fault_sim_thread = threading.Thread(target=__fault_thread)
326 fault_sim_thread.daemon = True
327 fault_sim_thread.start()
328
329
330_signal_received = False
331
332
333def _sig_handler(*args, **kwargs):
334 print("Signal received!", file=sys.stderr)
335 global _signal_received
336 _signal_received = True
337
338
339def main():
340 signal.signal(signal.SIGTERM, _sig_handler)
341
342 args = parse_args()
343
344 torch.manual_seed(123)
345 np.random.seed(123)
346 random.seed(123)
347
348 if args.device == 'cuda':
349 if torch.cuda.is_available():
350 device = torch.device('cuda')
351 torch.cuda.set_device(args.local_rank)
352 else:
353 raise RuntimeError("Selected 'cuda' device but torch.cuda is not available.")
354 elif args.device == 'cpu':
355 device = torch.device('cpu')
356 else:
357 raise RuntimeError('Unknown device')
358
359 if int(os.getenv('WORLD_SIZE', '1')) == 1:
360 raise RuntimeError('This example supports only multi-gpu training')
361
362 os.makedirs(args.output_dir, exist_ok=True)
363
364 if args.init_distributed_method == 'tcp':
365 # NOTE: when runing tests with tcp init method we noticed
366 # occasional "address already in use" errors, after workload
367 # is restarted
368 dist_utils.init_distributed_with_tcp_store(device)
369 elif args.init_distributed_method == 'file':
370 dist_utils.init_distributed_with_file_store(device, store_file_dir=args.output_dir)
371 else:
372 raise RuntimeError(
373 f"--init_distributed_method should be ['tcp','file'] it is {args.init_distributed_method}"
374 )
375
376 if args.log_all_ranks:
377 log_file_name = f'train_log_rank_{dist_utils.get_rank()}.log'
378 else:
379 log_file_name = 'train_log.log'
380 log_file_path = os.path.join(args.output_dir, log_file_name)
381
382 # NOTE: logging appends outputs to an existing log file if it already
383 # exists. Results from a single training run (potentially with many
384 # restarts from a checkpoint) are stored in a single log file.
385 log_utils.setup_logging(args.log_all_ranks, filename=log_file_path, filemode='a')
386 logging.info(args)
387
388 rank = dist_utils.get_rank()
389
390 logging.info(f"SLURM_JOB_ID={os.getenv('SLURM_JOB_ID','<none>')} RANK={rank} PID={os.getpid()}")
391
392 # Dummy datasets
393 train_dataset = Dataset(args.train_dataset_size, args.hidden)
394 val_dataset = Dataset(args.val_dataset_size, args.hidden)
395
396 # ResumableDistributedSampler is needed to skip consumed samples
397 train_sampler = dist_utils.ResumableDistributedSampler(
398 train_dataset,
399 drop_last=True,
400 )
401
402 val_sampler = torch.utils.data.DistributedSampler(
403 val_dataset,
404 )
405
406 # A dummy model and an optimizer
407 model = Model(args.hidden).to(device)
408 optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
409
410 # Initial value for start epoch - will be overwritten if training is resumed from a checkpoint
411 progress = {
412 'epoch_idx': 0,
413 'iter_idx': 0,
414 }
415
416 checkpoint_path = os.path.join(args.output_dir, args.checkpoint_fname)
417
418 # Initialize fault tolerance.
419 ft_client = fault_tolerance.RankMonitorClient()
420 ft_client.init_workload_monitoring()
421
422 checkpoint = None
423
424 # try to load checkpoint from disk
425 if os.path.exists(checkpoint_path):
426 checkpoint = load_checkpoint(checkpoint_path)
427 if checkpoint:
428 logging.info(f'Checkpoint was loaded from file: {checkpoint_path}')
429
430 if checkpoint:
431 model.load_state_dict(checkpoint['model_state'])
432 optimizer.load_state_dict(checkpoint['optimizer_state'])
433 ft_client.load_state_dict(checkpoint['ft_state'])
434 progress.update(checkpoint['progress'])
435 # Return with zero exit code if model is already fully trained.
436 if progress['epoch_idx'] == args.epochs:
437 logging.info('Training finished.')
438 ft_client.shutdown_workload_monitoring()
439 torch.distributed.destroy_process_group()
440 sys.exit(0)
441
442 train_dataloader = torch.utils.data.DataLoader(
443 dataset=train_dataset,
444 batch_size=args.batch,
445 sampler=train_sampler,
446 num_workers=4,
447 persistent_workers=True,
448 pin_memory=False,
449 )
450
451 val_dataloader = torch.utils.data.DataLoader(
452 dataset=val_dataset,
453 batch_size=args.batch,
454 sampler=val_sampler,
455 num_workers=4,
456 )
457
458 # Regular DDP init
459 # NOTE: for convenience code is keeping both wrapped and unwrapped model and
460 # uses wrapped model for training and unwrapped model for saving the
461 # checkpoint and validation. It doesn't increase memory consumption
462 # since both models are holding references to the same parameters.
463 # Additionally saved checkpoint is ready for inference and doesn't have to
464 # be manually unwrapped by accessing the (undocumented) "module" attribute
465 # of DDP-wrapped model.
466 if device.type == 'cuda':
467 device_ids = [args.local_rank]
468 output_device = args.local_rank
469 elif device.type == 'cpu':
470 device_ids = None
471 output_device = None
472 else:
473 raise RuntimeError('Unsupported device type')
474 para_model = torch.nn.parallel.DistributedDataParallel(
475 model, device_ids=device_ids, output_device=output_device
476 )
477
478 # Iteration over epochs, notice that it starts from 'epoch_idx'
479 # which was previously loaded from the checkpoint
480 for epoch_idx in range(progress['epoch_idx'], args.epochs):
481 training_loop(
482 ft_client,
483 para_model,
484 model,
485 optimizer,
486 device,
487 train_dataloader,
488 train_sampler,
489 progress,
490 args,
491 )
492
493 # epoch_idx is incremented because the current epoch is finished
494 # and potential resume from this checkpoint should start a new training epoch.
495 progress['epoch_idx'] += 1
496 progress['iter_idx'] = 0
497
498 validation_loop(ft_client, model, val_dataloader, epoch_idx, device)
499
500 # Checkpoint contains everything needed for deterministic resume:
501 # state of the model, optimizer and other components,
502 save_checkpoint(
503 progress=progress,
504 model=model,
505 optimizer=optimizer,
506 ft_client=ft_client,
507 output_dir=args.output_dir,
508 checkpoint_fname=args.checkpoint_fname,
509 )
510
511 # NOTE: SIGTERM is used by SLURM to initiate graceful job termination
512 # if _any_ rank received SIGTERM, we leave the main loop
513 if dist_utils.is_true_on_any_rank(_signal_received):
514 logging.info('Leaving the main loop, due to SIGTERM')
515 break
516
517 # Setup simulated fault as soon as we have valid timeouts
518 if args.simulated_fault and not _sim_fault_is_set and ft_client.hb_timeouts.are_valid:
519 _setup_simulated_fault(ft_client, args.simulated_fault, device)
520
521 _cancel_simulated_fault()
522 ft_client.shutdown_workload_monitoring()
523 torch.distributed.destroy_process_group()
524 logging.info('Leaving main, ret_code=0')
525 sys.exit(0)
526
527
528if __name__ == "__main__":
529 main()