Asynchronous Checkpoint Core Utilities

This module provides async utilities that allow starting a checkpoint save process in the background.
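
A minimal usage sketch is shown below; save_checkpoint and on_save_done are hypothetical user-provided callables (not part of this module) and the state dict is a toy placeholder:

    from functools import partial

    from nvidia_resiliency_ext.checkpointing.async_ckpt.core import (
        AsyncCallsQueue,
        AsyncRequest,
    )

    def save_checkpoint(state_dict, path):
        # hypothetical user function: write a host-resident state dict to storage
        ...

    def on_save_done(path):
        # hypothetical finalize hook; bound with partial so it is self-contained
        print(f"checkpoint written to {path}")

    state_dict = {"step": 100}  # toy placeholder; must be picklable

    queue = AsyncCallsQueue(persistent=True)
    req = AsyncRequest(
        async_fn=save_checkpoint,
        async_fn_args=(state_dict, "/tmp/ckpt"),
        finalize_fns=[partial(on_save_done, "/tmp/ckpt")],
    )

    call_idx = queue.schedule_async_request(req)     # must be called on all ranks
    # ... continue training while the checkpoint is written in the background ...
    queue.maybe_finalize_async_calls(blocking=True)  # wait, then run finalize_fns
    queue.close()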

class nvidia_resiliency_ext.checkpointing.async_ckpt.core.AsyncCaller[source]

Bases: ABC

Wrapper around mp.Process that ensures correct semantics of distributed finalization.

Starts the process asynchronously and allows checking whether all processes on all ranks are done.

abstract close()[source]

Terminate the async caller at application exit or under certain termination conditions.

abstract is_current_async_call_done(blocking, no_dist)[source]

Check if async save is finished on all ranks.

For semantic correctness, requires rank synchronization in each check. This method must be called on all ranks.

Parameters:
  • blocking (bool, optional) – if True, waits until the call is done on all ranks. Otherwise, returns immediately even if at least one rank is still active. Defaults to False.

  • no_dist (bool, optional) – if True, each training rank simply checks its asynchronous checkpoint writer without synchronization.

Returns:

True if all ranks are done (immediately or after an active wait if blocking is True), False if at least one rank is still active.

Return type:

bool

abstract schedule_async_call(async_req)[source]

Schedule async_req by forking a process or reusing a persistent worker.

This method must be called on all ranks.

Parameters:

async_req (AsyncRequest) – AsyncRequest object specifying the async process to start

Return type:

None

sync_all_async_calls(is_alive)[source]

Check if all ranks have completed async checkpoint writing

Parameters:

is_alive (bool) – if True, the async request on the current rank has not yet completed

Returns:

True if all ranks are done, False if at least one rank is still active.

Return type:

bool

class nvidia_resiliency_ext.checkpointing.async_ckpt.core.AsyncCallsQueue(persistent=True)[source]

Bases: object

Manages a queue of async calls.

Allows adding a new async call with schedule_async_request and finalizing active calls with maybe_finalize_async_calls.

Parameters:

persistent (bool) – if True, async calls are served by a persistent worker (see PersistentAsyncCaller); otherwise a temporal worker process is spawned per request (see TemporalAsyncCaller). Defaults to True.

close()[source]

Finalize all calls upon closing.

get_num_unfinalized_calls()[source]

Get the number of active async calls.

maybe_finalize_async_calls(blocking=False, no_dist=False)[source]

Finalizes all available calls.

This method must be called on all ranks.

Parameters:
  • blocking (bool, optional) – if True, waits until all active requests are done. Otherwise, finalizes only the async requests that have already finished. Defaults to False.

  • no_dist (bool, optional) – if True, each training rank simply checks its asynchronous checkpoint writer without synchronization.

Returns:

List of indices (as returned by schedule_async_request) of async calls that have been successfully finalized.

Return type:

List[int]
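
A sketch of periodic, non-blocking finalization inside a training loop; queue and req are built as in the module-level sketch above, while train_step and num_steps are hypothetical:

    for step in range(num_steps):
        train_step()  # hypothetical training work
        # finalize only the async requests that have already finished
        for idx in queue.maybe_finalize_async_calls(blocking=False):
            print(f"async checkpoint {idx} finalized")

    # before exiting, wait for everything still in flight
    queue.maybe_finalize_async_calls(blocking=True)
    assert queue.get_num_unfinalized_calls() == 0
    queue.close()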

schedule_async_request(async_request)[source]

Start a new async call and add it to a queue of active async calls.

This method must be called on all ranks.

Parameters:

async_request (AsyncRequest) – async request to start.

Returns:

index of the async call that was started.

This can help the user keep track of the async calls.

Return type:

int
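
The returned index can be used for simple bookkeeping, for example (pending is a hypothetical dictionary, with queue and req as in the sketches above):

    pending = {}  # hypothetical map: call index -> checkpoint path
    idx = queue.schedule_async_request(req)  # must be called on all ranks
    pending[idx] = "/tmp/ckpt"

    for done_idx in queue.maybe_finalize_async_calls(blocking=False):
        print(f"finished writing {pending.pop(done_idx)}")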

class nvidia_resiliency_ext.checkpointing.async_ckpt.core.AsyncRequest(async_fn, async_fn_args, finalize_fns, async_fn_kwargs={}, preload_fn=None, is_frozen=False, call_idx=0)[source]

Bases: NamedTuple

Represents an async request that needs to be scheduled for execution.

Parameters:
  • async_fn (Callable, optional) – async function to call. None represents noop.

  • async_fn_args (Tuple) – args to pass to async_fn.

  • finalize_fns (List[Callable]) – list of functions to call to finalize the request. These functions will be called synchronously after async_fn is done on all ranks.

  • async_fn_kwargs (Dict) – kwargs to pass to async_fn.

  • preload_fn (Callable) – preload function to stage tensors from GPU to host. This should be self-contained, with its arguments already bound, e.g. via functools.partial.

  • is_frozen (bool) – a flag indicating whether this async request can still be modified.

  • call_idx (int) – index used to order async requests for synchronizing the preloading and writing of tensors on the async caller.

Create new instance of AsyncRequest(async_fn, async_fn_args, finalize_fns, async_fn_kwargs, preload_fn, is_frozen, call_idx)
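
A sketch of constructing a request with a preload stage; stage_to_host, write_shards, and gpu_tensors are hypothetical placeholders, with arguments pre-bound via functools.partial as described above:

    from functools import partial

    from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncRequest

    def stage_to_host(tensors):
        # hypothetical: copy GPU tensors to (pinned) host memory
        ...

    def write_shards(path):
        # hypothetical: persist the staged tensors to storage
        ...

    gpu_tensors = []  # hypothetical list of device tensors

    req = AsyncRequest(
        async_fn=write_shards,
        async_fn_args=("/tmp/ckpt",),
        finalize_fns=[],
        preload_fn=partial(stage_to_host, gpu_tensors),  # self-contained via partial
    )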

add_finalize_fn(fn)[source]

Adds a new finalize function to the request.

Parameters:

fn (Callable) – function to add to the async request. This function will be called after existing finalization functions.

Returns:

None

Return type:

None
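
For example, reusing the hypothetical on_save_done hook and req from the sketches above:

    from functools import partial

    req.add_finalize_fn(partial(on_save_done, "/tmp/ckpt"))  # runs after the existing finalize fns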

async_fn: Callable | None

Alias for field number 0

async_fn_args: Tuple

Alias for field number 1

async_fn_kwargs: Dict

Alias for field number 3

call_idx: int

Alias for field number 6

execute_finalize_fns(validate_matching_call_idx=True)[source]

Execute all the finalize functions associated with this async request.

Parameters:

validate_matching_call_idx (bool, optional) – Validate that all ranks invoke checkpoint finalization on the same call_idx. This is typically useful in async checkpointing, where multiple checkpoint requests can be pending; the validation is unnecessary during a synchronous checkpoint step. When this parameter is True, an all-reduce synchronization across all participating ranks is invoked. Defaults to True for conservative validation.

Returns:

The call_idx of the async request that has been finalized.

Return type:

int

execute_sync()[source]

Helper to synchronously execute the request.

This logic is equivalent to what happens in the asynchronous case.

Return type:

None
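
A sketch of falling back to a synchronous save when background checkpointing is disabled; use_async is a hypothetical flag, with queue and req as in the sketches above:

    if use_async:
        queue.schedule_async_request(req)  # save in the background
    else:
        req.execute_sync()  # same work, executed inline on the calling rank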

finalize_fns: List[Callable]

Alias for field number 2

freeze()[source]

Freezes the async request, disallowing adding new finalization functions.

Returns:

A new async request with the same fields except for the is_frozen flag.

Return type:

AsyncRequest
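
Since freeze() returns a new request rather than mutating in place, the result must be reassigned; a short sketch:

    frozen_req = req.freeze()  # copy of req with is_frozen set
    assert frozen_req.is_frozen
    # frozen_req.add_finalize_fn(...)  # adding finalize fns is now disallowed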

is_frozen: bool

Alias for field number 5

preload_fn: Callable

Alias for field number 4

class nvidia_resiliency_ext.checkpointing.async_ckpt.core.PersistentAsyncCaller[source]

Bases: AsyncCaller

Wrapper around mp.Process that ensures correct semantics of distributed finalization.

Starts the process asynchronously and allows checking whether all processes on all ranks are done.

static async_loop(rank, queue, preload_q, comp_q, log_level=20)[source]

Main function for the persistent checkpoint worker.

The persistent worker is created once and terminated at exit or when the application calls close() explicitly.

This routine receives an AsyncRequest, runs preload_fn first, and puts an integer value in preload_q to inform the trainer that it can proceed. When the async_fn from the request is completed (background saving is done), it puts an integer value into comp_q to notify the trainer of the completion.

Parameters:
  • rank (int) – the rank of the trainer where the persistent worker is created.

  • queue (mp.JoinableQueue) – the main queue used to receive AsyncRequest from the training rank

  • preload_q (mp.JoinableQueue) – a queue to inform trainer that preloading of tensors from GPU to Host or dedicated location is completed

  • comp_q (mp.Queue) – a queue to inform the training rank of the completion of a scheduled async checkpoint request

  • log_level (int, optional) – an integer used to set the log level in this spawned process so that it matches the training rank's logging level

close()[source]

Wait for the remaining async requests to complete and terminate the PersistentAsyncCaller.

Signals the PersistentAsyncCaller to terminate by sending a 'DONE' message.

is_current_async_call_done(blocking=False, no_dist=False)[source]

Check if async save is finished on all ranks.

For semantic correctness, requires rank synchronization in each check. This method must be called on all ranks.

Parameters:
  • blocking (bool, optional) – if True, waits until the call is done on all ranks. Otherwise, returns immediately even if at least one rank is still active. Defaults to False.

  • no_dist (bool, optional) – if True, each training rank simply checks its asynchronous checkpoint writer without synchronization.

Returns:

True if all ranks are done (immediately or after an active wait if blocking is True), False if at least one rank is still active.

Return type:

bool

schedule_async_call(async_req)[source]

Send an AsyncRequest to the persistent async worker.

This method must be called on all ranks. The async_req object is pickled and sent to the persistent async worker via a JoinableQueue. Therefore, all arguments within async_req must be picklable.

Parameters:
  • async_req (AsyncRequest) – AsyncRequest object specifying the checkpointing request to schedule. If its async_fn is None, the request is treated as a no-op.

Return type:

None

class nvidia_resiliency_ext.checkpointing.async_ckpt.core.TemporalAsyncCaller[source]

Bases: AsyncCaller

Wrapper around mp.Process that ensures correct semantics of distributed finalization.

Starts the process asynchronously and allows checking whether all processes on all ranks are done.

close()[source]

For TemporalAsyncCaller, this method is called explicitly from is_current_async_call_done.

This method makes sure the TemporalAsyncCaller is terminated with all of its assigned async requests completed.

is_current_async_call_done(blocking=False, no_dist=False)[source]

Check if async save is finished on all ranks.

For semantic correctness, requires rank synchronization in each check. This method must be called on all ranks.

Parameters:
  • blocking (bool, optional) – if True, waits until the call is done on all ranks. Otherwise, returns immediately even if at least one rank is still active. Defaults to False.

  • no_dist (bool, optional) – if True, each training rank simply checks its asynchronous checkpoint writer without synchronization.

Returns:

True if all ranks are done (immediately or after an active wait if blocking is True), False if at least one rank is still active.

Return type:

bool

schedule_async_call(async_req)[source]

Spawn a process with async_fn as the target.

This method must be called on all ranks.

Parameters:
  • async_req (AsyncRequest) – AsyncRequest object specifying the async process to start. If its async_fn is None, no process will be started.

Return type:

None