Usage Guide

The nvidia_resiliency_ext.inprocess.Wrapper serves as the primary interface for accessing in-process restart functionality. It provides various configuration options through its arguments, enabling customization of the restart process and fault monitoring capabilities. To ensure efficient and effective restarts, the function being wrapped must meet specific requirements. This usage guide outlines the requirements, features, and limitations of the in-process restart functionality provided by the Wrapper.
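
As a quick orientation before the detailed requirements, the sketch below shows one way the Wrapper is typically applied. It is illustrative only: the decorator-style usage and the health_check argument name are assumptions based on identifiers mentioned later in this guide; exact signatures and defaults should be checked against the API reference.

from nvidia_resiliency_ext import inprocess

# Illustrative sketch: decorator-style usage and the selected arguments are
# assumptions based on this guide; consult the API reference for exact signatures.
@inprocess.Wrapper(
    rank_assignment=inprocess.rank_assignment.ShiftRanks(),
    health_check=inprocess.health_check.CudaHealthCheck(),
)
def train():
    # Must initialize torch.distributed itself and be safe to execute multiple
    # times within the same process (see Requirements below).
    ...

if __name__ == '__main__':
    train()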

Requirements

In-process restart functionality requires PyPI PyTorch v2.5.1 or PyTorch NGC Container versions 24.07 through 24.10. For further limitations and compatibility details, refer to the Known issues section.

Requirements for the wrapped function

  • The wrapped function should be designed to support restarts: it should carefully manage any external (e.g., global) state and resources, and it should avoid functions that can only be called once per process, such as multiprocessing.set_start_method() or MPI_Init, so that it can be executed multiple times in the same process without issues.

    • The function is automatically retried on any failure, meaning it is called again with the same set of input arguments; extra caution is needed if the function accepts mutable arguments that might be modified during its execution, as these changes could affect subsequent retries.

  • All operations that wait on results from NCCL kernels, or synchronize with the GPU, need to release the Python Global Interpreter Lock (GIL).

    • If the Python GIL is not released when a fault occurs, the graceful restart procedure cannot proceed. This is because the procedure runs in a separate Python thread, which is blocked from execution due to the GIL being held. As a result, hung ranks must be forcibly terminated using the hard timeout mechanism (SIGKILL). These terminated ranks will not rejoin the distributed job upon restart.

  • The function must not suppress BaseException. If the wrapped function catches a BaseException, it must re-raise it so that it propagates to the outer scope.

  • The function is responsible for initializing the PyTorch distributed backend (torch.distributed.init_process_group()); the initialization needs to read the standard PyTorch distributed environment variables (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT and LOCAL_RANK); see the sketch after this list.

  • It is strongly recommended that the wrapped function loads any state affected by distributed collectives from a checkpoint on every restart (e.g., reloads the model weights); outputs of distributed collectives are likely to be corrupted or invalid if a fault occurred while a collective was in flight and the distributed backend was terminated.
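
The snippet below is a minimal sketch of an initialization that satisfies the requirements above: it reads the standard environment variables and is safe to call again on every restart. The helper name and structure are illustrative, not part of the library's API.

import os

import torch

def init_distributed():
    # Read the environment on every call; rank and world size can change
    # between restart iterations after rank reassignment.
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    # init_method='env://' (the default) picks up RANK, WORLD_SIZE,
    # MASTER_ADDR and MASTER_PORT from the environment.
    torch.distributed.init_process_group(backend='nccl')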

Requirements for the execution environment

  • The PyTorch NCCL watchdog must either be disabled or configured with a timeout longer than the hard_timeout of the nvidia_resiliency_ext.inprocess.Wrapper. If the NCCL watchdog is triggered, it forcibly terminates the process, preventing a restart. To adjust the NCCL watchdog timeout, use the timeout argument when calling torch.distributed.init_process_group() with the backend parameter set to "nccl"; see the example after this list.

  • The job scheduler must not terminate the entire job if a faulty rank exits early or if the main process is terminated; instead, it should wait until all user-launched processes have fully exited before ending the distributed job.
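
For example, assuming the Wrapper is configured with a known hard_timeout, the NCCL watchdog timeout can be set to a longer interval when the process group is created (the concrete values below are illustrative):

import datetime

import torch

# Illustrative values: the process group timeout (used by the NCCL watchdog)
# must be longer than the hard_timeout configured on inprocess.Wrapper.
torch.distributed.init_process_group(
    backend='nccl',
    timeout=datetime.timedelta(minutes=10),  # > Wrapper hard_timeout
)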

Restrictions

  • A node failure on rank 0 causes termination of the entire job; by default, rank 0 hosts the internal torch.distributed.TCPStore that allows communication between ranks. Users may specify a different distributed store implementation by subclassing nvidia_resiliency_ext.inprocess.store.StoreMixin and passing the subclass as the store_factory argument to nvidia_resiliency_ext.inprocess.Wrapper.

  • Blocking calls issued by the main process are generally not recoverable if they hang, except for NCCL collectives or functions waiting on them; NCCL collectives are asynchronously aborted by a separate monitoring thread that calls nvidia_resiliency_ext.inprocess.abort.AbortTorchDistributed. Users can specify additional nvidia_resiliency_ext.inprocess.abort.Abort subclasses to asynchronously abort blocking calls from other software components; see the sketch after this list.
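
The sketch below illustrates the second point only in outline: it assumes that custom aborts subclass inprocess.abort.Abort with a __call__ method operating on the wrapper state, and that the composed object is passed as the Wrapper's abort argument. Both assumptions should be verified against the API reference.

from nvidia_resiliency_ext import inprocess

class AbortMyQueue(inprocess.abort.Abort):
    # Hypothetical example: unblock a third-party component whose blocking
    # calls could otherwise hang the main process. The __call__ signature is
    # an assumption; check the inprocess.abort.Abort base class.
    def __call__(self, state):
        ...  # e.g. close a message queue to interrupt a blocking receive
        return state

abort = inprocess.Compose(
    inprocess.abort.AbortTorchDistributed(),
    AbortMyQueue(),
)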

Functionality overview

Implementation overview

Below is a simplified pseudocode snippet that illustrates the order of operations executed by nvidia_resiliency_ext.inprocess.Wrapper, providing a high-level overview of the workflow within this class. This code is for illustrative purposes only and may omit certain implementation details.

distributed_store = store_factory(**store_kwargs)
initial_barrier()
rank_assignment()
rank_filter()

while True:
    iteration_barrier()
    initialize()
    health_check()
    try:
        if rank_is_active:
            wrapped_function()
        else:
            sleep()
        completion_barrier()
    except:
        abort()
        finalize()
        health_check()
        termination_barrier()
        rank_assignment()
        rank_filter()
    else:
        break

Distributed execution behavior

Entering and exiting the Wrapper act as distributed synchronization points. Upon entry, all workers retrieve their initial rank assignments and the total number of workers by reading the standard PyTorch distributed environment variables (RANK, WORLD_SIZE). Subsequently, all workers synchronize through an initial_barrier using a user-defined barrier_timeout to ensure consistent initialization.

Upon completion of the wrapped function, all ranks that finish enter a completion_barrier governed by a user-defined completion_timeout. If any rank fails to synchronize within the completion_timeout, it is treated as a rank failure, triggering a restart of the wrapped function on all distributed ranks.

The restart Wrapper incorporates additional distributed barriers to ensure proper synchronization: iteration_barrier (executed at the start of each restart iteration), and termination_barrier (executed before rank reassignment and filtering). These barriers are designed to be transparent to the user, requiring no modifications to the wrapped function or assumptions about the execution environment. They operate seamlessly to maintain distributed consistency and coordination.

Rank assignment

The Wrapper needs to ensure that the wrapped function is restarted with a consecutive sequence of integer rank indices, from 0 to WORLD_SIZE - 1, as some of the ranks from the previous iteration may have been terminated or may be in an unhealthy state. Rank reassignment and computation of the new world size are performed by the nvidia_resiliency_ext.inprocess.rank_assignment.RankAssignment instance passed as the rank_assignment argument to the Wrapper.

Multiple RankAssignments could be composed with nvidia_resiliency_ext.inprocess.Compose to achieve the desired behavior.

For example:

rank_assignment = inprocess.Compose(
    inprocess.rank_assignment.ShiftRanks(),
    inprocess.rank_assignment.FilterGroupedByKey(
        key_or_fn=lambda rank, _: rank // 8,
        condition=lambda count: count == 8,
    ),
)

ensures that all ranks within each non-overlapping group of 8 consecutive ranks remain healthy. If any rank within a group of 8 is unhealthy or terminated, the entire group is terminated. The remaining healthy ranks are then reassigned by shifting left to close any gaps, forming a new sequence of consecutive integers from 0 up to the updated world size.

Rank filter

By default, all ranks are active and call the wrapped function. This behavior can be customized by providing a nvidia_resiliency_ext.inprocess.rank_filter.RankFilter instance as the rank_filter argument of the Wrapper. RankFilter selects which ranks are active in the current restart iteration. Active ranks call the wrapped function. Inactive ranks wait idle and can serve as a static, preallocated, and preinitialized pool of spare ranks. Spare ranks are activated in a subsequent restart iteration if previously active ranks were terminated or became unhealthy.

Multiple rank filters could be composed with nvidia_resiliency_ext.inprocess.Compose to achieve the desired behavior. For example:

rank_filter = inprocess.Compose(
    inprocess.rank_filter.WorldSizeDivisibleBy(M),
    inprocess.rank_filter.MaxActiveWorldSize(N),
)

ensures that the active world size visible to the wrapped function is the largest multiple of M that is not greater than N. The remaining ranks would be inactive and serve as spares.

Initialize

The Wrapper accepts an optional, user-provided nvidia_resiliency_ext.inprocess.initialize.Initialize class, which is executed at the start of every restart iteration, including the first one. Initialize can raise exceptions (e.g., if specific preconditions are not met). Raising a standard Python Exception triggers another restart of the wrapped function, while raising a BaseException terminates the Wrapper. The included nvidia_resiliency_ext.inprocess.initialize.RetryController can be used to limit the number of restart attempts or to halt execution if the number of healthy workers drops below a specified threshold.

Multiple initializers could be composed with nvidia_resiliency_ext.inprocess.Compose.
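
For example, a minimal sketch limiting restarts with RetryController; the keyword names below, and the assumption that the object is passed as the Wrapper's initialize argument, should be verified against the API reference.

from nvidia_resiliency_ext import inprocess

# Sketch only: parameter names are assumptions, not confirmed API.
initialize = inprocess.initialize.RetryController(
    max_iterations=10,   # assumed name: cap the number of restart attempts
    min_world_size=16,   # assumed name: halt if too few healthy workers remain
)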

Wrapped function termination mechanism

When a fault or timeout occurs on any rank participating in the distributed job, the Wrapper waits for the last_call_wait interval to allow all concurrent faults from other distributed ranks to be recorded. After this waiting period, the Wrapper initiates a termination and restart procedure across all ranks to ensure a consistent recovery process:

  • the Wrapper calls an instance of nvidia_resiliency_ext.inprocess.abort.Abort from a separate Python thread; by default, this operation is equivalent to calling torch.distributed.destroy_process_group(),

  • next, the Wrapper raises an asynchronous Python exception within the wrapped function; this exception interrupts the execution of the wrapped function, allowing control to return to the Wrapper, which then handles the restart process

The termination mechanism respects regular Python exception propagation logic, and gives the wrapped function an opportunity to properly clean up resources by calling all encountered exception handlers, context managers’ __exit__ methods, etc. The restart exception raised by the Wrapper is a direct subclass of Python BaseException, and the wrapped function is required to propagate this exception to the outer function scope.

The termination procedure runs in a separate Python thread. In some cases, the main thread - unblocked by the destruction of the distributed process group - might execute a few additional Python bytecode instructions before the asynchronous exception is received. In most cases, this is harmless because the wrapped function is about to be interrupted and restarted, but the wrapped function must not execute any code that may corrupt persistent storage and prevent correct execution after a restart (e.g., the function cannot be writing a checkpoint to persistent storage). To protect against this possible data corruption, the Wrapper offers the inprocess.CallWrapper.atomic() context manager, which implements a lock shared by the main thread and the thread performing the termination procedure. The termination procedure won’t be launched if the main thread is inside an inprocess.CallWrapper.atomic() code block, and the main thread won’t enter an inprocess.CallWrapper.atomic() code block if the termination procedure is already in progress. The use of the inprocess.CallWrapper.atomic() context manager is optional, and may be omitted if the workload already includes mechanisms to guarantee that the restarted wrapped function does not resume execution from a corrupted or incomplete persistent state (e.g., a compromised checkpoint).
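
A minimal sketch of guarding checkpoint writes with the atomic() context manager follows; how the CallWrapper instance reaches the wrapped function (shown here as a call_wrapper argument) is an assumption, and the training helpers are placeholders.

# Sketch only: the call_wrapper argument and the training helpers are
# illustrative placeholders, not confirmed API.
def train(call_wrapper):
    for step in range(num_steps):
        training_step()
        if step % checkpoint_interval == 0:
            # Checkpoint writes must not be interrupted by the asynchronous
            # restart exception; atomic() blocks the termination procedure
            # while this block executes.
            with call_wrapper.atomic():
                save_checkpoint(step)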

Progress timeout

The Wrapper implements two types of timeout events:

Soft timeout

Soft timeout is equivalent to a Python exception raised by one of the ranks, and triggers an attempt to restart the wrapped function on all healthy ranks.

Hard timeout

The hard timeout mechanism forcefully terminates the main Python interpreter process by sending a sequence of signals to ensure proper shutdown.

Initially, the Wrapper sends the signals (SIGCONT, SIGTERM) to allow for a graceful shutdown. If the process remains active after this step, a second sequence of signals (SIGCONT, SIGTERM, SIGKILL) is sent after a delay specified by the termination_grace_time parameter. This guarantees termination of the process if it fails to respond to the initial signals.

The termination_grace_time parameter, configurable via Wrapper, defines the time interval between the two signal sequences. If the workload implements SIGTERM cleanup handlers and their execution is critical for successfully restarting the wrapped function, termination_grace_time should be adjusted to allow sufficient time for these handlers to complete.

For workloads that do not implement SIGTERM handlers, it is safe to set termination_grace_time to 0 seconds to enable faster termination in cases where the process hangs. This minimizes restart latency while ensuring the process is terminated promptly.
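
For example, a workload without SIGTERM handlers might configure the Wrapper as sketched below; the use of datetime.timedelta for timeout-like arguments and the concrete values are assumptions to be checked against the API reference.

import datetime

from nvidia_resiliency_ext import inprocess

# Sketch: timeout-like arguments shown as datetime.timedelta values; verify
# expected types and defaults against the Wrapper API reference.
@inprocess.Wrapper(
    soft_timeout=datetime.timedelta(seconds=60),
    hard_timeout=datetime.timedelta(seconds=90),
    termination_grace_time=datetime.timedelta(seconds=0),  # no SIGTERM cleanup handlers
)
def train():
    ...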

Reporting progress

Timeout events are triggered when the wrapped function has not reported progress within the specified timeout interval.

There are two methods to record progress:

  • Automatic heartbeat: the Wrapper periodically checks if the main thread of the Python interpreter keeps executing new bytecode instructions;

    • this method is always active and protects against hangs in calls that block the Python interpreter, even when a blocking call released the GIL,

    • it doesn’t protect against while-true-like livelocks, where the interpreter keeps executing new bytecode instructions but doesn’t make meaningful forward progress

  • Manual heartbeat (optional): the wrapped function may report progress by periodically calling the inprocess.CallWrapper.ping() method (see the sketch below).

A timeout event is triggered if any of the active progress-monitoring methods did not record a heartbeat within the specified time interval.
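
A minimal sketch of manual heartbeats, assuming, as in the earlier sketch, that a CallWrapper instance is available inside the wrapped function; the training helpers are placeholders.

def train(call_wrapper):
    for step in range(num_steps):
        training_step()      # placeholder for the actual workload
        call_wrapper.ping()  # manual heartbeat recorded once per iteration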

Finalize

The Wrapper accepts an optional, user-provided nvidia_resiliency_ext.inprocess.finalize.Finalize class, which is executed after a fault was detected and the distributed group was destroyed, but before the HealthCheck is performed. Finalize should bring the process into a state where a restart of the wrapped function may be attempted, e.g., deinitialize any global variables or synchronize with any async work issued by the wrapped function that was not already handled by exception handlers in the wrapped function. Any failure during the execution of Finalize should raise an exception; in this case the health check is skipped, the exception is reraised by the Wrapper, and it should cause termination of the main Python interpreter process.

Multiple finalizers could be composed with nvidia_resiliency_ext.inprocess.Compose.

Health check

The Wrapper calls an optional, user-provided nvidia_resiliency_ext.inprocess.health_check.HealthCheck class before the restart to ensure that the worker is in a healthy state. HealthCheck is executed after the wrapped function failure was discovered (on the local or a remote distributed rank), the local distributed group was destroyed, and the optional Finalize finished execution. The health check is local to each rank that could potentially participate in the job after restart, and is meant to filter out unhealthy ranks that cannot continue executing the workload (e.g., a corrupted CUDA context); it must not rely on other ranks, which may have already been terminated, been lost, or still be executing the wrapped function. An unhealthy state is reported to nvidia_resiliency_ext.inprocess.Wrapper by raising an exception from the inprocess.health_check.HealthCheck.__call__() method. The exception is then reraised by the Wrapper and should cause termination of the main Python interpreter process on the local rank.

Multiple health checks could be composed with nvidia_resiliency_ext.inprocess.Compose.

Monitoring capabilities

The Wrapper provides several monitoring mechanisms to track the workload’s progress and enable rapid restart capabilities in the event of a fault.

Monitor Thread

The Monitor Thread runs as a separate threading.Thread and is tasked with periodically checking the distributed store for any faults reported by other distributed ranks. It also ensures that the local rank is reporting progress. If a fault or a lack of progress is detected, it triggers nvidia_resiliency_ext.inprocess.abort.Abort and raises an asynchronous Python exception within the wrapped function.

The execution interval of the monitoring loop is governed by the monitor_thread_interval parameter of the Wrapper. During each loop iteration, the thread queries the distributed store by invoking torch.distributed.Store.get(). For workloads with a large number of distributed workers, it may be necessary to increase the monitor_thread_interval to avoid creating a communication bottleneck in the distributed store caused by concurrent queries from multiple workers.

Monitor Process

The Monitor Process operates as a separate daemon process created by the Wrapper. Its responsibilities include ensuring the main workload process remains active, submitting heartbeat signals to the distributed store for the local rank, monitoring heartbeat signals from remote ranks, and terminating the main process if it becomes unresponsive and irrecoverable.

The timeout for receiving a heartbeat from other distributed ranks is configured with the heartbeat_timeout parameter of the Wrapper. If any distributed rank doesn’t submit a heartbeat within the heartbeat_timeout interval, that rank is considered unresponsive, and a restart is triggered on all distributed ranks.

The execution interval of the monitoring loop is governed by the monitor_process_interval parameter of the Wrapper. Similar to the Monitor Thread, each iteration of the loop queries the distributed store; to prevent communication bottlenecks in the distributed store, the monitoring interval should scale proportionally with the number of distributed workers.

Progress Watchdog

The Progress Watchdog runs as a separate threading.Thread and is responsible for issuing automatic heartbeats and receiving manual heartbeats to track the workload’s progress.

The execution interval is governed by the progress_watchdog_interval parameter of the Wrapper. The execution involves only the node-local inter-process communication, and the interval does not need to be scaled with the number of distributed workers.

Logging

The Wrapper leverages the Python logging module to output messages. It does not adhere to the conventional methods of fully integrating with an application’s root logger. Instead, logging from Wrapper within the main process is managed through a logging.StreamHandler, which is defined by the first ancestor in the logger hierarchy. Notably, the logging in Wrapper is configured to not store logs in files, and to not propagate logging messages to the ancestor loggers’ handlers.

Logging with logging.DEBUG level shows the location where the wrapped function suppressed the BaseException raised asynchronously by the Wrapper. The restart logic requires that BaseExceptions are propagated from the wrapped function to the outer scope. This feature helps to find locations where this assumption is not met, and the restart flow is interrupted.

For the monitoring daemon process, logging is handled differently; logs are written only to a file. The location of this log file is configurable. Users can specify a custom path by passing a string to the monitor_process_logfile argument. This string may include the {rank} placeholder, which allows for dynamic filename generation based on the initial distributed rank of the calling process.
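
For example (the log path below is illustrative, and the decorator-style usage is an assumption carried over from the earlier sketches):

from nvidia_resiliency_ext import inprocess

@inprocess.Wrapper(
    # {rank} is substituted with the initial distributed rank of the process.
    monitor_process_logfile='/tmp/inprocess_monitor_{rank}.log',  # illustrative path
)
def train():
    ...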

Restart latency

Restart latency refers to the time elapsed between a fault occurring on any distributed rank and successfully relaunching the wrapped function across all distributed ranks.

The following table summarizes the latencies of all major items contributing to the total restart latency. Rows marked with (H) increase restart latency only when the application hangs. These items are not included if the application raises a Python exception on any distributed rank.

Category  | Item                                       | Latency
----------|--------------------------------------------|------------------------------------------------------------------
NCCL/PyT  | torch.distributed.destroy_process_group() | ~0.5s + 0.01s * num pending NCCL kernels
CUDA/user | drain all pending CUDA kernels             | ~training iteration
Wrapper   | query TCPStore for any faults              | monitor_thread_interval
Wrapper   | wait for concurrent faults on other ranks  | last_call_wait
Wrapper   | execute rank_assignment                    | ~0.5s
Wrapper   | 3x TCPStore-based barrier                  | 0.5s @ 16k ranks
Wrapper   | (H) detect GIL-holding hang                | hard_timeout + monitor_process_interval + termination_grace_time
Wrapper   | (H) detect GIL-released hang               | soft_timeout + monitor_thread_interval
user      | execute user-provided finalize             | N/A
user      | execute user-provided health_check         | N/A

The latency for executing torch.distributed.destroy_process_group() assumes that NCCL collective kernel termination interval was optimized. See Known issues for more details.

Known issues

  1. torch.distributed.ProcessGroupGloo doesn’t offer a _shutdown() method to terminate pending Gloo collectives (pytorch/#130345); if a rank participating in a Gloo collective stops making forward progress, the remaining ranks wait until the ProcessGroupGloo timeout is exceeded; a workaround is to specify a short timeout for the gloo backend to enable faster restarts.

  2. NCCL collective kernel termination is implemented by periodically checking a flag residing in mapped memory, and exiting from the kernel if the flag is set. The interval of checking this flag is controlled by the NCCL_SPINS_BEFORE_CHECK_ABORT value specified in nccl/src/device/primitives.h:15. The current value of NCCL_SPINS_BEFORE_CHECK_ABORT=1000000 may be too high to quickly terminate NCCL if multiple collective kernels are being executed or are pending. A workaround is to decrease the interval to 10000 and rebuild NCCL. This issue will be addressed in future NCCL versions.

  3. To perform a restart, the nvidia_resiliency_ext.inprocess.Wrapper needs to wait for completion of all executing and pending CUDA kernels. This is implemented with a GPU synchronization, and is a part of nvidia_resiliency_ext.inprocess.health_check.CudaHealthCheck. Waiting for CUDA kernels to complete could increase the restart latency if many CUDA kernels are pending execution. A workaround is to periodically synchronize with the GPU from the wrapped function to reduce the depth of pending kernels queue.

  4. Support for NVLink SHARP (NVLS) in NCCL must be disabled by setting the NCCL_NVLS_ENABLE environment variable to 0.

  5. NCCL net plugins must be disabled by setting NCCL_NET_PLUGIN environment variable to "none". This issue will be addressed in future NCCL versions.

  6. nvidia_resiliency_ext.inprocess.Wrapper is not fully compatible with torch.distributed.run(). torch.distributed.run() automatically terminates all worker processes if any one of them fails; in this case, nvidia_resiliency_ext.inprocess.Wrapper can only recover from transient faults that don’t cause termination of worker processes.

  7. By default, the PyTorch NCCL Watchdog forcefully terminates the process if a NCCL call returns an error or if the CUDA context was corrupted. Forceful termination of the worker process prevents nvidia_resiliency_ext.inprocess.Wrapper from restarting the wrapped function. A workaround is to set the TORCH_NCCL_RETHROW_CUDA_ERRORS environment variable to 0 to avoid rethrowing CUDA and NCCL errors in the PyTorch NCCL Watchdog.
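
A sketch of applying the environment-variable workarounds from items 4, 5, and 7 before any NCCL or torch.distributed initialization; setting these through the job launcher or shell works equally well.

import os

# Workarounds from the Known issues list; must be set before NCCL /
# torch.distributed initialization to take effect.
os.environ['NCCL_NVLS_ENABLE'] = '0'                # disable NVLink SHARP (item 4)
os.environ['NCCL_NET_PLUGIN'] = 'none'              # disable NCCL net plugins (item 5)
os.environ['TORCH_NCCL_RETHROW_CUDA_ERRORS'] = '0'  # keep the NCCL watchdog from killing the process (item 7)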