# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# This example demonstrates how to integrate ``inprocess.Wrapper()`` into an
# existing PyTorch training codebase.
#
# This example shows the optimal usage:
# - only the training loop and objects depending on a torch distributed process
#   group are restarted upon a failure
# - process-group-independent objects (e.g. TCPStore, Model, Optimizer) are
#   created once and reused across all restart iterations to minimize restart
#   latency
#
# NOTE: inprocess.Wrapper is not fully compatible with modern
# ``torch.distributed.run``, because it automatically terminates all local
# workers upon any local worker process failure; in this case inprocess.Wrapper
# can only recover from transient faults that don't terminate any of the
# training processes.
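#
# Example launch, assuming a single node and the ``torchrun`` launcher (any
# launcher that sets RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR, and MASTER_PORT
# for every worker can be used, subject to the NOTE above); the script name
# below is hypothetical:
#
#   torchrun --nproc_per_node=4 inprocess_restart_optimal.py --device=cuda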

import argparse
import datetime
import logging
import os
import pathlib
import random
import time
from typing import Optional

import torch

import nvidia_resiliency_ext.inprocess as inprocess

raise_timestamp = None


def parse_args():
    parser = argparse.ArgumentParser(
        description='Inprocess Restart Optimal Example',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        '--size',
        default=64,
        type=int,
        help='model hidden size',
    )
    parser.add_argument(
        '--layers',
        default=4,
        type=int,
        help='number of layers',
    )
    parser.add_argument(
        '--log-interval',
        default=100,
        type=int,
        help='logging interval',
    )
    parser.add_argument(
        '--chkpt-interval',
        default=100,
        type=int,
        help='checkpointing interval',
    )
    parser.add_argument(
        '--total-iterations',
        default=1000000,
        type=int,
        help='total training iterations',
    )
    parser.add_argument(
        '--seed',
        default=None,
        type=int,
        help='random seed, time-based if None',
    )
    parser.add_argument(
        '--path',
        default='/tmp/',
        type=str,
        help='directory for the checkpoint file',
    )
    parser.add_argument(
        '--fault-prob',
        default=0.001,
        type=float,
        help='fault injection probability',
    )
    parser.add_argument(
        '--device',
        default='cpu',
        choices=['cpu', 'cuda'],
        help='device',
    )
    parser.add_argument(
        '--log-level',
        type=lambda s: logging._nameToLevel[s.upper()],
        default=logging.INFO,
        help='logging level',
    )

    return parser.parse_args()


# The TCPStore created internally by the Wrapper listens on
# ``(MASTER_PORT + 2)`` to avoid conflicts with the application's TCPStore
# listening on ``(MASTER_PORT + 1)``, and with the TCPStore created by
# ``torch.distributed.run`` listening on ``MASTER_PORT``.
#
# An instance of ``inprocess.CallWrapper`` is automatically injected into the
# wrapped function's arguments when the Wrapper is invoked.
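#
# For example, with the default MASTER_PORT of 29500: ``torch.distributed.run``
# (if used) listens on 29500, the application's TCPStore created in ``main()``
# listens on 29501, and the Wrapper's internal TCPStore listens on 29502.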
@inprocess.Wrapper(
    store_kwargs={'port': int(os.getenv('MASTER_PORT', 29500)) + 2},
    health_check=inprocess.health_check.CudaHealthCheck(),
)
def train(
    base_store,
    model,
    opt,
    backend,
    device,
    timeout,
    args,
    call_wrapper: Optional[inprocess.CallWrapper] = None,
):
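    # ``raise_timestamp`` is set immediately before a fault is injected in the
    # training loop below, so the latency logged here covers fault propagation,
    # the Wrapper's restart handling, and re-entry into ``train()``.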
    global raise_timestamp
    if raise_timestamp is not None:
        restart_latency = time.perf_counter() - raise_timestamp
        logging.info(f'restart latency: {restart_latency:.3f}s')
    raise_timestamp = None

    log_interval = args.log_interval
    chkpt_interval = args.chkpt_interval

    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    # Create a new Store by adding a prefix based on the current inprocess
    # restart iteration. The PrefixStore wraps the baseline TCPStore, which is
    # reused for all restart iterations.
    store = torch.distributed.PrefixStore(
        str(call_wrapper.iteration), base_store
    )

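    # The process group is part of the restart scope: unlike the objects passed
    # in from ``main()``, it is created anew on every restart iteration.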
    torch.distributed.init_process_group(
        backend,
        store=store,
        rank=rank,
        world_size=world_size,
        timeout=timeout,
    )
    model_ddp = torch.nn.parallel.DistributedDataParallel(model)

    iteration = 0
    loss = torch.tensor(float('nan'))
    checkpoint_path = pathlib.Path(args.path) / 'checkpoint.pt'

    # The application loads state from the latest checkpoint on every restart
    # iteration of the wrapped function.
    if checkpoint_path.exists():
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model'])
        opt.load_state_dict(checkpoint['opt'])
        torch.set_rng_state(checkpoint['rng'])
        iteration = checkpoint['iteration']

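    # Combining the seed with the iteration, world size, and rank gives every
    # rank a distinct fault-injection schedule that also changes across
    # restarts; without ``--seed`` the schedule is time-based.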
    if args.seed is not None:
        random.seed(args.seed + iteration * world_size + rank)
    else:
        random.seed(time.perf_counter_ns())

    for iteration in range(iteration, args.total_iterations):

        # The application periodically saves a checkpoint. The checkpoint
        # allows the application to continue from its previous state after a
        # restart.
        if iteration % chkpt_interval == chkpt_interval - 1:
            torch.distributed.barrier()
            if rank == 0:
                checkpoint = {
                    'model': model.state_dict(),
                    'opt': opt.state_dict(),
                    'rng': torch.get_rng_state(),
                    'iteration': iteration,
                }
                # Saving the checkpoint is performed within the atomic()
                # context manager to ensure that the main thread won't execute
                # torch.save while a restart procedure is in progress.
                with call_wrapper.atomic():
                    torch.save(checkpoint, checkpoint_path)

        # Randomly trigger an example fault.
        if random.random() < args.fault_prob:
            raise_timestamp = time.perf_counter()
            raise RuntimeError(f'example fault at {iteration=} from {rank=}')

        inp = torch.rand(args.size, args.size).to(device)
        model.zero_grad()
        out = model_ddp(inp)
        loss = out.square().mean()
        loss.backward()
        opt.step()
        loss.item()

        if rank == 0 and iteration % log_interval == log_interval - 1:
            logging.info(f'{rank=} {iteration=} {loss.item()=}')


def main():
    args = parse_args()
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        level=args.log_level,
    )
    logging.info(f'{args}')

    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])

    if args.device == 'cuda':
        torch.cuda.set_device(local_rank)
        device = torch.device('cuda')
        backend = 'nccl'
        timeout = datetime.timedelta(seconds=150)
    elif args.device == 'cpu':
        device = torch.device('cpu')
        backend = 'gloo'
        timeout = datetime.timedelta(seconds=10)
    else:
        raise RuntimeError
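    # ``timeout`` is forwarded to ``init_process_group()`` in ``train()`` and
    # bounds how long collective calls may block; the CPU/Gloo configuration
    # uses a short value so injected faults surface quickly in this example.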

    # All objects created in ``main()`` are constructed only once, and reused
    # for all restart iterations.
    if args.seed is not None:
        torch.manual_seed(args.seed)
    model = torch.nn.Sequential(
        *[torch.nn.Linear(args.size, args.size) for _ in range(args.layers)]
    ).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=1e-5)
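    # The model and optimizer are process-group-independent, so only the DDP
    # wrapper built inside ``train()`` needs to be reconstructed after a
    # restart.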

    # The application's TCPStore uses ``(MASTER_PORT + 1)`` to avoid conflicts
    # with the TCPStore created by ``torch.distributed.run`` and listening on
    # ``MASTER_PORT``.
    store = torch.distributed.TCPStore(
        host_name=os.environ['MASTER_ADDR'],
        port=int(os.environ['MASTER_PORT']) + 1,
        world_size=int(os.environ['WORLD_SIZE']),
        is_master=(int(os.environ['RANK']) == 0),
        multi_tenant=True,
        wait_for_workers=True,
        use_libuv=True,
    )

    # Call the wrapped function.
    # ``train()`` is automatically restarted to recover from faults.
    train(store, model, opt, backend, device, timeout, args)


if __name__ == '__main__':
    main()