# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Demo of fault tolerance with DDP training, using the FT package heartbeats API
"""

import argparse
import logging
import os
import random
import signal
import sys
import threading
import time

import dist_utils
import log_utils
import numpy as np
import torch
import torch.nn as nn

import nvidia_resiliency_ext.fault_tolerance as fault_tolerance


# Dummy dataset.
class Dataset(torch.utils.data.Dataset):
    def __init__(self, size, hidden):
        self.size = size
        self.hidden = hidden

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        data = torch.full(
            (self.hidden,),
            fill_value=idx,
            dtype=torch.float32,
            device='cpu',
        )
        return data


# Dummy model
class Model(nn.Module):
    def __init__(self, hidden):
        super().__init__()
        self.l1 = nn.Linear(hidden, hidden)
        self.l2 = nn.Linear(hidden, hidden)

    def forward(self, x):
        x = self.l1(x)
        x = self.l2(x)
        return x


def parse_args():
    def tuple_type(strings):
        strings = strings.replace("(", "").replace(")", "")
        mapped_int = map(int, strings.split(","))
        return tuple(mapped_int)

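    # --simulated_fault is given as "<fault_type>,<base_delay_sec>", e.g. "rank_hung,60";
    # valid fault types are 'rank_killed', 'rank_hung' and 'random'.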
    def fault_desc(strings):
        parts = strings.split(",")
        assert len(parts) == 2
        return {'fault': parts[0], 'delay': float(parts[1])}

    parser = argparse.ArgumentParser(
        description='Example of PyTorch DDP training with the Fault Tolerance package',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # fmt: off
    parser.add_argument('--hidden', type=int, default=4096,
                        help='Hidden size')
    parser.add_argument('--batch', type=int, default=8,
                        help='Batch size')
    parser.add_argument('--epochs', type=int, default=4,
                        help='Number of training epochs')
    parser.add_argument('--train_dataset_size', type=int, default=1000000,
                        help='Train dataset size')
    parser.add_argument('--val_dataset_size', type=int, default=2000,
                        help='Validation dataset size')
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cpu', 'cuda'],
                        help='Device')

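    # --interrupt_at accepts one or more "epoch,iteration" pairs, e.g. --interrupt_at 0,100 1,200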
    parser.add_argument('--interrupt_at', type=tuple_type, nargs='*',
                        help='Manual interruption after (epoch, iteration), '
                             'for testing only')
    parser.add_argument('--save_interval', type=int, default=-1,
                        help='Interval for saving periodic checkpoints')
    parser.add_argument('--logging_interval', type=int, default=1,
                        help='Interval for log entries')
    parser.add_argument('--log_all_ranks', action='store_true',
                        help='Enable logging from all distributed ranks')
    parser.add_argument('--output_dir', type=str, default='results/output',
                        help='Output dir')
    parser.add_argument('--checkpoint_fname', type=str, default='checkpoint.pt',
                        help='Name of a checkpoint file')

    parser.add_argument('--local_rank', type=int,
                        default=os.getenv('LOCAL_RANK', 0))
    parser.add_argument('--init_distributed_method', type=str, default='tcp',
                        help='Init distributed group with TCP store ("tcp") or file store ("file")')

    parser.add_argument('--simulated_fault', type=fault_desc,
                        help='Description of a fault to be simulated')
    # fmt: on

    args = parser.parse_args()

    if args.interrupt_at:
        args.interrupt_at = set(args.interrupt_at)
    else:
        args.interrupt_at = set()

    return args


def load_checkpoint(path):
    map_location = {
        'cpu': 'cpu',
    }
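    # Tensors saved on 'cuda:0' are remapped to this rank's current CUDA device.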
    if torch.cuda.is_available():
        map_location['cuda:0'] = f'cuda:{torch.cuda.current_device()}'

    logging.info(f'Loading checkpoint from {path}')
    checkpoint = torch.load(path, map_location=map_location, weights_only=True)
    return checkpoint


def save_checkpoint(
    progress,
    model,
    optimizer,
    ft_client,
    output_dir,
    checkpoint_fname,
):
    state = {
        'progress': progress,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'ft_state': ft_client.state_dict(),
    }

    checkpoint_path = os.path.join(output_dir, checkpoint_fname)

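    # dist_utils.sync_workers() is assumed to yield this process's rank and to
    # synchronize all ranks, so only rank 0 writes the checkpoint file while the
    # other ranks wait.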
    with dist_utils.sync_workers() as rank:
        if rank == 0:
            logging.info(f'Saving checkpoint to {checkpoint_path}')
            torch.save(state, checkpoint_path)


def training_loop(
    ft_client,
    para_model,
    model,
    optimizer,
    device,
    dataloader,
    sampler,
    progress,
    args,
):
    epoch_idx = progress['epoch_idx']

    # NOTE: torch.utils.data.DistributedSampler must be prepared for the current
    # epoch; this needs to be done before iteration starts.
    sampler.start_sample_idx = progress['iter_idx'] * args.batch
    sampler.set_epoch(epoch_idx)

    para_model.train()

    last_log_time = time.monotonic()

    for iter_idx, x in enumerate(dataloader, start=progress['iter_idx']):
        if ft_client.hb_timeouts.are_valid is False and epoch_idx == 1 and iter_idx == 1:
            # Once the 0th epoch is complete and the 0th iteration of the 1st epoch is done,
            # we can calculate and set the heartbeat timeouts. This is a good moment to do so,
            # because the observed intervals now include the possibly long gap in which the
            # checkpoint was saved.
            ft_client.calculate_and_set_hb_timeouts()

        optimizer.zero_grad()
        x = x.to(device)
        y = para_model(x)
        loss = y.mean()
        train_loss = loss.item()
        loss.backward()

        if iter_idx % args.logging_interval == 0:
            avg_train_loss = dist_utils.all_reduce_item(train_loss, op='mean')
            logging.info(
                f'CHECK TRAIN epoch: {epoch_idx:4d} '
                f'iter: {iter_idx:5d} '
                f'loss: {avg_train_loss} '
                f'input: {x[:, 0]}'
            )
            if iter_idx > 0:
                time_per_iter = (time.monotonic() - last_log_time) / args.logging_interval
                last_log_time = time.monotonic()
                logging.debug(f'Avg time per iter: {time_per_iter:.3f} [sec]')

        progress['iter_idx'] = iter_idx + 1

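        # Regular heartbeats let the rank monitor distinguish a live (possibly slow)
        # rank from a hung or dead one; heartbeats missing for longer than the
        # configured timeouts are treated as a failure.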
        ft_client.send_heartbeat()
        optimizer.step()

        # Decide whether to save a periodic checkpoint
        periodic_save = iter_idx % args.save_interval == args.save_interval - 1

        if periodic_save or (epoch_idx, iter_idx) in args.interrupt_at:
            save_checkpoint(
                progress=progress,
                model=model,
                optimizer=optimizer,
                ft_client=ft_client,
                output_dir=args.output_dir,
                checkpoint_fname=args.checkpoint_fname,
            )
            if (epoch_idx, iter_idx) in args.interrupt_at:
                logging.info('Manual interruption, exiting')
                sys.exit(0)


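# Validation runs on the unwrapped model over this rank's shard of the validation set;
# heartbeats are still sent every iteration so the rank is not reported as hung.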
def validation_loop(ft_client, model, val_dataloader, epoch_idx, device):
    total_val_loss = 0
    model.eval()

    for iter_idx, x in enumerate(val_dataloader):
        x = x.to(device)
        y = model(x)
        loss = y.mean().item()
        total_val_loss += loss
        ft_client.send_heartbeat()

    logging.info(
        f'CHECK VAL SUMMARY: epoch: {epoch_idx:4d} loss: {total_val_loss / (iter_idx + 1)}'
    )


_sim_fault_canceled = False
_sim_fault_is_set = False


def _cancel_simulated_fault():
    global _sim_fault_canceled
    _sim_fault_canceled = True


def _setup_simulated_fault(ft_client, fault_desc, device):
    # FIXME: hanging a rank with SIGTSTP results in the rank monitor being
    # blocked when trying to receive the data in _on_ipc_data_from_rank

    global _sim_fault_is_set
    _sim_fault_is_set = True  # should be True on all ranks

    rng = random.Random()

    logging.info(f"Initializing simulated fault: {fault_desc}")

    fault_type = fault_desc['fault']
    if fault_type == 'random':
        fault_type = rng.choice(['rank_killed', 'rank_hung'])

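    # Every rank draws a candidate victim rank, then rank 0's draw is broadcast
    # so that all ranks agree on which rank is going to fail.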
    rank_to_fail = rng.randint(0, dist_utils.get_world_size() - 1)
    rank_to_fail = torch.tensor([rank_to_fail], device=device)
    dist_utils.broadcast(rank_to_fail, 0)
    rank_to_fail = int(rank_to_fail.item())

    rank = torch.distributed.get_rank()
    if rank != rank_to_fail:
        return

    if fault_type == 'rank_killed':
        target_pid = os.getpid()
        target_sig = signal.SIGKILL
    elif fault_type == 'rank_hung':
        target_pid = os.getpid()
        target_sig = signal.SIGSTOP
    else:
        raise Exception(f"Unknown fault type {fault_type}")

    delay = fault_desc['delay'] + 4.0 * rng.random()

    logging.info(
        f"Selected fault={fault_type}; target rank={rank_to_fail}; delay={delay}",
    )

    def __fault_thread():
        time.sleep(delay)
        if _sim_fault_canceled:
            return
        print(
            f"\n####\nSimulating fault: {fault_type}; rank to fail: {rank_to_fail}\n#####\n",
            file=sys.stderr,
        )
        os.kill(target_pid, target_sig)

    fault_sim_thread = threading.Thread(target=__fault_thread)
    fault_sim_thread.daemon = True
    fault_sim_thread.start()


_signal_received = False


def _sig_handler(*args, **kwargs):
    print("Signal received!", file=sys.stderr)
    global _signal_received
    _signal_received = True


def main():
    signal.signal(signal.SIGTERM, _sig_handler)

    args = parse_args()

    torch.manual_seed(123)
    np.random.seed(123)
    random.seed(123)

    if args.device == 'cuda':
        if torch.cuda.is_available():
            device = torch.device('cuda')
            torch.cuda.set_device(args.local_rank)
        else:
            raise RuntimeError("Selected 'cuda' device but torch.cuda is not available.")
    elif args.device == 'cpu':
        device = torch.device('cpu')
    else:
        raise RuntimeError('Unknown device')

    if int(os.getenv('WORLD_SIZE', '1')) == 1:
        raise RuntimeError('This example supports only multi-GPU training')

    os.makedirs(args.output_dir, exist_ok=True)

    if args.init_distributed_method == 'tcp':
        # NOTE: when running tests with the TCP init method, we noticed
        # occasional "address already in use" errors after the workload
        # is restarted
        dist_utils.init_distributed_with_tcp_store(device)
    elif args.init_distributed_method == 'file':
        dist_utils.init_distributed_with_file_store(device, store_file_dir=args.output_dir)
    else:
        raise RuntimeError(
            f"--init_distributed_method should be one of ['tcp', 'file'], but got {args.init_distributed_method}"
        )

    if args.log_all_ranks:
        log_file_name = f'train_log_rank_{dist_utils.get_rank()}.log'
    else:
        log_file_name = 'train_log.log'
    log_file_path = os.path.join(args.output_dir, log_file_name)

    # NOTE: logging appends outputs to an existing log file if it already
    # exists. Results from a single training run (potentially with many
    # restarts from a checkpoint) are stored in a single log file.
    log_utils.setup_logging(args.log_all_ranks, filename=log_file_path, filemode='a')
    logging.info(args)

    rank = dist_utils.get_rank()

    logging.info(f"SLURM_JOB_ID={os.getenv('SLURM_JOB_ID','<none>')} RANK={rank} PID={os.getpid()}")

    # Dummy datasets
    train_dataset = Dataset(args.train_dataset_size, args.hidden)
    val_dataset = Dataset(args.val_dataset_size, args.hidden)

    # ResumableDistributedSampler is needed to skip consumed samples
    train_sampler = dist_utils.ResumableDistributedSampler(
        train_dataset,
        drop_last=True,
    )

    val_sampler = torch.utils.data.DistributedSampler(
        val_dataset,
    )

    # A dummy model and an optimizer
    model = Model(args.hidden).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Initial value for start epoch - will be overwritten if training is resumed from a checkpoint
    progress = {
        'epoch_idx': 0,
        'iter_idx': 0,
    }

    checkpoint_path = os.path.join(args.output_dir, args.checkpoint_fname)

    # Initialize fault tolerance.
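    # RankMonitorClient is expected to connect this process to its rank monitor
    # (typically started by the FT launcher) and to report heartbeats to it;
    # init_workload_monitoring() establishes that connection.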
    ft_client = fault_tolerance.RankMonitorClient()
    ft_client.init_workload_monitoring()

    checkpoint = None

    # try to load checkpoint from disk
    if os.path.exists(checkpoint_path):
        checkpoint = load_checkpoint(checkpoint_path)
        if checkpoint:
            logging.info(f'Checkpoint was loaded from file: {checkpoint_path}')

    if checkpoint:
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        ft_client.load_state_dict(checkpoint['ft_state'])
        progress.update(checkpoint['progress'])
        # Return with zero exit code if model is already fully trained.
        if progress['epoch_idx'] == args.epochs:
            logging.info('Training finished.')
            ft_client.shutdown_workload_monitoring()
            torch.distributed.destroy_process_group()
            sys.exit(0)

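    # NOTE: persistent_workers=True keeps the dataloader worker processes alive
    # between epochs instead of re-spawning them for every epoch.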
    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch,
        sampler=train_sampler,
        num_workers=4,
        persistent_workers=True,
        pin_memory=False,
    )

    val_dataloader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=args.batch,
        sampler=val_sampler,
        num_workers=4,
    )

    # Regular DDP init
    # NOTE: for convenience, the code keeps both the wrapped and unwrapped model,
    # using the wrapped model for training and the unwrapped model for saving the
    # checkpoint and for validation. This does not increase memory consumption,
    # since both models hold references to the same parameters.
    # Additionally, the saved checkpoint is ready for inference and does not have
    # to be manually unwrapped by accessing the (undocumented) "module" attribute
    # of the DDP-wrapped model.
    if device.type == 'cuda':
        device_ids = [args.local_rank]
        output_device = args.local_rank
    elif device.type == 'cpu':
        device_ids = None
        output_device = None
    else:
        raise RuntimeError('Unsupported device type')
    para_model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=device_ids, output_device=output_device
    )

    # Iteration over epochs; note that it starts from 'epoch_idx',
    # which was previously loaded from the checkpoint.
    for epoch_idx in range(progress['epoch_idx'], args.epochs):
        training_loop(
            ft_client,
            para_model,
            model,
            optimizer,
            device,
            train_dataloader,
            train_sampler,
            progress,
            args,
        )

        # epoch_idx is incremented because the current epoch is finished
        # and a potential resume from this checkpoint should start a new training epoch.
        progress['epoch_idx'] += 1
        progress['iter_idx'] = 0

        validation_loop(ft_client, model, val_dataloader, epoch_idx, device)

        # The checkpoint contains everything needed for a deterministic resume:
        # the state of the model, optimizer and other components.
        save_checkpoint(
            progress=progress,
            model=model,
            optimizer=optimizer,
            ft_client=ft_client,
            output_dir=args.output_dir,
            checkpoint_fname=args.checkpoint_fname,
        )

        # NOTE: SIGTERM is used by SLURM to initiate graceful job termination
        # if _any_ rank received SIGTERM, we leave the main loop
        if dist_utils.is_true_on_any_rank(_signal_received):
            logging.info('Leaving the main loop, due to SIGTERM')
            break

        # Setup simulated fault as soon as we have valid timeouts
        if args.simulated_fault and not _sim_fault_is_set and ft_client.hb_timeouts.are_valid:
            _setup_simulated_fault(ft_client, args.simulated_fault, device)

    _cancel_simulated_fault()
    ft_client.shutdown_workload_monitoring()
    torch.distributed.destroy_process_group()
    logging.info('Leaving main, ret_code=0')
    sys.exit(0)


if __name__ == "__main__":
    main()