Heartbeat API usage example with DDP

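The full example below demonstrates the heartbeats API of the Fault Tolerance (FT) package in a DDP training script. As an orientation, a minimal sketch of the same flow is shown first. It is only a sketch: train_step and dataloader are placeholders, the warm-up threshold is arbitrary, and, like the full example, it assumes the script is started through the FT launcher (e.g. ft_launcher), which spawns the rank monitors that RankMonitorClient connects to.

import nvidia_resiliency_ext.fault_tolerance as fault_tolerance

ft_client = fault_tolerance.RankMonitorClient()
ft_client.init_workload_monitoring()  # connect this rank to its rank monitor

dataloader = range(1000)  # placeholder for a real DataLoader


def train_step(batch):
    pass  # placeholder for the forward/backward/optimizer step


for iter_idx, batch in enumerate(dataloader):
    train_step(batch)
    ft_client.send_heartbeat()  # signal liveness once per iteration

    # Once enough "typical" intervals have been observed (including slow ones,
    # e.g. around checkpoint saving), timeouts can be derived from them.
    if iter_idx == 100 and not ft_client.hb_timeouts.are_valid:
        ft_client.calculate_and_set_hb_timeouts()

# The client state (e.g. the calculated timeouts) can be checkpointed with
# state_dict() and restored after a restart with load_state_dict().
ft_state = ft_client.state_dict()

ft_client.shutdown_workload_monitoring()  # clean shutdown when training is done

The complete example follows.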
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Demo of fault tolerance with DDP training, using FT package heartbeats API
"""

import argparse
import logging
import os
import random
import signal
import sys
import threading
import time

import dist_utils
import log_utils
import numpy as np
import torch
import torch.nn as nn

import nvidia_resiliency_ext.fault_tolerance as fault_tolerance

# Dummy dataset.
class Dataset(torch.utils.data.Dataset):
    def __init__(self, size, hidden):
        self.size = size
        self.hidden = hidden

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        data = torch.full(
            (self.hidden,),
            fill_value=idx,
            dtype=torch.float32,
            device='cpu',
        )
        return data


# Dummy model
class Model(nn.Module):
    def __init__(self, hidden):
        super().__init__()
        self.l1 = nn.Linear(hidden, hidden)
        self.l2 = nn.Linear(hidden, hidden)

    def forward(self, x):
        x = self.l1(x)
        x = self.l2(x)
        return x


def parse_args():
    def tuple_type(strings):
        strings = strings.replace("(", "").replace(")", "")
        mapped_int = map(int, strings.split(","))
        return tuple(mapped_int)

    def fault_desc(strings):
        parts = strings.split(",")
        assert len(parts) == 2
        return {'fault': parts[0], 'delay': float(parts[1])}

    parser = argparse.ArgumentParser(
        description='Example of PyTorch DDP training with the Fault Tolerance package',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # fmt: off
    parser.add_argument('--hidden', type=int, default=4096,
                        help='Hidden size')
    parser.add_argument('--batch', type=int, default=8,
                        help='Batch size')
    parser.add_argument('--epochs', type=int, default=4,
                        help='Number of training epochs')
    parser.add_argument('--train_dataset_size', type=int, default=1000000,
                        help='Train dataset size')
    parser.add_argument('--val_dataset_size', type=int, default=2000,
                        help='Validation dataset size')
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cpu', 'cuda'],
                        help='Device')

    parser.add_argument('--interrupt_at', type=tuple_type, nargs='*',
                        help='Manual interruption after (epoch, iteration), '
                        'for testing only')
    parser.add_argument('--save_interval', type=int, default=-1,
                        help='Interval for saving periodic checkpoints')
    parser.add_argument('--logging_interval', type=int, default=1,
                        help='Interval for log entries')
    parser.add_argument('--log_all_ranks', action='store_true',
                        help='Enable logging from all distributed ranks')
    parser.add_argument('--output_dir', type=str, default='results/output',
                        help='Output dir')
    parser.add_argument('--checkpoint_fname', type=str, default='checkpoint.pt',
                        help='Name of a checkpoint file')

    parser.add_argument('--local_rank', type=int,
                        default=os.getenv('LOCAL_RANK', 0))
    parser.add_argument('--init_distributed_method', type=str, default='tcp',
                        help='Init distributed group with TCP store ("tcp") or file store ("file")')

    parser.add_argument('--simulated_fault', type=fault_desc,
                        help='Description of a fault to be simulated')
    # fmt: on

    args = parser.parse_args()

    if args.interrupt_at:
        args.interrupt_at = set(args.interrupt_at)
    else:
        args.interrupt_at = set()

    return args

def load_checkpoint(path):
    map_location = {
        'cpu': 'cpu',
    }
    if torch.cuda.is_available():
        map_location['cuda:0'] = f'cuda:{torch.cuda.current_device()}'

    logging.info(f'Loading checkpoint from {path}')
    checkpoint = torch.load(path, map_location=map_location, weights_only=True)
    return checkpoint


def save_checkpoint(
    progress,
    model,
    optimizer,
    ft_client,
    output_dir,
    checkpoint_fname,
):
    state = {
        'progress': progress,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'ft_state': ft_client.state_dict(),
    }

    checkpoint_path = os.path.join(output_dir, checkpoint_fname)

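    # dist_utils.sync_workers() is a local helper; presumably it synchronizes all
    # ranks and yields the rank index, so only rank 0 writes the checkpoint while
    # the remaining ranks wait.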
    with dist_utils.sync_workers() as rank:
        if rank == 0:
            logging.info(f'Saving checkpoint to {checkpoint_path}')
            torch.save(state, checkpoint_path)

def training_loop(
    ft_client,
    para_model,
    model,
    optimizer,
    device,
    dataloader,
    sampler,
    progress,
    args,
):
    epoch_idx = progress['epoch_idx']

    # NOTE: torch.utils.data.DistributedSampler must be prepared for the current epoch
    # before the iteration starts.
    sampler.start_sample_idx = progress['iter_idx'] * args.batch
    sampler.set_epoch(epoch_idx)

    para_model.train()

    last_log_time = time.monotonic()

    for iter_idx, x in enumerate(dataloader, start=progress['iter_idx']):
        if ft_client.hb_timeouts.are_valid is False and epoch_idx == 1 and iter_idx == 1:
            # Timeouts are calculated after the 0th epoch has completed and the 0th iteration
            # of the 1st epoch is done. This is a good moment to do so, because by now we have
            # also observed the possibly long interval in which a checkpoint was saved.
            ft_client.calculate_and_set_hb_timeouts()

        optimizer.zero_grad()
        x = x.to(device)
        y = para_model(x)
        loss = y.mean()
        train_loss = loss.item()
        loss.backward()

        if iter_idx % args.logging_interval == 0:
            avg_train_loss = dist_utils.all_reduce_item(train_loss, op='mean')
            logging.info(
                f'CHECK TRAIN epoch: {epoch_idx:4d} '
                f'iter: {iter_idx:5d} '
                f'loss: {avg_train_loss} '
                f'input: {x[:, 0]}'
            )
            if iter_idx > 0:
                time_per_iter = (time.monotonic() - last_log_time) / args.logging_interval
                last_log_time = time.monotonic()
                logging.debug(f'Avg time per iter: {time_per_iter:.3f} [sec]')

        progress['iter_idx'] = iter_idx + 1

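        # Send a heartbeat to the rank monitor to signal that this rank is making progress.
        # If heartbeats stop arriving within the (configured or calculated) timeout,
        # the rank is considered hung and the fault tolerance machinery can react.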
        ft_client.send_heartbeat()
        optimizer.step()

        # Whether to save a periodic checkpoint at this iteration
        periodic_save = iter_idx % args.save_interval == args.save_interval - 1

        if periodic_save or (epoch_idx, iter_idx) in args.interrupt_at:
            save_checkpoint(
                progress=progress,
                model=model,
                optimizer=optimizer,
                ft_client=ft_client,
                output_dir=args.output_dir,
                checkpoint_fname=args.checkpoint_fname,
            )
            if (epoch_idx, iter_idx) in args.interrupt_at:
                logging.info('Manual interruption, exiting')
                sys.exit(0)

def validation_loop(ft_client, model, val_dataloader, epoch_idx, device):
    total_val_loss = 0
    model.eval()

    for iter_idx, x in enumerate(val_dataloader):
        x = x.to(device)
        y = model(x)
        loss = y.mean().item()
        total_val_loss += loss
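        # Heartbeats are also sent during validation, so the rank monitor keeps
        # receiving liveness signals outside of the training loop.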
        ft_client.send_heartbeat()

    logging.info(
        f'CHECK VAL SUMMARY: epoch: {epoch_idx:4d} ' f'loss: {total_val_loss / (iter_idx + 1)}'
    )


_sim_fault_canceled = False
_sim_fault_is_set = False


def _cancel_simulated_fault():
    global _sim_fault_canceled
    _sim_fault_canceled = True


def _setup_simulated_fault(ft_client, fault_desc, device):
    # FIXME: hanging a rank with SIGTSTP results in the rank monitor
    # being blocked when trying to receive data in _on_ipc_data_from_rank

    global _sim_fault_is_set
    _sim_fault_is_set = True  # should be True on all ranks

    rng = random.Random()

    logging.info(f"Initializing simulated fault: {fault_desc}")

    fault_type = fault_desc['fault']
    if fault_type == 'random':
        fault_type = rng.choice(['rank_killed', 'rank_hung'])

    rank_to_fail = rng.randint(0, dist_utils.get_world_size() - 1)
    rank_to_fail = torch.tensor([rank_to_fail], device=device)
    dist_utils.broadcast(rank_to_fail, 0)
    rank_to_fail = int(rank_to_fail.item())

    rank = torch.distributed.get_rank()
    if rank != rank_to_fail:
        return

    if fault_type == 'rank_killed':
        target_pid = os.getpid()
        target_sig = signal.SIGKILL
    elif fault_type == 'rank_hung':
        target_pid = os.getpid()
        target_sig = signal.SIGSTOP
    else:
        raise Exception(f"Unknown fault type {fault_type}")

    delay = fault_desc['delay'] + 4.0 * rng.random()

    logging.info(
        f"Selected fault={fault_type}; target rank={rank_to_fail}; delay={delay}",
    )

    def __fault_thread():
        time.sleep(delay)
        if _sim_fault_canceled:
            return
        print(
            f"\n####\nSimulating fault: {fault_type}; rank to fail: {rank_to_fail}\n#####\n",
            file=sys.stderr,
        )
        os.kill(target_pid, target_sig)

    fault_sim_thread = threading.Thread(target=__fault_thread)
    fault_sim_thread.daemon = True
    fault_sim_thread.start()


_signal_received = False


def _sig_handler(*args, **kwargs):
    print("Signal received!", file=sys.stderr)
    global _signal_received
    _signal_received = True

def main():
    signal.signal(signal.SIGTERM, _sig_handler)

    args = parse_args()

    torch.manual_seed(123)
    np.random.seed(123)
    random.seed(123)

    if args.device == 'cuda':
        if torch.cuda.is_available():
            device = torch.device('cuda')
            torch.cuda.set_device(args.local_rank)
        else:
            raise RuntimeError("Selected 'cuda' device but torch.cuda is not available.")
    elif args.device == 'cpu':
        device = torch.device('cpu')
    else:
        raise RuntimeError('Unknown device')

    if int(os.getenv('WORLD_SIZE', '1')) == 1:
        raise RuntimeError('This example supports only multi-worker training (WORLD_SIZE > 1)')

    os.makedirs(args.output_dir, exist_ok=True)

    if args.init_distributed_method == 'tcp':
        # NOTE: when running tests with the TCP init method, we noticed
        # occasional "address already in use" errors after the workload
        # was restarted
        dist_utils.init_distributed_with_tcp_store(device)
    elif args.init_distributed_method == 'file':
        dist_utils.init_distributed_with_file_store(device, store_file_dir=args.output_dir)
    else:
        raise RuntimeError(
            f"--init_distributed_method should be one of ['tcp','file'], but it is {args.init_distributed_method}"
        )

    if args.log_all_ranks:
        log_file_name = f'train_log_rank_{dist_utils.get_rank()}.log'
    else:
        log_file_name = 'train_log.log'
    log_file_path = os.path.join(args.output_dir, log_file_name)

    # NOTE: logging appends outputs to an existing log file if it already
    # exists. Results from a single training run (potentially with many
    # restarts from a checkpoint) are stored in a single log file.
    log_utils.setup_logging(args.log_all_ranks, filename=log_file_path, filemode='a')
    logging.info(args)

    rank = dist_utils.get_rank()

    logging.info(f"SLURM_JOB_ID={os.getenv('SLURM_JOB_ID','<none>')} RANK={rank} PID={os.getpid()}")

    # Dummy datasets
    train_dataset = Dataset(args.train_dataset_size, args.hidden)
    val_dataset = Dataset(args.val_dataset_size, args.hidden)

    # ResumableDistributedSampler is needed to skip consumed samples
    train_sampler = dist_utils.ResumableDistributedSampler(
        train_dataset,
        drop_last=True,
    )

    val_sampler = torch.utils.data.DistributedSampler(
        val_dataset,
    )

    # A dummy model and an optimizer
    model = Model(args.hidden).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Initial value for the start epoch - will be overwritten if training is resumed from a checkpoint
    progress = {
        'epoch_idx': 0,
        'iter_idx': 0,
    }

    checkpoint_path = os.path.join(args.output_dir, args.checkpoint_fname)

    # Initialize fault tolerance.
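    # RankMonitorClient talks to the rank monitor attached to this rank
    # (the monitors are created by the FT launcher, e.g. ft_launcher).
    # init_workload_monitoring() starts heartbeat-based monitoring; the client's
    # state_dict()/load_state_dict() (see 'ft_state' in the checkpoint) preserve
    # the calculated timeouts across restarts.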
    ft_client = fault_tolerance.RankMonitorClient()
    ft_client.init_workload_monitoring()

    checkpoint = None

    # Try to load a checkpoint from disk
    if os.path.exists(checkpoint_path):
        checkpoint = load_checkpoint(checkpoint_path)
        if checkpoint:
            logging.info(f'Checkpoint was loaded from file: {checkpoint_path}')

    if checkpoint:
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        ft_client.load_state_dict(checkpoint['ft_state'])
        progress.update(checkpoint['progress'])
        # Return with a zero exit code if the model is already fully trained.
        if progress['epoch_idx'] == args.epochs:
            logging.info('Training finished.')
            ft_client.shutdown_workload_monitoring()
            torch.distributed.destroy_process_group()
            sys.exit(0)

    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch,
        sampler=train_sampler,
        num_workers=4,
        persistent_workers=True,
        pin_memory=False,
    )

    val_dataloader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=args.batch,
        sampler=val_sampler,
        num_workers=4,
    )

    # Regular DDP init
    # NOTE: for convenience, the code keeps both the wrapped and the unwrapped model,
    # and uses the wrapped model for training and the unwrapped model for validation
    # and for saving the checkpoint. This does not increase memory consumption,
    # since both models hold references to the same parameters.
    # Additionally, the saved checkpoint is ready for inference and does not have to
    # be manually unwrapped by accessing the (undocumented) "module" attribute
    # of the DDP-wrapped model.
    if device.type == 'cuda':
        device_ids = [args.local_rank]
        output_device = args.local_rank
    elif device.type == 'cpu':
        device_ids = None
        output_device = None
    else:
        raise RuntimeError('Unsupported device type')
    para_model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=device_ids, output_device=output_device
    )

    # Iteration over epochs; note that it starts from 'epoch_idx',
    # which was previously loaded from the checkpoint
    for epoch_idx in range(progress['epoch_idx'], args.epochs):
        training_loop(
            ft_client,
            para_model,
            model,
            optimizer,
            device,
            train_dataloader,
            train_sampler,
            progress,
            args,
        )

        # epoch_idx is incremented because the current epoch is finished,
        # and a potential resume from this checkpoint should start a new training epoch.
        progress['epoch_idx'] += 1
        progress['iter_idx'] = 0

        validation_loop(ft_client, model, val_dataloader, epoch_idx, device)

        # The checkpoint contains everything needed for a deterministic resume:
        # the state of the model, the optimizer and the other components.
        save_checkpoint(
            progress=progress,
            model=model,
            optimizer=optimizer,
            ft_client=ft_client,
            output_dir=args.output_dir,
            checkpoint_fname=args.checkpoint_fname,
        )

        # NOTE: SIGTERM is used by SLURM to initiate graceful job termination.
        # If _any_ rank received SIGTERM, we leave the main loop.
        if dist_utils.is_true_on_any_rank(_signal_received):
            logging.info('Leaving the main loop, due to SIGTERM')
            break

        # Set up the simulated fault as soon as we have valid timeouts
        if args.simulated_fault and not _sim_fault_is_set and ft_client.hb_timeouts.are_valid:
            _setup_simulated_fault(ft_client, args.simulated_fault, device)

    _cancel_simulated_fault()
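    # Shut down heartbeat monitoring for this rank, then tear down the process group.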
    ft_client.shutdown_workload_monitoring()
    torch.distributed.destroy_process_group()
    logging.info('Leaving main, ret_code=0')
    sys.exit(0)


if __name__ == "__main__":
    main()