Optimal usage example

# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# This example demonstrates how to integrate ``inprocess.Wrapper()`` into an
# existing PyTorch training codebase.
#
# This example shows the optimal usage:
# - only the training loop and objects depending on a torch distributed
#   process group are restarted upon a failure
# - process-group-independent objects (e.g. TCPStore, Model, Optimizer) are
#   created once and reused across all restart iterations to minimize
#   restart latency
#
# NOTE: inprocess.Wrapper is not fully compatible with modern
# ``torch.distributed.run``, because it automatically terminates all local
# workers upon any local worker process failure; in this case inprocess.Wrapper
# can only recover from transient faults that don't terminate any of the
# training processes.
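#
# An illustrative launch command (the file name and the number of workers are
# assumptions, adjust them to your setup):
#
#   torchrun --nproc-per-node=2 optimal_example.py --device=cuda
#
# ``torchrun`` exports the ``RANK``, ``LOCAL_RANK``, ``WORLD_SIZE``,
# ``MASTER_ADDR`` and ``MASTER_PORT`` environment variables consumed by this
# example.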

import argparse
import datetime
import logging
import os
import pathlib
import random
import time
from typing import Optional

import torch

import nvidia_resiliency_ext.inprocess as inprocess

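# Timestamp of the most recently injected fault; ``train()`` uses it to report
# the observed restart latency after the Wrapper restarts the function.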
raise_timestamp = None


def parse_args():
    parser = argparse.ArgumentParser(
        description='Inprocess Restart Optimal Example',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        '--size',
        default=64,
        type=int,
        help='model hidden size',
    )
    parser.add_argument(
        '--layers',
        default=4,
        type=int,
        help='number of layers',
    )
    parser.add_argument(
        '--log-interval',
        default=100,
        type=int,
        help='logging interval',
    )
    parser.add_argument(
        '--chkpt-interval',
        default=100,
        type=int,
        help='checkpointing interval',
    )
    parser.add_argument(
        '--total-iterations',
        default=1000000,
        type=int,
        help='total training iterations',
    )
    parser.add_argument(
        '--seed',
        default=None,
        type=int,
        help='random seed, time-based if None',
    )
    parser.add_argument(
        '--path',
        default='/tmp/',
        type=str,
        help='directory for the checkpoint file',
    )
    parser.add_argument(
        '--fault-prob',
        default=0.001,
        type=float,
        help='fault injection probability',
    )
    parser.add_argument(
        '--device',
        default='cpu',
        choices=['cpu', 'cuda'],
        help='device',
    )
    parser.add_argument(
        '--log-level',
        type=lambda s: logging._nameToLevel[s.upper()],
        default=logging.INFO,
        help='logging level',
    )

    return parser.parse_args()


# The internal TCPStore created by the Wrapper uses port ``(MASTER_PORT + 2)``
# to avoid conflicts with the application's TCPStore listening on
# ``(MASTER_PORT + 1)``, and with the TCPStore created by
# ``torch.distributed.run`` listening on ``MASTER_PORT``.
#
# An instance of ``inprocess.CallWrapper`` is automatically injected into
# wrapped function arguments when Wrapper is invoked.
@inprocess.Wrapper(
    store_kwargs={'port': int(os.getenv('MASTER_PORT', 29500)) + 2},
    health_check=inprocess.health_check.CudaHealthCheck(),
)
def train(
    base_store,
    model,
    opt,
    backend,
    device,
    timeout,
    args,
    call_wrapper: Optional[inprocess.CallWrapper] = None,
):
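    # The body of ``train()`` is re-executed from the top on every restart
    # iteration; objects constructed in ``main()`` and passed in as arguments
    # (store, model, optimizer) are reused across restarts.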
    global raise_timestamp
    if raise_timestamp is not None:
        restart_latency = time.perf_counter() - raise_timestamp
        logging.info(f'restart latency: {restart_latency:.3f}s')
    raise_timestamp = None

    log_interval = args.log_interval
    chkpt_interval = args.chkpt_interval

    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    # Create a new Store by adding a prefix based on the current inprocess
    # restart iteration. PrefixStore wraps the baseline TCPStore, which is
    # reused for all restart iterations.
    store = torch.distributed.PrefixStore(
        str(call_wrapper.iteration), base_store
    )

    torch.distributed.init_process_group(
        backend,
        store=store,
        rank=rank,
        world_size=world_size,
        timeout=timeout,
    )
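    # DistributedDataParallel holds references to the newly initialized
    # process group, so it is rebuilt on every restart; the wrapped ``model``
    # itself is reused.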
    model_ddp = torch.nn.parallel.DistributedDataParallel(model)

    iteration = 0
    loss = torch.tensor(float('nan'))
    checkpoint_path = pathlib.Path(args.path) / 'checkpoint.pt'

    # Application loads state from the latest checkpoint on every restart
    # iteration of the wrapped function.
    if checkpoint_path.exists():
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model'])
        opt.load_state_dict(checkpoint['opt'])
        torch.set_rng_state(checkpoint['rng'])
        iteration = checkpoint['iteration']

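    # Seed the Python RNG that drives fault injection; the effective seed
    # varies per rank and per resume point, so faults are not injected in
    # lockstep across ranks or restarts.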
    if args.seed is not None:
        random.seed(args.seed + iteration * world_size + rank)
    else:
        random.seed(time.perf_counter_ns())

    for iteration in range(iteration, args.total_iterations):

        # Application periodically saves a checkpoint. The checkpoint allows
        # the application to continue from the previous state after a restart.
        if iteration % chkpt_interval == chkpt_interval - 1:
            torch.distributed.barrier()
            if rank == 0:
                checkpoint = {
                    'model': model.state_dict(),
                    'opt': opt.state_dict(),
                    'rng': torch.get_rng_state(),
                    'iteration': iteration,
                }
                # Saving the checkpoint is performed within the atomic()
                # context manager to ensure that the main thread won't execute
                # torch.save while a restart procedure is in progress.
                with call_wrapper.atomic():
                    torch.save(checkpoint, checkpoint_path)

        # Randomly trigger an example fault
        if random.random() < args.fault_prob:
            raise_timestamp = time.perf_counter()
            raise RuntimeError(f'example fault at {iteration=} from {rank=}')

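        # A single training step on a random input batch.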
        inp = torch.rand(args.size, args.size).to(device)
        model.zero_grad()
        out = model_ddp(inp)
        loss = out.square().mean()
        loss.backward()
        opt.step()
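        # ``.item()`` synchronizes with the device, which helps surface
        # asynchronous errors within the current iteration.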
        loss.item()

        if rank == 0 and iteration % log_interval == log_interval - 1:
            logging.info(f'{rank=} {iteration=} {loss.item()=}')


def main():
    args = parse_args()
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        level=args.log_level,
    )
    logging.info(f'{args}')

    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])

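    # Device-specific settings: NCCL with a longer collective timeout on CUDA,
    # Gloo with a shorter timeout on CPU.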
    if args.device == 'cuda':
        torch.cuda.set_device(local_rank)
        device = torch.device('cuda')
        backend = 'nccl'
        timeout = datetime.timedelta(seconds=150)
    elif args.device == 'cpu':
        device = torch.device('cpu')
        backend = 'gloo'
        timeout = datetime.timedelta(seconds=10)
    else:
        raise RuntimeError

    # All objects created in ``main()`` are constructed only once, and reused
    # for all restart iterations.
    if args.seed is not None:
        torch.manual_seed(args.seed)
    model = torch.nn.Sequential(
        *[torch.nn.Linear(args.size, args.size) for _ in range(args.layers)]
    ).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=1e-5)

    # The application's TCPStore uses port ``(MASTER_PORT + 1)`` to avoid
    # conflicts with the TCPStore created by ``torch.distributed.run``
    # listening on ``MASTER_PORT``.
    store = torch.distributed.TCPStore(
        host_name=os.environ['MASTER_ADDR'],
        port=int(os.environ['MASTER_PORT']) + 1,
        world_size=int(os.environ['WORLD_SIZE']),
        is_master=(int(os.environ['RANK']) == 0),
        multi_tenant=True,
        wait_for_workers=True,
        use_libuv=True,
    )

    # Call the wrapped function.
    # ``train()`` is automatically restarted to recover from faults.
    train(store, model, opt, backend, device, timeout, args)


if __name__ == '__main__':
    main()