Asynchronous PyTorch torch.save with the Core utility
TorchAsyncCheckpoint wraps an asynchronous version of torch.save and provides an additional method to synchronize pending asynchronous save requests.
- class nvidia_resiliency_ext.checkpointing.async_ckpt.torch_ckpt.TorchAsyncCheckpoint(persistent_queue=False)[source]
Bases:
object
- async_fn = None
- async_save(state_dict, *args, **kwargs)[source]
Keeps the original interface of torch.save. Schedules an AsyncRequest that first copies tensors to CPU via a pinned-memory memcpy, so the device-to-host transfer completes before training resumes.
- finalize_async_save(blocking=False, no_dist=True, terminate=False)[source]
Finalizes active async save calls.
- Parameters:
blocking (bool, optional) – if True, waits until all active requests are done; otherwise, finalizes only those async requests that have already finished. Defaults to False.
no_dist (bool, optional) – if True, each training rank checks its asynchronous checkpoint writer independently, without cross-rank synchronization.
terminate (bool, optional) – if True, the asynchronous queue is closed as the last action of this function.
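To illustrate the lifecycle described above, here is a minimal stdlib sketch of the same pattern: the state dict is snapshotted synchronously (standing in for the library's pinned-memory CPU copy), written to disk on a background thread, and later synchronized with a blocking finalize call. The `AsyncSaver` class and its method names are illustrative only, not the library's implementation.

```python
import copy
import os
import pickle
import tempfile
import threading

class AsyncSaver:
    """Toy stand-in for the async checkpoint pattern (not the real API)."""

    def __init__(self):
        self._pending = []  # active background save threads

    def async_save(self, state_dict, path):
        # Snapshot synchronously so later mutations don't race the writer;
        # the real utility does this with a pinned-memory CPU copy instead.
        snapshot = copy.deepcopy(state_dict)
        t = threading.Thread(target=self._write, args=(snapshot, path))
        t.start()
        self._pending.append(t)

    @staticmethod
    def _write(snapshot, path):
        with open(path, "wb") as f:
            pickle.dump(snapshot, f)

    def finalize_async_save(self, blocking=False):
        # blocking=True waits for every active request, mirroring the
        # semantics documented for finalize_async_save above.
        if blocking:
            for t in self._pending:
                t.join()
        self._pending = [t for t in self._pending if t.is_alive()]

saver = AsyncSaver()
path = os.path.join(tempfile.mkdtemp(), "ckpt.pkl")
state = {"step": 10, "weights": [0.1, 0.2]}
saver.async_save(state, path)
state["step"] = 11                      # training continues; snapshot unaffected
saver.finalize_async_save(blocking=True)
with open(path, "rb") as f:
    restored = pickle.load(f)
print(restored["step"])  # → 10
```

With the real class, the call sites look the same: `ckpt.async_save(model.state_dict(), path)` during training, then `ckpt.finalize_async_save(blocking=True)` before relying on the checkpoint being on disk.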