# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import datetime
import logging
import os
import pathlib
import random
import time
from typing import Optional

os.environ['TORCH_CPP_LOG_LEVEL'] = 'error'
import torch

import nvidia_resiliency_ext.inprocess as inprocess

raise_timestamp = None

################################################################################
#
# THIS EXAMPLE DEMONSTRATES AN ADVANCED CASE WHERE THE MODEL AND OPTIMIZER ARE
# PROCESS-GROUP INDEPENDENT.
#
# It shows how to integrate ``inprocess.Wrapper()`` into an existing PyTorch
# training codebase for an advanced use case where:
# - only the training loop, and the objects that depend on a torch distributed
#   process group, are restarted upon a failure
# - process-group-independent objects (e.g. TCPStore, Model, Optimizer) are
#   created once and reused across all restart iterations to minimize restart
#   latency
#
# Because the model and optimizer are process-group independent, they can
# safely be created outside of the wrapped function and reused. See the usage
# guide for more details.
#
# NOTE: inprocess.Wrapper is not fully compatible with modern
# ``torch.distributed.run``, because the launcher automatically terminates all
# local workers upon any local worker process failure; in this case
# inprocess.Wrapper can only recover from transient faults that don't terminate
# any of the training processes.
#
################################################################################
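
# Example launch (illustrative): the script reads RANK, WORLD_SIZE, LOCAL_RANK,
# MASTER_ADDR, and MASTER_PORT from the environment, so a torchrun-style
# launcher can provide them, e.g.:
#
#   torchrun --nproc-per-node=2 <this_script>.py --device=cuda
#
# (``<this_script>.py`` is a placeholder for this file's name.)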


def parse_args():
    parser = argparse.ArgumentParser(
        description='Inprocess Restart Advanced Example',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        '--size',
        default=64,
        type=int,
        help='model hidden size',
    )
    parser.add_argument(
        '--layers',
        default=4,
        type=int,
        help='number of layers',
    )
    parser.add_argument(
        '--log-interval',
        default=100,
        type=int,
        help='logging interval',
    )
    parser.add_argument(
        '--chkpt-interval',
        default=100,
        type=int,
        help='checkpointing interval',
    )
    parser.add_argument(
        '--total-iterations',
        default=1000000,
        type=int,
        help='total training iterations',
    )
    parser.add_argument(
        '--seed',
        default=None,
        type=int,
        help='random seed, time-based if None',
    )
    parser.add_argument(
        '--path',
        default='/tmp/',
        type=str,
        help='directory for the checkpoint file',
    )
    parser.add_argument(
        '--fault-prob',
        default=0.001,
        type=float,
        help='fault injection probability',
    )
    parser.add_argument(
        '--device',
        default='cuda',
        choices=['cpu', 'cuda'],
        help='device',
    )
    parser.add_argument(
        '--log-level',
        type=lambda s: logging._nameToLevel[s.upper()],
        default=logging.INFO,
        help='logging level',
    )

    parser.add_argument(
        '--nested-restarter',
        action='store_true',
        help='extra logging for the nested restarter',
    )

    return parser.parse_args()


def train(
    base_store,
    model,
    opt,
    backend,
    device,
    timeout,
    args,
    call_wrapper: Optional[inprocess.CallWrapper] = None,
):
    global raise_timestamp
    if raise_timestamp is not None:
        restart_latency = time.perf_counter() - raise_timestamp
        logging.info(f'restart latency: {restart_latency:.3f}s')
    raise_timestamp = None

    log_interval = args.log_interval
    chkpt_interval = args.chkpt_interval

    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    # Create a new Store by adding a prefix based on the current inprocess
    # restart iteration. The PrefixStore wraps the baseline TCPStore, which is
    # reused across all restart iterations.
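    # (PrefixStore namespaces every key it stores, so keys written during a
    # previous, possibly failed, process-group initialization cannot collide
    # with keys written after this restart.)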
    store = torch.distributed.PrefixStore(str(call_wrapper.iteration), base_store)

    torch.distributed.init_process_group(
        backend,
        store=store,
        rank=rank,
        world_size=world_size,
        timeout=timeout,
    )
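    # DistributedDataParallel holds references to the process group created
    # above, so it is process-group dependent and is rebuilt here on every
    # restart iteration, while the underlying ``model`` object is reused.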
    model_ddp = torch.nn.parallel.DistributedDataParallel(model)

    iteration = 0
    loss = torch.tensor(float('nan'))
    checkpoint_path = pathlib.Path(args.path) / 'checkpoint.pt'

    # Application loads state from the latest checkpoint on every restart
    # iteration of the wrapped function.
    if checkpoint_path.exists():
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model'])
        opt.load_state_dict(checkpoint['opt'])
        torch.set_rng_state(checkpoint['rng'])
        iteration = checkpoint['iteration']

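    # Seed Python's RNG (used below for fault injection) differently on every
    # rank and restart iteration, so injected faults are decorrelated across
    # ranks and restarts; without an explicit seed, fall back to a time-based
    # seed.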
    if args.seed is not None:
        random.seed(args.seed + iteration * world_size + rank)
    else:
        random.seed(time.perf_counter_ns())

    for iteration in range(iteration, args.total_iterations):

        # Application periodically saves a checkpoint. The checkpoint allows
        # the application to continue from previous state after a restart.
        if iteration % chkpt_interval == chkpt_interval - 1:
            torch.distributed.barrier()
            if rank == 0:
                checkpoint = {
                    'model': model.state_dict(),
                    'opt': opt.state_dict(),
                    'rng': torch.get_rng_state(),
                    'iteration': iteration,
                }
                # Saving the checkpoint is performed within the atomic()
                # context manager to ensure that the main thread won't execute
                # torch.save while a restart procedure is in progress.
                with call_wrapper.atomic():
                    torch.save(checkpoint, checkpoint_path)

        # Randomly trigger an example fault
        if random.random() < args.fault_prob:
            raise_timestamp = time.perf_counter()
            raise RuntimeError(f'example fault at {iteration=} from {rank=}')

        inp = torch.rand(args.size, args.size).to(device)
        model.zero_grad()
        out = model_ddp(inp)
        loss = out.square().mean()
        loss.backward()
        opt.step()
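        # ``loss.item()`` synchronizes with the device, so asynchronous CUDA
        # errors from this iteration surface here, inside the wrapped function.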
        loss.item()

        if rank == 0 and iteration % log_interval == log_interval - 1:
            logging.info(f'{rank=} {iteration=} {loss.item()=}')


def main():
    args = parse_args()
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        level=args.log_level,
    )
    logging.info(f'{args}')

    local_rank = int(os.environ['LOCAL_RANK'])

    if args.device == 'cuda':
        torch.cuda.set_device(local_rank)
        device = torch.device('cuda')
        backend = 'nccl'
        timeout = datetime.timedelta(seconds=150)
    elif args.device == 'cpu':
        device = torch.device('cpu')
        backend = 'gloo'
        timeout = datetime.timedelta(seconds=10)
    else:
        raise RuntimeError(f'unsupported device: {args.device}')

    # In this *specific* example, all objects created in ``main()`` are
    # constructed only once and reused across all restart iterations. Here, the
    # model and optimizer are process-group independent, so they can safely be
    # created outside of the wrapped function and reused.
    #
    # Normally, process-group-dependent objects should be created or
    # initialized inside the wrapped function; otherwise the workload is
    # responsible for ensuring that they refer to valid process groups after a
    # restart. See the usage guide for more details.

    if args.seed is not None:
        torch.manual_seed(args.seed)
    model = torch.nn.Sequential(
        *[torch.nn.Linear(args.size, args.size) for _ in range(args.layers)]
    ).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=1e-5)

    # The application's TCPStore listens on ``(MASTER_PORT + 2)`` to avoid
    # conflicts with the TCPStore created by ``torch.distributed.run`` listening
    # on ``MASTER_PORT``, and with the Wrapper's TCPStore listening on
    # ``(MASTER_PORT + 1)``.
    store = torch.distributed.TCPStore(
        host_name=os.environ['MASTER_ADDR'],
        port=int(os.environ['MASTER_PORT']) + 2,
        world_size=int(os.environ['WORLD_SIZE']),
        is_master=(int(os.environ['RANK']) == 0),
        multi_tenant=True,
        wait_for_workers=True,
        use_libuv=True,
    )

    # The Wrapper's internal TCPStore listens on ``(MASTER_PORT + 1)`` to avoid
    # conflicts with the application's TCPStore listening on
    # ``(MASTER_PORT + 2)``, and with the TCPStore created by
    # ``torch.distributed.run`` listening on ``MASTER_PORT``.
    #
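    # ``health_check`` runs before a rank rejoins the next restart iteration;
    # ``CudaHealthCheck`` checks that this process's CUDA context is still
    # healthy.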
    wrapper_kwargs = {
        'store_kwargs': {'port': int(os.getenv('MASTER_PORT', 29500)) + 1},
        'health_check': inprocess.health_check.CudaHealthCheck(),
    }

    if args.nested_restarter:
        wrapper_kwargs['initialize'] = inprocess.nested_restarter.NestedRestarterHandlingCompleted()
        wrapper_kwargs['abort'] = inprocess.Compose(
            inprocess.abort.AbortTorchDistributed(),
            inprocess.nested_restarter.NestedRestarterHandlingStarting(),
        )
        wrapper_kwargs['completion'] = inprocess.nested_restarter.NestedRestarterFinalized()
        wrapper_kwargs['terminate'] = inprocess.nested_restarter.NestedRestarterAborted()

    # An instance of ``inprocess.CallWrapper`` is automatically injected into
    # the wrapped function's arguments when the Wrapper is invoked.
    wrapped_train = inprocess.Wrapper(**wrapper_kwargs)(train)

    # Call the wrapped function.
    # ``wrapped_train()`` is automatically restarted to recover from faults.
    wrapped_train(store, model, opt, backend, device, timeout, args)


if __name__ == '__main__':
    main()