# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from abc import abstractmethod
from datetime import timedelta
from functools import partial
from typing import Any, Callable, Dict, NewType, Optional
from ._utils import is_module_available
if is_module_available("lightning"):
import lightning.pytorch as pl
from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO
elif is_module_available("pytorch_lightning"):
import pytorch_lightning as pl
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
else:
raise ImportError("Could not find 'lightning' or 'pytorch_lightning' module")
from lightning_fabric.plugins.io.checkpoint_io import CheckpointIO
from lightning_fabric.utilities.types import _PATH
from torch import Tensor
from nvidia_resiliency_ext.checkpointing.local.base_state_dict import TensorAwareStateDict
from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.base_manager import (
BaseCheckpointManager,
)
logger = logging.getLogger(__name__)
StateDict = NewType('StateDict', Any)
LOCAL_CKPT_OPTS_KEY = 'local_checkpoint_options'
class LocalCheckpointCallback(pl.callbacks.ModelCheckpoint):
"""ModelCheckpoint with basic functionality. Only train_batch_end simple save.
Simple callback for initiating local checkpoint save in `on_train_batch_end` method.
Since local checkpoints are ephemeral, they shouldn't be used for "major" checkpoint
types like `on_train_epoch_end`.
This callback must be used in conjunction with the HierarchicalCheckpointIO,
since the only thing this callback really does is passing some options
to `trainer.save_checkpoint` which can be captured with HierarchicalCheckpointIO.
    Args:
        every_n_train_steps (int, optional): controls the local checkpointing interval in terms
            of train iterations. Same semantics as in PTL ModelCheckpoint.
        train_time_interval (timedelta, optional): controls the local checkpointing interval
            in terms of wall time. Same semantics as in PTL ModelCheckpoint.
        async_save (bool, optional): if True, the local checkpoint is saved asynchronously
            by the local checkpoint manager. Defaults to False.
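
    Example (a minimal sketch; `MyHierarchicalCheckpointIO` is a hypothetical subclass of
    `HierarchicalCheckpointIO`, and `global_checkpoint_io`, `local_ckpt_manager` and
    `get_iteration_fn` are assumed to be defined by the user):

        local_ckpt_callback = LocalCheckpointCallback(every_n_train_steps=100)
        hierarchical_io = MyHierarchicalCheckpointIO(
            global_checkpoint_io, local_ckpt_manager, get_iteration_fn
        )
        trainer = pl.Trainer(callbacks=[local_ckpt_callback], plugins=[hierarchical_io])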
"""
def __init__(
self,
every_n_train_steps: Optional[int] = None,
train_time_interval: Optional[timedelta] = None,
async_save: bool = False,
):
super().__init__(
every_n_train_steps=every_n_train_steps,
train_time_interval=train_time_interval,
)
self.async_save = async_save
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Skips super functionality"""
logger.info('Skipping on_train_epoch_end local ckpt save')
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Skips super functionality"""
logger.info('Skipping on_validation_end local ckpt save')
def _save_topk_checkpoint(
self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]
) -> None:
"""Skips super functionality"""
logger.info('Skipping _save_topk_checkpoint local ckpt save')
def _save_last_checkpoint(
self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]
) -> None:
"""Simply saves a local checkpoint with appropriate storage_options."""
local_ckpt_opts = dict(
ckpt_type='local', iteration=trainer.global_step, is_async=self.async_save
)
trainer.save_checkpoint(None, storage_options={LOCAL_CKPT_OPTS_KEY: local_ckpt_opts})
class HierarchicalCheckpointIO(_WrappingCheckpointIO):
"""Wrapper for a global CheckpointIO enabling local checkpointing.

    Based on the presence of local checkpointing options in the `storage_options` passed to
    `save_checkpoint`, routes the save either to the original global CheckpointIO or to the
    local checkpoint manager.

    Must be used in conjunction with LocalCheckpointCallback, which *initiates*
    local checkpoint saving during training.

Args:
wrapped_checkpoint_io (CheckpointIO): global CheckpointIO to wrap
local_ckpt_manager (BaseCheckpointManager): local manager to use for local checkpoints
get_global_ckpt_iteration_fn (Callable[[_PATH], int]): a function that
given a path to a global checkpoint, extracts the global step iteration from it
(either from the path itself or by loading metadata from the checkpoint).
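
    Example (a minimal sketch, assuming a hypothetical `MyTensorAwareStateDict`
    implementation of `TensorAwareStateDict`, plus user-defined `local_ckpt_manager`,
    `get_global_ckpt_iteration_fn` and `global_checkpoint_io`):

        class MyHierarchicalCheckpointIO(HierarchicalCheckpointIO):
            def to_tensor_aware_state_dict(self, checkpoint):
                return MyTensorAwareStateDict(checkpoint)

        io_factory = MyHierarchicalCheckpointIO.get_partial_wrapper_constructor(
            local_ckpt_manager, get_global_ckpt_iteration_fn
        )
        hierarchical_io = io_factory(global_checkpoint_io)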
"""
def __init__(
self,
wrapped_checkpoint_io: CheckpointIO,
local_ckpt_manager: BaseCheckpointManager,
get_global_ckpt_iteration_fn: Callable[[_PATH], int],
):
super().__init__(wrapped_checkpoint_io)
self.local_ckpt_manager = local_ckpt_manager
self.get_global_ckpt_iteration_fn = get_global_ckpt_iteration_fn
def save_checkpoint(
self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None
) -> None:
"""Save local or global checkpoint, depending on the presence of options."""
if storage_options is None or LOCAL_CKPT_OPTS_KEY not in storage_options:
return self.checkpoint_io.save_checkpoint(checkpoint, path, storage_options)
if path is not None:
raise ValueError(f'Path shouldn\'t be set for a local checkpoint, got: {path}.')
return self._save_local_checkpoint(checkpoint, storage_options.get(LOCAL_CKPT_OPTS_KEY))
def _save_local_checkpoint(self, checkpoint: Dict[str, Any], local_ckpt_options: dict) -> None:
"""Save local checkpoint."""
return self.local_ckpt_manager.save(
self.to_tensor_aware_state_dict(checkpoint),
local_ckpt_options['iteration'],
is_async=local_ckpt_options['is_async'],
)
def load_checkpoint(
self, path: _PATH, map_location: Optional[Any] = None, **kwargs
) -> Dict[str, Any]:
"""Load the newer of local (if available) and global checkpoint."""
latest_local_iteration = self.local_ckpt_manager.find_latest()
if latest_local_iteration < 0:
logger.debug('No local checkpoint available')
return self.checkpoint_io.load_checkpoint(path, map_location=map_location, **kwargs)
# There is a local ckpt available, but we don't know if it's newer than the global ckpt yet
latest_global_iteration = self.get_global_ckpt_iteration_fn(path)
if latest_local_iteration >= latest_global_iteration:
logger.info(
                f'Local checkpoint iteration {latest_local_iteration} is not older than'
                f' the global checkpoint iteration {latest_global_iteration}.'
                f' Resuming from the local checkpoint.'
)
intermediate_state_dict, checkpoint_name = self.local_ckpt_manager.load()
logger.debug(f'Loaded local checkpoint {checkpoint_name}')
return self.from_tensor_aware_state_dict(intermediate_state_dict, **kwargs)
else:
logger.warning(
                f'Found an available local checkpoint from iteration {latest_local_iteration},'
                f' but the global checkpoint iteration {latest_global_iteration} is greater.'
                f' Resuming from the global checkpoint.'
)
return self.checkpoint_io.load_checkpoint(path, map_location=map_location, **kwargs)
def remove_checkpoint(self, path: _PATH) -> None:
"""Checkpoint removal is handled independently by the LocalCkptManager."""
return self.checkpoint_io.remove_checkpoint(path)
@classmethod
def get_partial_wrapper_constructor(
cls,
local_ckpt_manager: BaseCheckpointManager,
get_global_ckpt_iteration_fn: Callable[[_PATH], int],
):
"""Allows to provide all arguments to the constructor except for the wrapped checkpoint io."""
return partial(
cls,
local_ckpt_manager=local_ckpt_manager,
get_global_ckpt_iteration_fn=get_global_ckpt_iteration_fn,
)
@abstractmethod
    def to_tensor_aware_state_dict(self, checkpoint: Dict[str, Any]) -> TensorAwareStateDict:
        """Converts a checkpoint state dict into a TensorAwareStateDict suitable for local saving."""
        raise NotImplementedError
    def from_tensor_aware_state_dict(self, tensor_aware_checkpoint: TensorAwareStateDict, **kwargs):
        """Recovers the original checkpoint state dict from a TensorAwareStateDict."""
        return tensor_aware_checkpoint.to_state_dict()