Basic usage example

# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# This example demonstrates how to integrate ``inprocess.Wrapper()`` into an
# existing PyTorch training codebase.
#
# In this case, the entire ``main()`` function is wrapped. While all features
# of ``inprocess.Wrapper()`` are available and active, the Wrapper is
# configured to restart the entire application upon any failure. Consequently,
# the application state is not preserved between restarts and the entire
# ``main()`` is relaunched, leading to less efficient recovery from failures.
#
# NOTE: inprocess.Wrapper is not fully compatible with modern
# ``torch.distributed.run``, because ``torch.distributed.run`` automatically
# terminates all local workers upon any local worker process failure; in this
# case inprocess.Wrapper can only recover from transient faults that don't
# terminate any of the training processes.
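#
# A minimal, illustrative launch command (the file name and worker count are
# placeholders; ``torchrun`` supplies the ``RANK``, ``LOCAL_RANK``,
# ``WORLD_SIZE``, ``MASTER_ADDR`` and ``MASTER_PORT`` environment variables
# that this example reads):
#
#   torchrun --nproc_per_node=2 basic_example.py --device=cpu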

import argparse
import datetime
import logging
import os
import pathlib
import random
import time
from typing import Optional

import torch

import nvidia_resiliency_ext.inprocess as inprocess

raise_timestamp = None


def parse_args():
    parser = argparse.ArgumentParser(
        description='Inprocess Restart Basic Example',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        '--size',
        default=64,
        type=int,
        help='model hidden size',
    )
    parser.add_argument(
        '--layers',
        default=4,
        type=int,
        help='number of layers',
    )
    parser.add_argument(
        '--log-interval',
        default=100,
        type=int,
        help='logging interval',
    )
    parser.add_argument(
        '--chkpt-interval',
        default=100,
        type=int,
        help='checkpointing interval',
    )
    parser.add_argument(
        '--total-iterations',
        default=1000000,
        type=int,
        help='total training iterations',
    )
    parser.add_argument(
        '--seed',
        default=None,
        type=int,
        help='random seed, time-based if None',
    )
    parser.add_argument(
        '--path',
        default='/tmp/',
        type=str,
        help='directory for the checkpoint file',
    )
    parser.add_argument(
        '--fault-prob',
        default=0.001,
        type=float,
        help='fault injection probability',
    )
    parser.add_argument(
        '--device',
        default='cpu',
        choices=['cpu', 'cuda'],
        help='device',
    )
    parser.add_argument(
        '--log-level',
        type=lambda s: logging._nameToLevel[s.upper()],
        default=logging.INFO,
        help='logging level',
    )

    return parser.parse_args()


# The internal TCPStore created by the Wrapper listens on ``(MASTER_PORT + 2)``
# to avoid conflicts with the application's TCPStore listening on
# ``(MASTER_PORT + 1)``, and with the TCPStore created by
# ``torch.distributed.run`` listening on ``MASTER_PORT``.
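# For example, with the default ``MASTER_PORT=29500`` the port layout is:
# 29500 for ``torch.distributed.run``, 29501 for the application's TCPStore,
# and 29502 for the Wrapper's internal TCPStore.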
@inprocess.Wrapper(
    store_kwargs={'port': int(os.getenv('MASTER_PORT', 29500)) + 2},
    health_check=inprocess.health_check.CudaHealthCheck(),
)
def main(call_wrapper: Optional[inprocess.CallWrapper] = None):
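    # ``call_wrapper`` is filled in by ``inprocess.Wrapper`` when it invokes
    # the wrapped function; it exposes the ``atomic()`` context manager used
    # below to coordinate checkpoint saving with an in-progress restart.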
    global raise_timestamp
    if raise_timestamp is not None:
        restart_latency = time.perf_counter() - raise_timestamp
        logging.info(f'restart latency: {restart_latency:.3f}s')
    raise_timestamp = None

    args = parse_args()
    logging.info(f'{args}')

    log_interval = args.log_interval
    chkpt_interval = args.chkpt_interval

    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    if args.device == 'cuda':
        torch.cuda.set_device(local_rank)
        device = torch.device('cuda')
        backend = 'nccl'
        timeout = datetime.timedelta(seconds=150)
    elif args.device == 'cpu':
        device = torch.device('cpu')
        backend = 'gloo'
        timeout = datetime.timedelta(seconds=10)
    else:
        raise RuntimeError(f'unsupported device: {args.device}')

    if args.seed is not None:
        torch.manual_seed(args.seed)
    model = torch.nn.Sequential(
        *[torch.nn.Linear(args.size, args.size) for _ in range(args.layers)]
    ).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=1e-5)

    # The application's TCPStore uses ``(MASTER_PORT + 1)`` to avoid conflicts
    # with a TCPStore created by ``torch.distributed.run`` and listening on
    # ``MASTER_PORT``.
    store = torch.distributed.TCPStore(
        host_name=os.environ['MASTER_ADDR'],
        port=int(os.environ['MASTER_PORT']) + 1,
        world_size=int(os.environ['WORLD_SIZE']),
        is_master=int(os.environ['RANK']) == 0,
        multi_tenant=True,
        wait_for_workers=True,
        use_libuv=True,
    )

    torch.distributed.init_process_group(
        backend=backend,
        store=store,
        rank=int(os.environ['RANK']),
        world_size=int(os.environ['WORLD_SIZE']),
        timeout=timeout,
    )
    model_ddp = torch.nn.parallel.DistributedDataParallel(model)
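    # Since the whole ``main()`` is wrapped, every restart rebuilds the
    # TCPStore, the process group and the DDP-wrapped model from scratch; the
    # only state carried across restarts is the checkpoint file handled below.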

    iteration = 0
    loss = torch.tensor(float('nan'))
    checkpoint_path = pathlib.Path(args.path) / 'checkpoint.pt'

    # Application loads state from the latest checkpoint on every restart
    # iteration of the wrapped function.
    if checkpoint_path.exists():
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model'])
        opt.load_state_dict(checkpoint['opt'])
        torch.set_rng_state(checkpoint['rng'])
        iteration = checkpoint['iteration']

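    # Seed the fault-injection RNG so that the random fault pattern differs
    # across ranks and across restarts resumed from different iterations.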
    if args.seed is not None:
        random.seed(args.seed + iteration * world_size + rank)
    else:
        random.seed(time.perf_counter_ns())

    for iteration in range(iteration, args.total_iterations):

        # The application periodically saves a checkpoint. The checkpoint
        # allows the application to continue from its previous state after a
        # restart.
        if iteration % chkpt_interval == chkpt_interval - 1:
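            # The barrier keeps all ranks in step before rank 0 serializes the
            # checkpoint, so the saved state corresponds to the same iteration
            # on every rank.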
            torch.distributed.barrier()
            if rank == 0:
                checkpoint = {
                    'model': model.state_dict(),
                    'opt': opt.state_dict(),
                    'rng': torch.get_rng_state(),
                    'iteration': iteration,
                }
                # Saving the checkpoint is performed within the ``atomic()``
                # context manager to ensure that the main thread won't execute
                # ``torch.save`` while a restart procedure is in progress.
                with call_wrapper.atomic():
                    torch.save(checkpoint, checkpoint_path)

        # Randomly trigger an example fault
        if random.random() < args.fault_prob:
            raise_timestamp = time.perf_counter()
            raise RuntimeError(f'example fault at {iteration=} from {rank=}')

        inp = torch.rand(args.size, args.size).to(device)
        model.zero_grad()
        out = model_ddp(inp)
        loss = out.square().mean()
        loss.backward()
        opt.step()
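        # ``loss.item()`` forces a host-device synchronization on every
        # iteration; this helps surface asynchronous device-side errors close
        # to where they occur rather than at a later, unrelated point.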
        loss.item()

        if rank == 0 and iteration % log_interval == log_interval - 1:
            logging.info(f'{rank=} {iteration=} {loss.item()=}')


if __name__ == '__main__':
    # ``inprocess.Wrapper`` uses the logging library to output messages. In
    # this example the Wrapper is applied to ``main()``, therefore logging
    # needs to be initialized and configured before the Wrapper is launched.
    args = parse_args()
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        level=args.log_level,
    )
    main()