Callback

class nvidia_resiliency_ext.ptl_resiliency.fault_tolerance_callback.FaultToleranceCallback(autoresume, calculate_timeouts, simulated_fault_params=None, exp_dir=None, logger_name='nemo_logger.FaultToleranceCallback')

Bases: Callback

FaultToleranceCallback is a PyTorch Lightning callback for integration with the Fault Tolerance package.

Fault tolerance (FT) is active only during the ‘fit’ stage. Training must be launched with ‘ft_launcher’ for the callback to work.

Initialize the callback instance.

Initialization here is lightweight; most of the work is done in the ‘setup’ hook.

Parameters:
  • autoresume (bool) – Set to True if the FT auto-resume feature is used (e.g., there are multiple training jobs to be run).

  • calculate_timeouts (bool) – Set to True if FT timeouts should be calculated from observed heartbeat intervals. Calculated timeouts overwrite the timeouts from the FT config. Timeouts are computed at the end of a training job, but only if that job both loaded and saved a checkpoint. For example, when training starts from scratch, the first job loads no checkpoint, so the timeouts are computed at the end of the second job.

  • simulated_fault_params (SimulatedFaultParams, dict, DictConfig, None) – Simulated fault specification, for debugging only. Defaults to None. Must be a SimulatedFaultParams instance or an object obj that can initialize one via SimulatedFaultParams(**obj).

  • exp_dir (Union[str, pathlib.Path, None], optional) – Directory where the FT state should be saved. Must be accessible to all training jobs. NOTE: PTL can move files written to its trainer.log_dir. Defaults to None, in which case trainer.log_dir/ft_state is used.

  • logger_name (Optional[str], optional) – Logger name to be used. Defaults to “nemo_logger.FaultToleranceCallback”.
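A minimal usage sketch of the constructor arguments described above, assuming the package is installed and importable under the class path shown in the signature. The helper names make_ft_callback and default_ft_state_dir are hypothetical and exist only for illustration; default_ft_state_dir merely mirrors the documented exp_dir default and is not the library's own code.

```python
from pathlib import Path


def default_ft_state_dir(log_dir):
    """Mirror the documented exp_dir default: <trainer.log_dir>/ft_state.

    Illustrative helper only; the callback resolves this path itself.
    """
    return Path(log_dir) / "ft_state"


def make_ft_callback(exp_dir=None):
    """Build the callback using only the documented constructor arguments."""
    # Imported lazily so this sketch can be loaded without the package installed.
    from nvidia_resiliency_ext.ptl_resiliency.fault_tolerance_callback import (
        FaultToleranceCallback,
    )

    return FaultToleranceCallback(
        autoresume=True,          # FT auto-resume is in use
        calculate_timeouts=True,  # derive timeouts from observed heartbeats
        exp_dir=exp_dir,          # None: fall back to trainer.log_dir/ft_state
    )
```

The returned object would typically be added to the trainer's list of callbacks, with the job itself started through ft_launcher as noted above.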

on_exception(trainer, pl_module, exception)

Called when any trainer execution is interrupted by an exception.

on_load_checkpoint(*args, **kwargs)

Called when loading a model checkpoint; use it to reload state.

Parameters:
  • trainer – the current Trainer instance.

  • pl_module – the current LightningModule instance.

  • checkpoint – the full checkpoint dictionary that got loaded by the Trainer.

on_save_checkpoint(*args, **kwargs)

Called when saving a checkpoint to give you a chance to store anything else you might want to save.

Parameters:
  • trainer – the current Trainer instance.

  • pl_module – the current LightningModule instance.

  • checkpoint – the checkpoint dictionary that will be saved.

on_train_batch_end(*args, **kwargs)

Called when the train batch ends.

Note

The value outputs["loss"] here is the loss returned from training_step, normalized by accumulate_grad_batches.

on_train_end(*args, **kwargs)

Called when training ends.

on_train_start(*args, **kwargs)

Called when training begins.

on_validation_batch_end(*args, **kwargs)

Called when the validation batch ends.

on_validation_end(*args, **kwargs)

Called when the validation loop ends.

on_validation_start(*args, **kwargs)

Called when the validation loop begins.

setup(trainer, pl_module, stage)

Called when fit, validate, test, predict, or tune begins.

teardown(trainer, pl_module, stage)

Called when fit, validate, test, predict, or tune ends.

class nvidia_resiliency_ext.ptl_resiliency.fault_tolerance_callback.SimulatedFaultParams(fault_type, base_delay, rand_delay=0.0, rank_to_fail=None)

Bases: object

Description of a simulated rank fault, used for FT testing and debugging.

Simulated fault types are:
  • ‘rank_killed’ – a rank is killed with SIGKILL.

  • ‘rank_hung’ – a rank is stopped with SIGSTOP.

  • ‘random’ – randomly selects one of the above faults.

Fault delay is computed as base_delay + RAND_FLOAT_FROM_0.0_to_1.0 * rand_delay, where RAND_FLOAT_FROM_0.0_to_1.0 is a random float between 0.0 and 1.0.
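The delay formula above can be sketched as follows. This is an illustration of the documented arithmetic, not the library's implementation; the function name simulated_fault_delay and the rng argument are assumptions introduced so the example can be made reproducible.

```python
import random


def simulated_fault_delay(base_delay, rand_delay=0.0, rng=random):
    """Return base_delay plus a random fraction of rand_delay.

    rng.random() yields a float in [0.0, 1.0), so the result lies between
    base_delay and base_delay + rand_delay.
    """
    return base_delay + rng.random() * rand_delay


# With rand_delay left at its default of 0.0 the delay is exactly base_delay.
fixed = simulated_fault_delay(30.0)

# A seeded generator (hypothetical choice) makes the random part repeatable.
jittered = simulated_fault_delay(30.0, 10.0, rng=random.Random(0))
```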

Parameters:
  • fault_type (str)

  • base_delay (float)

  • rand_delay (float)

  • rank_to_fail (int | None)

fault_type

The type of fault, one of: [‘random’, ‘rank_killed’, ‘rank_hung’].

Type: str
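As a rough illustration of how the three fault types relate to the signals named in the class description, the sketch below maps each type to a POSIX signal and resolves ‘random’ to one of the two concrete faults. The names FAULT_SIGNALS and resolve_fault_signal are hypothetical, and this is not how the library is known to implement fault injection.

```python
import random
import signal

# Documented fault types mapped to the signals named in the description.
# SIGKILL and SIGSTOP are available on POSIX systems only.
FAULT_SIGNALS = {
    "rank_killed": signal.SIGKILL,
    "rank_hung": signal.SIGSTOP,
}


def resolve_fault_signal(fault_type, rng=random):
    """Return the signal for a fault type, picking one at random for 'random'."""
    if fault_type == "random":
        fault_type = rng.choice(sorted(FAULT_SIGNALS))
    if fault_type not in FAULT_SIGNALS:
        raise ValueError(f"unknown fault_type: {fault_type!r}")
    return FAULT_SIGNALS[fault_type]
```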

base_delay

The base (minimum) delay for the fault, in seconds.

Type: float

rand_delay

The maximum additional random delay for the fault, in seconds. Defaults to 0.0.

Type: float, optional

rank_to_fail

The rank to fail. Defaults to None, in which case a random rank is picked.

Type: int, optional
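The fields and defaults above, together with the SimulatedFaultParams(**obj) initialization mentioned for the callback's simulated_fault_params argument, can be sketched with a stand-in dataclass. SimulatedFaultParamsSketch and to_params are hypothetical names; the stand-in only copies the documented signature and is not the library class.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SimulatedFaultParamsSketch:
    """Stand-in with the documented fields and defaults (not the library class)."""

    fault_type: str
    base_delay: float
    rand_delay: float = 0.0
    rank_to_fail: Optional[int] = None


def to_params(spec):
    """Accept None, an existing instance, or a mapping of field values."""
    if spec is None or isinstance(spec, SimulatedFaultParamsSketch):
        return spec
    # Mirrors the documented SimulatedFaultParams(**obj) initialization.
    return SimulatedFaultParamsSketch(**spec)


# A dict spec: hang a randomly chosen rank 120 seconds in, with no extra jitter.
params = to_params({"fault_type": "rank_hung", "base_delay": 120.0})
```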