Usage guide
The nvidia_resiliency_ext.checkpointing.async_ckpt.core.AsyncCallsQueue provides application users with an interface to schedule an nvidia_resiliency_ext.checkpointing.async_ckpt.core.AsyncRequest, which defines the checkpoint routine, its args/kwargs, and the finalization steps to run once the checkpoint routine has finished.
nvidia_resiliency_ext.checkpointing.async_ckpt.torch_ckpt.TorchAsyncCheckpoint is an instantiation of the core utilities that makes torch.save run asynchronously.
By default, the implementation assumes that every training rank creates a core.AsyncCallsQueue and synchronizes with core.AsyncCallsQueue.maybe_finalize_async_calls.
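For orientation, the following minimal sketch schedules a single request. The AsyncRequest fields (async_fn, async_fn_args, finalize_fns) and the schedule_async_request method are assumptions drawn from the description above; verify them against the installed version.

import torch
from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue, AsyncRequest

def save_fn(state_dict, path):
    # Checkpoint routine; it runs asynchronously in a forked process.
    torch.save(state_dict, path)

queue = AsyncCallsQueue()
request = AsyncRequest(
    async_fn=save_fn,                                  # checkpoint routine
    async_fn_args=({"w": torch.ones(4)}, "ckpt.pt"),   # its args
    finalize_fns=[],                                   # finalization steps, run after save_fn finishes
)
queue.schedule_async_request(request)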
Requirements
nvidia_resiliency_ext.checkpointing.utils includes a couple of routines used by nvidia_resiliency_ext.checkpointing.async_ckpt.core.
nvidia_resiliency_ext.checkpointing.utils.wrap_for_async disables garbage collection in the forked process that runs the user's checkpoint routine, preventing failures caused by the GC attempting to deallocate CUDA tensors in a forked process. This routine requires that the first argument of the wrapped user function be the state dictionary containing the tensors or objects to checkpoint.
The current implementation uses a forked process to persist tensors that have been pre-staged to host memory via pinned-memory copies. The checkpoint routine should therefore use nvidia_resiliency_ext.checkpointing.utils.preload_tensors to stage the GPU tensors in the state dictionary to host memory before the state dictionary is passed to AsyncCallsQueue.
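Putting these requirements together, a checkpoint routine might be prepared roughly as follows (a sketch; the exact signatures of preload_tensors and wrap_for_async should be verified against the installed version):

import torch
from nvidia_resiliency_ext.checkpointing.utils import preload_tensors, wrap_for_async

def save_fn(state_dict, path):
    # The state dictionary must be the first argument (required by wrap_for_async).
    torch.save(state_dict, path)

state_dict = {"weight": torch.ones(8, device="cuda")}
# Stage GPU tensors to host memory before the request is handed to AsyncCallsQueue.
staged_state_dict = preload_tensors(state_dict)
# Disable GC in the forked process so it does not try to deallocate CUDA tensors.
async_fn = wrap_for_async(save_fn)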
Synchronization of Asynchronous Checkpoint Requests
The class nvidia_resiliency_ext.checkpointing.async_ckpt.core.AsyncCallsQueue provides a method to verify whether asynchronous checkpointing has completed in the background. Each trainer can check the status of its forked checkpoint process by calling nvidia_resiliency_ext.checkpointing.async_ckpt.core.AsyncCallsQueue.maybe_finalize_async_calls() with blocking=False.
When a trainer needs to finalize all active checkpoint requests in a blocking manner, it can call the same method with blocking=True.
Additionally, AsyncCallsQueue.maybe_finalize_async_calls() accepts a no_dist parameter, which must be set to False when global synchronization across all ranks is required.
For example, if a checkpointing routine needs to write metadata (e.g., iteration or sharding information) after completing a set of checkpoints,
global synchronization ensures that all ranks finish their asynchronous checkpointing before proceeding.
This global synchronization is implemented using a single integer collective operation, ensuring that all ranks have completed their asynchronous checkpoint writes.
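From the caller's side, such a globally synchronized finalize might look like the following sketch (queue is an AsyncCallsQueue as above; save_metadata is a hypothetical helper, not part of the library):

# Blocking, globally synchronized finalization; no_dist=False enables the
# cross-rank synchronization described above.
queue.maybe_finalize_async_calls(blocking=True, no_dist=False)
# All ranks have now finished their asynchronous checkpoint writes, so rank 0
# can safely write iteration/sharding metadata.
if torch.distributed.get_rank() == 0:
    save_metadata(ckpt_dir, iteration)  # hypothetical helper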
The synchronization logic is handled within nvidia_resiliency_ext.checkpointing.async_ckpt.core.DistributedAsyncCaller.is_current_async_call_done(), which is invoked by AsyncCallsQueue.maybe_finalize_async_calls().
The following snippet (abridged) demonstrates how global synchronization is performed when no_dist is set to False (indicating that synchronization is required):
is_alive = int(self.process.is_alive()) if self.process is not None else 0
is_done = not is_alive
if not no_dist:
    ten = torch.tensor([is_alive], dtype=torch.int, device=torch.cuda.current_device())
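    # (continuation, reconstructed from the description of the single-integer collective above)
    torch.distributed.all_reduce(ten)      # one-integer all-reduce across all ranks
    is_done = ten[0].item() == 0           # done only once no rank's checkpoint process is alive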
TorchAsyncCheckpoint wraps these synchronization routines in nvidia_resiliency_ext.checkpointing.async_ckpt.TorchAsyncCheckpoint.finalize_async_save. The following example shows how the routine can be used to synchronize the asynchronous checkpoint in a non-blocking or blocking manner:
from nvidia_resiliency_ext.checkpointing.async_ckpt import TorchAsyncCheckpoint
...

async_impl = TorchAsyncCheckpoint()

# Training loop
while True:
    # Non-blocking check: finalize any checkpoint requests that have already completed
    async_impl.finalize_async_save(blocking=False)
    # Perform a training iteration
    ...
    # Save checkpoint if conditions are met
    if save_condition():
        async_impl.async_save(model.state_dict(), ckpt_dir)

# Before exiting, block until all pending checkpoint requests are finalized
async_impl.finalize_async_save(blocking=True)