DDP usage example
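The full, self-contained script is listed at the end of this section. Before reading it, the fragment below is a minimal sketch (not part of the example itself) of the fault tolerance integration points the script exercises: creating a RankMonitorClient, sending heartbeats from the training loop, deriving timeouts once a representative iteration has been observed, and saving/restoring the client state alongside the checkpoint. The tiny loop and the time.sleep call are placeholders for real training work, and the sketch assumes the process is launched under the fault tolerance rank monitor so that the client can connect to it.

# Minimal sketch of the fault tolerance calls used by the full example below.
# Assumes the process was started under the fault tolerance rank monitor,
# so RankMonitorClient can connect to it; the loop body is a placeholder.
import time

import nvidia_resiliency_ext.fault_tolerance as fault_tolerance

ft_client = fault_tolerance.RankMonitorClient()
ft_client.init_workload_monitoring()  # connect this rank to its rank monitor

for step in range(8):  # stand-in for the real training loop
    time.sleep(0.1)  # stand-in for forward/backward/optimizer step
    ft_client.send_heartbeat()  # tell the monitor this rank is still alive

    # Once a representative (possibly slow) iteration has been observed,
    # derive heartbeat timeouts from the measured intervals.
    if not ft_client.timeouts.are_valid and step == 1:
        ft_client.calculate_and_set_timeouts()

# The computed timeouts can be saved with ft_client.state_dict() and restored
# with ft_client.load_state_dict(), exactly as the checkpointing code below does.
ft_state = ft_client.state_dict()

ft_client.shutdown_workload_monitoring()

The complete example follows.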

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Demo of fault tolerance with DDP training
"""

import argparse
import logging
import os
import random
import signal
import sys
import threading
import time

import dist_utils
import log_utils
import numpy as np
import torch
import torch.nn as nn

import nvidia_resiliency_ext.fault_tolerance as fault_tolerance


# Dummy dataset.
class Dataset(torch.utils.data.Dataset):
    def __init__(self, size, hidden):
        self.size = size
        self.hidden = hidden

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        data = torch.full(
            (self.hidden,),
            fill_value=idx,
            dtype=torch.float32,
            device='cpu',
        )
        return data


# Dummy model
class Model(nn.Module):
    def __init__(self, hidden):
        super().__init__()
        self.l1 = nn.Linear(hidden, hidden)
        self.l2 = nn.Linear(hidden, hidden)

    def forward(self, x):
        x = self.l1(x)
        x = self.l2(x)
        return x


def parse_args():
    def tuple_type(strings):
        strings = strings.replace("(", "").replace(")", "")
        mapped_int = map(int, strings.split(","))
        return tuple(mapped_int)

    def fault_desc(strings):
        parts = strings.split(",")
        assert len(parts) == 2
        return {'fault': parts[0], 'delay': float(parts[1])}

    parser = argparse.ArgumentParser(
        description='Example of PyTorch DDP training with the Fault Tolerance package',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # fmt: off
    parser.add_argument('--hidden', type=int, default=4096,
                        help='Hidden size')
    parser.add_argument('--batch', type=int, default=8,
                        help='Batch size')
    parser.add_argument('--epochs', type=int, default=4,
                        help='Number of training epochs')
    parser.add_argument('--train_dataset_size', type=int, default=1000000,
                        help='Train dataset size')
    parser.add_argument('--val_dataset_size', type=int, default=2000,
                        help='Validation dataset size')
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cpu', 'cuda'],
                        help='Device')

    parser.add_argument('--interrupt_at', type=tuple_type, nargs='*',
                        help='Manual interruption after (epoch, iteration), '
                        'for testing only')
    parser.add_argument('--save_interval', type=int, default=-1,
                        help='Interval for saving periodic checkpoints')
    parser.add_argument('--logging_interval', type=int, default=1,
                        help='Interval for log entries')
    parser.add_argument('--log_all_ranks', action='store_true',
                        help='Enable logging from all distributed ranks')
    parser.add_argument('--output_dir', type=str, default='results/output',
                        help='Output dir')
    parser.add_argument('--checkpoint_fname', type=str, default='checkpoint.pt',
                        help='Name of a checkpoint file')

    parser.add_argument('--local_rank', type=int,
                        default=os.getenv('LOCAL_RANK', 0))
    parser.add_argument('--init_distributed_method', type=str, default='tcp',
                        help='Init distributed group with TCP store ("tcp") or file store ("file")')

    parser.add_argument('--simulated_fault', type=fault_desc,
                        help='Description of a fault to be simulated')
    # fmt: on

    args = parser.parse_args()

    if args.interrupt_at:
        args.interrupt_at = set(args.interrupt_at)
    else:
        args.interrupt_at = set()

    return args


def load_checkpoint(path):
    map_location = {
        'cpu': 'cpu',
    }
    if torch.cuda.is_available():
        map_location['cuda:0'] = f'cuda:{torch.cuda.current_device()}'

    logging.info(f'Loading checkpoint from {path}')
    checkpoint = torch.load(path, map_location=map_location)
    return checkpoint


def save_checkpoint(
    progress,
    model,
    optimizer,
    ft_client,
    output_dir,
    checkpoint_fname,
):
    state = {
        'progress': progress,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'ft_state': ft_client.state_dict(),
    }

    checkpoint_path = os.path.join(output_dir, checkpoint_fname)

    with dist_utils.sync_workers() as rank:
        if rank == 0:
            logging.info(f'Saving checkpoint to {checkpoint_path}')
            torch.save(state, checkpoint_path)


def training_loop(
    ft_client,
    para_model,
    model,
    optimizer,
    device,
    dataloader,
    sampler,
    progress,
    args,
):
    epoch_idx = progress['epoch_idx']

    # NOTE: the DistributedSampler-based sampler must be prepared for the current
    # epoch and the resume offset before iteration starts
    sampler.start_sample_idx = progress['iter_idx'] * args.batch
    sampler.set_epoch(epoch_idx)

    para_model.train()

    last_log_time = time.monotonic()

    for iter_idx, x in enumerate(dataloader, start=progress['iter_idx']):
        if not ft_client.timeouts.are_valid and epoch_idx == 1 and iter_idx == 1:
            # After the 0th epoch has completed and the 0th iteration of the 1st epoch
            # is done, we can calculate and set the timeouts. This is a good moment to
            # do so, because by now we have seen the possibly long interval during
            # which a checkpoint was saved.
            ft_client.calculate_and_set_timeouts()

        optimizer.zero_grad()
        x = x.to(device)
        y = para_model(x)
        loss = y.mean()
        train_loss = loss.item()
        loss.backward()

        if iter_idx % args.logging_interval == 0:
            avg_train_loss = dist_utils.all_reduce_item(train_loss, op='mean')
            logging.info(
                f'CHECK TRAIN epoch: {epoch_idx:4d} '
                f'iter: {iter_idx:5d} '
                f'loss: {avg_train_loss} '
                f'input: {x[:, 0]}'
            )
            if iter_idx > 0:
                time_per_iter = (time.monotonic() - last_log_time) / args.logging_interval
                last_log_time = time.monotonic()
                logging.debug(f'Avg time per iter: {time_per_iter:.3f} [sec]')

        progress['iter_idx'] = iter_idx + 1

        ft_client.send_heartbeat()
        optimizer.step()

        # Decide whether to save a periodic checkpoint at this iteration
        periodic_save = iter_idx % args.save_interval == args.save_interval - 1

        if periodic_save or (epoch_idx, iter_idx) in args.interrupt_at:
            save_checkpoint(
                progress=progress,
                model=model,
                optimizer=optimizer,
                ft_client=ft_client,
                output_dir=args.output_dir,
                checkpoint_fname=args.checkpoint_fname,
            )
            if (epoch_idx, iter_idx) in args.interrupt_at:
                logging.info('Manual interruption, exiting')
                sys.exit(0)


def validation_loop(ft_client, model, val_dataloader, epoch_idx, device):
    total_val_loss = 0
    model.eval()

    for iter_idx, x in enumerate(val_dataloader):
        x = x.to(device)
        y = model(x)
        loss = y.mean().item()
        total_val_loss += loss
        ft_client.send_heartbeat()

    logging.info(
        f'CHECK VAL SUMMARY: epoch: {epoch_idx:4d} loss: {total_val_loss / (iter_idx + 1)}'
    )


_sim_fault_canceled = False
_sim_fault_is_set = False


def _cancel_simulated_fault():
    global _sim_fault_canceled
    _sim_fault_canceled = True


def _setup_simulated_fault(ft_client, fault_desc, device):
    # FIXME: hanging a rank with SIGTSTP leaves the rank monitor
    # blocked when trying to receive data in _on_ipc_data_from_rank

    global _sim_fault_is_set
    _sim_fault_is_set = True  # should be True on all ranks

    rng = random.Random()

    logging.info(f"Initializing simulated fault: {fault_desc}")

    fault_type = fault_desc['fault']
    if fault_type == 'random':
        fault_type = rng.choice(['rank_killed', 'rank_hung'])

    rank_to_fail = rng.randint(0, dist_utils.get_world_size() - 1)
    rank_to_fail = torch.tensor([rank_to_fail], device=device)
    dist_utils.broadcast(rank_to_fail, 0)
    rank_to_fail = int(rank_to_fail.item())

    rank = torch.distributed.get_rank()
    if rank != rank_to_fail:
        return

    if fault_type == 'rank_killed':
        target_pid = os.getpid()
        target_sig = signal.SIGKILL
    elif fault_type == 'rank_hung':
        target_pid = os.getpid()
        target_sig = signal.SIGSTOP
    else:
        raise Exception(f"Unknown fault type {fault_type}")

    delay = fault_desc['delay'] + 4.0 * rng.random()

    logging.info(
        f"Selected fault={fault_type}; target rank={rank_to_fail}; delay={delay}",
    )

    def __fault_thread():
        time.sleep(delay)
        if _sim_fault_canceled:
            return
        print(
            f"\n####\nSimulating fault: {fault_type}; rank to fail: {rank_to_fail}\n####\n",
            file=sys.stderr,
        )
        os.kill(target_pid, target_sig)

    fault_sim_thread = threading.Thread(target=__fault_thread)
    fault_sim_thread.daemon = True
    fault_sim_thread.start()


_signal_received = False


def _sig_handler(*args, **kwargs):
    print("Signal received!", file=sys.stderr)
    global _signal_received
    _signal_received = True


def main():
    signal.signal(signal.SIGTERM, _sig_handler)

    args = parse_args()

    torch.manual_seed(123)
    np.random.seed(123)
    random.seed(123)

    if args.device == 'cuda':
        if torch.cuda.is_available():
            device = torch.device('cuda')
            torch.cuda.set_device(args.local_rank)
        else:
            raise RuntimeError("Selected 'cuda' device but torch.cuda is not available.")
    elif args.device == 'cpu':
        device = torch.device('cpu')
    else:
        raise RuntimeError('Unknown device')

    if int(os.getenv('WORLD_SIZE', '1')) == 1:
        raise RuntimeError('This example supports only multi-gpu training')

    os.makedirs(args.output_dir, exist_ok=True)

    if args.init_distributed_method == 'tcp':
        # NOTE: when running tests with the TCP init method, we noticed
        # occasional "address already in use" errors after the workload
        # is restarted
        dist_utils.init_distributed_with_tcp_store(device)
    elif args.init_distributed_method == 'file':
        dist_utils.init_distributed_with_file_store(device, store_file_dir=args.output_dir)
    else:
        raise RuntimeError(
            f"--init_distributed_method should be one of ['tcp', 'file'], got {args.init_distributed_method}"
        )

    if args.log_all_ranks:
        log_file_name = f'train_log_rank_{dist_utils.get_rank()}.log'
    else:
        log_file_name = 'train_log.log'
    log_file_path = os.path.join(args.output_dir, log_file_name)

    # NOTE: logging appends to the log file if it already exists, so results
    # from a single training run (potentially with many restarts from a
    # checkpoint) are stored in a single log file.
    log_utils.setup_logging(args.log_all_ranks, filename=log_file_path, filemode='a')
    logging.info(args)

    rank = dist_utils.get_rank()

    logging.info(f"SLURM_JOB_ID={os.getenv('SLURM_JOB_ID','<none>')} RANK={rank} PID={os.getpid()}")

    # Dummy datasets
    train_dataset = Dataset(args.train_dataset_size, args.hidden)
    val_dataset = Dataset(args.val_dataset_size, args.hidden)

    # ResumableDistributedSampler is needed to skip consumed samples
    train_sampler = dist_utils.ResumableDistributedSampler(
        train_dataset,
        drop_last=True,
    )

    val_sampler = torch.utils.data.DistributedSampler(
        val_dataset,
    )

    # A dummy model and an optimizer
    model = Model(args.hidden).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Initial progress values; these are overwritten if training is resumed from a checkpoint
    progress = {
        'epoch_idx': 0,
        'iter_idx': 0,
    }

    checkpoint_path = os.path.join(args.output_dir, args.checkpoint_fname)

    # Initialize fault tolerance.
    ft_client = fault_tolerance.RankMonitorClient()
    ft_client.init_workload_monitoring()

    checkpoint = None

    # Try to load a checkpoint from disk
    if os.path.exists(checkpoint_path):
        checkpoint = load_checkpoint(checkpoint_path)
        if checkpoint:
            logging.info(f'Checkpoint was loaded from file: {checkpoint_path}')

    if checkpoint:
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        ft_client.load_state_dict(checkpoint['ft_state'])
        progress.update(checkpoint['progress'])
        # Return with a zero exit code if the model is already fully trained.
        if progress['epoch_idx'] == args.epochs:
            logging.info('Training finished.')
            sys.exit(0)

    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch,
        sampler=train_sampler,
        num_workers=4,
        persistent_workers=True,
        pin_memory=False,
    )

    val_dataloader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=args.batch,
        sampler=val_sampler,
        num_workers=4,
    )

    # Regular DDP init
    # NOTE: for convenience, the code keeps both the wrapped and the unwrapped model,
    # using the wrapped model for training and the unwrapped model for saving the
    # checkpoint and for validation. This doesn't increase memory consumption,
    # since both models hold references to the same parameters.
    # Additionally, the saved checkpoint is ready for inference and doesn't have to
    # be manually unwrapped by accessing the (undocumented) "module" attribute
    # of the DDP-wrapped model.
    if device.type == 'cuda':
        device_ids = [args.local_rank]
        output_device = args.local_rank
    elif device.type == 'cpu':
        device_ids = None
        output_device = None
    else:
        raise RuntimeError('Unsupported device type')
    para_model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=device_ids, output_device=output_device
    )

    # Iterate over epochs; note that the loop starts from 'epoch_idx',
    # which may have been loaded from the checkpoint.
    for epoch_idx in range(progress['epoch_idx'], args.epochs):
        training_loop(
            ft_client,
            para_model,
            model,
            optimizer,
            device,
            train_dataloader,
            train_sampler,
            progress,
            args,
        )

        # epoch_idx is incremented because the current epoch is finished,
        # so a potential resume from this checkpoint should start a new training epoch.
        progress['epoch_idx'] += 1
        progress['iter_idx'] = 0

        validation_loop(ft_client, model, val_dataloader, epoch_idx, device)

        # The checkpoint contains everything needed for a deterministic resume:
        # the state of the model, the optimizer, and other components.
        save_checkpoint(
            progress=progress,
            model=model,
            optimizer=optimizer,
            ft_client=ft_client,
            output_dir=args.output_dir,
            checkpoint_fname=args.checkpoint_fname,
        )

        # NOTE: SIGTERM is used by SLURM to initiate graceful job termination;
        # if _any_ rank received SIGTERM, we leave the main loop.
        if dist_utils.is_true_on_any_rank(_signal_received):
            logging.info('Leaving the main loop due to SIGTERM')
            break

        # Set up the simulated fault as soon as we have valid timeouts
        if args.simulated_fault and not _sim_fault_is_set and ft_client.timeouts.are_valid:
            _setup_simulated_fault(ft_client, args.simulated_fault, device)

    _cancel_simulated_fault()
    ft_client.shutdown_workload_monitoring()
    logging.info('Leaving main, ret_code=0')
    sys.exit(0)


if __name__ == "__main__":
    main()