Usage Guide
The nvidia_resiliency_ext.inprocess.Wrapper serves as the primary interface for accessing in-process restart functionality. It provides various configuration options through its arguments, enabling customization of the restart process and fault monitoring capabilities. To ensure efficient and effective restarts, the function being wrapped must meet specific requirements. This usage guide outlines the requirements, features, and limitations of the in-process restart functionality provided by the Wrapper.
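For orientation, below is a minimal, illustrative sketch of how the Wrapper might be applied. It assumes the Wrapper instance is used as a decorator around the function to be restarted, with all configuration arguments left at their defaults; consult the API reference for the exact signature.

import nvidia_resiliency_ext.inprocess as inprocess

# Minimal illustrative sketch (not a complete training script). The decorator
# form below is an assumption about typical usage; all Wrapper configuration
# arguments are left at their defaults.
@inprocess.Wrapper()
def train():
    ...  # initialization, training loop, checkpointing

train()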
Requirements
In-process restart functionality requires PyPI PyTorch v2.5.1 or PyTorch NGC Container versions 24.07 through 24.10. For further limitations and compatibility details, refer to the Known issues section.
Requirements for the wrapped function
- The wrapped function should be designed to support restarts: it should carefully manage any external (e.g., global) state and resources, and avoid functions that can only be called once per process, such as multiprocessing.set_start_method() or MPI_Init, so that it can be executed multiple times in the same process without issues.
- The function is automatically retried on any failure, meaning it is called again with the same set of input arguments; extra caution is needed if the function accepts mutable arguments that might be modified during its execution, as these changes could affect subsequent retries.
- All operations that wait on results from NCCL kernels, or synchronize with the GPU, need to release the Python Global Interpreter Lock (GIL). If the GIL is not released when a fault occurs, the graceful restart procedure cannot proceed, because it runs in a separate Python thread that is blocked from execution while the GIL is held. As a result, hung ranks must be forcibly terminated using the hard timeout mechanism (SIGKILL), and these terminated ranks will not rejoin the distributed job upon restart.
- The function must not suppress BaseException. If the wrapped function catches a BaseException, it must re-raise it to ensure it propagates to the outer scope.
- The function is responsible for initialization of the PyTorch distributed backend (torch.distributed.init_process_group()); the initialization needs to read the standard PyTorch distributed variables (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT and LOCAL_RANK) from the environment.
- It is strongly recommended that the wrapped function loads any state affected by distributed collectives from a checkpoint on every restart (e.g., loads the weights of a model); outputs of distributed collectives are likely to become corrupted or invalid if a fault happened while a collective was in flight and the distributed backend was terminated.
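A minimal sketch of a wrapped function that follows these requirements is shown below; the checkpoint loading and training loop bodies are elided.

import os

import torch

# Illustrative sketch of a restart-friendly wrapped function.
def train():
    # Read the standard PyTorch distributed variables from the environment on
    # every (re)start; the assignment may change between restart iterations.
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    local_rank = int(os.environ['LOCAL_RANK'])

    torch.cuda.set_device(local_rank)

    # The wrapped function owns initialization of the distributed backend.
    torch.distributed.init_process_group(
        backend='nccl', rank=rank, world_size=world_size,
    )

    try:
        # Reload any state affected by distributed collectives on every
        # restart, e.g. model weights from the latest checkpoint, then run
        # the training loop.
        ...
    except BaseException:
        # Never suppress BaseException: re-raise so the Wrapper's restart
        # exception propagates to the outer scope.
        raise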
Requirements for the execution environment
- The PyTorch NCCL watchdog must either be disabled or configured with a timeout longer than the hard_timeout of the nvidia_resiliency_ext.inprocess.Wrapper. If the NCCL watchdog is triggered, it forcibly terminates the process, preventing a restart. To adjust the NCCL watchdog timeout, use the timeout argument when calling torch.distributed.init_process_group() with the backend parameter set to "nccl" (see the sketch after this list).
- The job scheduler must not terminate the entire job if a faulty rank exits early or if the main process is terminated; instead, it should wait until all user-launched processes have fully exited before ending the distributed job.
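As a sketch, the NCCL watchdog timeout can be raised above the Wrapper's hard_timeout through the standard timeout argument of torch.distributed.init_process_group(); the 45-minute value below is only an example.

import datetime

import torch

# Example only: choose a timeout longer than the Wrapper's hard_timeout.
torch.distributed.init_process_group(
    backend='nccl',
    timeout=datetime.timedelta(minutes=45),
)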
Restrictions
- A node failure on rank 0 causes termination of the entire job. By default, rank 0 hosts the internal torch.distributed.TCPStore used for communication between ranks; users may specify a different implementation of a distributed store by subclassing nvidia_resiliency_ext.inprocess.store.StoreMixin and passing the subclass as the store_factory argument to the nvidia_resiliency_ext.inprocess.Wrapper.
- Blocking calls issued by the main process are generally not recoverable if they hang, except for NCCL collectives or functions waiting on them; NCCL collectives are asynchronously aborted by a separate monitoring thread that calls nvidia_resiliency_ext.inprocess.abort.AbortTorchDistributed; users can specify additional nvidia_resiliency_ext.inprocess.abort.Abort subclasses to asynchronously abort blocking calls from other software components.
Functionality overview
Implementation overview
Below is a simplified pseudocode snippet that illustrates the order of operations executed by nvidia_resiliency_ext.inprocess.Wrapper, providing a high-level overview of the workflow within this class. This code is for illustrative purposes only and may omit certain implementation details.
distributed_store = store_factory(**store_kwargs)

initial_barrier()
rank_assignment()
rank_filter()

while True:
    iteration_barrier()
    initialize()
    health_check()

    try:
        if rank_is_active:
            wrapped_function()
        else:
            sleep()
        completion_barrier()
    except:
        abort()
        finalize()
        health_check()
        termination_barrier()
        rank_assignment()
        rank_filter()
    else:
        break
Distributed execution behavior
Entering and exiting the Wrapper act as distributed synchronization points. Upon entry, all workers retrieve their initial rank assignments and the total number of workers by reading the standard PyTorch distributed environment variables (RANK, WORLD_SIZE). Subsequently, all workers synchronize through an initial_barrier with a user-defined barrier_timeout to ensure consistent initialization.
Upon completion of the wrapped function, all ranks that finish enter a completion_barrier governed by a user-defined completion_timeout. If any rank fails to synchronize within the completion_timeout, it is treated as a rank failure, triggering a restart of the wrapped function on all distributed ranks.
The restart Wrapper incorporates additional distributed barriers to ensure proper synchronization: an iteration_barrier (executed at the start of each restart iteration), and a termination_barrier (executed before rank reassignment and filtering). These barriers are designed to be transparent to the user, requiring no modifications to the wrapped function or assumptions about the execution environment. They operate seamlessly to maintain distributed consistency and coordination.
Rank assignment
The Wrapper needs to ensure that the wrapped function is restarted with a consecutive sequence of integer rank indices, from 0 to WORLD_SIZE - 1, as some of the ranks from the previous iteration may have been terminated or may be in an unhealthy state. Rank reassignment and computation of the new world size are performed by the nvidia_resiliency_ext.inprocess.rank_assignment.RankAssignment instance passed as the rank_assignment argument to the Wrapper.

Multiple RankAssignments can be composed with nvidia_resiliency_ext.inprocess.Compose to achieve the desired behavior. For example:
rank_assignment = inprocess.Compose(
    inprocess.rank_assignment.ShiftRanks(),
    inprocess.rank_assignment.FilterGroupedByKey(
        key_or_fn=lambda rank, _: rank // 8,
        condition=lambda count: count == 8,
    ),
)
This composition ensures that all ranks within each non-overlapping group of 8 consecutive ranks remain healthy. If any rank within a group of 8 is unhealthy or terminated, the entire group is terminated. The remaining healthy ranks are then reassigned by shifting left to close any gaps, forming a new sequence of consecutive integers from 0 up to the updated world size.
Rank filter
By default, all active ranks call the wrapped function. This behavior can be customized by providing a nvidia_resiliency_ext.inprocess.rank_filter.RankFilter instance as the rank_filter argument of the Wrapper.

RankFilter selects which ranks are active in the current restart iteration. Active ranks call the wrapped function. Inactive ranks wait idle, and can serve as a static, preallocated and preinitialized pool of spare ranks. Spare ranks are activated in a subsequent restart iteration if previously active ranks were terminated or became unhealthy.

Multiple rank filters can be composed with nvidia_resiliency_ext.inprocess.Compose to achieve the desired behavior. For example:
rank_filter=inprocess.Compose(
    inprocess.rank_filter.WorldSizeDivisibleBy(M),
    inprocess.rank_filter.MaxActiveWorldSize(N),
),
This composition ensures that the active world size visible to the wrapped function is the largest multiple of M that is not greater than N. The remaining ranks would be inactive and serve as spares.
Initialize
The Wrapper accepts an optional, user-provided nvidia_resiliency_ext.inprocess.initialize.Initialize class, which is executed at the start of every restart iteration, including the first one. Initialize can raise exceptions (e.g., if specific preconditions are not met). Raising a standard Python Exception triggers another restart of the wrapped function, while raising a BaseException terminates the Wrapper. The included nvidia_resiliency_ext.inprocess.initialize.RetryController can be used to limit the number of restart attempts or to halt execution if the number of healthy workers drops below a specified threshold.

Multiple initializers can be composed with nvidia_resiliency_ext.inprocess.Compose.
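For example, a restart budget can be enforced with the bundled RetryController. The sketch below uses the decorator application shown in the introduction; the initialize keyword and the max_iterations argument name are assumptions made for illustration and should be verified against the API reference.

import nvidia_resiliency_ext.inprocess as inprocess

# Illustrative sketch only: the `initialize` keyword of the Wrapper and the
# `max_iterations` argument of RetryController are assumed names.
@inprocess.Wrapper(
    initialize=inprocess.initialize.RetryController(max_iterations=8),
)
def train():
    ...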
Wrapped function termination mechanism
When a fault or timeout occurs on any rank participating in the distributed job, the Wrapper waits for the last_call_wait interval to allow all concurrent faults from other distributed ranks to be recorded. After this waiting period, the Wrapper initiates a termination and restart procedure across all ranks to ensure a consistent recovery process:

- the Wrapper calls an instance of nvidia_resiliency_ext.inprocess.abort.Abort from a separate Python thread; by default, this operation is equivalent to calling torch.distributed.destroy_process_group();
- next, the Wrapper raises an asynchronous Python exception within the wrapped function; this exception interrupts the execution of the wrapped function, allowing control to return to the Wrapper, which then handles the restart process.
The termination mechanism respects regular Python exception propagation logic, and gives the wrapped function an opportunity to properly clean up resources by calling all encountered exception handlers, context managers' __exit__ methods, etc. The restart exception raised by the Wrapper is a direct subclass of Python BaseException, and the wrapped function is required to propagate this exception to the outer function scope.

The termination procedure runs in a separate Python thread. In some cases, the main thread, unblocked by the destruction of the distributed process group, might execute a few additional Python bytecode instructions before the asynchronous exception is received. In most cases this is harmless, as the wrapped function is about to be interrupted and restarted, but the wrapped function must not execute any code that may corrupt persistent storage and prevent correct execution after a restart (e.g., the function cannot write a checkpoint to persistent storage). To protect against this possible data corruption, the Wrapper offers the inprocess.CallWrapper.atomic() context manager, which implements a lock shared by the main thread and the thread performing the termination procedure. The termination procedure won't be launched if the main thread is inside an inprocess.CallWrapper.atomic() code block, and the main thread won't enter an inprocess.CallWrapper.atomic() code block if the termination procedure is already in progress. The use of the inprocess.CallWrapper.atomic() context manager is optional, and may be omitted if the workload already includes mechanisms to guarantee that the restarted wrapped function does not resume execution from a corrupted or incomplete persistent state (e.g., a compromised checkpoint).
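A sketch of this pattern is shown below; it relies on the CallWrapper argument injection described under Reporting progress, and save_checkpoint is a hypothetical placeholder for user code.

import nvidia_resiliency_ext.inprocess as inprocess

def train(call_wrapper: inprocess.CallWrapper):
    for iteration in range(1000):
        ...  # forward, backward, optimizer step

        # The termination procedure is not launched while this block executes,
        # so a partially written checkpoint cannot be left behind.
        with call_wrapper.atomic():
            save_checkpoint(iteration)  # hypothetical user helper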
Progress timeout
The Wrapper implements two types of timeout events:
Soft timeout
Soft timeout is equivalent to a Python exception raised by one of the ranks, and triggers an attempt to restart the wrapped function on all healthy ranks.
Hard timeout
The hard timeout mechanism forcefully terminates the main Python interpreter process by sending a sequence of signals to ensure proper shutdown.
Initially, the Wrapper sends the signals (SIGCONT, SIGTERM) to allow for a graceful shutdown. If the process remains active after this step, a second sequence of signals (SIGCONT, SIGTERM, SIGKILL) is sent after a delay specified by the termination_grace_time parameter. This guarantees termination of the process if it fails to respond to the initial signals.
The termination_grace_time parameter, configurable via the Wrapper, defines the time interval between the two signal sequences. If the workload implements SIGTERM cleanup handlers and their execution is critical for successfully restarting the wrapped function, termination_grace_time should be adjusted to allow sufficient time for these handlers to complete.

For workloads that do not implement SIGTERM handlers, it is safe to set termination_grace_time to 0 seconds to enable faster termination in cases where the process hangs. This minimizes restart latency while ensuring the process is terminated promptly.
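For example, a workload without SIGTERM handlers might disable the grace period entirely, as in the sketch below; passing a datetime.timedelta is an assumption about the expected argument type, and the decorator application follows the introductory sketch.

import datetime

import nvidia_resiliency_ext.inprocess as inprocess

# Sketch: no SIGTERM cleanup handlers, so terminate hung processes promptly.
@inprocess.Wrapper(
    termination_grace_time=datetime.timedelta(seconds=0),
)
def train():
    ...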
Reporting progress
Timeout events are triggered when the wrapped function hasn't reported progress within the specified timeout interval. There are two methods of recording progress:

- Automatic heartbeat: the Wrapper periodically checks whether the main thread of the Python interpreter keeps executing new bytecode instructions;
  - this method is always active and protects against hangs in calls that block the Python interpreter, even when the blocking call released the GIL;
  - it doesn't protect against while-true-like livelocks, where the interpreter keeps executing new bytecode instructions but doesn't make meaningful forward progress.
- Manual heartbeat (optional): the wrapped function can optionally report progress by periodically calling the inprocess.CallWrapper.ping() method (see the sketch below):
  - the nvidia_resiliency_ext.inprocess.Wrapper inspects the signature of the wrapped function for an argument annotated with the type nvidia_resiliency_ext.inprocess.CallWrapper;
  - if such an argument is present, the Wrapper injects an instance of nvidia_resiliency_ext.inprocess.CallWrapper into the function, enabling it to call inprocess.CallWrapper.ping() within its scope;
  - the timeout for the manual heartbeat is activated after the first call to the inprocess.CallWrapper.ping() method.
A timeout event is triggered if either of the active progress monitoring methods didn't record a heartbeat within the specified time interval.
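A sketch of the manual heartbeat is shown below; the annotated argument is detected by the Wrapper and an inprocess.CallWrapper instance is injected automatically.

import nvidia_resiliency_ext.inprocess as inprocess

def train(call_wrapper: inprocess.CallWrapper):
    for step in range(1000):
        ...  # one unit of work

        # Record a manual heartbeat; the manual-heartbeat timeout becomes
        # active after the first call to ping().
        call_wrapper.ping()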
Finalize
The Wrapper accepts an optional, user-provided nvidia_resiliency_ext.inprocess.finalize.Finalize class. The Finalize class is executed after a fault was detected and the distributed group was destroyed, but before the HealthCheck is performed. Finalize should bring the process into a state where a restart of the wrapped function may be attempted, e.g., deinitialize any global variables or synchronize with any async work issued by the wrapped function that was not already handled by exception handlers in the wrapped function. Any failure during the execution of Finalize should raise an exception; in this case the health check is skipped, the exception is reraised by the Wrapper, and the exception should cause termination of the main Python interpreter process.

Multiple finalizers can be composed with nvidia_resiliency_ext.inprocess.Compose.
Health check
The Wrapper calls an optional, user-provided nvidia_resiliency_ext.inprocess.health_check.HealthCheck class before the restart to ensure that the worker is in a healthy state. HealthCheck is executed after the wrapped function failure was discovered (on a local or remote distributed rank), the local distributed group was destroyed, and the optional Finalize finished execution. The health check is executed locally on each rank that could potentially participate in the job after restart, and it is meant to filter out unhealthy ranks that cannot continue executing the workload (e.g., due to a corrupted CUDA context). The execution should be local to the calling rank; other ranks may have already been terminated, lost, or still be executing the wrapped function. An unhealthy state is reported to nvidia_resiliency_ext.inprocess.Wrapper by raising an exception from the inprocess.health_check.HealthCheck.__call__() method. The exception is then reraised by the Wrapper, and should cause termination of the main Python interpreter process on the local rank.

Multiple health checks can be composed with nvidia_resiliency_ext.inprocess.Compose.
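For example, the bundled CudaHealthCheck (see Known issues) can be used to verify the CUDA context before a restart; the health_check keyword name below is an assumption made for illustration, and the decorator application follows the introductory sketch.

import nvidia_resiliency_ext.inprocess as inprocess

# Sketch: the `health_check` keyword name is assumed; CudaHealthCheck is the
# bundled check that synchronizes with the GPU before allowing a restart.
@inprocess.Wrapper(
    health_check=inprocess.health_check.CudaHealthCheck(),
)
def train():
    ...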
Monitoring capabilities
The Wrapper
provides several monitoring mechanisms to track the
workload’s progress and enable rapid restart capabilities in the event of a
fault.
Monitor Thread
The Monitor Thread runs as a separate threading.Thread
and is
tasked with periodically checking the distributed store for any faults reported
by other distributed ranks. It also ensures that the local rank is
reporting progress. If a fault or a lack of progress is detected, it triggers nvidia_resiliency_ext.inprocess.abort.Abort and raises an asynchronous Python exception within the wrapped function.
The execution interval of the monitoring loop is governed by the
monitor_thread_interval
parameter of the Wrapper
. During each
loop iteration, the thread queries the distributed store by invoking
torch.distributed.Store.get()
. For workloads with a large number of
distributed workers, it may be necessary to increase the
monitor_thread_interval
to avoid creating a communication bottleneck in the
distributed store caused by concurrent queries from multiple workers.
Monitor Process
The Monitor Process operates as a separate daemon process created by the
Wrapper
. Its responsibilities include ensuring the main workload
process remains active, submitting heartbeat signals to the distributed store
for the local rank, monitoring heartbeat signals from remote ranks, and
terminating the main process if it becomes unresponsive and irrecoverable.
The timeout for receiving a heartbeat from other distributed ranks is configured with the heartbeat_timeout parameter of the Wrapper. If any of the distributed ranks doesn't submit a heartbeat within the heartbeat_timeout interval, the rank is considered unresponsive, and a restart is triggered on all distributed ranks.
The execution interval of the monitoring loop is governed by the
monitor_process_interval
parameter of the Wrapper
. Similar to
the Monitor Thread, each iteration of the loop queries the distributed store. To prevent communication bottlenecks in the distributed store, the monitoring interval should scale proportionally with the number of distributed workers.
Progress Watchdog
The Progress Watchdog runs as a separate threading.Thread and is responsible for issuing automatic heartbeats and receiving manual heartbeats to track the workload's progress.
The execution interval is governed by the progress_watchdog_interval parameter of the Wrapper. The execution involves only node-local inter-process communication, and the interval does not need to be scaled with the number of distributed workers.
Logging
The Wrapper leverages the Python logging module to output messages. It does not adhere to the conventional methods of fully integrating with an application's root logger. Instead, logging from the Wrapper within the main process is managed through a logging.StreamHandler, which is defined by the first ancestor in the logger hierarchy. Notably, the logging in the Wrapper is configured to not store logs in files, and to not propagate logging messages to the ancestor loggers' handlers.
Logging at the logging.DEBUG level shows the location where the wrapped function suppressed the BaseException raised asynchronously by the Wrapper. The restart logic requires that BaseExceptions are propagated from the wrapped function to the outer scope. This feature helps to find locations where this assumption is not met and the restart flow is interrupted.
For the monitoring daemon process, logging is handled differently; logs are
written only to a file. The location of this log file is configurable. Users
can specify a custom path by passing a string to the
monitor_process_logfile
argument. This string may include the {rank}
placeholder, which allows for dynamic filename generation based on the initial
distributed rank of the calling process.
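For example, to write per-rank monitor logs (the path below is purely illustrative, and the decorator application follows the introductory sketch):

import nvidia_resiliency_ext.inprocess as inprocess

# Sketch: {rank} is substituted with the initial distributed rank of the
# calling process; the path is an arbitrary example.
@inprocess.Wrapper(
    monitor_process_logfile='/tmp/inprocess_monitor_{rank}.log',
)
def train():
    ...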
Restart latency
Restart latency refers to the time elapsed between a fault occurring on any distributed rank and successfully relaunching the wrapped function across all distributed ranks.
The following table summarizes the latencies of all major items contributing to
the total restart latency. Rows marked with (H)
increase restart latency
only when the application hangs. These items are not included if the
application raises a Python exception on any distributed rank.
| Category | Item | Latency |
|---|---|---|
| NCCL/PyT | | ~0.5s + 0.01s * num pending NCCL kernels |
| CUDA/user | drain all pending CUDA kernels | ~training iteration |
| Wrapper | query TCPStore for any faults | |
| Wrapper | wait for concurrent faults on other ranks | |
| Wrapper | execute torch.distributed.destroy_process_group() | ~0.5s |
| Wrapper | 3x TCPStore-based barrier | 0.5s @ 16k ranks |
| Wrapper | | |
| Wrapper | | |
| user | execute user-provided | N/A |
| user | execute user-provided | N/A |
The latency for executing torch.distributed.destroy_process_group()
assumes that NCCL collective kernel termination interval was optimized. See
Known issues for more details.
Known issues
- torch.distributed.ProcessGroupGloo doesn't offer a _shutdown() method to terminate pending Gloo collectives (pytorch/#130345); if a rank participating in a Gloo collective stops making forward progress, the remaining ranks wait till the ProcessGroupGloo timeout is exceeded; a workaround is to specify a short timeout for the gloo backend to enable faster restarts.
- NCCL collective kernel termination is implemented by periodically checking a flag residing in mapped memory, and exiting from the kernel if the flag is set. The interval of checking for this flag is controlled by the NCCL_SPINS_BEFORE_CHECK_ABORT value specified in nccl/src/device/primitives.h:15. The current value of NCCL_SPINS_BEFORE_CHECK_ABORT=1000000 may be too high to quickly terminate NCCL if multiple collective kernels are being executed or are pending. A workaround is to decrease the interval to 10000 and rebuild NCCL. This issue will be addressed in future NCCL versions.
- To perform a restart, the nvidia_resiliency_ext.inprocess.Wrapper needs to wait for completion of all executing and pending CUDA kernels. This is implemented with a GPU synchronization, and is a part of nvidia_resiliency_ext.inprocess.health_check.CudaHealthCheck. Waiting for CUDA kernels to complete could increase the restart latency if many CUDA kernels are pending execution. A workaround is to periodically synchronize with the GPU from the wrapped function to reduce the depth of the pending kernel queue.
- Support for NVLink SHARP (NVLS) in NCCL must be disabled by setting the NCCL_NVLS_ENABLE environment variable to 0 (see the sketch after this list).
- NCCL net plugins must be disabled by setting the NCCL_NET_PLUGIN environment variable to "none". This issue will be addressed in future NCCL versions.
- nvidia_resiliency_ext.inprocess.Wrapper is not fully compatible with torch.distributed.run(). torch.distributed.run() automatically terminates all worker processes if any one of them fails; in this case nvidia_resiliency_ext.inprocess.Wrapper can only recover from transient faults that don't cause termination of worker processes.
- By default, the PyTorch NCCL Watchdog forcefully terminates the process if an NCCL call returns an error, or if the CUDA context was corrupted. Forceful termination of the worker process prevents nvidia_resiliency_ext.inprocess.Wrapper from restarting the wrapped function. A workaround is to set the TORCH_NCCL_RETHROW_CUDA_ERRORS environment variable to 0, to avoid rethrowing CUDA and NCCL errors in the PyTorch NCCL Watchdog.
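The environment-variable workarounds listed above can be consolidated as in the sketch below; they must take effect before NCCL and the PyTorch distributed backend are initialized (exporting them in the job's launch environment works as well).

import os

# Workarounds discussed in the Known issues above.
os.environ['NCCL_NVLS_ENABLE'] = '0'                # disable NVLink SHARP (NVLS)
os.environ['NCCL_NET_PLUGIN'] = 'none'              # disable NCCL net plugins
os.environ['TORCH_NCCL_RETHROW_CUDA_ERRORS'] = '0'  # don't rethrow CUDA/NCCL errors in the NCCL watchdog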