Usage guide

The nvidia_resiliency_ext.checkpointing.async_ckpt.core.AsyncCallsQueue provides application users with an interface to schedule a nvidia_resiliency_ext.checkpointing.async_ckpt.core.AsyncRequest, which defines the checkpoint routine, its args/kwargs, and the finalization steps to run once the checkpoint routine finishes.

nvidia_resiliency_ext.checkpointing.async_ckpt.torch_ckpt.TorchAsyncCheckpoint is an instantiation of the core utilities that makes torch.save run asynchronously.

By default, the implementation assumes that all training ranks create a core.AsyncCallsQueue and synchronize with core.AsyncCallsQueue.maybe_finalize_async_calls.
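The basic pattern looks roughly as follows; this is a minimal single-rank sketch built only from the interfaces named above (paths and the toy state dictionary are placeholders), while the full distributed workflow is shown later in this guide:

import torch
from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue, AsyncRequest

def save_checkpoint(state_dict, path):
    # User-defined checkpoint routine scheduled by the async request.
    torch.save(state_dict, path)

def finalize_fn():
    # Runs after the checkpoint routine finishes, e.g. to write metadata.
    pass

async_queue = AsyncCallsQueue()
state_dict = {"step": torch.tensor(0)}

# An AsyncRequest bundles the checkpoint routine, its args, and the finalization steps.
request = AsyncRequest(save_checkpoint, (state_dict, "/tmp/ckpt.pt"), [finalize_fn])
async_queue.schedule_async_request(request)

# Poll for completion without blocking; pass blocking=True to wait.
async_queue.maybe_finalize_async_calls(blocking=False, no_dist=True)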

Requirements

nvidia_resiliency_ext.checkpointing.utils includes a couple of routines used by nvidia_resiliency_ext.checkpointing.async_ckpt.core. nvidia_resiliency_ext.checkpointing.utils.wrap_for_async disables garbage collection in the forked process that runs the user's checkpoint routine, preventing failures caused by GC trying to deallocate CUDA tensors in the forked process. This routine requires that the first argument of the wrapped user fn be the state dictionary containing the tensors or objects to checkpoint.

The current implementation uses a forked process to write out tensors that have been pre-staged to host memory via pinned memcpy. The checkpoint routine should therefore use nvidia_resiliency_ext.checkpointing.utils.preload_tensors to stage the GPU tensors in the state dictionary to host memory before the request is passed to AsyncCallsQueue.
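As a rough illustration of the expected flow (the exact signatures of preload_tensors and wrap_for_async below are assumptions — a state dict in and a host-staged state dict out, and a callable in and a wrapped callable out — so consult the API reference for the authoritative forms):

import torch
from nvidia_resiliency_ext.checkpointing.utils import preload_tensors, wrap_for_async

def save_checkpoint(state_dict, path):
    # The state dict must be the first argument, as wrap_for_async requires.
    torch.save(state_dict, path)

model = torch.nn.Linear(8, 8).cuda()  # any module with CUDA parameters

# Stage GPU tensors in the state dict to host memory before the request is
# handed to AsyncCallsQueue (assumed: returns a host-staged copy of the dict).
host_state_dict = preload_tensors(model.state_dict())

# Disable garbage collection in the forked process that runs the routine.
wrapped_save_fn = wrap_for_async(save_checkpoint)

The wrapped routine and the host-staged state dictionary are then scheduled through AsyncCallsQueue exactly as in the sketch above.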

Synchronization of Asynchronous Checkpoint Requests

The class nvidia_resiliency_ext.checkpointing.async_ckpt.core.AsyncCallsQueue provides a method to verify whether asynchronous checkpointing has completed in the background. Each trainer can check the status of its forked checkpoint process by calling nvidia_resiliency_ext.checkpointing.async_ckpt.core.AsyncCallsQueue.maybe_finalize_async_calls() with blocking=False.

When a trainer needs to finalize all active checkpoint requests in a blocking manner, it can call the same method with blocking=True.

Additionally, AsyncCallsQueue.maybe_finalize_async_calls() accepts a no_dist parameter, which must be set to False when global synchronization across all ranks is required. For example, if a checkpointing routine needs to write metadata (e.g., iteration or sharding information) after completing a set of checkpoints, global synchronization ensures that all ranks finish their asynchronous checkpointing before proceeding.
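For instance, a trainer that owns an AsyncCallsQueue instance (async_queue below) could poll during training and force a globally synchronized finalization at a checkpoint boundary:

# Non-blocking poll; finalization runs only if the forked process has finished.
async_queue.maybe_finalize_async_calls(blocking=False, no_dist=True)

# Blocking finalization with global synchronization across all ranks.
async_queue.maybe_finalize_async_calls(blocking=True, no_dist=False)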

This global synchronization is implemented using a single integer collective operation, ensuring that all ranks have completed their asynchronous checkpoint writes. The synchronization logic is handled within nvidia_resiliency_ext.checkpointing.async_ckpt.core.DistributedAsyncCaller.is_current_async_call_done(), which is invoked by AsyncCallsQueue.maybe_finalize_async_calls().

The following snippet demonstrates how global synchronization is performed when no_dist is set to False (indicating that synchronization is required):

is_alive = int(self.process.is_alive()) if self.process is not None else 0

if not no_dist:
    # Single-integer all-reduce: every rank learns whether any rank still
    # has an in-flight checkpoint process.
    ten = torch.tensor([is_alive], dtype=torch.int, device=torch.cuda.current_device())
    torch.distributed.all_reduce(ten)
    is_alive = int(ten[0].item())

is_done = is_alive == 0  # done only once no rank reports a live process

TorchAsyncCheckpoint wraps these synchronization routines in nvidia_resiliency_ext.checkpointing.async_ckpt.TorchAsyncCheckpoint.finalize_async_save. The following example shows how the routine can be used to synchronize the asynchronous checkpoint in a non-blocking or blocking manner:

from nvidia_resiliency_ext.checkpointing.async_ckpt import TorchAsyncCheckpoint
...
async_impl = TorchAsyncCheckpoint()

# Training loop
while True:
    async_impl.finalize_async_save(blocking=False)
    # Perform a training iteration
    ...
    # Save checkpoint if conditions are met
    if save_condition():
        async_impl.async_save(model.state_dict(), ckpt_dir)

# Finalize any remaining asynchronous checkpoint before exiting
async_impl.finalize_async_save(blocking=True)

Using Multi-Storage Client with Async Checkpointing

nvidia_resiliency_ext supports saving checkpoints to object stores like AWS S3, Azure Blob Storage, Google Cloud Storage, and more through the NVIDIA Multi-Storage Client (MSC) integration.

MSC (GitHub repository) provides a unified API for various storage backends, allowing you to write checkpoints to different storage services using the same code.

Installation

Before using MSC integration, you need to install the Multi-Storage Client package:

# Install the Multi-Storage Client package with boto3 support
pip install "multi-storage-client[boto3]"

Configuration

Create a configuration file for the Multi-Storage Client and export the environment variable MSC_CONFIG to point to it:

export MSC_CONFIG=/path/to/your/msc_config.yaml

Example configuration file for AWS S3:

profiles:
  model-checkpoints:
    storage_provider:
      type: s3
      options:
        base_path: bucket-checkpoints # Set the bucket name as the base path
    credentials_provider:
      type: S3Credentials
      options:
        access_key: ${AWS_ACCESS_KEY} # Set the AWS access key in the environment variable
        secret_key: ${AWS_SECRET_KEY} # Set the AWS secret key in the environment variable

Basic Usage

To enable MSC integration, simply pass use_msc=True when creating the FileSystemWriterAsync instance.

The MSC URL scheme is msc://<profile-name>/<path>. The example below shows how to save a checkpoint under the model-checkpoints profile; the data will be stored in the bucket-checkpoints bucket in AWS S3.

from nvidia_resiliency_ext.checkpointing.async_ckpt.filesystem_async import FileSystemWriterAsync

# Create writer with MSC enabled
writer = FileSystemWriterAsync(
    "msc://model-checkpoints/iteration-0010",
    use_msc=True
)

Example: Saving and Loading Checkpoints with MSC

The following example demonstrates a complete workflow for saving and loading checkpoints using Multi-Storage Client integration:

import torch
from torch.distributed.checkpoint import (
    DefaultLoadPlanner,
    DefaultSavePlanner,
    load,
)

from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue, AsyncRequest
from nvidia_resiliency_ext.checkpointing.async_ckpt.filesystem_async import FileSystemWriterAsync
from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import (
    save_state_dict_async_finalize,
    save_state_dict_async_plan,
)

import multistorageclient as msc


def async_save_checkpoint(checkpoint_path, state_dict, thread_count=2):
    """
    Save checkpoint asynchronously to MSC storage.
    """
    # Create async queue
    async_queue = AsyncCallsQueue()

    # Create writer with MSC enabled
    writer = FileSystemWriterAsync(checkpoint_path, thread_count=thread_count, use_msc=True)
    coordinator_rank = 0
    planner = DefaultSavePlanner()

    # Plan the save operation
    save_state_dict_ret = save_state_dict_async_plan(
        state_dict, writer, None, coordinator_rank, planner=planner
    )

    # Create async request with finalization
    save_fn, preload_fn, save_args = writer.get_save_function_and_args()

    def finalize_fn():
        """Finalizes async checkpointing and synchronizes processes."""
        save_state_dict_async_finalize(*save_state_dict_ret)
        if torch.distributed.is_initialized():
            torch.distributed.barrier()

    async_request = AsyncRequest(save_fn, save_args, [finalize_fn], preload_fn=preload_fn)

    # Schedule the request and return the queue for later checking
    async_queue.schedule_async_request(async_request)
    return async_queue


def load_checkpoint(checkpoint_path, state_dict):
    """
    Load checkpoint from MSC storage into the state_dict.
    """
    # Create reader with MSC path
    reader = msc.torch.MultiStorageFileSystemReader(checkpoint_path, thread_count=2)

    # Load the checkpoint into the state_dict
    load(
        state_dict=state_dict,
        storage_reader=reader,
        planner=DefaultLoadPlanner(),
    )
    return state_dict


# Initialize your model and get state_dict
model = YourModel()
state_dict = model.state_dict()

# Save checkpoint asynchronously
checkpoint_path = "msc://model-checkpoints/iteration-0010"
async_queue = async_save_checkpoint(checkpoint_path, state_dict)
async_queue.maybe_finalize_async_calls(blocking=True, no_dist=False)

# Load checkpoint synchronously
loaded_state_dict = load_checkpoint(checkpoint_path, state_dict.copy())