Source code for nvidia_resiliency_ext.checkpointing.local.ckpt_managers.base_manager

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" BaseCheckpointManager defines interface for managing local checkpoints.

Each CheckpointManager handles tasks such as:
    - cleaning up old checkpoints
    - tracking the iteration of the latest valid checkpoint
    - saving and loading checkpoints using the implemented backend.

It uses a state_dict interface, so users are expected to adjust the state_dict as needed,
with MCore facilitating these modifications.
"""

import gc
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from contextlib import contextmanager
from typing import Any, Iterable, Optional, Tuple

import torch

from ..base_state_dict import TensorAwareStateDict
from ..replication.group_utils import GroupWrapper
from ..replication.strategies import ReplicationStrategy
from ..replication.utils import debug_time
from ...async_ckpt.core import AsyncRequest
from ...utils import _disable_gc

logger = logging.getLogger(__name__)

CkptID = Tuple[int, int, Any]
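# A CkptID is the (iteration, rank, session_id) triple produced by
# `BaseCheckpointManager._ckpt_id`; e.g. (1000, 3, "sess-a") would identify rank 3's
# local checkpoint for iteration 1000 of session "sess-a" (illustrative values only).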


class CheckpointingException(Exception):
    """Base checkpointing related exception"""

    pass


class SameMachineReplicationException(CheckpointingException):
    """Exception raised when an attempt is made to override a file during replication.

    Inherits from `CheckpointingException`.
    """

    def __init__(self, ckpt_id):
        message = f"Checkpoint '{ckpt_id}' already exists on the same machine."
        super().__init__(message)


class BaseCheckpointManager(ABC):
    """The Base Checkpoint Manager provides an interface for integrating different
    checkpoint managers, abstracting replication mechanisms from the underlying
    implementations.
    """

    def __init__(self, session_id, repl_strategy: Optional[ReplicationStrategy] = None):
        self.latest_iteration = -1
        self.repl_strategy = repl_strategy
        self.session_id = session_id
        self._rank = None

    @property
    def rank(self):
        if self._rank is None:
            if torch.distributed.is_initialized():
                self._rank = torch.distributed.get_rank()
            else:
                logger.warning("Torch distributed backend has not been initialized.")
                self._rank = 0
        return self._rank

    def _ckpt_id(self, iteration: int) -> CkptID:
        """Generates a unique checkpoint ID from the iteration number.

        Each rank assigns its own distinct ID.

        Args:
            iteration (int): The iteration number.

        Returns:
            A unique checkpoint ID.
        """
        if iteration < 0:
            raise CheckpointingException(
                f"Invalid iteration: expected a non-negative value, got {iteration}."
            )
        return (iteration, self.rank, self.session_id)

    @abstractmethod
    def _my_ckpt_ids(self) -> Iterable[CkptID]:
        """Collects all locally available checkpoint IDs."""
        pass

    @abstractmethod
    def _load(self, ckpt_id: CkptID) -> TensorAwareStateDict:
        """Loads the checkpoint identified by `ckpt_id`.

        Should raise a CheckpointingException on failure.
        """
        pass

    @abstractmethod
    def _save(self, state_dict: TensorAwareStateDict, ckpt_id: CkptID):
        """Saves the tensor-aware `state_dict` identified by `ckpt_id`.

        Should raise a SameMachineReplicationException if the checkpoint already exists.
        """
        pass

    @abstractmethod
    def _cleanup(self, iteration):
        """Removes outdated or invalid checkpoints after successfully saving
        the checkpoint for the specified iteration.

        Args:
            iteration: The iteration number for which the checkpoint was successfully saved.
        """
        pass

    @abstractmethod
    def _cleanup_failed_save(self, iteration):
        """Removes invalid checkpoints that could not be saved due to a failure.

        Args:
            iteration: The iteration number for which the checkpoint failed to save.
        """

    @debug_time('BaseCheckpointManager._load_fn', logger)
    def _load_fn(self, ckpt_id: CkptID) -> TensorAwareStateDict:
        state_dict = self._load(ckpt_id)
        state_dict.restore_tensor_device(non_blocking=False)
        logger.debug(f'Finish loading {ckpt_id}')
        return state_dict

    @debug_time('BaseCheckpointManager._save_fn', logger)
    @_disable_gc()
    def _save_fn(self, id_to_state_dict):
        for ckpt_id, state_dict in id_to_state_dict.items():
            self._save(state_dict, ckpt_id)
            logger.debug(f'Finish saving {ckpt_id}')

    @debug_time('BaseCheckpointManager.find_latest', logger)
    def find_latest(self):
        """Searches for the most recent complete checkpoint and returns its iteration number.

        If no complete checkpoints are found, the method returns -1.
        All training ranks must call this method collectively.

        Returns:
            int: The iteration number of the most recent complete checkpoint,
                or -1 if no checkpoints are available.
        """
        if self.latest_iteration != -1:
            # Use cache to optimize performance in case of two-step loading.
            # Assumes the cache remains valid unless a new save occurs,
            # as no other operations should invalidate the most recent iteration.
            logger.debug(f'Using cached latest_iteration: {self.latest_iteration} in find_latest')
            return self.latest_iteration
        group_wrapper = GroupWrapper()
        self.globally_available_ids = group_wrapper.all_gather_object(self._my_ckpt_ids())
        # Maps each iteration to a corresponding set of ranks
        checkpoint_coverage_map = defaultdict(set)
        for ids in self.globally_available_ids:
            for ckpt_id in ids:
                iteration, rank, session_id = ckpt_id
                assert isinstance(iteration, int)
                assert session_id == self.session_id
                checkpoint_coverage_map[iteration].add(rank)
        # An iteration counts as complete only if every rank in the group holds a checkpoint for it.
        self.latest_iteration = max(
            [
                iteration
                for iteration, rank_set in checkpoint_coverage_map.items()
                if rank_set == set(group_wrapper.ranks)
            ],
            default=-1,
        )
        return self.latest_iteration
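
    # Illustrative example of the completeness check above (hypothetical values):
    # with group ranks {0, 1}, gathered IDs [(10, 0, "s"), (12, 0, "s")] from rank 0
    # and [(10, 1, "s")] from rank 1, only iteration 10 is covered by every rank,
    # so find_latest() returns 10 even though rank 0 also holds iteration 12.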

    @debug_time('BaseCheckpointManager.load', logger)
    def load(self) -> Tuple[TensorAwareStateDict, CkptID]:
        """Loads the most recent complete checkpoint.

        Ensure that `find_latest()` has been called first to identify the latest checkpoint.
        All training ranks must call this method collectively.

        Returns:
            Tuple[TensorAwareStateDict, CkptID]:
                - `state_dict`: The state dictionary loaded from the most recent complete checkpoint.
                - `ckpt_id`: The identifier of the checkpoint that was successfully loaded.
        """
        if self.latest_iteration == -1:
            raise CheckpointingException(
                "The 'find_latest' method must be called before invoking the 'load' function."
            )
        ckpt_id = self._ckpt_id(self.latest_iteration)
        logger.debug(f'Loading checkpoint from iteration {self.latest_iteration}')
        if self.repl_strategy is not None:
            plan = self.repl_strategy.retrieve_plan(self.globally_available_ids, [ckpt_id])
            my_data = {k: self._load_fn(k) for k in plan.required_ids()}
            execute_result = list(self.repl_strategy.retrieve_execute(plan, my_data).items())
            # TODO: refactor
            assert len(execute_result) == 1, f"Got {len(execute_result)} IDs, but requested only 1!"
            assert (
                execute_result[0][0] == ckpt_id
            ), f"Retrieved different ID ({execute_result[0][0]}) than requested ({ckpt_id})?"
            return execute_result[0][1], ckpt_id
        return self._load_fn(ckpt_id), ckpt_id
[docs] @debug_time("BaseCheckpointManager.save", logger) def save( self, state_dict: TensorAwareStateDict, iteration: int, is_async: bool = False ) -> Optional[AsyncRequest]: """ Saves the `state_dict` associated with the specified `iteration` number. If `is_async` is set to `True`, the save operation is performed asynchronously, and the function returns an `AsyncRequest` object. Otherwise, the save operation is completed synchronously. All training ranks have to call this method at once. Args: state_dict (dict): The state dictionary to be saved. iteration (int): The iteration number for identifying the checkpoint. is_async (bool): Whether to perform the save operation asynchronously. Returns: AsyncRequest or None: An `AsyncRequest` object if `is_async` is True; otherwise, None as the operation completes synchronously. """ assert ( self.latest_iteration < iteration ), f'A newer checkpoint is already available: {self.latest_iteration} (saving {iteration})' if self.repl_strategy: save_arg = { ckpt_id: s_dict for s_dict, ckpt_id in zip( *self.repl_strategy.replicate(state_dict, self._ckpt_id(iteration)) ) } # TODO consider D2H (below) during replicate, and add more stuff in async save_fn save_arg[self._ckpt_id(iteration)].copy_tensors_to_cpu(non_blocking=True) save_args = (save_arg,) else: state_dict.copy_tensors_to_cpu(non_blocking=True) save_args = ({self._ckpt_id(iteration): state_dict},) save_fn = self._save_fn self.latest_iteration = -1 # invalidate latest_iteration def finalize_fn(): validated_latest_iteration = self.find_latest() # TODO optimize self.latest_iteration = -1 # invalidate latest_iteration if validated_latest_iteration < iteration: # TODO Execute cleanup in a separate process self._cleanup_failed_save(iteration) raise CheckpointingException( f"Failure during saving local checkpoint from iteration {iteration}" f" (last valid iteration is {validated_latest_iteration})" ) else: if validated_latest_iteration == iteration: logging.info(f"Succesfully saved local checkpoint from iteration {iteration}") else: logging.info( f"WARNING: during saving iteration {iteration} " f"found valid checkpoint from iteration {validated_latest_iteration}" ) # TODO Execute cleanup in a separate process self._cleanup(iteration) if is_async: # we must wait for D2H to complete before returning control to the training with debug_time("ckpt_D2H_synchronize", logger): torch.cuda.synchronize() return AsyncRequest(save_fn, save_args, [finalize_fn]) assert not is_async save_fn(*save_args) # Wait so everyone is done (necessary) if torch.distributed.is_initialized(): torch.distributed.barrier() finalize_fn()