# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Demo of fault tolerance with DDP training
"""

import argparse
import logging
import os
import random
import signal
import sys
import threading
import time

import dist_utils
import log_utils
import numpy as np
import torch
import torch.nn as nn

import nvidia_resiliency_ext.fault_tolerance as fault_tolerance


# Dummy dataset.
class Dataset(torch.utils.data.Dataset):
    def __init__(self, size, hidden):
        self.size = size
        self.hidden = hidden

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        data = torch.full(
            (self.hidden,),
            fill_value=idx,
            dtype=torch.float32,
            device='cpu',
        )
        return data


# Dummy model
class Model(nn.Module):
    def __init__(self, hidden):
        super().__init__()
        self.l1 = nn.Linear(hidden, hidden)
        self.l2 = nn.Linear(hidden, hidden)

    def forward(self, x):
        x = self.l1(x)
        x = self.l2(x)
        return x


def parse_args():
    def tuple_type(strings):
        strings = strings.replace("(", "").replace(")", "")
        mapped_int = map(int, strings.split(","))
        return tuple(mapped_int)

    def fault_desc(strings):
        parts = strings.split(",")
        assert len(parts) == 2
        return {'fault': parts[0], 'delay': float(parts[1])}

    parser = argparse.ArgumentParser(
        description='Example of PyTorch DDP training with the Fault Tolerance package',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # fmt: off
    parser.add_argument('--hidden', type=int, default=4096,
                        help='Hidden size')
    parser.add_argument('--batch', type=int, default=8,
                        help='Batch size')
    parser.add_argument('--epochs', type=int, default=4,
                        help='Number of training epochs')
    parser.add_argument('--train_dataset_size', type=int, default=1000000,
                        help='Train dataset size')
    parser.add_argument('--val_dataset_size', type=int, default=2000,
                        help='Validation dataset size')
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cpu', 'cuda'],
                        help='Device')

    parser.add_argument('--interrupt_at', type=tuple_type, nargs='*',
                        help='Manual interruption after (epoch, iteration), '
                             'for testing only')
    parser.add_argument('--save_interval', type=int, default=-1,
                        help='Interval for saving periodic checkpoints')
    parser.add_argument('--logging_interval', type=int, default=1,
                        help='Interval for log entries')
    parser.add_argument('--log_all_ranks', action='store_true',
                        help='Enable logging from all distributed ranks')
    parser.add_argument('--output_dir', type=str, default='results/output',
                        help='Output dir')
    parser.add_argument('--checkpoint_fname', type=str, default='checkpoint.pt',
                        help='Name of a checkpoint file')

    parser.add_argument('--local_rank', type=int,
                        default=os.getenv('LOCAL_RANK', 0))
    parser.add_argument('--init_distributed_method', type=str, default='tcp',
                        help='Init distributed group with TCP store ("tcp") or file store ("file")')

    parser.add_argument('--simulated_fault', type=fault_desc,
                        help='Description of a fault to be simulated')
    # fmt: on

    args = parser.parse_args()

    if args.interrupt_at:
        args.interrupt_at = set(args.interrupt_at)
    else:
        args.interrupt_at = set()

    return args


def load_checkpoint(path):
    map_location = {
        'cpu': 'cpu',
    }
    if torch.cuda.is_available():
        map_location['cuda:0'] = f'cuda:{torch.cuda.current_device()}'

    logging.info(f'Loading checkpoint from {path}')
    checkpoint = torch.load(path, map_location=map_location)
    return checkpoint


def save_checkpoint(
    progress,
    model,
    optimizer,
    ft_client,
    output_dir,
    checkpoint_fname,
):
    state = {
        'progress': progress,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'ft_state': ft_client.state_dict(),
    }

    checkpoint_path = os.path.join(output_dir, checkpoint_fname)

    with dist_utils.sync_workers() as rank:
        if rank == 0:
            logging.info(f'Saving checkpoint to {checkpoint_path}')
            torch.save(state, checkpoint_path)


def training_loop(
    ft_client,
    para_model,
    model,
    optimizer,
    device,
    dataloader,
    sampler,
    progress,
    args,
):
    epoch_idx = progress['epoch_idx']

    # NOTE: torch.utils.data.DistributedSampler must be prepared for current epoch
    # need to do it before starting iteration
    sampler.start_sample_idx = progress['iter_idx'] * args.batch
    sampler.set_epoch(epoch_idx)

    para_model.train()

    last_log_time = time.monotonic()

    for iter_idx, x in enumerate(dataloader, start=progress['iter_idx']):
        if ft_client.timeouts.are_valid is False and epoch_idx == 1 and iter_idx == 1:
            # after 0th epoch is completed and we've done 0th iteration of the 1st epoch,
            # we can calculate and set timeouts. this is a good moment to do so,
            # because now we've seen the possibly long interval where checkpoint was saved.
            ft_client.calculate_and_set_timeouts()

        optimizer.zero_grad()
        x = x.to(device)
        y = para_model(x)
        loss = y.mean()
        train_loss = loss.item()
        loss.backward()

        if iter_idx % args.logging_interval == 0:
            avg_train_loss = dist_utils.all_reduce_item(train_loss, op='mean')
            logging.info(
                f'CHECK TRAIN epoch: {epoch_idx:4d} '
                f'iter: {iter_idx:5d} '
                f'loss: {avg_train_loss} '
                f'input: {x[:, 0]}'
            )
            if iter_idx > 0:
                time_per_iter = (time.monotonic() - last_log_time) / args.logging_interval
                last_log_time = time.monotonic()
                logging.debug(f'Avg time per iter: {time_per_iter:.3f} [sec]')

        progress['iter_idx'] = iter_idx + 1

        ft_client.send_heartbeat()
        optimizer.step()

        # Whether to do a periodic checkpointing
        periodic_save = iter_idx % args.save_interval == args.save_interval - 1

        if periodic_save or (epoch_idx, iter_idx) in args.interrupt_at:
            save_checkpoint(
                progress=progress,
                model=model,
                optimizer=optimizer,
                ft_client=ft_client,
                output_dir=args.output_dir,
                checkpoint_fname=args.checkpoint_fname,
            )
            if (epoch_idx, iter_idx) in args.interrupt_at:
                logging.info('Manual interruption, exiting')
                sys.exit(0)


def validation_loop(ft_client, model, val_dataloader, epoch_idx, device):
    total_val_loss = 0
    model.eval()

    for iter_idx, x in enumerate(val_dataloader):
        x = x.to(device)
        y = model(x)
        loss = y.mean().item()
        total_val_loss += loss
        ft_client.send_heartbeat()

    logging.info(
        f'CHECK VAL SUMMARY: epoch: {epoch_idx:4d} '
        f'loss: {total_val_loss / (iter_idx + 1)}'
    )


_sim_fault_canceled = False
_sim_fault_is_set = False


def _cancel_simulated_fault():
    global _sim_fault_canceled
    _sim_fault_canceled = True


def _setup_simulated_fault(ft_client, fault_desc, device):
    # FIXME: hanging rank with SIGTSTP results in rank monitor
    # blocked when trying to receive the data in _on_ipc_data_from_rank

    global _sim_fault_is_set
    _sim_fault_is_set = True  # should be True on all ranks

    rng = random.Random()

    logging.info(f"Initializing simulated fault: {fault_desc}")

    fault_type = fault_desc['fault']
    if fault_type == 'random':
        fault_type = rng.choice(['rank_killed', 'rank_hung'])

    rank_to_fail = rng.randint(0, dist_utils.get_world_size() - 1)
    rank_to_fail = torch.tensor([rank_to_fail], device=device)
    dist_utils.broadcast(rank_to_fail, 0)
    rank_to_fail = int(rank_to_fail.item())

    rank = torch.distributed.get_rank()
    if rank != rank_to_fail:
        return

    if fault_type == 'rank_killed':
        target_pid = os.getpid()
        target_sig = signal.SIGKILL
    elif fault_type == 'rank_hung':
        target_pid = os.getpid()
        target_sig = signal.SIGSTOP
    else:
        raise Exception(f"Unknown fault type {fault_type}")

    delay = fault_desc['delay'] + 4.0 * rng.random()

    logging.info(
        f"Selected fault={fault_type}; target rank={rank_to_fail}; delay={delay}",
    )

    def __fault_thread():
        time.sleep(delay)
        if _sim_fault_canceled:
            return
        print(
            f"\n####\nSimulating fault: {fault_type}; rank to fail: {rank_to_fail}\n#####\n",
            file=sys.stderr,
        )
        os.kill(target_pid, target_sig)

    fault_sim_thread = threading.Thread(target=__fault_thread)
    fault_sim_thread.daemon = True
    fault_sim_thread.start()


_signal_received = False


def _sig_handler(*args, **kwargs):
    print("Signal received!", file=sys.stderr)
    global _signal_received
    _signal_received = True


def main():
    signal.signal(signal.SIGTERM, _sig_handler)

    args = parse_args()

    torch.manual_seed(123)
    np.random.seed(123)
    random.seed(123)

    if args.device == 'cuda':
        if torch.cuda.is_available():
            device = torch.device('cuda')
            torch.cuda.set_device(args.local_rank)
        else:
            raise RuntimeError("Selected 'cuda' device but torch.cuda is not available.")
    elif args.device == 'cpu':
        device = torch.device('cpu')
    else:
        raise RuntimeError('Unknown device')

    if int(os.getenv('WORLD_SIZE', '1')) == 1:
        raise RuntimeError('This example supports only multi-gpu training')

    os.makedirs(args.output_dir, exist_ok=True)

    if args.init_distributed_method == 'tcp':
        # NOTE: when running tests with tcp init method we noticed
        # occasional "address already in use" errors, after workload
        # is restarted
        dist_utils.init_distributed_with_tcp_store(device)
    elif args.init_distributed_method == 'file':
        dist_utils.init_distributed_with_file_store(device, store_file_dir=args.output_dir)
    else:
        raise RuntimeError(
            f"--init_distributed_method should be ['tcp','file'] it is {args.init_distributed_method}"
        )

    if args.log_all_ranks:
        log_file_name = f'train_log_rank_{dist_utils.get_rank()}.log'
    else:
        log_file_name = 'train_log.log'
    log_file_path = os.path.join(args.output_dir, log_file_name)

    # NOTE: logging appends outputs to an existing log file if it already
    # exists. Results from a single training run (potentially with many
    # restarts from a checkpoint) are stored in a single log file.
    log_utils.setup_logging(args.log_all_ranks, filename=log_file_path, filemode='a')
    logging.info(args)

    rank = dist_utils.get_rank()

    logging.info(f"SLURM_JOB_ID={os.getenv('SLURM_JOB_ID', '<none>')} RANK={rank} PID={os.getpid()}")

    # Dummy datasets
    train_dataset = Dataset(args.train_dataset_size, args.hidden)
    val_dataset = Dataset(args.val_dataset_size, args.hidden)

    # ResumableDistributedSampler is needed to skip consumed samples
    train_sampler = dist_utils.ResumableDistributedSampler(
        train_dataset,
        drop_last=True,
    )

    val_sampler = torch.utils.data.DistributedSampler(
        val_dataset,
    )

    # A dummy model and an optimizer
    model = Model(args.hidden).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Initial value for start epoch - will be overwritten if training is resumed from a checkpoint
    progress = {
        'epoch_idx': 0,
        'iter_idx': 0,
    }

    checkpoint_path = os.path.join(args.output_dir, args.checkpoint_fname)

    # Initialize fault tolerance.
    ft_client = fault_tolerance.RankMonitorClient()
    ft_client.init_workload_monitoring()

    checkpoint = None

    # try to load checkpoint from disk
    if os.path.exists(checkpoint_path):
        checkpoint = load_checkpoint(checkpoint_path)
        if checkpoint:
            logging.info(f'Checkpoint was loaded from file: {checkpoint_path}')

    if checkpoint:
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        ft_client.load_state_dict(checkpoint['ft_state'])
        progress.update(checkpoint['progress'])
        # Return with zero exit code if model is already fully trained.
        if progress['epoch_idx'] == args.epochs:
            logging.info('Training finished.')
            sys.exit(0)

    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch,
        sampler=train_sampler,
        num_workers=4,
        persistent_workers=True,
        pin_memory=False,
    )

    val_dataloader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=args.batch,
        sampler=val_sampler,
        num_workers=4,
    )

    # Regular DDP init
    # NOTE: for convenience code is keeping both wrapped and unwrapped model and
    # uses wrapped model for training and unwrapped model for saving the
    # checkpoint and validation. It doesn't increase memory consumption
    # since both models are holding references to the same parameters.
    # Additionally saved checkpoint is ready for inference and doesn't have to
    # be manually unwrapped by accessing the (undocumented) "module" attribute
    # of DDP-wrapped model.
    if device.type == 'cuda':
        device_ids = [args.local_rank]
        output_device = args.local_rank
    elif device.type == 'cpu':
        device_ids = None
        output_device = None
    else:
        raise RuntimeError('Unsupported device type')
    para_model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=device_ids, output_device=output_device
    )

    # Iteration over epochs, notice that it starts from 'epoch_idx'
    # which was previously loaded from the checkpoint
    for epoch_idx in range(progress['epoch_idx'], args.epochs):
        training_loop(
            ft_client,
            para_model,
            model,
            optimizer,
            device,
            train_dataloader,
            train_sampler,
            progress,
            args,
        )

        # epoch_idx is incremented because the current epoch is finished
        # and potential resume from this checkpoint should start a new training epoch.
        progress['epoch_idx'] += 1
        progress['iter_idx'] = 0

        validation_loop(ft_client, model, val_dataloader, epoch_idx, device)

        # Checkpoint contains everything needed for deterministic resume:
        # state of the model, optimizer and other components,
        save_checkpoint(
            progress=progress,
            model=model,
            optimizer=optimizer,
            ft_client=ft_client,
            output_dir=args.output_dir,
            checkpoint_fname=args.checkpoint_fname,
        )

        # NOTE: SIGTERM is used by SLURM to initiate graceful job termination
        # if _any_ rank received SIGTERM, we leave the main loop
        if dist_utils.is_true_on_any_rank(_signal_received):
            logging.info('Leaving the main loop, due to SIGTERM')
            break

        # Setup simulated fault as soon as we have valid timeouts
        if args.simulated_fault and not _sim_fault_is_set and ft_client.timeouts.are_valid:
            _setup_simulated_fault(ft_client, args.simulated_fault, device)

    _cancel_simulated_fault()
    ft_client.shutdown_workload_monitoring()
    logging.info('Leaving main, ret_code=0')
    sys.exit(0)


if __name__ == "__main__":
    main()
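
# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original example, and
# the exact command is an assumption). The script requires WORLD_SIZE > 1 and
# one process per GPU, so it must be started via a multi-process launcher.
# With the fault tolerance package this is typically the `ft_launcher` tool
# shipped with nvidia_resiliency_ext (a torchrun-style launcher that also
# starts the rank monitors that RankMonitorClient connects to).
# A hypothetical 2-GPU launch could look like:
#
#   ft_launcher --nproc-per-node=2 train_ddp.py \
#       --device=cuda --output_dir=results/output --save_interval=100
#
# Launcher flags vary between nvidia_resiliency_ext versions; consult the
# package documentation for the authoritative invocation.
# ---------------------------------------------------------------------------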