import argparse
import logging
import os
import shutil

import torch
import torch.distributed as dist
import torch.nn as nn

from nvidia_resiliency_ext.checkpointing.async_ckpt.torch_ckpt import TorchAsyncCheckpoint

# Set up basic logging configuration
logging.basicConfig(level=logging.INFO)


def parse_args():
    parser = argparse.ArgumentParser(
        description='Async Checkpointing Basic Example',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    return parser.parse_args()


# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 5)  # Linear layer: input size 10, output size 5
        self.fc2 = nn.Linear(5, 2)  # Linear layer: input size 5, output size 2
        self.activation = nn.ReLU()  # Activation function: ReLU

    def forward(self, x):
        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        return x


def init_distributed_backend(backend="nccl"):
    """
    Initialize the distributed process group for the NCCL backend.
    Assumes the environment variables (CUDA_VISIBLE_DEVICES, etc.) are already set.
    """
    try:
        dist.init_process_group(
            backend=backend,  # Use NCCL backend
            init_method="env://",  # Use environment variables for initialization
        )
        logging.info(f"Rank {dist.get_rank()} initialized with {backend} backend.")

        # Ensure each process uses a different GPU
        torch.cuda.set_device(dist.get_rank())
    except Exception as e:
        logging.error(f"Error initializing the distributed backend: {e}")
        raise


def main():
    args = parse_args()
    logging.info(f'{args}')

    # Initialize the distributed backend
    init_distributed_backend(backend="nccl")

    # Instantiate the model and move it to CUDA
    model = SimpleModel().to("cuda")

    # Define the checkpoint directory and file path, and make sure the directory exists
    ckpt_dir = "/tmp/test_local_checkpointing"
    ckpt_path = os.path.join(ckpt_dir, "ckpt.pt")
    os.makedirs(ckpt_dir, exist_ok=True)

    # Create the async checkpointing wrapper
    ckpt_impl = TorchAsyncCheckpoint()

    # Start an asynchronous save of the model's state dict
    ckpt_impl.async_save(model.state_dict(), ckpt_path)

    # Block until the in-flight async save is finalized
    ckpt_impl.finalize_async_save(blocking=True, no_dist=True)

    # Load the checkpoint back to verify the save completed
    state_dict = torch.load(ckpt_path, map_location="cuda")
    model.load_state_dict(state_dict)

    # Synchronize processes to ensure all have completed the loading
    dist.barrier()

    # Clean up checkpoint directory only on rank 0
    if dist.get_rank() == 0:
        logging.info(f"Cleaning up checkpoint directory: {ckpt_dir}")
        shutil.rmtree(ckpt_dir)


if __name__ == "__main__":
    main()
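
# Launch note (a sketch, not part of the original example): the script expects the
# usual torch.distributed environment variables (RANK, WORLD_SIZE, MASTER_ADDR,
# MASTER_PORT), which torchrun sets up, e.g. on a single node with 2 GPUs:
#
#   torchrun --nproc_per_node=2 basic_async_ckpt_example.py
#
# The file name "basic_async_ckpt_example.py" is a placeholder for this script.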