Optimal usage example
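
This example shows how to integrate ``inprocess.Wrapper()`` into an existing PyTorch training script so that, after a fault, only the training loop and the objects tied to a torch distributed process group are restarted, while process-group-independent objects (the TCPStore, model, and optimizer) are created once in ``main()`` and reused across restart iterations. The script relies on the environment variables set by ``torch.distributed.run``/``torchrun`` (``RANK``, ``WORLD_SIZE``, ``LOCAL_RANK``, ``MASTER_ADDR``, ``MASTER_PORT``), so it is intended to be launched with one of those tools, for example ``torchrun --nproc-per-node=2 optimal.py`` (assuming the listing below is saved as ``optimal.py``).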

# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# This example demonstrates how to integrate ``inprocess.Wrapper()`` into an
# existing PyTorch training codebase.
#
# This example shows the optimal usage:
# - only the training loop and objects depending on a torch distributed
#   process group are restarted upon a failure
# - process-group-independent objects (e.g. TCPStore, Model, Optimizer) are
#   created once, and reused across all restart iterations to minimize restart
#   latency
#
# NOTE: inprocess.Wrapper is not fully compatible with modern
# ``torch.distributed.run``, because it automatically terminates all local
# workers upon any local worker process failure; in this case inprocess.Wrapper
# can only recover from transient faults that don't terminate any of the
# training processes.

import argparse
import datetime
import logging
import os
import pathlib
import random
import time
from typing import Optional

os.environ['TORCH_CPP_LOG_LEVEL'] = 'error'
import torch

import nvidia_resiliency_ext.inprocess as inprocess

raise_timestamp = None


def parse_args():
    parser = argparse.ArgumentParser(
        description='Inprocess Restart Optimal Example',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        '--size',
        default=64,
        type=int,
        help='model hidden size',
    )
    parser.add_argument(
        '--layers',
        default=4,
        type=int,
        help='number of layers',
    )
    parser.add_argument(
        '--log-interval',
        default=100,
        type=int,
        help='logging interval',
    )
    parser.add_argument(
        '--chkpt-interval',
        default=100,
        type=int,
        help='checkpointing interval',
    )
    parser.add_argument(
        '--total-iterations',
        default=1000000,
        type=int,
        help='total training iterations',
    )
    parser.add_argument(
        '--seed',
        default=None,
        type=int,
        help='random seed, time-based if None',
    )
    parser.add_argument(
        '--path',
        default='/tmp/',
        type=str,
        help='directory for the checkpoint file',
    )
    parser.add_argument(
        '--fault-prob',
        default=0.001,
        type=float,
        help='fault injection probability',
    )
    parser.add_argument(
        '--device',
        default='cpu',
        choices=['cpu', 'cuda'],
        help='device',
    )
    parser.add_argument(
        '--log-level',
        type=lambda s: logging._nameToLevel[s.upper()],
        default=logging.INFO,
        help='logging level',
    )

    parser.add_argument(
        '--nested-restarter',
        action='store_true',
        help='extra logging for the nested restarter',
    )

    return parser.parse_args()


def train(
    base_store,
    model,
    opt,
    backend,
    device,
    timeout,
    args,
    call_wrapper: Optional[inprocess.CallWrapper] = None,
):
    global raise_timestamp
    if raise_timestamp is not None:
        restart_latency = time.perf_counter() - raise_timestamp
        logging.info(f'restart latency: {restart_latency:.3f}s')
    raise_timestamp = None

    log_interval = args.log_interval
    chkpt_interval = args.chkpt_interval

    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    # Create a new Store by adding a prefix based on the current inprocess
    # restart iteration. PrefixStore wraps the baseline TCPStore, which is
    # reused for all restart iterations.
    store = torch.distributed.PrefixStore(str(call_wrapper.iteration), base_store)

    torch.distributed.init_process_group(
        backend,
        store=store,
        rank=rank,
        world_size=world_size,
        timeout=timeout,
    )
    model_ddp = torch.nn.parallel.DistributedDataParallel(model)

    iteration = 0
    loss = torch.tensor(float('nan'))
    checkpoint_path = pathlib.Path(args.path) / 'checkpoint.pt'

    # Application loads state from the latest checkpoint on every restart
    # iteration of the wrapped function.
    if checkpoint_path.exists():
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model'])
        opt.load_state_dict(checkpoint['opt'])
        torch.set_rng_state(checkpoint['rng'])
        iteration = checkpoint['iteration']

    if args.seed is not None:
        random.seed(args.seed + iteration * world_size + rank)
    else:
        random.seed(time.perf_counter_ns())

    for iteration in range(iteration, args.total_iterations):

        # Application periodically saves a checkpoint. The checkpoint allows
        # the application to continue from previous state after a restart.
        if iteration % chkpt_interval == chkpt_interval - 1:
            torch.distributed.barrier()
            if rank == 0:
                checkpoint = {
                    'model': model.state_dict(),
                    'opt': opt.state_dict(),
                    'rng': torch.get_rng_state(),
                    'iteration': iteration,
                }
                # Saving the checkpoint is performed within the atomic()
                # context manager to ensure that the main thread won't execute
                # torch.save while a restart procedure is in progress.
                with call_wrapper.atomic():
                    torch.save(checkpoint, checkpoint_path)

        # Randomly trigger an example fault
        if random.random() < args.fault_prob:
            raise_timestamp = time.perf_counter()
            raise RuntimeError(f'example fault at {iteration=} from {rank=}')

        inp = torch.rand(args.size, args.size).to(device)
        model.zero_grad()
        out = model_ddp(inp)
        loss = out.square().mean()
        loss.backward()
        opt.step()
        # ``loss.item()`` synchronizes with the device, so any asynchronous
        # error raised during this iteration surfaces inside ``train()``.
        loss.item()

        if rank == 0 and iteration % log_interval == log_interval - 1:
            logging.info(f'{rank=} {iteration=} {loss.item()=}')


def main():
    args = parse_args()
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        level=args.log_level,
    )
    logging.info(f'{args}')

    local_rank = int(os.environ['LOCAL_RANK'])

    if args.device == 'cuda':
        torch.cuda.set_device(local_rank)
        device = torch.device('cuda')
        backend = 'nccl'
        timeout = datetime.timedelta(seconds=150)
    elif args.device == 'cpu':
        device = torch.device('cpu')
        backend = 'gloo'
        timeout = datetime.timedelta(seconds=10)
    else:
        raise RuntimeError

    # All objects created in ``main()`` are constructed only once, and reused
    # for all restart iterations.
    if args.seed is not None:
        torch.manual_seed(args.seed)
    model = torch.nn.Sequential(
        *[torch.nn.Linear(args.size, args.size) for _ in range(args.layers)]
    ).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=1e-5)

    # The application's TCPStore listens on ``(MASTER_PORT + 2)`` to avoid
    # conflicts with the TCPStore created by ``torch.distributed.run``
    # listening on ``MASTER_PORT``, and with the Wrapper's TCPStore listening
    # on ``(MASTER_PORT + 1)``.
    store = torch.distributed.TCPStore(
        host_name=os.environ['MASTER_ADDR'],
        port=int(os.environ['MASTER_PORT']) + 2,
        world_size=int(os.environ['WORLD_SIZE']),
        is_master=(int(os.environ['RANK']) == 0),
        multi_tenant=True,
        wait_for_workers=True,
        use_libuv=True,
    )

    # The internal TCPStore created by the Wrapper listens on
    # ``(MASTER_PORT + 1)`` to avoid conflicts with the application's TCPStore
    # listening on ``(MASTER_PORT + 2)``, and with the TCPStore created by
    # ``torch.distributed.run`` listening on ``MASTER_PORT``.
    wrapper_kwargs = {
        'store_kwargs': {'port': int(os.getenv('MASTER_PORT', 29500)) + 1},
        'health_check': inprocess.health_check.CudaHealthCheck(),
    }

    if args.nested_restarter:
        wrapper_kwargs['initialize'] = inprocess.nested_restarter.NestedRestarterHandlingCompleted()
        wrapper_kwargs['abort'] = inprocess.Compose(
            inprocess.abort.AbortTorchDistributed(),
            inprocess.nested_restarter.NestedRestarterHandlingStarting(),
        )
        wrapper_kwargs['completion'] = inprocess.nested_restarter.NestedRestarterFinalized()
        wrapper_kwargs['terminate'] = inprocess.nested_restarter.NestedRestarterAborted()

    # An instance of ``inprocess.CallWrapper`` is automatically injected into
    # the wrapped function's arguments when the Wrapper is invoked.
    wrapped_train = inprocess.Wrapper(**wrapper_kwargs)(train)

    # Call the wrapped function.
    # ``wrapped_train()`` is automatically restarted to recover from faults.
    wrapped_train(store, model, opt, backend, device, timeout, args)


if __name__ == '__main__':
    main()
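
The main latency optimization in this example is that the baseline ``TCPStore`` is created once in ``main()``, and each restart iteration only wraps it in a lightweight ``PrefixStore`` inside ``train()``, so no new store has to be constructed after a fault. The snippet below is a minimal, single-process sketch of that pattern in isolation; it is not part of the example above, and the address ``127.0.0.1``, the port ``29503``, and the key names are arbitrary placeholders:

import torch.distributed as dist

# Baseline store: created once and reused, like ``base_store`` in ``main()``.
base_store = dist.TCPStore('127.0.0.1', 29503, world_size=1, is_master=True)

for restart_iteration in range(3):
    # Per-iteration namespace: keys written in one restart iteration cannot
    # collide with keys of the same name written in a later iteration.
    store = dist.PrefixStore(str(restart_iteration), base_store)
    store.set('status', f'initialized in iteration {restart_iteration}')
    print(restart_iteration, store.get('status'))

In the example itself the prefix is ``call_wrapper.iteration``, which reflects the current inprocess restart iteration, so each ``torch.distributed.init_process_group()`` call sees a clean key namespace without paying the cost of constructing a new ``TCPStore``.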