# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# This example demonstrates how to integrate ``inprocess.Wrapper()`` into an
# existing PyTorch training codebase.
#
# This example shows the optimal usage:
# - only the training loop and objects depending on a torch distributed
#   process group are restarted upon a failure
# - process-group-independent objects (e.g. TCPStore, Model, Optimizer) are
#   created once and reused across all restart iterations to minimize restart
#   latency
#
# NOTE: inprocess.Wrapper is not fully compatible with modern
# ``torch.distributed.run``, because it automatically terminates all local
# workers upon any local worker process failure; in this case
# inprocess.Wrapper can only recover from transient faults that don't
# terminate any of the training processes.
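#
# Example launch command (illustrative; assumes this file is saved as
# ``optimal_example.py``; adjust the filename, the number of processes, and
# the rendezvous settings to your environment):
#
#   torchrun --nproc-per-node=2 optimal_example.py --device=cpu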

import argparse
import datetime
import logging
import os
import pathlib
import random
import time
from typing import Optional

os.environ['TORCH_CPP_LOG_LEVEL'] = 'error'
import torch

import nvidia_resiliency_ext.inprocess as inprocess

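# Timestamp captured in ``train()`` right before an example fault is raised;
# used on the next restart iteration to measure and log restart latency.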
raise_timestamp = None


def parse_args():
    parser = argparse.ArgumentParser(
        description='Inprocess Restart Optimal Example',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        '--size',
        default=64,
        type=int,
        help='model hidden size',
    )
    parser.add_argument(
        '--layers',
        default=4,
        type=int,
        help='number of layers',
    )
    parser.add_argument(
        '--log-interval',
        default=100,
        type=int,
        help='logging interval',
    )
    parser.add_argument(
        '--chkpt-interval',
        default=100,
        type=int,
        help='checkpointing interval',
    )
    parser.add_argument(
        '--total-iterations',
        default=1000000,
        type=int,
        help='total training iterations',
    )
    parser.add_argument(
        '--seed',
        default=None,
        type=int,
        help='random seed, time-based if None',
    )
    parser.add_argument(
        '--path',
        default='/tmp/',
        type=str,
        help='directory for the checkpoint file',
    )
    parser.add_argument(
        '--fault-prob',
        default=0.001,
        type=float,
        help='fault injection probability',
    )
    parser.add_argument(
        '--device',
        default='cpu',
        choices=['cpu', 'cuda'],
        help='device',
    )
    parser.add_argument(
        '--log-level',
        type=lambda s: logging._nameToLevel[s.upper()],
        default=logging.INFO,
        help='logging level',
    )

    parser.add_argument(
        '--nested-restarter',
        action='store_true',
        help='extra logging for the nested restarter',
    )

    return parser.parse_args()


def train(
    base_store,
    model,
    opt,
    backend,
    device,
    timeout,
    args,
    call_wrapper: Optional[inprocess.CallWrapper] = None,
):
    global raise_timestamp
    if raise_timestamp is not None:
        restart_latency = time.perf_counter() - raise_timestamp
        logging.info(f'restart latency: {restart_latency:.3f}s')
    raise_timestamp = None

    log_interval = args.log_interval
    chkpt_interval = args.chkpt_interval

    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    # Create a new Store by adding a prefix based on the current inprocess
    # restart iteration. PrefixStore wraps the baseline TCPStore, which is
    # reused for all restart iterations.
    store = torch.distributed.PrefixStore(str(call_wrapper.iteration), base_store)

    torch.distributed.init_process_group(
        backend,
        store=store,
        rank=rank,
        world_size=world_size,
        timeout=timeout,
    )
    model_ddp = torch.nn.parallel.DistributedDataParallel(model)

    iteration = 0
    loss = torch.tensor(float('nan'))
    checkpoint_path = pathlib.Path(args.path) / 'checkpoint.pt'

    # The application loads state from the latest checkpoint on every restart
    # iteration of the wrapped function.
    if checkpoint_path.exists():
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model'])
        opt.load_state_dict(checkpoint['opt'])
        torch.set_rng_state(checkpoint['rng'])
        iteration = checkpoint['iteration']

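    # Seeding ``random`` per rank and per starting iteration keeps the fault
    # injection below reproducible for a fixed ``--seed`` while still
    # differing across ranks and restart iterations.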
    if args.seed is not None:
        random.seed(args.seed + iteration * world_size + rank)
    else:
        random.seed(time.perf_counter_ns())

    for iteration in range(iteration, args.total_iterations):

        # The application periodically saves a checkpoint. The checkpoint
        # allows the application to continue from the previous state after a
        # restart.
        if iteration % chkpt_interval == chkpt_interval - 1:
            torch.distributed.barrier()
            if rank == 0:
                checkpoint = {
                    'model': model.state_dict(),
                    'opt': opt.state_dict(),
                    'rng': torch.get_rng_state(),
                    'iteration': iteration,
                }
                # Saving the checkpoint is performed within the atomic()
                # context manager to ensure that the main thread won't execute
                # torch.save while a restart procedure is in progress.
                with call_wrapper.atomic():
                    torch.save(checkpoint, checkpoint_path)

        # Randomly trigger an example fault.
        if random.random() < args.fault_prob:
            raise_timestamp = time.perf_counter()
            raise RuntimeError(f'example fault at {iteration=} from {rank=}')
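        # If the fault fires, the exception propagates out of ``train()`` and
        # is intercepted by ``inprocess.Wrapper``, which tears down the
        # distributed state and calls ``train()`` again in a new restart
        # iteration.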

        inp = torch.rand(args.size, args.size).to(device)
        model.zero_grad()
        out = model_ddp(inp)
        loss = out.square().mean()
        loss.backward()
        opt.step()
        loss.item()

        if rank == 0 and iteration % log_interval == log_interval - 1:
            logging.info(f'{rank=} {iteration=} {loss.item()=}')


def main():
    args = parse_args()
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        level=args.log_level,
    )
    logging.info(f'{args}')

    local_rank = int(os.environ['LOCAL_RANK'])

    if args.device == 'cuda':
        torch.cuda.set_device(local_rank)
        device = torch.device('cuda')
        backend = 'nccl'
        timeout = datetime.timedelta(seconds=150)
    elif args.device == 'cpu':
        device = torch.device('cpu')
        backend = 'gloo'
        timeout = datetime.timedelta(seconds=10)
    else:
        raise RuntimeError

    # All objects created in ``main()`` are constructed only once, and reused
    # for all restart iterations.
    if args.seed is not None:
        torch.manual_seed(args.seed)
    model = torch.nn.Sequential(
        *[torch.nn.Linear(args.size, args.size) for _ in range(args.layers)]
    ).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=1e-5)

    # The application's TCPStore listens on ``(MASTER_PORT + 2)`` to avoid
    # conflicts with the TCPStore created by ``torch.distributed.run``
    # listening on ``MASTER_PORT``, and with the Wrapper's TCPStore listening
    # on ``(MASTER_PORT + 1)``.
    store = torch.distributed.TCPStore(
        host_name=os.environ['MASTER_ADDR'],
        port=int(os.environ['MASTER_PORT']) + 2,
        world_size=int(os.environ['WORLD_SIZE']),
        is_master=(int(os.environ['RANK']) == 0),
        multi_tenant=True,
        wait_for_workers=True,
        use_libuv=True,
    )

    # The Wrapper's internal TCPStore listens on ``(MASTER_PORT + 1)`` to
    # avoid conflicts with the application's TCPStore listening on
    # ``(MASTER_PORT + 2)``, and with the TCPStore created by
    # ``torch.distributed.run`` listening on ``MASTER_PORT``.
    #
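    # Resulting port layout (``MASTER_PORT`` falls back to 29500 below; the
    # actual value depends on how the job was launched):
    #   MASTER_PORT      - TCPStore created by ``torch.distributed.run``
    #   MASTER_PORT + 1  - Wrapper's internal TCPStore
    #   MASTER_PORT + 2  - application's TCPStore (``base_store`` in ``train()``)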
    wrapper_kwargs = {
        'store_kwargs': {'port': int(os.getenv('MASTER_PORT', 29500)) + 1},
        'health_check': inprocess.health_check.CudaHealthCheck(),
    }

    if args.nested_restarter:
        wrapper_kwargs['initialize'] = inprocess.nested_restarter.NestedRestarterHandlingCompleted()
        wrapper_kwargs['abort'] = inprocess.Compose(
            inprocess.abort.AbortTorchDistributed(),
            inprocess.nested_restarter.NestedRestarterHandlingStarting(),
        )
        wrapper_kwargs['completion'] = inprocess.nested_restarter.NestedRestarterFinalized()
        wrapper_kwargs['terminate'] = inprocess.nested_restarter.NestedRestarterAborted()

    # An instance of ``inprocess.CallWrapper`` is automatically injected into
    # the wrapped function's arguments when the Wrapper is invoked.
    wrapped_train = inprocess.Wrapper(**wrapper_kwargs)(train)

    # Call the wrapped function.
    # ``wrapped_train()`` is automatically restarted to recover from faults.
    wrapped_train(store, model, opt, backend, device, timeout, args)


if __name__ == '__main__':
    main()