FT Launcher & Inprocess integration example

# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This example demonstrates how to integrate ``inprocess`` and ``fault_tolerance``
# into an existing PyTorch training codebase. For simplicity, ``inprocess`` does not
# filter out any ranks, and there are no idle or spare ranks. Otherwise, fault tolerance (FT)
# would need to be disabled on inactive ranks.
#
# To run this example, use the accompanying bash script:
# ./examples/fault_tolerance/run_inprocess_injob_example.sh
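# (The script is expected to launch this file via ``ft_launcher``, the
# torchrun-compatible launcher shipped with nvidia-resiliency-ext, which also
# starts the rank monitors that the FT client connects to; see the script for
# the exact command line.)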

import argparse
import contextlib
import datetime
import logging
import os
import pathlib
import random
import signal
import time
from dataclasses import dataclass
from typing import Mapping, Optional, Tuple

import torch

import nvidia_resiliency_ext.fault_tolerance as fault_tolerance
import nvidia_resiliency_ext.inprocess as inprocess

raise_timestamp = None


def _get_last_sim_fault_iter_path(rank, target_dir="/tmp/") -> str:
    # Returns the path of the file that stores the last simulated fault iteration for this rank
    return os.path.join(target_dir, f"_injob_inproc_example_rank{rank}_failed_iter.txt")


def _save_last_sim_fault_iter(rank, iteration, target_dir="/tmp/"):
    file_path = _get_last_sim_fault_iter_path(rank=rank, target_dir=target_dir)
    with open(file_path, mode='w') as f:
        f.write(f"{iteration}")


def _get_last_sim_fault_iter(rank, target_dir="/tmp/") -> Optional[int]:
    file_path = _get_last_sim_fault_iter_path(rank=rank, target_dir=target_dir)
    if os.path.exists(file_path):
        with open(file_path, mode='r') as f:
            return int(f.read())
    return None


@dataclass
class _SimFaultDesc:
    """Represents a simulated fault description"""

    rank: int
    iteration: int
    fault_type: str

    @classmethod
    def from_str(cls, str_desc):
        try:
            split = str_desc.split(':')
            return cls(int(split[0]), int(split[1]), split[2].strip())
        except (ValueError, IndexError):
            raise argparse.ArgumentTypeError(
                f"Invalid format for a simulated fault description: {str_desc}"
            )


def _parse_fault_desc_arg(value) -> Mapping[Tuple[int, int], _SimFaultDesc]:
    # Returns a mapping of (rank, iteration) to the simulated fault that should occur at that point.
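    # For example, "0:1000:exc,1:2000:sleep" parses to:
    #   {(0, 1000): _SimFaultDesc(0, 1000, 'exc'), (1, 2000): _SimFaultDesc(1, 2000, 'sleep')}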
    rank_iter_to_fault = dict()
    if value:
        for str_desc in value.split(','):
            f = _SimFaultDesc.from_str(str_desc)
            rank_iter_to_fault[(f.rank, f.iteration)] = f
    return rank_iter_to_fault


def _maybe_simulate_fault(rank, iteration, rank_iter_to_fault):
    # Checks whether a simulated fault should be triggered at the given rank and iteration.
    # Executes the simulated fault if the conditions are met.

    fault_desc = rank_iter_to_fault.get((rank, iteration), None)

    if fault_desc is None:
        return

    if _get_last_sim_fault_iter(rank) == iteration:
        # Prevents re-triggering the same fault after resuming from a checkpoint.
        logging.info(f'Skipped sim fault {fault_desc} as it was triggered before')
        return

    _save_last_sim_fault_iter(rank, iteration)

    logging.info(f'\n\n\n### Issuing simulated fault {fault_desc} ###\n\n\n')

    global raise_timestamp
    raise_timestamp = time.perf_counter()

    if fault_desc.fault_type == 'exc':
        raise RuntimeError(f'example fault at {iteration=} from {rank=}')
    elif fault_desc.fault_type == 'sigkill':
        os.kill(os.getpid(), signal.SIGKILL)
    elif fault_desc.fault_type == 'sleep':
        time.sleep(int(1e6))
    else:
        raise BaseException(f"Unexpected fault type {fault_desc.fault_type}")


def parse_args():
    parser = argparse.ArgumentParser(
        description='Inprocess and Fault Tolerance Example',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        '--size',
        default=64,
        type=int,
        help='model hidden size',
    )
    parser.add_argument(
        '--layers',
        default=4,
        type=int,
        help='number of layers',
    )
    parser.add_argument(
        '--log-interval',
        default=100,
        type=int,
        help='logging interval',
    )
    parser.add_argument(
        '--chkpt-interval',
        default=100,
        type=int,
        help='checkpointing interval',
    )
    parser.add_argument(
        '--total-iterations',
        default=1000000,
        type=int,
        help='total training iterations',
    )
    parser.add_argument(
        '--seed',
        default=None,
        type=int,
        help='random seed, time-based if None',
    )
    parser.add_argument(
        '--path',
        default='/tmp/',
        type=str,
        help='directory for the checkpoint file',
    )
    parser.add_argument(
        '--fault-iters',
        default='',
        type=_parse_fault_desc_arg,
        help='Comma-separated list of rank:iter:fault tuples for fault injection. '
        'fault can be exc|sleep|sigkill. Example: 0:1000:exc,1:2000:sleep',
    )
    parser.add_argument(
        '--device',
        default='cpu',
        choices=['cpu', 'cuda'],
        help='device',
    )
    parser.add_argument(
        '--log-level',
        type=lambda s: logging._nameToLevel[s.upper()],
        default=logging.INFO,
        help='logging level',
    )

    return parser.parse_args()


# The Wrapper's internal TCPStore uses port ``(MASTER_PORT + 2)`` to avoid
# conflicts with the application's TCPStore listening on ``(MASTER_PORT + 1)``,
# and with the TCPStore created by ``torch.distributed.run`` listening on
# ``MASTER_PORT``.
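#
# For example, with the default ``MASTER_PORT=29500``:
#   29500 - TCPStore created by ``torch.distributed.run``
#   29501 - application TCPStore created in ``main()``
#   29502 - internal Wrapper TCPStore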
#
# An instance of ``inprocess.CallWrapper`` is automatically injected into the
# wrapped function's arguments when the Wrapper is invoked.


@inprocess.Wrapper(
    store_kwargs={'port': int(os.getenv('MASTER_PORT', 29500)) + 2},
    health_check=inprocess.health_check.CudaHealthCheck(),
)
def train(
    ft_client,
    base_store,
    model,
    opt,
    backend,
    device,
    timeout,
    args,
    call_wrapper: Optional[inprocess.CallWrapper] = None,
):
    global raise_timestamp
    if raise_timestamp is not None:
        restart_latency = time.perf_counter() - raise_timestamp
        logging.info(f'restart latency: {restart_latency:.3f}s')
    raise_timestamp = None

    log_interval = args.log_interval
    chkpt_interval = args.chkpt_interval

    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    logging.info(f"### STARTING RANK {rank} IN WORLD_SIZE {world_size} ###")

    # Reconnects FT so that rank monitors are aware of potential changes in rank-to-node mapping
    if ft_client.is_initialized:
        ft_client.shutdown_workload_monitoring()
    ft_client.init_workload_monitoring()

    # Create a new Store by adding a prefix based on the current inprocess
    # restart iteration. PrefixStore wraps the baseline TCPStore, which is
    # reused for all restart iterations.
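    # (e.g. on restart iteration N, keys are created under the prefix str(N),
    # so they cannot collide with keys left over from earlier attempts)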
    store = torch.distributed.PrefixStore(str(call_wrapper.iteration), base_store)

    torch.distributed.init_process_group(
        backend,
        store=store,
        rank=rank,
        world_size=world_size,
        timeout=timeout,
    )

    model_ddp = torch.nn.parallel.DistributedDataParallel(model)

    iteration = 0
    loss = torch.tensor(float('nan'))
    checkpoint_path = pathlib.Path(args.path) / '_in_process_example_checkpoint.pt'

    # Application loads state from the latest checkpoint on every restart
    # iteration of the wrapped function.
    if checkpoint_path.exists():
        checkpoint = torch.load(checkpoint_path, weights_only=True)
        model.load_state_dict(checkpoint['model'])
        opt.load_state_dict(checkpoint['opt'])
        torch.set_rng_state(checkpoint['rng'])
        iteration = checkpoint['iteration']
        ft_client.load_state_dict(checkpoint['ft_state'])
    else:
        # Starting from scratch: clear any stale simulated-fault marker.
        with contextlib.suppress(FileNotFoundError):
            os.unlink(_get_last_sim_fault_iter_path(rank))

    if args.seed is not None:
        random.seed(args.seed + iteration * world_size + rank)
    else:
        random.seed(time.perf_counter_ns())

    for iteration in range(iteration, args.total_iterations):

        # Application periodically saves a checkpoint. The checkpoint allows
        # the application to continue from its previous state after a restart.
        if iteration % chkpt_interval == chkpt_interval - 1:
            torch.distributed.barrier()
            if rank == 0:
                checkpoint = {
                    'model': model.state_dict(),
                    'opt': opt.state_dict(),
                    'rng': torch.get_rng_state(),
                    'iteration': iteration,
                    'ft_state': ft_client.state_dict(),
                }
                # Saving the checkpoint is performed within atomic() context
                # manager to ensure that the main thread won't execute
                # torch.save while a restart procedure is in progress.
                with call_wrapper.atomic():
                    torch.save(checkpoint, checkpoint_path)

        _maybe_simulate_fault(rank, iteration, args.fault_iters)

        inp = torch.rand(args.size, args.size).to(device)
        model.zero_grad()
        out = model_ddp(inp)
        loss = out.square().mean()
        loss.backward()
        opt.step()
        loss.item()  # materializes the loss (forces a device sync on CUDA)

        if rank == 0 and iteration % log_interval == log_interval - 1:
            logging.info(f'{rank=} {iteration=} {loss.item()=}')

        ft_client.send_heartbeat()  # notifies FT that the training process is still active.
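        # (if heartbeats stop arriving within the configured timeout, the rank
        # monitor treats this rank as failed and FT reacts accordingly)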


def main():
    args = parse_args()
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        level=args.log_level,
    )

    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    if rank == 0:
        logging.info(f'\n##### NEW RUN {args} #####\n')

    if args.device == 'cuda':
        torch.cuda.set_device(local_rank)
        device = torch.device('cuda')
        backend = 'nccl'
        timeout = datetime.timedelta(seconds=150)
    elif args.device == 'cpu':
        device = torch.device('cpu')
        backend = 'gloo'
        timeout = datetime.timedelta(seconds=10)
    else:
        raise RuntimeError

    # All objects created in ``main()`` are constructed only once, and reused
    # for all restart iterations.
    if args.seed is not None:
        torch.manual_seed(args.seed)
    model = torch.nn.Sequential(
        *[torch.nn.Linear(args.size, args.size) for _ in range(args.layers)]
    ).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=1e-5)

    # This TCPStore uses port ``(MASTER_PORT + 1)`` to avoid conflicts with the
    # TCPStore created by ``torch.distributed.run``, which listens on ``MASTER_PORT``.
    store = torch.distributed.TCPStore(
        host_name=os.environ['MASTER_ADDR'],
        port=int(os.environ['MASTER_PORT']) + 1,
        world_size=int(os.environ['WORLD_SIZE']),
        is_master=(int(os.environ['RANK']) == 0),
        multi_tenant=True,
        wait_for_workers=True,
        use_libuv=True,
    )

    # Prepare the FT client instance; it will be initialized in ``train()``.
    ft_client = fault_tolerance.RankMonitorClient()

    try:
        # Call the wrapped function.
        # ``train()`` is automatically restarted to recover from faults.
        train(ft_client, store, model, opt, backend, device, timeout, args)
    finally:
        if ft_client.is_initialized:
            ft_client.shutdown_workload_monitoring()


if __name__ == '__main__':
    main()