Heartbeat API usage example with DDP

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Demo of fault tolerance with DDP training, using the FT package heartbeats API

This script demonstrates how to use the FT heartbeats API for hang detection in
distributed training. It should be run with the ft_launcher command. E.g.:

`ft_launcher --nproc-per-node=2 --ft-cfg-path=./examples/fault_tolerance/fault_tol_cfg_heartbeats.yaml examples/fault_tolerance/train_ddp_heartbeats_api.py --device=cpu`

Fault tolerance features demonstrated:
1. Heartbeat sending during training
2. Timeout calculation and setting
3. State persistence through checkpoints
4. Simulated fault injection
"""
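
# Heartbeat API calls demonstrated below:
#   - fault_tolerance.RankMonitorClient() + init_workload_monitoring()  -- setup (main)
#   - ft_client.send_heartbeat()                                        -- training & validation loops
#   - ft_client.calculate_and_set_hb_timeouts()                         -- once, early in training_loop
#   - ft_client.state_dict() / ft_client.load_state_dict()              -- checkpoint save / resume
#   - ft_client.shutdown_workload_monitoring()                          -- cleanup (main)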

import argparse
import logging
import os
import random
import signal
import sys
import threading
import time

import dist_utils
import log_utils
import numpy as np
import torch
import torch.nn as nn

import nvidia_resiliency_ext.fault_tolerance as fault_tolerance


# Dummy dataset.
class Dataset(torch.utils.data.Dataset):
    def __init__(self, size, hidden):
        self.size = size
        self.hidden = hidden

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        data = torch.full(
            (self.hidden,),
            fill_value=idx,
            dtype=torch.float32,
            device='cpu',
        )
        return data


# Dummy model
class Model(nn.Module):
    def __init__(self, hidden):
        super().__init__()
        self.l1 = nn.Linear(hidden, hidden)
        self.l2 = nn.Linear(hidden, hidden)

    def forward(self, x):
        x = self.l1(x)
        x = self.l2(x)
        return x


def parse_args():
    def tuple_type(strings):
        strings = strings.replace("(", "").replace(")", "")
        mapped_int = map(int, strings.split(","))
        return tuple(mapped_int)

    def fault_desc(strings):
        parts = strings.split(",")
        assert len(parts) == 2
        return {'fault': parts[0], 'delay': float(parts[1])}

    parser = argparse.ArgumentParser(
        description='Example of PyTorch DDP training with the Fault Tolerance package',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # fmt: off
    parser.add_argument('--hidden', type=int, default=4096,
                        help='Hidden size')
    parser.add_argument('--batch', type=int, default=8,
                        help='Batch size')
    parser.add_argument('--epochs', type=int, default=4,
                        help='Number of training epochs')
    parser.add_argument('--train_dataset_size', type=int, default=1000000,
                        help='Train dataset size')
    parser.add_argument('--val_dataset_size', type=int, default=2000,
                        help='Validation dataset size')
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cpu', 'cuda'],
                        help='Device')

    parser.add_argument('--interrupt_at', type=tuple_type, nargs='*',
                        help='Manual interruption after (epoch, iteration), '
                        'for testing only')
    parser.add_argument('--save_interval', type=int, default=-1,
                        help='Interval for saving periodic checkpoints')
    parser.add_argument('--logging_interval', type=int, default=1,
                        help='Interval for log entries')
    parser.add_argument('--log_all_ranks', action='store_true',
                        help='Enable logging from all distributed ranks')
    parser.add_argument('--output_dir', type=str, default='results/output',
                        help='Output dir')
    parser.add_argument('--checkpoint_fname', type=str, default='checkpoint.pt',
                        help='Name of a checkpoint file')

    parser.add_argument('--local_rank', type=int,
                        default=os.getenv('LOCAL_RANK', 0))
    parser.add_argument('--init_distributed_method', type=str, default='tcp',
                        help='Init distributed group with TCP store ("tcp") or file store ("file")')

    parser.add_argument('--simulated_fault', type=fault_desc,
                        help='Description of a fault to be simulated')
    # fmt: on

    args = parser.parse_args()

    if args.interrupt_at:
        args.interrupt_at = set(args.interrupt_at)
    else:
        args.interrupt_at = set()

    return args


def load_checkpoint(path):
    map_location = {
        'cpu': 'cpu',
    }
    if torch.cuda.is_available():
        map_location['cuda:0'] = f'cuda:{torch.cuda.current_device()}'

    logging.info(f'Loading checkpoint from {path}')
    checkpoint = torch.load(path, map_location=map_location, weights_only=True)
    return checkpoint


def save_checkpoint(
    progress,
    model,
    optimizer,
    ft_client,
    output_dir,
    checkpoint_fname,
):
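    # NOTE: the FT client state is checkpointed together with the model and optimizer
    # state ('ft_state' below), so that heartbeat-related settings survive a restart;
    # see "State persistence through checkpoints" in the module docstring.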
    state = {
        'progress': progress,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'ft_state': ft_client.state_dict(),
    }

    checkpoint_path = os.path.join(output_dir, checkpoint_fname)

    with dist_utils.sync_workers() as rank:
        if rank == 0:
            logging.info(f'Saving checkpoint to {checkpoint_path}')
            torch.save(state, checkpoint_path)


def training_loop(
    ft_client,
    para_model,
    model,
    optimizer,
    device,
    dataloader,
    sampler,
    progress,
    args,
):
    epoch_idx = progress['epoch_idx']

    # NOTE: torch.utils.data.DistributedSampler must be prepared for the current epoch
    # before the iteration starts.
    sampler.start_sample_idx = progress['iter_idx'] * args.batch
    sampler.set_epoch(epoch_idx)

    para_model.train()

    last_log_time = time.monotonic()

    for iter_idx, x in enumerate(dataloader, start=progress['iter_idx']):
        if ft_client.hb_timeouts.are_valid is False and epoch_idx == 1 and iter_idx == 1:
            # After the 0th epoch has completed and we've done the 0th iteration of the 1st epoch,
            # we can calculate and set the timeouts. This is a good moment to do so,
            # because by now we have seen the possibly long interval during which a checkpoint was saved.
            ft_client.calculate_and_set_hb_timeouts()

        optimizer.zero_grad()
        x = x.to(device)
        y = para_model(x)
        loss = y.mean()
        train_loss = loss.item()
        loss.backward()

        if iter_idx % args.logging_interval == 0:
            avg_train_loss = dist_utils.all_reduce_item(train_loss, op='mean')
            logging.info(
                f'CHECK TRAIN epoch: {epoch_idx:4d} '
                f'iter: {iter_idx:5d} '
                f'loss: {avg_train_loss} '
                f'input: {x[:, 0]}'
            )
            if iter_idx > 0:
                time_per_iter = (time.monotonic() - last_log_time) / args.logging_interval
                last_log_time = time.monotonic()
                logging.debug(f'Avg time per iter: {time_per_iter:.3f} [sec]')

        progress['iter_idx'] = iter_idx + 1

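        # Report liveness to the rank monitor; missing heartbeats (past the
        # configured or calculated timeouts) are treated as a hang.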
        ft_client.send_heartbeat()
        optimizer.step()

        # Whether to do periodic checkpointing
        periodic_save = iter_idx % args.save_interval == args.save_interval - 1

        if periodic_save or (epoch_idx, iter_idx) in args.interrupt_at:
            save_checkpoint(
                progress=progress,
                model=model,
                optimizer=optimizer,
                ft_client=ft_client,
                output_dir=args.output_dir,
                checkpoint_fname=args.checkpoint_fname,
            )
            if (epoch_idx, iter_idx) in args.interrupt_at:
                logging.info('Manual interruption, exiting')
                sys.exit(0)


def validation_loop(ft_client, model, val_dataloader, epoch_idx, device):
    total_val_loss = 0
    model.eval()

    for iter_idx, x in enumerate(val_dataloader):
        x = x.to(device)
        y = model(x)
        loss = y.mean().item()
        total_val_loss += loss
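        # Heartbeats are sent during validation as well, so that the evaluation
        # phase is not mistaken for a hang.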
        ft_client.send_heartbeat()

    logging.info(
        f'CHECK VAL SUMMARY: epoch: {epoch_idx:4d} ' f'loss: {total_val_loss / (iter_idx + 1)}'
    )


_sim_fault_canceled = False
_sim_fault_is_set = False


def _cancel_simulated_fault():
    global _sim_fault_canceled
    _sim_fault_canceled = True


def _setup_simulated_fault(ft_client, fault_desc, device):
    # FIXME: hanging a rank with SIGTSTP results in the rank monitor
    # being blocked when trying to receive data in _on_ipc_data_from_rank

    global _sim_fault_is_set
    _sim_fault_is_set = True  # should be True on all ranks

    rng = random.Random()

    logging.info(f"Initializing simulated fault: {fault_desc}")

    fault_type = fault_desc['fault']
    if fault_type == 'random':
        fault_type = rng.choice(['rank_killed', 'rank_hung'])

    rank_to_fail = rng.randint(0, dist_utils.get_world_size() - 1)
    rank_to_fail = torch.tensor([rank_to_fail], device=device)
    dist_utils.broadcast(rank_to_fail, 0)
    rank_to_fail = int(rank_to_fail.item())

    rank = torch.distributed.get_rank()
    if rank != rank_to_fail:
        return

    if fault_type == 'rank_killed':
        target_pid = os.getpid()
        target_sig = signal.SIGKILL
    elif fault_type == 'rank_hung':
        target_pid = os.getpid()
        target_sig = signal.SIGSTOP
    else:
        raise Exception(f"Unknown fault type {fault_type}")

    delay = fault_desc['delay'] + 4.0 * rng.random()

    logging.info(
        f"Selected fault={fault_type}; target rank={rank_to_fail}; delay={delay}",
    )

    def __fault_thread():
        time.sleep(delay)
        if _sim_fault_canceled:
            return
        print(
            f"\n####\nSimulating fault: {fault_type}; rank to fail: {rank_to_fail}\n#####\n",
            file=sys.stderr,
        )
        os.kill(target_pid, target_sig)

    fault_sim_thread = threading.Thread(target=__fault_thread)
    fault_sim_thread.daemon = True
    fault_sim_thread.start()


_signal_received = False


def _sig_handler(*args, **kwargs):
    print("Signal received!", file=sys.stderr)
    global _signal_received
    _signal_received = True


def main():
    signal.signal(signal.SIGTERM, _sig_handler)

    args = parse_args()

    torch.manual_seed(123)
    np.random.seed(123)
    random.seed(123)

    if args.device == 'cuda':
        if torch.cuda.is_available():
            device = torch.device('cuda')
            torch.cuda.set_device(args.local_rank)
        else:
            raise RuntimeError("Selected 'cuda' device but torch.cuda is not available.")
    elif args.device == 'cpu':
        device = torch.device('cpu')
    else:
        raise RuntimeError('Unknown device')

    if int(os.getenv('WORLD_SIZE', '1')) == 1:
        raise RuntimeError('This example supports only multi-gpu training')

    os.makedirs(args.output_dir, exist_ok=True)

    if args.init_distributed_method == 'tcp':
        # NOTE: when running tests with the TCP init method, we noticed
        # occasional "address already in use" errors after the workload
        # is restarted
        dist_utils.init_distributed_with_tcp_store(device)
    elif args.init_distributed_method == 'file':
        dist_utils.init_distributed_with_file_store(device, store_file_dir=args.output_dir)
    else:
        raise RuntimeError(
            f"--init_distributed_method should be one of ['tcp', 'file'], but it is {args.init_distributed_method}"
        )

    if args.log_all_ranks:
        log_file_name = f'train_log_rank_{dist_utils.get_rank()}.log'
    else:
        log_file_name = 'train_log.log'
    log_file_path = os.path.join(args.output_dir, log_file_name)

    # NOTE: logging appends outputs to an existing log file if it already
    # exists. Results from a single training run (potentially with many
    # restarts from a checkpoint) are stored in a single log file.
    log_utils.setup_logging(args.log_all_ranks, filename=log_file_path, filemode='a')
    logging.info(args)

    rank = dist_utils.get_rank()

    logging.info(f"SLURM_JOB_ID={os.getenv('SLURM_JOB_ID','<none>')} RANK={rank} PID={os.getpid()}")

    # Dummy datasets
    train_dataset = Dataset(args.train_dataset_size, args.hidden)
    val_dataset = Dataset(args.val_dataset_size, args.hidden)

    # ResumableDistributedSampler is needed to skip consumed samples
    train_sampler = dist_utils.ResumableDistributedSampler(
        train_dataset,
        drop_last=True,
    )

    val_sampler = torch.utils.data.DistributedSampler(
        val_dataset,
    )

    # A dummy model and an optimizer
    model = Model(args.hidden).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Initial value for start epoch - will be overwritten if training is resumed from a checkpoint
    progress = {
        'epoch_idx': 0,
        'iter_idx': 0,
    }

    checkpoint_path = os.path.join(args.output_dir, args.checkpoint_fname)

    # Initialize fault tolerance.
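    # The FT client created below sends heartbeats (see the training and validation
    # loops), which are used for hang detection.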
    ft_client = fault_tolerance.RankMonitorClient()
    ft_client.init_workload_monitoring()

    checkpoint = None

    # try to load checkpoint from disk
    if os.path.exists(checkpoint_path):
        checkpoint = load_checkpoint(checkpoint_path)
        if checkpoint:
            logging.info(f'Checkpoint was loaded from file: {checkpoint_path}')

    if checkpoint:
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
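        # Restore the FT client state saved by save_checkpoint() (e.g. previously
        # calculated heartbeat timeouts), so it doesn't have to be recomputed after a restart.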
        ft_client.load_state_dict(checkpoint['ft_state'])
        progress.update(checkpoint['progress'])
        # Return with zero exit code if model is already fully trained.
        if progress['epoch_idx'] == args.epochs:
            logging.info('Training finished.')
            ft_client.shutdown_workload_monitoring()
            torch.distributed.destroy_process_group()
            sys.exit(0)

    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch,
        sampler=train_sampler,
        num_workers=4,
        persistent_workers=True,
        pin_memory=False,
    )

    val_dataloader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=args.batch,
        sampler=val_sampler,
        num_workers=4,
    )

    # Regular DDP init
    # NOTE: for convenience, the code keeps both the wrapped and the unwrapped model,
    # and uses the wrapped model for training and the unwrapped model for saving the
    # checkpoint and for validation. This doesn't increase memory consumption,
    # since both models hold references to the same parameters.
    # Additionally, the saved checkpoint is ready for inference and doesn't have to
    # be manually unwrapped by accessing the (undocumented) "module" attribute
    # of the DDP-wrapped model.
    if device.type == 'cuda':
        device_ids = [args.local_rank]
        output_device = args.local_rank
    elif device.type == 'cpu':
        device_ids = None
        output_device = None
    else:
        raise RuntimeError('Unsupported device type')
    para_model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=device_ids, output_device=output_device
    )

    # Iteration over epochs; note that it starts from 'epoch_idx',
    # which was previously loaded from the checkpoint
    for epoch_idx in range(progress['epoch_idx'], args.epochs):
        training_loop(
            ft_client,
            para_model,
            model,
            optimizer,
            device,
            train_dataloader,
            train_sampler,
            progress,
            args,
        )

        # epoch_idx is incremented because the current epoch is finished
        # and a potential resume from this checkpoint should start a new training epoch.
        progress['epoch_idx'] += 1
        progress['iter_idx'] = 0

        validation_loop(ft_client, model, val_dataloader, epoch_idx, device)

        # The checkpoint contains everything needed for a deterministic resume:
        # the state of the model, the optimizer, and other components.
        save_checkpoint(
            progress=progress,
            model=model,
            optimizer=optimizer,
            ft_client=ft_client,
            output_dir=args.output_dir,
            checkpoint_fname=args.checkpoint_fname,
        )

        # NOTE: SIGTERM is used by SLURM to initiate graceful job termination.
        # If _any_ rank received SIGTERM, we leave the main loop.
        if dist_utils.is_true_on_any_rank(_signal_received):
            logging.info('Leaving the main loop, due to SIGTERM')
            break

        # Set up the simulated fault as soon as we have valid timeouts
        if args.simulated_fault and not _sim_fault_is_set and ft_client.hb_timeouts.are_valid:
            _setup_simulated_fault(ft_client, args.simulated_fault, device)

    _cancel_simulated_fault()
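    # Training is complete: stop FT workload monitoring before tearing down the process group.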
    ft_client.shutdown_workload_monitoring()
    torch.distributed.destroy_process_group()
    logging.info('Leaving main, ret_code=0')
    sys.exit(0)


if __name__ == "__main__":
    main()
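
For quick reference, the heartbeat-related calls used in the script above boil down to the skeleton below. This is a minimal sketch, not part of the original example: the training work is replaced by a sleep placeholder, checkpoint handling is only hinted at in comments, the torch.distributed/DDP setup is omitted, and it still assumes the process is started via ft_launcher with a heartbeat FT config, as in the command shown in the module docstring.

import time

import nvidia_resiliency_ext.fault_tolerance as fault_tolerance

ft_client = fault_tolerance.RankMonitorClient()
ft_client.init_workload_monitoring()

# When resuming, restore the FT state saved in a previous checkpoint:
# ft_client.load_state_dict(checkpoint['ft_state'])

for step in range(100):
    time.sleep(0.1)  # placeholder for one real training step
    ft_client.send_heartbeat()  # tell the rank monitor this rank is still alive

    # Once the slowest expected heartbeat intervals have been observed
    # (e.g. one that included a checkpoint save), derive timeouts from them:
    if step == 10 and not ft_client.hb_timeouts.are_valid:
        ft_client.calculate_and_set_hb_timeouts()

    # When checkpointing, include 'ft_state': ft_client.state_dict() in the saved state.

ft_client.shutdown_workload_monitoring()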