PTL Callback support

class nvidia_resiliency_ext.ptl_resiliency.local_checkpoint_callback.HierarchicalCheckpointIO(wrapped_checkpoint_io, local_ckpt_manager, get_global_ckpt_iteration_fn)[source]

Bases: _WrappingCheckpointIO

Wrapper for a global CheckpointIO enabling local checkpointing.

Based on the presence of local checkpointing options in the storage_options passed during saving, routes the save either to the original global CheckpointIO or to the local checkpoint manager.

Must be used in conjunction with LocalCheckpointCallback, which initiates local checkpoint saving during training; see the wiring sketch after the parameter list.

Parameters:
  • wrapped_checkpoint_io (CheckpointIO) – global CheckpointIO to wrap

  • local_ckpt_manager (BaseCheckpointManager) – local manager to use for local checkpoints

  • get_global_ckpt_iteration_fn (Callable[[_PATH], int]) – a function that given a path to a global checkpoint, extracts the global step iteration from it (either from the path itself or by loading metadata from the checkpoint).
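
A minimal wiring sketch. The helpers iteration_from_path and make_hierarchical_io, and the step=<N> path layout, are illustrative assumptions, not part of this API::

    import re

    from nvidia_resiliency_ext.ptl_resiliency.local_checkpoint_callback import (
        HierarchicalCheckpointIO,
    )

    def iteration_from_path(path) -> int:
        # Hypothetical helper: recover the global step from a path such as
        # ".../step=1000.ckpt" (the path layout is an assumption).
        match = re.search(r"step=(\d+)", str(path))
        return int(match.group(1)) if match else 0

    def make_hierarchical_io(global_io, local_ckpt_manager):
        # Wrap the trainer's global CheckpointIO so that saves carrying
        # local checkpointing options go to the local checkpoint manager.
        return HierarchicalCheckpointIO(
            wrapped_checkpoint_io=global_io,
            local_ckpt_manager=local_ckpt_manager,
            get_global_ckpt_iteration_fn=iteration_from_path,
        )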

from_tensor_aware_state_dict(tensor_aware_checkpoint, **kwargs)[source]

Recovers the original checkpoint from a tensor-aware state dict.

Parameters:

tensor_aware_checkpoint (TensorAwareStateDict)

classmethod get_partial_wrapper_constructor(local_ckpt_manager, get_global_ckpt_iteration_fn)[source]

Allows providing all constructor arguments except for the wrapped checkpoint IO, which can be supplied later.

Parameters:
  • local_ckpt_manager (BaseCheckpointManager) – local manager to use for local checkpoints

  • get_global_ckpt_iteration_fn (Callable[[_PATH], int]) – a function that, given a path to a global checkpoint, extracts the global step iteration from it.
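
An illustrative two-step construction; trainer.strategy.checkpoint_io is where Lightning typically exposes the active CheckpointIO, and local_ckpt_manager and iteration_from_path are assumed to exist (see the sketch above)::

    # Build a constructor that is missing only the wrapped checkpoint IO.
    partial_ctor = HierarchicalCheckpointIO.get_partial_wrapper_constructor(
        local_ckpt_manager, iteration_from_path
    )
    # Later, once the global CheckpointIO is known, complete construction.
    trainer.strategy.checkpoint_io = partial_ctor(trainer.strategy.checkpoint_io)
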
load_checkpoint(path, map_location=None, **kwargs)[source]

Load the newer of local (if available) and global checkpoint.

Parameters:
  • path (str | Path)

  • map_location (optional)
Return type:

Dict[str, Any]
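
A conceptual sketch of the "load the newer" resolution described above, not the library's implementation; both arguments besides the path are hypothetical stand-ins::

    def pick_newer(global_path, get_global_ckpt_iteration_fn, latest_local_iteration):
        # Compare iterations and prefer whichever checkpoint is newer.
        global_iter = get_global_ckpt_iteration_fn(global_path)
        if latest_local_iteration is not None and latest_local_iteration > global_iter:
            return "local"   # load via the local checkpoint manager
        return "global"      # delegate to the wrapped global CheckpointIO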

remove_checkpoint(path)[source]

Checkpoint removal is handled independently by the local checkpoint manager.

Parameters:

path (str | Path)

Return type:

None

save_checkpoint(checkpoint, path, storage_options=None)[source]

Save local or global checkpoint, depending on the presence of options.

Parameters:
  • checkpoint (Dict[str, Any]) – checkpoint state dict to save

  • path (str | Path)

  • storage_options (optional) – if local checkpointing options are present, the save is routed to the local checkpoint manager
Return type:

None
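
A conceptual routing sketch; the actual option key is an internal detail negotiated with LocalCheckpointCallback, so "local_checkpoint" below is purely hypothetical::

    def route_save(storage_options):
        # Saves carrying local checkpointing options go to the local
        # manager; everything else goes to the wrapped global CheckpointIO.
        if storage_options and storage_options.get("local_checkpoint"):
            return "local_ckpt_manager"
        return "wrapped_checkpoint_io"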

abstract to_tensor_aware_state_dict(checkpoint)[source]

Converts a checkpoint into a tensor-aware state dict.

Parameters:

checkpoint (Dict[str, Any])

Return type:

TensorAwareStateDict
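
A hypothetical subclass skeleton; a concrete implementation must provide this conversion, whose details depend on the checkpoint contents::

    class MyHierarchicalCheckpointIO(HierarchicalCheckpointIO):
        def to_tensor_aware_state_dict(self, checkpoint):
            # Placeholder body: convert the plain state dict into a
            # TensorAwareStateDict here (not the library's implementation).
            raise NotImplementedError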

class nvidia_resiliency_ext.ptl_resiliency.local_checkpoint_callback.LocalCheckpointCallback(every_n_train_steps=None, train_time_interval=None, async_save=False)[source]

Bases: ModelCheckpoint

ModelCheckpoint with basic functionality: only a simple save at the end of a training batch.

Simple callback that initiates a local checkpoint save in the on_train_batch_end method. Since local checkpoints are ephemeral, they shouldn’t be used for “major” checkpoint types like on_train_epoch_end.

This callback must be used in conjunction with HierarchicalCheckpointIO, since all it really does is pass options to trainer.save_checkpoint, which HierarchicalCheckpointIO can then capture; see the usage sketch after the parameter list.

Parameters:
  • every_n_train_steps (int, optional) – controls the local checkpointing interval in terms of train iterations. Same semantics as in PTL ModelCheckpoint.

  • train_time_interval (timedelta, optional) – controls the local checkpointing interval in terms of wall time. Same semantics as in PTL ModelCheckpoint.

  • async_save (bool, optional) – whether to save local checkpoints asynchronously.
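
A hedged end-to-end sketch, assuming a standard Lightning setup; the interval value is illustrative::

    import lightning.pytorch as pl  # or pytorch_lightning, depending on install

    from nvidia_resiliency_ext.ptl_resiliency.local_checkpoint_callback import (
        LocalCheckpointCallback,
    )

    local_cb = LocalCheckpointCallback(every_n_train_steps=100)
    trainer = pl.Trainer(callbacks=[local_cb])
    # Pairing with HierarchicalCheckpointIO (see above) is required: the
    # callback only passes options that the hierarchical IO captures.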

on_train_epoch_end(trainer, pl_module)[source]

Skips the superclass functionality.

Parameters:
  • trainer (Trainer)

  • pl_module (LightningModule)

Return type:

None

on_validation_end(trainer, pl_module)[source]

Skips the superclass functionality.

Parameters:
  • trainer (Trainer)

  • pl_module (LightningModule)

Return type:

None