Usage guide

The nvidia_resiliency_ext.checkpointing.async_ckpt.core.AsyncCallsQueue class provides application users with an interface to schedule a nvidia_resiliency_ext.checkpointing.async_ckpt.core.AsyncRequest, which defines the checkpoint routine, its args/kwargs, and the finalization steps to run once the checkpoint routine has finished.
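
The snippet below is a minimal sketch of this flow. It assumes that an AsyncRequest is built from the checkpoint function, its positional arguments, and a list of finalize callbacks (field names async_fn, async_fn_args, and finalize_fns), and that the queue exposes a schedule_async_request method; these names are assumptions based on the description above, so check the package's API reference for the exact signatures.

import torch

from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue, AsyncRequest

def save_fn(state_dict, path):
    # Checkpoint routine executed asynchronously; in practice it should be
    # prepared as described in the Requirements section below.
    torch.save(state_dict, path)

def finalize_fn():
    # Finalization step run once the checkpoint routine has finished.
    print("checkpoint finished")

queue = AsyncCallsQueue()
state_dict = {"weight": torch.zeros(4)}

# Assumed field and method names (async_fn, async_fn_args, finalize_fns,
# schedule_async_request); consult the package API reference.
request = AsyncRequest(
    async_fn=save_fn,
    async_fn_args=(state_dict, "ckpt.pt"),
    finalize_fns=[finalize_fn],
)
queue.schedule_async_request(request)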

nvidia_resiliency_ext.checkpointing.async_ckpt.torch_ckpt.TorchAsyncCheckpoint

is an instantiation of the core utilities that makes torch.save run asynchronously.

The implementation assumes that, by default, all training ranks create a core.AsyncCallsQueue and synchronize with core.AsyncCallsQueue.maybe_finalize_async_calls.

Requirements

nvidia_resiliency_ext.checkpointing.utils includes a couple of routines used by nvidia_resiliency_ext.checkpointing.async_ckpt.core. nvidia_resiliency_ext.checkpointing.utils.wrap_for_async disables garbage collection in the forked process that runs the user's checkpoint routine, preventing failures caused by the GC attempting to deallocate CUDA tensors in that forked process. This routine requires that the first argument of the passed user function be the state dictionary containing the tensors or objects to checkpoint.
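
As a brief, hedged illustration (assuming wrap_for_async takes the user function and returns a wrapped callable, which is how the routine is described here):

import torch

from nvidia_resiliency_ext.checkpointing.utils import wrap_for_async

def save_fn(state_dict, path):
    # The first positional argument must be the state dictionary to checkpoint.
    torch.save(state_dict, path)

# The wrapper disables garbage collection inside the forked process that will
# run save_fn, so the GC cannot try to deallocate CUDA tensors there.
async_fn = wrap_for_async(save_fn)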

The current implementation uses a forked process to persist tensors that have been pre-staged to host memory via pinned-memory copies. The checkpoint routine should therefore use nvidia_resiliency_ext.checkpointing.utils.preload_tensors to stage the GPU tensors of a state dictionary to host memory before the state dictionary is passed to AsyncCallsQueue.
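
A sketch of the staging step, assuming preload_tensors takes the state dictionary and returns a copy whose GPU tensors have been moved to pinned host memory (the exact signature and return semantics may differ; see the API reference):

import torch

from nvidia_resiliency_ext.checkpointing.utils import preload_tensors

state_dict = {"weight": torch.randn(4, 4, device="cuda")}

# Stage GPU tensors to host memory so the forked checkpoint process only
# touches host-resident data.
staged_state_dict = preload_tensors(state_dict)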

Synchronization of Asynchronous Checkpoint Requests

The class nvidia_resiliency_ext.checkpointing.async_ckpt.core.AsyncCallsQueue provides a method to verify whether asynchronous checkpointing has completed in the background. Each trainer can check the status of its forked checkpoint process by calling nvidia_resiliency_ext.checkpointing.async_ckpt.core.AsyncCallsQueue.maybe_finalize_async_calls() with blocking=False.

When a trainer needs to finalize all active checkpoint requests in a blocking manner, it can call the same method with blocking=True.

Additionally, AsyncCallsQueue.maybe_finalize_async_calls() accepts another parameter, no_dist, which must be set to False when global synchronization across all ranks is required. For example, if a checkpointing routine needs to write metadata (e.g., iteration or sharding information) after completing a set of checkpoints, global synchronization ensures that all ranks finish their asynchronous checkpointing before proceeding.
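
For example, where queue is the application's AsyncCallsQueue instance (blocking and no_dist are the parameters described above):

# Per-iteration, non-blocking check of this rank's forked checkpoint process.
queue.maybe_finalize_async_calls(blocking=False)

# Before writing global metadata: block until every rank has finished its
# outstanding asynchronous checkpoints (uses the single-integer collective).
queue.maybe_finalize_async_calls(blocking=True, no_dist=False)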

This global synchronization is implemented using a single integer collective operation, ensuring that all ranks have completed their asynchronous checkpoint writes. The synchronization logic is handled within nvidia_resiliency_ext.checkpointing.async_ckpt.core.DistributedAsyncCaller.is_current_async_call_done(), which is invoked by AsyncCallsQueue.maybe_finalize_async_calls().

The following snippet demonstrates how global synchronization is performed when no_dist is set to False (indicating that synchronization is required):

# 1 if this rank's forked checkpoint process is still running, 0 otherwise
is_alive = int(self.process.is_alive()) if self.process is not None else 0

is_done = is_alive
if not no_dist:
    # The per-rank liveness flag is placed in a CUDA tensor that is subsequently
    # reduced across all ranks (the single-integer collective described above).
    ten = torch.tensor([is_alive], dtype=torch.int, device=torch.cuda.current_device())

TorchAsyncCheckpoint wraps these synchronization routines in nvidia_resiliency_ext.checkpointing.async_ckpt.TorchAsyncCheckpoint.finalize_async_save. The following example shows how the routine can be used to synchronize asynchronous checkpoints in a non-blocking or blocking manner:

from nvidia_resiliency_ext.checkpointing.async_ckpt import TorchAsyncCheckpoint
...
async_impl = TorchAsyncCheckpoint()

# Training loop
while True:
    # Non-blocking: finalize any asynchronous checkpoints that have already completed
    async_impl.finalize_async_save(blocking=False)
    # Perform a training iteration
    ...
    # Save checkpoint if conditions are met
    if save_condition():
        async_impl.async_save(model.state_dict(), ckpt_dir)

# Block until all outstanding asynchronous checkpoints are finalized
async_impl.finalize_async_save(blocking=True)