PTL Callback support
- class nvidia_resiliency_ext.ptl_resiliency.local_checkpoint_callback.HierarchicalCheckpointIO(wrapped_checkpoint_io, local_ckpt_manager, get_global_ckpt_iteration_fn)[source]
Bases:
_WrappingCheckpointIO
Wrapper for a global CheckpointIO enabling local checkpointing.
Routes each save to either the original global CheckpointIO or the local checkpoint manager, based on the presence of local checkpointing options in the storage_options passed to save.
Must be used in conjunction with LocalCheckpointCallback which initiates local checkpoint saving during training.
- Parameters:
wrapped_checkpoint_io (CheckpointIO) – global CheckpointIO to wrap
local_ckpt_manager (BaseCheckpointManager) – local manager to use for local checkpoints
get_global_ckpt_iteration_fn (Callable[[_PATH], int]) – a function that given a path to a global checkpoint, extracts the global step iteration from it (either from the path itself or by loading metadata from the checkpoint).
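A minimal construction sketch using the parameters above. This is not the library's canonical setup: `my_local_ckpt_manager`, `iteration_from_path`, and the `step=` path layout are placeholders/assumptions, and the `TorchCheckpointIO` import path assumes Lightning 2.x (`lightning.pytorch`); adapt as needed.

```python
import re

from lightning.pytorch.plugins.io import TorchCheckpointIO

from nvidia_resiliency_ext.ptl_resiliency.local_checkpoint_callback import (
    HierarchicalCheckpointIO,
)

# Placeholder: a BaseCheckpointManager implementation from nvidia_resiliency_ext
# (not constructed here).
my_local_ckpt_manager = ...


def iteration_from_path(path) -> int:
    # Hypothetical helper: extract the global step iteration from a global
    # checkpoint path, assuming a layout such as ".../step=1000.ckpt".
    match = re.search(r"step=(\d+)", str(path))
    return int(match.group(1)) if match else 0


hierarchical_io = HierarchicalCheckpointIO(
    wrapped_checkpoint_io=TorchCheckpointIO(),         # global CheckpointIO to wrap
    local_ckpt_manager=my_local_ckpt_manager,          # local checkpoint manager
    get_global_ckpt_iteration_fn=iteration_from_path,  # iteration extractor
)
```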
- from_tensor_aware_state_dict(tensor_aware_checkpoint, **kwargs)[source]
- Parameters:
tensor_aware_checkpoint (TensorAwareStateDict)
- classmethod get_partial_wrapper_constructor(local_ckpt_manager, get_global_ckpt_iteration_fn)[source]
Allows providing all constructor arguments except for the wrapped checkpoint IO (see the sketch after the parameter list).
- Parameters:
local_ckpt_manager (BaseCheckpointManager)
get_global_ckpt_iteration_fn (Callable[[_PATH], int])
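A sketch of the partial-constructor pattern, assuming the returned callable accepts the wrapped global CheckpointIO as its only remaining argument; `my_local_ckpt_manager` and `iteration_from_path` are the placeholders from the sketch above.

```python
# Provide everything except the wrapped CheckpointIO up front.
partial_ctor = HierarchicalCheckpointIO.get_partial_wrapper_constructor(
    local_ckpt_manager=my_local_ckpt_manager,
    get_global_ckpt_iteration_fn=iteration_from_path,
)

# Supply the global CheckpointIO later, e.g. once the trainer's default
# CheckpointIO is known.
hierarchical_io = partial_ctor(TorchCheckpointIO())
```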
- load_checkpoint(path, map_location=None, **kwargs)[source]
Load the newer of the local checkpoint (if available) and the global checkpoint.
- remove_checkpoint(path)[source]
Checkpoint removal is handled independently by the LocalCkptManager.
- class nvidia_resiliency_ext.ptl_resiliency.local_checkpoint_callback.LocalCheckpointCallback(every_n_train_steps=None, train_time_interval=None, async_save=False)[source]
Bases:
ModelCheckpoint
A ModelCheckpoint with basic functionality that only performs a simple save on train_batch_end.
Simple callback for initiating local checkpoint save in on_train_batch_end method. Since local checkpoints are ephemeral, they shouldn’t be used for “major” checkpoint types like on_train_epoch_end.
This callback must be used in conjunction with HierarchicalCheckpointIO, since all it really does is pass extra options to trainer.save_checkpoint, which HierarchicalCheckpointIO then intercepts (see the wiring sketch after the parameter list).
- Parameters:
every_n_train_steps (int, optional) – controls the local checkpointing interval in terms of train iterations. Same semantics as in the PTL ModelCheckpoint.
train_time_interval (int, optional) – controls local checkpointing interval in terms of wall time. Same semantics as in PTL ModelCheckpoint.
async_save (bool, optional) – whether to save local checkpoints asynchronously.
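A hedged end-to-end wiring sketch, assuming `hierarchical_io` from the earlier sketch and placeholder `MyModel` / `my_datamodule` defined elsewhere; the custom CheckpointIO is passed through the Trainer's standard plugins argument.

```python
import lightning.pytorch as pl

from nvidia_resiliency_ext.ptl_resiliency.local_checkpoint_callback import (
    LocalCheckpointCallback,
)

# Initiate a local checkpoint save every 100 training iterations.
local_ckpt_callback = LocalCheckpointCallback(
    every_n_train_steps=100,
    async_save=False,
)

trainer = pl.Trainer(
    max_steps=10_000,
    callbacks=[local_ckpt_callback],
    plugins=[hierarchical_io],  # routes saves to the local manager or the global CheckpointIO
)

trainer.fit(MyModel(), datamodule=my_datamodule)  # MyModel / my_datamodule are placeholders
```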