# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This example demonstrates how to integrate ``inprocess`` and ``fault_tolerance``
# into an existing PyTorch training codebase. For simplicity, ``inprocess`` does not
# filter out any ranks, and there are no idle or spare ranks. Otherwise, fault tolerance (FT)
# would need to be disabled on inactive ranks.
#
# To run this example, use the accompanying bash script:
# ./examples/fault_tolerance/run_inprocess_injob_example.sh
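#
# The script expects torchrun-style environment variables (RANK, LOCAL_RANK,
# WORLD_SIZE, MASTER_ADDR, MASTER_PORT) to be set by the launcher.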

import argparse
import contextlib
import datetime
import logging
import os
import pathlib
import random
import signal
import time
from dataclasses import dataclass
from typing import Mapping, Optional, Tuple

import torch

import nvidia_resiliency_ext.fault_tolerance as fault_tolerance
import nvidia_resiliency_ext.inprocess as inprocess

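# Timestamp of the most recently injected simulated fault; ``train()`` uses it to
# report the restart latency once the wrapper restarts the function.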
raise_timestamp = None


def _get_last_sim_fault_iter_path(rank, target_dir="/tmp/") -> str:
    # Returns the path of the file that stores the last simulated fault iteration for this rank
    return os.path.join(target_dir, f"_injob_inproc_example_rank{rank}_failed_iter.txt")


def _save_last_sim_fault_iter(rank, iteration, target_dir="/tmp/"):
    file_path = _get_last_sim_fault_iter_path(rank=rank, target_dir=target_dir)
    with open(file_path, mode='w') as f:
        f.write(f"{iteration}")

def _get_last_sim_fault_iter(rank, target_dir="/tmp/") -> Optional[int]:
    file_path = _get_last_sim_fault_iter_path(rank=rank, target_dir=target_dir)
    if os.path.exists(file_path):
        with open(file_path, mode='r') as f:
            return int(f.read())
    return None

@dataclass
class _SimFaultDesc:
    """Represents a simulated fault description"""

    rank: int
    iteration: int
    fault_type: str

    @classmethod
    def from_str(cls, str_desc):
        try:
            split = str_desc.split(':')
            return cls(int(split[0]), int(split[1]), split[2].strip())
        except (ValueError, IndexError):
            raise argparse.ArgumentTypeError(
                f"Invalid format for a simulated fault description: {str_desc}"
            )

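# Example: ``--fault-iters 0:1000:exc,1:2000:sleep`` parses to
# {(0, 1000): _SimFaultDesc(0, 1000, 'exc'), (1, 2000): _SimFaultDesc(1, 2000, 'sleep')}.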
def _parse_fault_desc_arg(value) -> Mapping[Tuple[int, int], _SimFaultDesc]:
    # Returns a mapping of (rank, iteration) to the simulated fault that should occur at that point.
    rank_iter_to_fault = dict()
    if value:
        for str_desc in value.split(','):
            f = _SimFaultDesc.from_str(str_desc)
            rank_iter_to_fault[(f.rank, f.iteration)] = f
    return rank_iter_to_fault

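# Fault semantics: 'exc' raises a RuntimeError (recovered in-process by the Wrapper),
# 'sigkill' kills this process, and 'sleep' hangs the rank so that FT heartbeats stop.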
def _maybe_simulate_fault(rank, iteration, rank_iter_to_fault):

    # Checks whether a simulated fault should be triggered at the given rank and iteration.
    # Executes the simulated fault if the conditions are met.

    fault_desc = rank_iter_to_fault.get((rank, iteration), None)

    if fault_desc is None:
        return

    if _get_last_sim_fault_iter(rank) == iteration:
        # Prevents re-triggering the same fault after resuming from a checkpoint.
        logging.info(f'Skipped sim fault {fault_desc} as it was triggered before')
        return

    _save_last_sim_fault_iter(rank, iteration)

    logging.info(f'\n\n\n### Issuing simulated fault {fault_desc} ###\n\n\n')

    global raise_timestamp
    raise_timestamp = time.perf_counter()

    if fault_desc.fault_type == 'exc':
        raise RuntimeError(f'example fault at {iteration=} from {rank=}')
    elif fault_desc.fault_type == 'sigkill':
        os.kill(os.getpid(), signal.SIGKILL)
    elif fault_desc.fault_type == 'sleep':
        time.sleep(int(1e6))
    else:
        raise BaseException(f"Unexpected fault type {fault_desc.fault_type}")

def parse_args():
    parser = argparse.ArgumentParser(
        description='Inprocess and Fault Tolerance Example',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        '--size',
        default=64,
        type=int,
        help='model hidden size',
    )
    parser.add_argument(
        '--layers',
        default=4,
        type=int,
        help='number of layers',
    )
    parser.add_argument(
        '--log-interval',
        default=100,
        type=int,
        help='logging interval',
    )
    parser.add_argument(
        '--chkpt-interval',
        default=100,
        type=int,
        help='checkpointing interval',
    )
    parser.add_argument(
        '--total-iterations',
        default=1000000,
        type=int,
        help='total training iterations',
    )
    parser.add_argument(
        '--seed',
        default=None,
        type=int,
        help='random seed, time-based if None',
    )
    parser.add_argument(
        '--path',
        default='/tmp/',
        type=str,
        help='directory for the checkpoint file',
    )
    parser.add_argument(
        '--fault-iters',
        default='',
        type=_parse_fault_desc_arg,
        help='Comma-separated list of rank:iter:fault tuples for fault injection. '
        'fault can be exc|sleep|sigkill. Example: 0:1000:exc,1:2000:sleep',
    )
    parser.add_argument(
        '--device',
        default='cpu',
        choices=['cpu', 'cuda'],
        help='device',
    )
    parser.add_argument(
        '--log-level',
        type=lambda s: logging._nameToLevel[s.upper()],
        default=logging.INFO,
        help='logging level',
    )

    return parser.parse_args()

# The Wrapper's internal TCPStore listens on ``(MASTER_PORT + 2)`` to avoid
# conflicts with the application's TCPStore listening on ``(MASTER_PORT + 1)``,
# and with the TCPStore created by ``torch.distributed.run`` listening on
# ``MASTER_PORT``.
#
# An instance of ``inprocess.CallWrapper`` is automatically injected into the
# wrapped function's arguments when the Wrapper is invoked.
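#
# The Wrapper restarts the wrapped ``train()`` automatically to recover from
# faults (see ``main()``); ``CudaHealthCheck`` is configured as its health check.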


@inprocess.Wrapper(
    store_kwargs={'port': int(os.getenv('MASTER_PORT', 29500)) + 2},
    health_check=inprocess.health_check.CudaHealthCheck(),
)
def train(
    ft_client,
    base_store,
    model,
    opt,
    backend,
    device,
    timeout,
    args,
    call_wrapper: Optional[inprocess.CallWrapper] = None,
):
    global raise_timestamp
    if raise_timestamp is not None:
        restart_latency = time.perf_counter() - raise_timestamp
        logging.info(f'restart latency: {restart_latency:.3f}s')
        raise_timestamp = None

    log_interval = args.log_interval
    chkpt_interval = args.chkpt_interval

    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    logging.info(f"### STARTING RANK {rank} IN WORLD_SIZE {world_size} ###")

    # Reconnects FT so that rank monitors are aware of potential changes in rank-to-node mapping
    if ft_client.is_initialized:
        ft_client.shutdown_workload_monitoring()
    ft_client.init_workload_monitoring()

    # Create a new Store by adding a prefix based on the current inprocess
    # restart iteration. PrefixStore wraps the baseline TCPStore which is
    # reused for all restart iterations.
    store = torch.distributed.PrefixStore(str(call_wrapper.iteration), base_store)

    torch.distributed.init_process_group(
        backend,
        store=store,
        rank=rank,
        world_size=world_size,
        timeout=timeout,
    )

    model_ddp = torch.nn.parallel.DistributedDataParallel(model)

    iteration = 0
    loss = torch.tensor(float('nan'))
    checkpoint_path = pathlib.Path(args.path) / '_in_process_example_checkpoint.pt'

    # Application loads state from the latest checkpoint on every restart
    # iteration of the wrapped function.
    if checkpoint_path.exists():
        checkpoint = torch.load(checkpoint_path, weights_only=True)
        model.load_state_dict(checkpoint['model'])
        opt.load_state_dict(checkpoint['opt'])
        torch.set_rng_state(checkpoint['rng'])
        iteration = checkpoint['iteration']
        ft_client.load_state_dict(checkpoint['ft_state'])
    else:
        # Starting from scratch: remove any stale simulated-fault marker left by a previous run.
        with contextlib.suppress(FileNotFoundError):
            os.unlink(_get_last_sim_fault_iter_path(rank))

    if args.seed is not None:
        random.seed(args.seed + iteration * world_size + rank)
    else:
        random.seed(time.perf_counter_ns())

    for iteration in range(iteration, args.total_iterations):

        # Application periodically saves a checkpoint. The checkpoint allows
        # the application to continue from its previous state after a restart.
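        # Only rank 0 writes the checkpoint, while every rank loads it on restart,
        # so ``args.path`` must point to a location visible to all ranks.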
        if iteration % chkpt_interval == chkpt_interval - 1:
            torch.distributed.barrier()
            if rank == 0:
                checkpoint = {
                    'model': model.state_dict(),
                    'opt': opt.state_dict(),
                    'rng': torch.get_rng_state(),
                    'iteration': iteration,
                    'ft_state': ft_client.state_dict(),
                }
                # Saving the checkpoint is performed within the atomic() context
                # manager to ensure that the main thread won't execute
                # torch.save while a restart procedure is in progress.
                with call_wrapper.atomic():
                    torch.save(checkpoint, checkpoint_path)

        _maybe_simulate_fault(rank, iteration, args.fault_iters)

        inp = torch.rand(args.size, args.size).to(device)
        model.zero_grad()
        out = model_ddp(inp)
        loss = out.square().mean()
        loss.backward()
        opt.step()
        loss.item()

        if rank == 0 and iteration % log_interval == log_interval - 1:
            logging.info(f'{rank=} {iteration=} {loss.item()=}')

        ft_client.send_heartbeat()  # notifies FT that the training process is still active


def main():
    args = parse_args()
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        level=args.log_level,
    )

    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    if rank == 0:
        logging.info(f'\n##### NEW RUN {args} #####\n')

    if args.device == 'cuda':
        torch.cuda.set_device(local_rank)
        device = torch.device('cuda')
        backend = 'nccl'
        timeout = datetime.timedelta(seconds=150)
    elif args.device == 'cpu':
        device = torch.device('cpu')
        backend = 'gloo'
        timeout = datetime.timedelta(seconds=10)
    else:
        raise RuntimeError

    # All objects created in ``main()`` are constructed only once, and reused
    # for all restart iterations.
    if args.seed is not None:
        torch.manual_seed(args.seed)
    model = torch.nn.Sequential(
        *[torch.nn.Linear(args.size, args.size) for _ in range(args.layers)]
    ).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=1e-5)

    # This TCPStore listens on ``(MASTER_PORT + 1)`` to avoid conflicts with the
    # TCPStore created by ``torch.distributed.run`` listening on ``MASTER_PORT``.
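    # Rank 0 hosts the store (``is_master``); it is created once here and reused,
    # wrapped in a ``PrefixStore``, by every restart iteration inside ``train()``.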
    store = torch.distributed.TCPStore(
        host_name=os.environ['MASTER_ADDR'],
        port=int(os.environ['MASTER_PORT']) + 1,
        world_size=int(os.environ['WORLD_SIZE']),
        is_master=(int(os.environ['RANK']) == 0),
        multi_tenant=True,
        wait_for_workers=True,
        use_libuv=True,
    )

    # Prepare the FT client instance; it is initialized inside ``train()``.
    ft_client = fault_tolerance.RankMonitorClient()

    try:
        # Call the wrapped function.
        # ``train()`` is automatically restarted to recover from faults.
        train(ft_client, store, model, opt, backend, device, timeout, args)
    finally:
        if ft_client.is_initialized:
            ft_client.shutdown_workload_monitoring()


if __name__ == '__main__':
    main()