Callback

class nvidia_resiliency_ext.ptl_resiliency.fault_tolerance_callback.FaultToleranceCallback(autoresume, calculate_timeouts, simulated_fault_params=None, exp_dir=None, logger_name='nemo_logger.FaultToleranceCallback')[source]

Bases: Callback

FaultToleranceCallback is a PyTorch Lightning callback for integration with the Fault Tolerance package.

FT is active only during the ‘fit’ stage. Training must be run with ‘ft_launcher’ for the callback to work.

Initialize callback instance.

This is a lightweight initialization. Most of the initialization is conducted in the ‘setup’ hook.

Parameters:
  • autoresume (bool) – Set to True if the FT auto-resume feature is used (e.g., there are multiple training jobs to be run).

  • calculate_timeouts (bool) – Set to True if FT timeouts should be calculated from observed heartbeat intervals. Calculated timeouts overwrite the timeouts from the FT config. Timeouts are computed at the end of a training job, provided that checkpoint loading and saving were observed during that job. For example, for a training run started from scratch, the timeouts are computed at the end of the second job.

  • simulated_fault_params (SimulatedFaultParams, dict, DictConfig, None) – Simulated fault spec. It’s for debugging only. Defaults to None. Should be a SimulatedFaultParams instance or any object that can be used for SimulatedFaultParams initialization with SimulatedFaultParams(**obj).

  • exp_dir (Union[str, pathlib.Path, None], optional) – Directory where the FT state should be saved. Must be available for all training jobs. NOTE: Beware that PTL can move files written to its trainer.log_dir. Defaults to None, in which case it defaults to trainer.log_dir/ft_state.

  • logger_name (Optional[str], optional) – Logger name to be used. Defaults to “nemo_logger.FaultToleranceCallback”.
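
A minimal usage sketch, under stated assumptions: the toy LightningModule, the random data, and the ./results directory are placeholders standing in for a real training setup, and the import of pytorch_lightning may instead be lightning.pytorch depending on the environment. The script is assumed to be started with ‘ft_launcher’, as the callback requires.

  import torch
  from torch.utils.data import DataLoader, TensorDataset
  import pytorch_lightning as pl  # or `import lightning.pytorch as pl`, depending on the environment

  from nvidia_resiliency_ext.ptl_resiliency.fault_tolerance_callback import FaultToleranceCallback


  class ToyModel(pl.LightningModule):
      # Placeholder module; any LightningModule works with the callback.
      def __init__(self):
          super().__init__()
          self.layer = torch.nn.Linear(8, 1)

      def training_step(self, batch, batch_idx):
          x, y = batch
          return torch.nn.functional.mse_loss(self.layer(x), y)

      def configure_optimizers(self):
          return torch.optim.SGD(self.parameters(), lr=0.01)


  train_loader = DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=8)

  ft_callback = FaultToleranceCallback(
      autoresume=True,          # part of an auto-resume chain of training jobs
      calculate_timeouts=True,  # derive FT timeouts from observed heartbeats
      exp_dir="./results",      # directory shared by all jobs; FT state is stored under it
  )

  trainer = pl.Trainer(
      max_epochs=2,
      default_root_dir="./results",
      callbacks=[ft_callback],
  )

  # The callback is active only during `fit`, and the script has to be launched
  # with `ft_launcher` so that the rank monitors the callback talks to are running.
  trainer.fit(ToyModel(), train_loader)
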

on_exception(trainer, pl_module, exception)[source]

Called when any trainer execution is interrupted by an exception.

on_load_checkpoint(*args, **kwargs)[source]

Called when loading a model checkpoint; use it to reload state.

Parameters:
  • trainer – the current Trainer instance.

  • pl_module – the current LightningModule instance.

  • checkpoint – the full checkpoint dictionary that got loaded by the Trainer.

on_save_checkpoint(*args, **kwargs)[source]

Called when saving a checkpoint to give you a chance to store anything else you might want to save.

Parameters:
  • trainer – the current Trainer instance.

  • pl_module – the current LightningModule instance.

  • checkpoint – the checkpoint dictionary that will be saved.

on_train_batch_end(*args, **kwargs)[source]

Called when the train batch ends.

Note

The value outputs["loss"] here is the loss returned from training_step, normalized with respect to accumulate_grad_batches.

on_train_end(*args, **kwargs)[source]

Called when the train ends.

on_train_start(*args, **kwargs)[source]

Called when the train begins.

on_validation_batch_end(*args, **kwargs)[source]

Called when the validation batch ends.

on_validation_end(*args, **kwargs)[source]

Called when the validation loop ends.

on_validation_start(*args, **kwargs)[source]

Called when the validation loop begins.

setup(trainer, pl_module, stage)[source]

Called when fit, validate, test, predict, or tune begins.

teardown(trainer, pl_module, stage)[source]

Called when fit, validate, test, predict, or tune ends.