Section API usage example with DDP

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

 16"""
 17Demo of DDP training with fault tolerance, using FT package sections API
 18
 19It should be run with `ft_launcher`. E.g.
 20`ft_launcher --nproc-per-node=2 --fault-tol-cfg-path=./examples/fault_tolerance/fault_tol_cfg_sections.yaml examples/fault_tolerance/train_ddp_sections_api.py --device=cpu`
 21
 22This example uses following custom FT sections
 23- 'init' - covers workload initialization
 24- 'step' - covers training/evaluation step (fwd/bwd, loss calculation etc)
 25- 'checkpoint' - covers checkpoint saving
 26
 27Timeout for each section is calculated when enough data is collected.
 28FT "out-of-section" timeout is calculated when the training run ends normally.
 29FT state is saved in a JSON file.
 30
 31This example allows to simulate a training fault:
 32- selected rank hung
 33- selected rank terminated
 34"""
import argparse
import json
import logging
import os
import random
import signal
import sys
import threading
import time

import dist_utils
import log_utils
import numpy as np
import torch
import torch.nn as nn

import nvidia_resiliency_ext.fault_tolerance as fault_tolerance


# Dummy dataset.
class Dataset(torch.utils.data.Dataset):
    def __init__(self, size, hidden):
        self.size = size
        self.hidden = hidden

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        data = torch.full(
            (self.hidden,),
            fill_value=idx,
            dtype=torch.float32,
            device='cpu',
        )
        return data


# Dummy model
class Model(nn.Module):
    def __init__(self, hidden):
        super().__init__()
        self.l1 = nn.Linear(hidden, hidden)
        self.l2 = nn.Linear(hidden, hidden)

    def forward(self, x):
        x = self.l1(x)
        x = self.l2(x)
        return x


def parse_args():
    def fault_desc(strings):
        parts = strings.split(",")
        assert len(parts) == 2, "Fault description must be in format 'fault,delay'"
        return {'fault': parts[0], 'delay': float(parts[1])}

    parser = argparse.ArgumentParser(
        description='Example of PyTorch DDP training with the Fault Tolerance package',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # fmt: off
    parser.add_argument('--hidden', type=int, default=4096,
                        help='Hidden size')
    parser.add_argument('--batch', type=int, default=8,
                        help='Batch size')
    parser.add_argument('--epochs', type=int, default=4,
                        help='Number of training epochs')
    parser.add_argument('--train_dataset_size', type=int, default=1000000,
                        help='Train dataset size')
    parser.add_argument('--val_dataset_size', type=int, default=2000,
                        help='Validation dataset size')
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cpu', 'cuda'],
                        help='Device')
    parser.add_argument('--save_interval', type=int, default=-1,
                        help='Interval for saving periodic checkpoints.')
    parser.add_argument('--logging_interval', type=int, default=1,
                        help='Interval for log entries')
    parser.add_argument('--log_all_ranks', action='store_true',
                        help='Enable logging from all distributed ranks')
    parser.add_argument('--output_dir', type=str, default='results/output',
                        help='Output dir')
    parser.add_argument('--checkpoint_fname', type=str, default='checkpoint.pt',
                        help='Name of a checkpoint file')
    parser.add_argument('--local_rank', type=int,
                        default=os.getenv('LOCAL_RANK', 0))
    parser.add_argument('--simulated_fault', type=fault_desc,
                        help='Description of a fault to be simulated')
    # fmt: on

    args = parser.parse_args()
    return args


def load_checkpoint(path):
    map_location = {
        'cpu': 'cpu',
    }
    if torch.cuda.is_available():
        map_location['cuda:0'] = f'cuda:{torch.cuda.current_device()}'

    logging.info(f'Loading checkpoint from {path}')
    checkpoint = torch.load(path, map_location=map_location, weights_only=True)
    return checkpoint


def save_checkpoint(
    progress,
    model,
    optimizer,
    ft_client,
    output_dir,
    checkpoint_fname,
):
    # Checkpointing is wrapped into "checkpoint" FT section
    # NOTE: FT state is not stored in the checkpoint, but in a separate JSON file
    ft_client.start_section('checkpoint')

    state = {
        'progress': progress,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
    }

    checkpoint_path = os.path.join(output_dir, checkpoint_fname)

    with dist_utils.sync_workers() as rank:
        if rank == 0:
            logging.info(f'Saving checkpoint to {checkpoint_path}')
            torch.save(state, checkpoint_path)

    ft_client.end_section('checkpoint')


def maybe_load_ft_state(path):
    # Load FT state from JSON file
    if os.path.exists(path):
        logging.info(f'FT state loading from: {path}')
        with open(path, 'r') as f:
            return json.load(f)
    else:
        logging.info(f'FT state file not found at: {path}')
        return None


def save_ft_state(ft_client, path):
    # Save FT state into a JSON file
    with dist_utils.sync_workers() as rank:
        if rank == 0:
            logging.info(f'Saving FT state into: {path}')
            ft_state = ft_client.state_dict()
            with open(path, 'w') as f:
                json.dump(ft_state, f)


def update_ft_section_timeouts(ft_client, selected_sections, calc_out_of_section, ft_state_path):
    # Update FT timeouts and save the FT state
    logging.info(
        f'Updating FT section timeouts for: {selected_sections}; update out-of-section: {calc_out_of_section}'
    )
    ft_client.calculate_and_set_section_timeouts(
        selected_sections=selected_sections, calc_out_of_section=calc_out_of_section
    )
    save_ft_state(ft_client, ft_state_path)


def training_loop(
    ft_client,
    para_model,
    model,
    optimizer,
    device,
    dataloader,
    progress,
    args,
):
    # Training epoch implementation

    epoch_idx = progress['epoch_idx']

    para_model.train()

    last_log_time = time.monotonic()

    num_iters_made = 0

    for iter_idx, x in enumerate(dataloader, start=progress['iter_idx']):

        # fwd/bwd and optimizer step are wrapped into "step" FT section
        ft_client.start_section('step')

        optimizer.zero_grad()
        x = x.to(device)
        y = para_model(x)
        loss = y.mean()
        train_loss = loss.item()
        loss.backward()

        if iter_idx % args.logging_interval == 0:
            avg_train_loss = dist_utils.all_reduce_item(train_loss, op='mean')
            logging.info(
                f'CHECK TRAIN epoch: {epoch_idx:4d} '
                f'iter: {iter_idx:5d} '
                f'loss: {avg_train_loss} '
                f'input: {x[:, 0]}'
            )
            if iter_idx > 0:
                time_per_iter = (time.monotonic() - last_log_time) / args.logging_interval
                last_log_time = time.monotonic()
                logging.debug(f'Avg time per iter: {time_per_iter:.3f} [sec]')

        progress['iter_idx'] = iter_idx + 1

        optimizer.step()

        ft_client.end_section('step')

        # Decide whether to save a periodic checkpoint
        periodic_save = iter_idx % args.save_interval == args.save_interval - 1
        if periodic_save:
            save_checkpoint(
                progress=progress,
                model=model,
                optimizer=optimizer,
                ft_client=ft_client,
                output_dir=args.output_dir,
                checkpoint_fname=args.checkpoint_fname,
            )

        num_iters_made += 1

    return num_iters_made


def validation_loop(ft_client, model, val_dataloader, epoch_idx, device):

    # Validation epoch implementation

    total_val_loss = 0
    model.eval()

    for iter_idx, x in enumerate(val_dataloader):

        # fwd and loss are wrapped into the "step" FT section
        # the 'step' section is used for both training and eval steps
        ft_client.start_section('step')

        x = x.to(device)
        y = model(x)
        loss = y.mean().item()
        total_val_loss += loss

        ft_client.end_section('step')

    logging.info(
        f'CHECK VAL SUMMARY: epoch: {epoch_idx:4d} ' f'loss: {total_val_loss / (iter_idx + 1)}'
    )


_sim_fault_canceled = False
_sim_fault_is_set = False


def _cancel_simulated_fault():
    global _sim_fault_canceled
    _sim_fault_canceled = True


def _setup_simulated_fault(fault_desc, device):

    global _sim_fault_is_set
    _sim_fault_is_set = True  # should be True on all ranks

    rng = random.Random()

    logging.info(f"Initializing simulated fault: {fault_desc}")

    fault_type = fault_desc['fault']
    if fault_type == 'random':
        fault_type = rng.choice(['rank_killed', 'rank_hung'])

    rank_to_fail = rng.randint(0, dist_utils.get_world_size() - 1)
    rank_to_fail = torch.tensor([rank_to_fail], device=device)
    dist_utils.broadcast(rank_to_fail, 0)
    rank_to_fail = int(rank_to_fail.item())

    rank = torch.distributed.get_rank()
    if rank != rank_to_fail:
        return

    if fault_type == 'rank_killed':
        target_pid = os.getpid()
        target_sig = signal.SIGKILL
    elif fault_type == 'rank_hung':
        target_pid = os.getpid()
        target_sig = signal.SIGSTOP
    else:
        raise Exception(f"Unknown fault type {fault_type}")

    delay = fault_desc['delay'] + 4.0 * rng.random()

    logging.info(
        f"Selected fault={fault_type}; target rank={rank_to_fail}; delay={delay}",
    )

    def __fault_thread():
        time.sleep(delay)
        if _sim_fault_canceled:
            return
        print(
            f"\n####\nSimulating fault: {fault_type}; rank to fail: {rank_to_fail}\n#####\n",
            file=sys.stderr,
        )
        os.kill(target_pid, target_sig)

    fault_sim_thread = threading.Thread(target=__fault_thread)
    fault_sim_thread.daemon = True
    fault_sim_thread.start()


_signal_received = False


def _sig_handler(*args, **kwargs):
    print("Signal received!", file=sys.stderr)
    global _signal_received
    _signal_received = True


def main():
    signal.signal(signal.SIGTERM, _sig_handler)

    args = parse_args()

    torch.manual_seed(123)
    np.random.seed(123)
    random.seed(123)

    if args.device == 'cuda':
        device = torch.device('cuda')
        torch.cuda.set_device(args.local_rank)
    elif args.device == 'cpu':
        device = torch.device('cpu')
    else:
        raise RuntimeError('Unknown device')

    os.makedirs(args.output_dir, exist_ok=True)

    dist_utils.init_distributed_with_tcp_store(device)
    rank = dist_utils.get_rank()

    if args.log_all_ranks:
        log_file_name = f'train_log_rank_{dist_utils.get_rank()}.log'
    else:
        log_file_name = 'train_log.log'
    log_file_path = os.path.join(args.output_dir, log_file_name)

    # NOTE: logging appends outputs to an existing log file if it already
    # exists. Results from a single training run (potentially with many
    # restarts from a checkpoint) are stored in a single log file.
    log_utils.setup_logging(args.log_all_ranks, filename=log_file_path, filemode='a')

    logging.info(args)
    logging.info(f"SLURM_JOB_ID={os.getenv('SLURM_JOB_ID','<none>')} RANK={rank} PID={os.getpid()}")

    # Dummy datasets
    train_dataset = Dataset(args.train_dataset_size, args.hidden)
    val_dataset = Dataset(args.val_dataset_size, args.hidden)

    train_sampler = torch.utils.data.DistributedSampler(
        train_dataset,
        drop_last=True,
    )

    val_sampler = torch.utils.data.DistributedSampler(
        val_dataset,
    )

    # A dummy model and an optimizer
    model = Model(args.hidden).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Initial progress values; they will be overwritten if training is resumed from a checkpoint
    progress = {
        'epoch_idx': 0,
        'iter_idx': 0,
    }

    checkpoint_path = os.path.join(args.output_dir, args.checkpoint_fname)

    # Initialize fault tolerance.
    ft_client = fault_tolerance.RankMonitorClient()
    ft_client.init_workload_monitoring()

    # try to load FT state from a JSON file
    ft_state_path = os.path.join(args.output_dir, 'ft_state.json')
    ft_state = maybe_load_ft_state(ft_state_path)
    if ft_state:
        ft_client.load_state_dict(ft_state)

    # Open "init" FT section that covers workload initialization
    ft_client.start_section('init')

    is_checkpoint_loaded = False

    # try to load checkpoint from disk
    if os.path.exists(checkpoint_path):
        checkpoint = load_checkpoint(checkpoint_path)
        if checkpoint:
            logging.info(f'Checkpoint was loaded from file: {checkpoint_path}')
            is_checkpoint_loaded = True
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            progress.update(checkpoint['progress'])

    # Return with zero exit code if model is already fully trained.
    if progress['epoch_idx'] == args.epochs:
        ft_client.end_section('init')  # explicitly end "init" section, to avoid FT warning
        ft_client.shutdown_workload_monitoring()
        torch.distributed.destroy_process_group()
        logging.info('Training finished.')
        sys.exit(0)

    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch,
        sampler=train_sampler,
        num_workers=4,
        persistent_workers=True,
        pin_memory=False,
    )

    val_dataloader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=args.batch,
        sampler=val_sampler,
        num_workers=4,
    )

    # Regular DDP init
    # NOTE: for convenience, the code keeps both the wrapped and the unwrapped
    # model: the wrapped model is used for training, while the unwrapped model
    # is used for validation and for saving the checkpoint. This does not
    # increase memory consumption, since both models hold references to the
    # same parameters. Additionally, the saved checkpoint is ready for
    # inference and does not have to be manually unwrapped by accessing the
    # (undocumented) "module" attribute of the DDP-wrapped model.
    if device.type == 'cuda':
        device_ids = [args.local_rank]
        output_device = args.local_rank
    elif device.type == 'cpu':
        device_ids = None
        output_device = None
    else:
        raise RuntimeError('Unsupported device type')
    para_model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=device_ids, output_device=output_device
    )

    # "init" FT section ends here
    ft_client.end_section('init')

    if is_checkpoint_loaded:
        # init can take longer when a checkpoint is loaded,
        # so update the "init" section timeout in that case
        update_ft_section_timeouts(ft_client, ['init'], False, ft_state_path)

    # Iterate over epochs; note that the loop starts from 'epoch_idx',
    # which may have been restored from the checkpoint
    for epoch_idx in range(progress['epoch_idx'], args.epochs):

        num_tr_iters_made = training_loop(
            ft_client,
            para_model,
            model,
            optimizer,
            device,
            train_dataloader,
            progress,
            args,
        )

        # If any training iterations were observed, update the "step" section timeout
        if num_tr_iters_made > 0:
            update_ft_section_timeouts(ft_client, ['step'], False, ft_state_path)

        # epoch_idx is incremented because the current epoch is finished
        # and a potential resume from this checkpoint should start a new training epoch.
        progress['epoch_idx'] += 1
        progress['iter_idx'] = 0

        validation_loop(ft_client, model, val_dataloader, epoch_idx, device)

        # The checkpoint contains everything needed for a deterministic resume:
        # the state of the model, the optimizer, and other components.
        save_checkpoint(
            progress=progress,
            model=model,
            optimizer=optimizer,
            ft_client=ft_client,
            output_dir=args.output_dir,
            checkpoint_fname=args.checkpoint_fname,
        )

        # Update the "checkpoint" section timeout after a checkpoint save was observed
        update_ft_section_timeouts(ft_client, ['checkpoint'], False, ft_state_path)

        # NOTE: SIGTERM is used by SLURM to initiate graceful job termination.
        # If _any_ rank received SIGTERM, we leave the main loop.
        if dist_utils.is_true_on_any_rank(_signal_received):
            logging.info('Leaving the main loop due to SIGTERM')
            break

        # Set up the simulated fault, if requested
        if args.simulated_fault and not _sim_fault_is_set:
            _setup_simulated_fault(args.simulated_fault, device)

    _cancel_simulated_fault()

    # Update the "out-of-section" FT timeout when the training run ends normally
    update_ft_section_timeouts(ft_client, [], True, ft_state_path)
    ft_client.shutdown_workload_monitoring()
    torch.distributed.destroy_process_group()
    logging.info('Leaving main, ret_code=0')
    sys.exit(0)


if __name__ == "__main__":
    main()
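
For reference, the sections API usage in the script above boils down to a small pattern. The sketch below is distilled solely from the calls demonstrated in this example (RankMonitorClient, init_workload_monitoring, start_section/end_section, calculate_and_set_section_timeouts, state_dict/load_state_dict, shutdown_workload_monitoring); the time.sleep placeholders and the local 'ft_state.json' path are illustrative assumptions, not part of the package API, and the snippet still needs to be launched via `ft_launcher` with a sections config so that the rank monitor is available.

    # Minimal sketch of the FT sections API flow used in this example.
    # NOTE: an illustration, not a complete training script; run it under
    # `ft_launcher --fault-tol-cfg-path=<sections config>` like the example above.
    import json
    import os
    import time

    import nvidia_resiliency_ext.fault_tolerance as fault_tolerance

    ft_client = fault_tolerance.RankMonitorClient()
    ft_client.init_workload_monitoring()

    # Restore previously calculated timeouts, if any were persisted earlier
    ft_state_path = 'ft_state.json'
    if os.path.exists(ft_state_path):
        with open(ft_state_path) as f:
            ft_client.load_state_dict(json.load(f))

    ft_client.start_section('init')
    time.sleep(1.0)  # placeholder for model/dataloader setup and checkpoint loading
    ft_client.end_section('init')

    for _ in range(16):
        ft_client.start_section('step')
        time.sleep(0.1)  # placeholder for fwd/bwd and the optimizer step
        ft_client.end_section('step')

    ft_client.start_section('checkpoint')
    time.sleep(0.5)  # placeholder for checkpoint saving
    ft_client.end_section('checkpoint')

    # Recompute timeouts from the collected section data and persist them,
    # so the next restart starts with tuned timeouts instead of the defaults
    ft_client.calculate_and_set_section_timeouts(
        selected_sections=['init', 'step', 'checkpoint'], calc_out_of_section=True
    )
    with open(ft_state_path, 'w') as f:
        json.dump(ft_client.state_dict(), f)

    ft_client.shutdown_workload_monitoring()

To exercise the fault-handling path of the full example, pass `--simulated_fault` with a `fault,delay` pair, where the fault is one of `rank_killed`, `rank_hung`, or `random` (e.g. `--simulated_fault=rank_hung,60`). The fault is armed at the end of the first epoch and fires on one randomly selected rank after roughly the requested delay.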