Advanced usage example

# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import datetime
import logging
import os
import pathlib
import random
import time
from typing import Optional

os.environ['TORCH_CPP_LOG_LEVEL'] = 'error'
import torch

import nvidia_resiliency_ext.inprocess as inprocess

raise_timestamp = None

################################################################################
#
# THIS EXAMPLE DEMONSTRATES AN ADVANCED CASE WHERE THE MODEL AND OPTIMIZER ARE
# PROCESS-GROUP INDEPENDENT.
#
# Because the model and optimizer do not depend on a torch distributed process
# group, they can safely be created outside of the wrapped function and reused
# across all restarts. See the usage guide for more details.
#
# This example demonstrates how to integrate ``inprocess.Wrapper()`` into an
# existing PyTorch training codebase for an advanced use case where:
# - only the training loop and the objects depending on a torch distributed
#   process group are restarted upon a failure
# - process-group-independent objects (e.g. TCPStore, Model, Optimizer) are
#   created once and reused between all restart iterations to minimize restart
#   latency
#
# NOTE: inprocess.Wrapper is not fully compatible with modern
# ``torch.distributed.run``, because the launcher automatically terminates all
# local workers upon any local worker process failure; in this case
# inprocess.Wrapper can only recover from transient faults that don't terminate
# any of the training processes.
#
################################################################################
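# Simplified restart flow in this example:
# 1. ``main()`` creates the TCPStore, model, and optimizer once.
# 2. ``inprocess.Wrapper`` invokes the wrapped ``train()``; when a rank raises
#    an exception, the wrapper aborts the torch distributed state, runs the
#    configured health check, and calls ``train()`` again.
# 3. On each (re)start, ``train()`` re-initializes the process group and
#    resumes from the latest checkpoint.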


def parse_args():
    parser = argparse.ArgumentParser(
        description='Inprocess Restart Advanced Example',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        '--size',
        default=64,
        type=int,
        help='model hidden size',
    )
    parser.add_argument(
        '--layers',
        default=4,
        type=int,
        help='number of layers',
    )
    parser.add_argument(
        '--log-interval',
        default=100,
        type=int,
        help='logging interval',
    )
    parser.add_argument(
        '--chkpt-interval',
        default=100,
        type=int,
        help='checkpointing interval',
    )
    parser.add_argument(
        '--total-iterations',
        default=1000000,
        type=int,
        help='total training iterations',
    )
    parser.add_argument(
        '--seed',
        default=None,
        type=int,
        help='random seed, time-based if None',
    )
    parser.add_argument(
        '--path',
        default='/tmp/',
        type=str,
        help='directory for the checkpoint file',
    )
    parser.add_argument(
        '--fault-prob',
        default=0.001,
        type=float,
        help='fault injection probability',
    )
    parser.add_argument(
        '--device',
        default='cuda',
        choices=['cpu', 'cuda'],
        help='device',
    )
    parser.add_argument(
        '--log-level',
        type=lambda s: logging._nameToLevel[s.upper()],
        default=logging.INFO,
        help='logging level',
    )

    parser.add_argument(
        '--nested-restarter',
        action='store_true',
        help='extra logging for the nested restarter',
    )

    return parser.parse_args()


def train(
    base_store,
    model,
    opt,
    backend,
    device,
    timeout,
    args,
    call_wrapper: Optional[inprocess.CallWrapper] = None,
):
    global raise_timestamp
    if raise_timestamp is not None:
        restart_latency = time.perf_counter() - raise_timestamp
        logging.info(f'restart latency: {restart_latency:.3f}s')
    raise_timestamp = None

    log_interval = args.log_interval
    chkpt_interval = args.chkpt_interval

    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    # Create a new Store by adding a prefix based on the current inprocess
    # restart iteration. PrefixStore wraps the baseline TCPStore which is
    # reused for all restart iterations.
    store = torch.distributed.PrefixStore(str(call_wrapper.iteration), base_store)

    torch.distributed.init_process_group(
        backend,
        store=store,
        rank=rank,
        world_size=world_size,
        timeout=timeout,
    )
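    # DistributedDataParallel is re-created on every restart so that its
    # gradient buckets and communication hooks are bound to the newly
    # initialized process group; the underlying ``model`` and its parameters
    # are reused.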
    model_ddp = torch.nn.parallel.DistributedDataParallel(model)

    iteration = 0
    loss = torch.tensor(float('nan'))
    checkpoint_path = pathlib.Path(args.path) / 'checkpoint.pt'

    # Application loads state from the latest checkpoint on every restart
    # iteration of the wrapped function.
    if checkpoint_path.exists():
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model'])
        opt.load_state_dict(checkpoint['opt'])
        torch.set_rng_state(checkpoint['rng'])
        iteration = checkpoint['iteration']

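    # Re-seed the Python RNG used for fault injection: deterministic per rank
    # and restart iteration when a seed is given, time-based otherwise.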
    if args.seed is not None:
        random.seed(args.seed + iteration * world_size + rank)
    else:
        random.seed(time.perf_counter_ns())

    for iteration in range(iteration, args.total_iterations):

        # Application periodically saves a checkpoint. The checkpoint allows
        # the application to continue from the previous state after a restart.
        if iteration % chkpt_interval == chkpt_interval - 1:
            torch.distributed.barrier()
            if rank == 0:
                checkpoint = {
                    'model': model.state_dict(),
                    'opt': opt.state_dict(),
                    'rng': torch.get_rng_state(),
                    'iteration': iteration,
                }
                # Saving the checkpoint is performed within the atomic()
                # context manager to ensure that the main thread won't execute
                # torch.save while a restart procedure is in progress.
                with call_wrapper.atomic():
                    torch.save(checkpoint, checkpoint_path)

        # Randomly trigger an example fault
        if random.random() < args.fault_prob:
            raise_timestamp = time.perf_counter()
            raise RuntimeError(f'example fault at {iteration=} from {rank=}')

        inp = torch.rand(args.size, args.size).to(device)
        model.zero_grad()
        out = model_ddp(inp)
        loss = out.square().mean()
        loss.backward()
        opt.step()
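        # ``loss.item()`` forces a host-device synchronization, so pending
        # asynchronous CUDA errors surface here, inside the restartable
        # wrapped function.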
        loss.item()

        if rank == 0 and iteration % log_interval == log_interval - 1:
            logging.info(f'{rank=} {iteration=} {loss.item()=}')


def main():
    args = parse_args()
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        level=args.log_level,
    )
    logging.info(f'{args}')

    local_rank = int(os.environ['LOCAL_RANK'])

    if args.device == 'cuda':
        torch.cuda.set_device(local_rank)
        device = torch.device('cuda')
        backend = 'nccl'
        timeout = datetime.timedelta(seconds=150)
    elif args.device == 'cpu':
        device = torch.device('cpu')
        backend = 'gloo'
        timeout = datetime.timedelta(seconds=10)
    else:
        raise RuntimeError

    # In this *specific* example, all objects created in ``main()`` are constructed
    # only once and reused for all restart iterations. Here, the model and optimizer
    # are process-group independent, so those objects can safely be created outside
    # of the wrapped function and reused.
    #
    # Normally, process-group-dependent variables should be created or initialized
    # inside the wrapped function; otherwise, the workload is responsible for
    # ensuring that such objects refer to valid process groups after a restart.
    # See the usage guide for more details.

    if args.seed is not None:
        torch.manual_seed(args.seed)
    model = torch.nn.Sequential(
        *[torch.nn.Linear(args.size, args.size) for _ in range(args.layers)]
    ).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=1e-5)

    # The application's TCPStore listens on ``(MASTER_PORT + 2)`` to avoid
    # conflicts with the TCPStore created by ``torch.distributed.run`` listening
    # on ``MASTER_PORT``, and with the Wrapper's TCPStore listening on
    # ``(MASTER_PORT + 1)``.
    store = torch.distributed.TCPStore(
        host_name=os.environ['MASTER_ADDR'],
        port=int(os.environ['MASTER_PORT']) + 2,
        world_size=int(os.environ['WORLD_SIZE']),
        is_master=(int(os.environ['RANK']) == 0),
        multi_tenant=True,
        wait_for_workers=True,
        use_libuv=True,
    )

    # The Wrapper's internal TCPStore uses port ``(MASTER_PORT + 1)`` to avoid
    # conflicts with the application's TCPStore listening on ``(MASTER_PORT + 2)``,
    # and with the TCPStore created by ``torch.distributed.run`` listening on
    # ``MASTER_PORT``.
    #
    wrapper_kwargs = {
        'store_kwargs': {'port': int(os.getenv('MASTER_PORT', 29500)) + 1},
        'health_check': inprocess.health_check.CudaHealthCheck(),
    }
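    # ``CudaHealthCheck`` synchronizes with the local GPU to verify that the
    # CUDA context is still healthy before this rank joins the next restart
    # iteration; see the inprocess documentation for details.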

    if args.nested_restarter:
        wrapper_kwargs['initialize'] = inprocess.nested_restarter.NestedRestarterHandlingCompleted()
        wrapper_kwargs['abort'] = inprocess.Compose(
            inprocess.abort.AbortTorchDistributed(),
            inprocess.nested_restarter.NestedRestarterHandlingStarting(),
        )
        wrapper_kwargs['completion'] = inprocess.nested_restarter.NestedRestarterFinalized()
        wrapper_kwargs['terminate'] = inprocess.nested_restarter.NestedRestarterAborted()

    # An instance of ``inprocess.CallWrapper`` is automatically injected into
    # the wrapped function's arguments when the Wrapper is invoked.
    wrapped_train = inprocess.Wrapper(**wrapper_kwargs)(train)

    # Call the wrapped function.
    # ``wrapped_train()`` is automatically restarted to recover from faults.
    wrapped_train(store, model, opt, backend, device, timeout, args)


if __name__ == '__main__':
    main()
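
The script reads the standard launcher environment variables (``RANK``, ``LOCAL_RANK``, ``WORLD_SIZE``, ``MASTER_ADDR``, ``MASTER_PORT``), so it can be started with ``torchrun``. As a minimal sketch, assuming the listing is saved as ``advanced.py`` (a hypothetical filename):

    torchrun --nproc_per_node=2 advanced.py --device=cuda --fault-prob=0.001

Note the caveat in the header comment: ``torch.distributed.run`` terminates all local workers when any local worker process dies, so under this launcher in-process restart can only recover from faults that leave the worker processes alive.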