# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# This example demonstrates how to integrate ``inprocess.Wrapper()`` into an
# existing PyTorch training codebase.
#
# In this case, the entire ``main()`` function is wrapped. While all features
# of ``inprocess.Wrapper()`` are available and active, the Wrapper is
# configured to restart the entire application upon any failure. Consequently,
# the application state is not preserved between restarts and the entire
# ``main()`` is relaunched, leading to less efficient recovery from failures.
#
# NOTE: inprocess.Wrapper is not fully compatible with modern
# ``torch.distributed.run``, because it automatically terminates all local
# workers upon any local worker process failure; in this case, inprocess.Wrapper
# can only recover from transient faults that don't terminate any of the
# training processes.

import argparse
import datetime
import logging
import os
import pathlib
import random
import time
from typing import Optional

import torch

import nvidia_resiliency_ext.inprocess as inprocess

raise_timestamp = None


def parse_args():
    parser = argparse.ArgumentParser(
        description='Inprocess Restart Basic Example',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        '--size',
        default=64,
        type=int,
        help='model hidden size',
    )
    parser.add_argument(
        '--layers',
        default=4,
        type=int,
        help='number of layers',
    )
    parser.add_argument(
        '--log-interval',
        default=100,
        type=int,
        help='logging interval',
    )
    parser.add_argument(
        '--chkpt-interval',
        default=100,
        type=int,
        help='checkpointing interval',
    )
    parser.add_argument(
        '--total-iterations',
        default=1000000,
        type=int,
        help='total training iterations',
    )
    parser.add_argument(
        '--seed',
        default=None,
        type=int,
        help='random seed, time-based if None',
    )
    parser.add_argument(
        '--path',
        default='/tmp/',
        type=str,
        help='directory for the checkpoint file',
    )
    parser.add_argument(
        '--fault-prob',
        default=0.001,
        type=float,
        help='fault injection probability',
    )
    parser.add_argument(
        '--device',
        default='cpu',
        choices=['cpu', 'cuda'],
        help='device',
    )
    parser.add_argument(
        '--log-level',
        type=lambda s: logging._nameToLevel[s.upper()],
        default=logging.INFO,
        help='logging level',
    )

    return parser.parse_args()


# The Wrapper's internal TCPStore uses port ``(MASTER_PORT + 2)`` to avoid
# conflicts with the application's TCPStore listening on ``(MASTER_PORT + 1)``,
# and with the TCPStore created by ``torch.distributed.run`` listening on
# ``MASTER_PORT``.
@inprocess.Wrapper(
    store_kwargs={'port': int(os.getenv('MASTER_PORT', 29500)) + 2},
    health_check=inprocess.health_check.CudaHealthCheck(),
)
def main(call_wrapper: Optional[inprocess.CallWrapper] = None):
    global raise_timestamp
    if raise_timestamp is not None:
        restart_latency = time.perf_counter() - raise_timestamp
        logging.info(f'restart latency: {restart_latency:.3f}s')
    raise_timestamp = None

    args = parse_args()
    logging.info(f'{args}')

    log_interval = args.log_interval
    chkpt_interval = args.chkpt_interval

    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    if args.device == 'cuda':
        torch.cuda.set_device(local_rank)
        device = torch.device('cuda')
        backend = 'nccl'
        timeout = datetime.timedelta(seconds=150)
    elif args.device == 'cpu':
        device = torch.device('cpu')
        backend = 'gloo'
        timeout = datetime.timedelta(seconds=10)
    else:
        raise RuntimeError

    if args.seed is not None:
        torch.manual_seed(args.seed)
    model = torch.nn.Sequential(
        *[torch.nn.Linear(args.size, args.size) for _ in range(args.layers)]
    ).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=1e-5)

    # The application's TCPStore uses ``(MASTER_PORT + 1)`` to avoid conflicts
    # with the TCPStore created by ``torch.distributed.run`` and listening on
    # ``MASTER_PORT``.
    store = torch.distributed.TCPStore(
        host_name=os.environ['MASTER_ADDR'],
        port=int(os.environ['MASTER_PORT']) + 1,
        world_size=int(os.environ['WORLD_SIZE']),
        is_master=int(os.environ['RANK']) == 0,
        multi_tenant=True,
        wait_for_workers=True,
        use_libuv=True,
    )

    torch.distributed.init_process_group(
        backend=backend,
        store=store,
        rank=int(os.environ['RANK']),
        world_size=int(os.environ['WORLD_SIZE']),
        timeout=timeout,
    )
    model_ddp = torch.nn.parallel.DistributedDataParallel(model)

    iteration = 0
    loss = torch.tensor(float('nan'))
    checkpoint_path = pathlib.Path(args.path) / 'checkpoint.pt'

    # The application loads state from the latest checkpoint on every restart
    # iteration of the wrapped function.
    if checkpoint_path.exists():
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model'])
        opt.load_state_dict(checkpoint['opt'])
        torch.set_rng_state(checkpoint['rng'])
        iteration = checkpoint['iteration']

    if args.seed is not None:
        random.seed(args.seed + iteration * world_size + rank)
    else:
        random.seed(time.perf_counter_ns())

    for iteration in range(iteration, args.total_iterations):

        # The application periodically saves a checkpoint. The checkpoint
        # allows the application to continue from its previous state after a
        # restart.
        if iteration % chkpt_interval == chkpt_interval - 1:
            torch.distributed.barrier()
            if rank == 0:
                checkpoint = {
                    'model': model.state_dict(),
                    'opt': opt.state_dict(),
                    'rng': torch.get_rng_state(),
                    'iteration': iteration,
                }
                # Saving the checkpoint is performed within the atomic()
                # context manager to ensure that the main thread won't execute
                # torch.save while a restart procedure is in progress.
                with call_wrapper.atomic():
                    torch.save(checkpoint, checkpoint_path)

        # Randomly trigger an example fault.
        if random.random() < args.fault_prob:
            raise_timestamp = time.perf_counter()
            raise RuntimeError(f'example fault at {iteration=} from {rank=}')

        inp = torch.rand(args.size, args.size).to(device)
        model.zero_grad()
        out = model_ddp(inp)
        loss = out.square().mean()
        loss.backward()
        opt.step()
        loss.item()

        if rank == 0 and iteration % log_interval == log_interval - 1:
            logging.info(f'{rank=} {iteration=} {loss.item()=}')


if __name__ == '__main__':
    # ``inprocess.Wrapper`` uses the logging library to output messages. In
    # this example the Wrapper is applied to ``main()``; therefore, logging
    # needs to be initialized and configured before the Wrapper is launched.
    args = parse_args()
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        level=args.log_level,
    )
    main()
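

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original example): this script reads RANK,
# LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT from the environment,
# so it is typically launched with a torch.distributed launcher that sets
# them, for example:
#
#   torchrun --nproc_per_node=2 example.py --device=cpu
#
# The script name and argument values above are placeholders; adjust them to
# your environment. As the NOTE at the top of the file explains, restarts
# under ``torch.distributed.run`` can only recover from transient faults that
# leave all local worker processes alive.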