Asynchronous PyTorch torch.save with the Core utility

TorchAsyncCheckpoint wraps an asynchronous version of torch.save and provides an additional method to synchronize outstanding asynchronous save requests.

class nvidia_resiliency_ext.checkpointing.async_ckpt.torch_ckpt.TorchAsyncCheckpoint(persistent_queue=False)[source]

Bases: object

async_fn = None
async_save(state_dict, *args, **kwargs)[source]

Keeps the original interface of torch.save. Schedules an AsyncRequest that preloads tensors to CPU using a pinned-memory memcpy.

finalize_async_save(blocking=False, no_dist=True, terminate=False)[source]

Finalizes active async save calls.

Parameters:
  • blocking (bool, optional) – if True, waits until all active requests are done. Otherwise, finalizes only async requests that have already finished. Defaults to False.

  • no_dist (bool, optional) – if True, each training rank checks its own asynchronous checkpoint writer without cross-rank synchronization. Defaults to True.

  • terminate (bool, optional) – if True, the asynchronous queue will be closed as the last action of this function. Defaults to False.
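To illustrate the pattern these two methods describe, here is a minimal stdlib sketch, not the nvidia_resiliency_ext implementation: a background worker thread consumes save requests from a queue (standing in for the asynchronous checkpoint writer), while a finalize method mirrors the blocking and terminate flags documented above. The class name AsyncSaverSketch and the use of pickle in place of torch.save are assumptions for illustration only.

```python
import pickle
import queue
import threading

class AsyncSaverSketch:
    """Illustrative stand-in for the async_save / finalize_async_save pattern."""

    def __init__(self):
        self._queue = queue.Queue()
        self._worker = threading.Thread(target=self._run, daemon=True)
        self._worker.start()

    def _run(self):
        while True:
            item = self._queue.get()
            if item is None:        # sentinel: close the queue
                break
            obj, path = item
            # The real implementation first copies tensors to pinned CPU
            # memory; here we simply serialize with pickle.
            with open(path, "wb") as f:
                pickle.dump(obj, f)
            self._queue.task_done()

    def async_save(self, state_dict, path):
        # Keeps a torch.save-like call shape but returns immediately,
        # scheduling the write on the worker thread.
        self._queue.put((state_dict, path))

    def finalize_async_save(self, blocking=False, terminate=False):
        if blocking:
            # Wait until every scheduled request has been written out.
            self._queue.join()
        if terminate:
            # Close the queue as the last action, as the docs describe.
            self._queue.put(None)
            self._worker.join()
```

A typical usage mirrors the documented API: call async_save after producing a state dict, then call finalize_async_save(blocking=True, terminate=True) at shutdown to drain pending writes and close the queue.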