Sections API Integration
1. Prerequisites
- Run ranks using ft_launcher. The command line is mostly compatible with torchrun.
- Pass the FT config to the ft_launcher.
Note
Some clusters (e.g., SLURM) use SIGTERM as a default method of requesting a graceful workload shutdown. It is recommended to implement appropriate signal handling in a fault-tolerant workload. To avoid deadlocks and other unintended side effects, signal handling should be synchronized across all ranks.
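One possible way to follow this recommendation is to let the signal handler only record the request, and to have all ranks agree on it at a step boundary. The sketch below illustrates that pattern under stated assumptions: `ShutdownFlag`, `should_stop`, and the `agree` hook are hypothetical names, not part of the FT package. In a multi-rank job, `agree` would typically wrap a collective (for example a `torch.distributed.all_reduce` with `ReduceOp.MAX`) so that every rank leaves the loop at the same step.

```python
import signal


class ShutdownFlag:
    """Records a SIGTERM request instead of acting on it inside the handler."""

    def __init__(self):
        self.requested = False

    def install(self):
        # Replace the default SIGTERM action (terminate) with a flag update.
        # Must be called from the main thread of the process.
        signal.signal(signal.SIGTERM, self._on_sigterm)

    def _on_sigterm(self, signum, frame):
        self.requested = True


def should_stop(flag, agree=lambda local_value: local_value):
    """Return True if the workload should stop at this step boundary.

    `agree` is a hypothetical hook that combines the local flag (0 or 1)
    across ranks. The default is the identity, which is only correct for a
    single process; a distributed job would pass a max-reduction here.
    """
    return bool(agree(int(flag.requested)))
```

Calling `should_stop(flag)` once per training step, outside any collective operation, keeps the handler itself trivial and moves the actual shutdown decision to a point where all ranks are in a known state.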
2. FT configuration
With the section-based API, timeouts must be set for all defined sections, which wrap operations like training/eval steps, checkpoint saving, and initialization. Additionally, an out-of-section timeout applies when no section is active.
Note
Ensure the out-of-section timeout is long enough to accommodate restart overhead, as excessively small values can cause imbalance. If needed, consider merging sections (e.g., moving ‘init’ into ‘out-of-section’) to provide more buffer time.
- Relevant FT configuration items are:
  - rank_section_timeouts is a map of a section name to its timeout.
  - rank_out_of_section_timeout is the out-of-section timeout.
Fixed timeout values can be used throughout the training runs, or timeouts can be calculated based on observed intervals. Timeout values of null are interpreted as infinite timeouts. In such cases, values need to be calculated to make FT usable.
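To make the idea of calculating timeouts from observed intervals more concrete, here is a minimal sketch. It is not the library's algorithm: the helper name `estimate_timeouts`, the rule "longest observed interval times a safety factor", and the default factor of 5 are illustrative assumptions. A section with no observations keeps `None`, mirroring a null (infinite) timeout.

```python
import math


def estimate_timeouts(observed, safety_factor=5.0):
    """Map each section name to ceil(max(observed interval) * safety_factor).

    `observed` maps a section name to a list of measured durations in seconds.
    Sections without observations keep None, i.e. an infinite timeout.
    """
    timeouts = {}
    for name, intervals in observed.items():
        if not intervals:
            timeouts[name] = None
        else:
            timeouts[name] = math.ceil(max(intervals) * safety_factor)
    return timeouts


# Hypothetical measurements, in seconds.
observed = {
    "init": [3.0],
    "step": [0.5, 1.5, 1.0],
    "checkpoint": [],
}
print(estimate_timeouts(observed))
```

Basing the estimate on the longest interval rather than the average is a deliberately conservative choice here, since a timeout that fires on a slow but healthy step would trigger an unnecessary restart.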
Config file example:
fault_tolerance:
  initial_rank_heartbeat_timeout: null
  rank_heartbeat_timeout: null
  rank_section_timeouts:
    init: 20
    step: 10
    checkpoint: 30
  rank_out_of_section_timeout: 30
  log_level: "DEBUG"
A summary of all FT configuration items can be found in nvidia_resiliency_ext.fault_tolerance.config.FaultToleranceConfig
3. Integration with PyTorch workload code
- Initialize a RankMonitorClient instance on each rank with RankMonitorClient.init_workload_monitoring().
- (Optional) Restore the state of RankMonitorClient instances using RankMonitorClient.load_state_dict().
- Mark some sections of the code with RankMonitorClient.start_section('<section name>') and RankMonitorClient.end_section('<section name>').
- (Optional) After a sufficient range of section intervals has been observed, call RankMonitorClient.calculate_and_set_section_timeouts() to estimate timeouts.
- (Optional) Save the RankMonitorClient instance’s state_dict() to a file so that computed timeouts can be reused in the next run.
- Shut down RankMonitorClient instances using RankMonitorClient.shutdown_workload_monitoring().
Please refer to the Section API usage example with DDP for an implementation example.
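The order of the integration steps listed in this section can also be sketched in a few lines. The code below is a dry-run illustration, not the official example: `run_workload`, `train_step`, and `RecordingClient` are hypothetical names, and `RecordingClient` is only a stand-in that records calls so the sketch runs without the FT package or ft_launcher. In a real job, a RankMonitorClient instance would be passed in its place; its construction and any method arguments beyond those named above are intentionally not shown here.

```python
class RecordingClient:
    """Stand-in for dry runs only; it records calls instead of monitoring."""

    def __init__(self):
        self.calls = []

    def init_workload_monitoring(self):
        self.calls.append("init_workload_monitoring")

    def load_state_dict(self, state):
        self.calls.append("load_state_dict")

    def start_section(self, name):
        self.calls.append(f"start_section:{name}")

    def end_section(self, name):
        self.calls.append(f"end_section:{name}")

    def calculate_and_set_section_timeouts(self):
        self.calls.append("calculate_and_set_section_timeouts")

    def state_dict(self):
        self.calls.append("state_dict")
        return {"recorded_calls": len(self.calls)}

    def shutdown_workload_monitoring(self):
        self.calls.append("shutdown_workload_monitoring")


def run_workload(client, num_steps, train_step, saved_state=None):
    """Run `num_steps` iterations, wrapping each one in a 'step' section.

    Returns the client's state so the caller can save it to a file.
    """
    client.init_workload_monitoring()
    try:
        if saved_state is not None:
            # Reuse timeouts computed in a previous run.
            client.load_state_dict(saved_state)
        for step in range(num_steps):
            client.start_section("step")
            train_step(step)
            client.end_section("step")
        # Enough intervals observed (by assumption): estimate timeouts.
        client.calculate_and_set_section_timeouts()
        return client.state_dict()
    finally:
        # Shut monitoring down even if a step raised.
        client.shutdown_workload_monitoring()


client = RecordingClient()
state = run_workload(client, num_steps=2, train_step=lambda step: None)
print(client.calls)
```

The section name "step" matches the key used in the config file example, since a section is expected to have a timeout defined for it.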