# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Demo of DDP training with fault tolerance, using the FT package sections API.

It should be run with `ft_launcher`, e.g.:
`ft_launcher --nproc-per-node=2 --fault-tol-cfg-path=./examples/fault_tolerance/fault_tol_cfg_sections.yaml examples/fault_tolerance/train_ddp_sections_api.py --device=cpu`

This example uses the following custom FT sections:
- 'init' - covers workload initialization
- 'step' - covers a training/evaluation step (fwd/bwd, loss calculation etc.)
- 'checkpoint' - covers checkpoint saving

The timeout for each section is calculated once enough data has been collected.
The FT "out-of-section" timeout is calculated when the training run ends normally.
The FT state is saved in a JSON file.

This example can simulate a training fault:
- a selected rank hangs
- a selected rank is terminated
"""
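# The FT sections API used below follows a simple open/close pattern:
#
#   ft_client.start_section('step')
#   ...  # monitored code
#   ft_client.end_section('step')
#
# Section timeouts are (re)calculated from the observed section durations with
# `ft_client.calculate_and_set_section_timeouts(...)` and persisted across restarts
# via `ft_client.state_dict()` / `ft_client.load_state_dict()` (see below).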
import argparse
import json
import logging
import os
import random
import signal
import sys
import threading
import time

import dist_utils
import log_utils
import numpy as np
import torch
import torch.nn as nn

import nvidia_resiliency_ext.fault_tolerance as fault_tolerance


# Dummy dataset.
class Dataset(torch.utils.data.Dataset):
    def __init__(self, size, hidden):
        self.size = size
        self.hidden = hidden

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
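        # Each sample is a constant vector filled with its index, so the
        # `input: ...` values in the training log are easy to eyeball.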
        data = torch.full(
            (self.hidden,),
            fill_value=idx,
            dtype=torch.float32,
            device='cpu',
        )
        return data


# Dummy model
class Model(nn.Module):
    def __init__(self, hidden):
        super().__init__()
        self.l1 = nn.Linear(hidden, hidden)
        self.l2 = nn.Linear(hidden, hidden)

    def forward(self, x):
        x = self.l1(x)
        x = self.l2(x)
        return x


def parse_args():
    def fault_desc(strings):
        parts = strings.split(",")
        assert len(parts) == 2, "Fault description must be in format 'fault,delay'"
        return {'fault': parts[0], 'delay': float(parts[1])}

    parser = argparse.ArgumentParser(
        description='Example of PyTorch DDP training with the Fault Tolerance package',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # fmt: off
    parser.add_argument('--hidden', type=int, default=4096,
                        help='Hidden size')
    parser.add_argument('--batch', type=int, default=8,
                        help='Batch size')
    parser.add_argument('--epochs', type=int, default=4,
                        help='Number of training epochs')
    parser.add_argument('--train_dataset_size', type=int, default=1000000,
                        help='Train dataset size')
    parser.add_argument('--val_dataset_size', type=int, default=2000,
                        help='Validation dataset size')
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cpu', 'cuda'],
                        help='Device')
    parser.add_argument('--save_interval', type=int, default=-1,
                        help='Interval for saving periodic checkpoints.')
    parser.add_argument('--logging_interval', type=int, default=1,
                        help='Interval for log entries')
    parser.add_argument('--log_all_ranks', action='store_true',
                        help='Enable logging from all distributed ranks')
    parser.add_argument('--output_dir', type=str, default='results/output',
                        help='Output dir')
    parser.add_argument('--checkpoint_fname', type=str, default='checkpoint.pt',
                        help='Name of a checkpoint file')
    # NOTE: argparse applies `type` only to command line values, not to `default`,
    # so the LOCAL_RANK env var is converted to int explicitly.
    parser.add_argument('--local_rank', type=int,
                        default=int(os.getenv('LOCAL_RANK', 0)))
    parser.add_argument('--simulated_fault', type=fault_desc,
                        help='Description of a fault to be simulated')
    # fmt: on

    args = parser.parse_args()
    return args


def load_checkpoint(path):
    map_location = {
        'cpu': 'cpu',
    }
    if torch.cuda.is_available():
        map_location['cuda:0'] = f'cuda:{torch.cuda.current_device()}'

    logging.info(f'Loading checkpoint from {path}')
    checkpoint = torch.load(path, map_location=map_location, weights_only=True)
    return checkpoint


def save_checkpoint(
    progress,
    model,
    optimizer,
    ft_client,
    output_dir,
    checkpoint_fname,
):
    # Checkpointing is wrapped into "checkpoint" FT section
    # NOTE: FT state is not stored in the checkpoint, but in a separate JSON file
    ft_client.start_section('checkpoint')

    state = {
        'progress': progress,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
    }

    checkpoint_path = os.path.join(output_dir, checkpoint_fname)

    with dist_utils.sync_workers() as rank:
        if rank == 0:
            logging.info(f'Saving checkpoint to {checkpoint_path}')
            torch.save(state, checkpoint_path)

    ft_client.end_section('checkpoint')


def maybe_load_ft_state(path):
    # Load FT state from a JSON file
    if os.path.exists(path):
        logging.info(f'Loading FT state from: {path}')
        with open(path, 'r') as f:
            return json.load(f)
    else:
        logging.info(f'FT state file not found at: {path}')
        return None


def save_ft_state(ft_client, path):
    # Save FT state into a JSON file
    with dist_utils.sync_workers() as rank:
        if rank == 0:
            logging.info(f'Saving FT state into: {path}')
            ft_state = ft_client.state_dict()
            with open(path, 'w') as f:
                json.dump(ft_state, f)


def update_ft_section_timeouts(ft_client, selected_sections, calc_out_of_section, ft_state_path):
    # Update FT timeouts and save the FT state
    logging.info(
        f'Updating FT section timeouts for: {selected_sections}; '
        f'will update out-of-section: {calc_out_of_section}'
    )
    ft_client.calculate_and_set_section_timeouts(
        selected_sections=selected_sections, calc_out_of_section=calc_out_of_section
    )
    save_ft_state(ft_client, ft_state_path)


def training_loop(
    ft_client,
    para_model,
    model,
    optimizer,
    device,
    dataloader,
    progress,
    args,
):
    # Training epoch implementation

    epoch_idx = progress['epoch_idx']

    para_model.train()

    last_log_time = time.monotonic()

    num_iters_made = 0

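    # NOTE: `start=progress['iter_idx']` only offsets the iteration counter after a
    # resume; the dataloader itself is not fast-forwarded (the data is dummy anyway).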
    for iter_idx, x in enumerate(dataloader, start=progress['iter_idx']):

        # fwd/bwd and optimizer step are wrapped into "step" FT section
        ft_client.start_section('step')

        optimizer.zero_grad()
        x = x.to(device)
        y = para_model(x)
        loss = y.mean()
        train_loss = loss.item()
        loss.backward()

        if iter_idx % args.logging_interval == 0:
            avg_train_loss = dist_utils.all_reduce_item(train_loss, op='mean')
            logging.info(
                f'CHECK TRAIN epoch: {epoch_idx:4d} '
                f'iter: {iter_idx:5d} '
                f'loss: {avg_train_loss} '
                f'input: {x[:, 0]}'
            )
            if iter_idx > 0:
                time_per_iter = (time.monotonic() - last_log_time) / args.logging_interval
                last_log_time = time.monotonic()
                logging.debug(f'Avg time per iter: {time_per_iter:.3f} [sec]')

        progress['iter_idx'] = iter_idx + 1

        optimizer.step()

        ft_client.end_section('step')

        # Whether to do a periodic checkpointing
        # (disabled when save_interval <= 0, which is the default)
        periodic_save = (
            args.save_interval > 0 and iter_idx % args.save_interval == args.save_interval - 1
        )
        if periodic_save:
            save_checkpoint(
                progress=progress,
                model=model,
                optimizer=optimizer,
                ft_client=ft_client,
                output_dir=args.output_dir,
                checkpoint_fname=args.checkpoint_fname,
            )

        num_iters_made += 1

    return num_iters_made


def validation_loop(ft_client, model, val_dataloader, epoch_idx, device):

    # Validation epoch implementation

    total_val_loss = 0
    model.eval()

    for iter_idx, x in enumerate(val_dataloader):

        # fwd and loss are wrapped into "step" FT section
        # the 'step' section is used for both training and eval steps
        ft_client.start_section('step')

        x = x.to(device)
        y = model(x)
        loss = y.mean().item()
        total_val_loss += loss

        ft_client.end_section('step')

    logging.info(
        f'CHECK VAL SUMMARY: epoch: {epoch_idx:4d} '
        f'loss: {total_val_loss / (iter_idx + 1)}'
    )

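# Module-level flags used by the simulated-fault helpers below; they are set on all
# ranks so that the fault is installed only once and can be canceled after the main
# training loop finishes.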
_sim_fault_canceled = False
_sim_fault_is_set = False


def _cancel_simulated_fault():
    global _sim_fault_canceled
    _sim_fault_canceled = True


def _setup_simulated_fault(fault_desc, device):

    global _sim_fault_is_set
    _sim_fault_is_set = True  # should be True on all ranks

    rng = random.Random()

    logging.info(f"Initializing simulated fault: {fault_desc}")

    fault_type = fault_desc['fault']
    if fault_type == 'random':
        fault_type = rng.choice(['rank_killed', 'rank_hung'])

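    # Rank 0's randomly chosen victim rank is broadcast, so all ranks agree on
    # which rank is going to fail.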
    rank_to_fail = rng.randint(0, dist_utils.get_world_size() - 1)
    rank_to_fail = torch.tensor([rank_to_fail], device=device)
    dist_utils.broadcast(rank_to_fail, 0)
    rank_to_fail = int(rank_to_fail.item())

    rank = torch.distributed.get_rank()
    if rank != rank_to_fail:
        return

    if fault_type == 'rank_killed':
        target_pid = os.getpid()
        target_sig = signal.SIGKILL
    elif fault_type == 'rank_hung':
        target_pid = os.getpid()
        target_sig = signal.SIGSTOP
    else:
        raise Exception(f"Unknown fault type {fault_type}")

    delay = fault_desc['delay'] + 4.0 * rng.random()

    logging.info(
        f"Selected fault={fault_type}; target rank={rank_to_fail}; delay={delay}",
    )

    def __fault_thread():
        time.sleep(delay)
        if _sim_fault_canceled:
            return
        print(
            f"\n####\nSimulating fault: {fault_type}; rank to fail: {rank_to_fail}\n#####\n",
            file=sys.stderr,
        )
        os.kill(target_pid, target_sig)

    fault_sim_thread = threading.Thread(target=__fault_thread)
    fault_sim_thread.daemon = True
    fault_sim_thread.start()


_signal_received = False


def _sig_handler(*args, **kwargs):
    print("Signal received!", file=sys.stderr)
    global _signal_received
    _signal_received = True


def main():
    signal.signal(signal.SIGTERM, _sig_handler)

    args = parse_args()

    torch.manual_seed(123)
    np.random.seed(123)
    random.seed(123)

    if args.device == 'cuda':
        device = torch.device('cuda')
        torch.cuda.set_device(args.local_rank)
    elif args.device == 'cpu':
        device = torch.device('cpu')
    else:
        raise RuntimeError('Unknown device')

    os.makedirs(args.output_dir, exist_ok=True)

    dist_utils.init_distributed_with_tcp_store(device)
    rank = dist_utils.get_rank()

    if args.log_all_ranks:
        log_file_name = f'train_log_rank_{dist_utils.get_rank()}.log'
    else:
        log_file_name = 'train_log.log'
    log_file_path = os.path.join(args.output_dir, log_file_name)

    # NOTE: logging appends outputs to an existing log file if it already
    # exists. Results from a single training run (potentially with many
    # restarts from a checkpoint) are stored in a single log file.
    log_utils.setup_logging(args.log_all_ranks, filename=log_file_path, filemode='a')

    logging.info(args)
    logging.info(f"SLURM_JOB_ID={os.getenv('SLURM_JOB_ID','<none>')} RANK={rank} PID={os.getpid()}")

    # Dummy datasets
    train_dataset = Dataset(args.train_dataset_size, args.hidden)
    val_dataset = Dataset(args.val_dataset_size, args.hidden)

    train_sampler = torch.utils.data.DistributedSampler(
        train_dataset,
        drop_last=True,
    )

    val_sampler = torch.utils.data.DistributedSampler(
        val_dataset,
    )

    # A dummy model and an optimizer
    model = Model(args.hidden).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Initial value for start epoch - will be overwritten if training is resumed from a checkpoint
    progress = {
        'epoch_idx': 0,
        'iter_idx': 0,
    }

    checkpoint_path = os.path.join(args.output_dir, args.checkpoint_fname)

    # Initialize fault tolerance.
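    # RankMonitorClient communicates with the rank monitor spawned by `ft_launcher`;
    # section start/end events and the computed timeouts are used by the monitor to
    # detect hung ranks.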
    ft_client = fault_tolerance.RankMonitorClient()
    ft_client.init_workload_monitoring()

    # try to load FT state from a JSON file
    ft_state_path = os.path.join(args.output_dir, 'ft_state.json')
    ft_state = maybe_load_ft_state(ft_state_path)
    if ft_state:
        ft_client.load_state_dict(ft_state)

    # Open the "init" FT section that covers workload initialization
    ft_client.start_section('init')

    is_checkpoint_loaded = False

    # try to load a checkpoint from disk
    if os.path.exists(checkpoint_path):
        checkpoint = load_checkpoint(checkpoint_path)
        if checkpoint:
            logging.info(f'Checkpoint was loaded from file: {checkpoint_path}')
            is_checkpoint_loaded = True
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            progress.update(checkpoint['progress'])

    # Return with zero exit code if the model is already fully trained.
    if progress['epoch_idx'] == args.epochs:
        ft_client.end_section('init')  # explicitly end the "init" section, to avoid an FT warning
        ft_client.shutdown_workload_monitoring()
        torch.distributed.destroy_process_group()
        logging.info('Training finished.')
        sys.exit(0)

    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch,
        sampler=train_sampler,
        num_workers=4,
        persistent_workers=True,
        pin_memory=False,
    )

    val_dataloader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=args.batch,
        sampler=val_sampler,
        num_workers=4,
    )

    # Regular DDP init
    # NOTE: for convenience, the code keeps both the wrapped and the unwrapped model,
    # and uses the wrapped model for training and the unwrapped model for saving the
    # checkpoint and for validation. This doesn't increase memory consumption,
    # since both models hold references to the same parameters.
    # Additionally, the saved checkpoint is ready for inference and doesn't have to
    # be manually unwrapped by accessing the (undocumented) "module" attribute
    # of the DDP-wrapped model.
    if device.type == 'cuda':
        device_ids = [args.local_rank]
        output_device = args.local_rank
    elif device.type == 'cpu':
        device_ids = None
        output_device = None
    else:
        raise RuntimeError('Unsupported device type')
    para_model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=device_ids, output_device=output_device
    )

    # "init" FT section ends here
    ft_client.end_section('init')

    if is_checkpoint_loaded:
        # init time can be longer if there was checkpoint loading,
        # so we update the "init" section timeout if a checkpoint was loaded
        update_ft_section_timeouts(ft_client, ['init'], False, ft_state_path)

    # Iteration over epochs; note that it starts from 'epoch_idx',
    # which was previously loaded from the checkpoint
    for epoch_idx in range(progress['epoch_idx'], args.epochs):

        num_tr_iters_made = training_loop(
            ft_client,
            para_model,
            model,
            optimizer,
            device,
            train_dataloader,
            progress,
            args,
        )

        # If some training iterations were observed, update the "step" section timeout
        if num_tr_iters_made > 0:
            update_ft_section_timeouts(ft_client, ['step'], False, ft_state_path)

        # epoch_idx is incremented because the current epoch is finished
        # and a potential resume from this checkpoint should start a new training epoch.
        progress['epoch_idx'] += 1
        progress['iter_idx'] = 0

        validation_loop(ft_client, model, val_dataloader, epoch_idx, device)

        # The checkpoint contains everything needed for a deterministic resume:
        # the state of the model, the optimizer and other components.
        save_checkpoint(
            progress=progress,
            model=model,
            optimizer=optimizer,
            ft_client=ft_client,
            output_dir=args.output_dir,
            checkpoint_fname=args.checkpoint_fname,
        )

        # Update the checkpointing section timeout after a checkpoint save was observed
        update_ft_section_timeouts(ft_client, ['checkpoint'], False, ft_state_path)

        # NOTE: SIGTERM is used by SLURM to initiate graceful job termination;
        # if _any_ rank received SIGTERM, we leave the main loop
        if dist_utils.is_true_on_any_rank(_signal_received):
            logging.info('Leaving the main loop, due to SIGTERM')
            break

        # Setup a simulated fault
        if args.simulated_fault and not _sim_fault_is_set:
            _setup_simulated_fault(args.simulated_fault, device)

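    # Cancel any pending simulated fault once the main loop is over, so it does not
    # fire during the normal shutdown sequence below.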
    _cancel_simulated_fault()

    # update "out-of-section" FT timeout when the training run ends normally
    update_ft_section_timeouts(ft_client, [], True, ft_state_path)
    ft_client.shutdown_workload_monitoring()
    torch.distributed.destroy_process_group()
    logging.info('Leaving main, ret_code=0')
    sys.exit(0)


if __name__ == "__main__":
    main()