Source code for nvidia_resiliency_ext.checkpointing.local.replication.strategies

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import random
from abc import ABC, abstractmethod
from typing import List, Mapping, Sequence, Tuple, TypeVar, Generic, Optional

import torch

from ..base_state_dict import TensorAwareStateDict
from .group_utils import ExchangePlan, GroupWrapper, ProcessGroupLike, parse_group_sequence
from .utils import debug_time, debug_msg, zip_strict

logger = logging.getLogger(__name__)


class NoReplicasAvailableError(Exception):
    """Exception raised when no replicas are available for a requested ID."""

    pass


class ReplicationStrategy(ABC):
    """Abstract base class defining the interface for replication strategies."""

    @abstractmethod
    def replicate(
        self, local_ckpt: TensorAwareStateDict, id_: str
    ) -> Tuple[List[TensorAwareStateDict], List[str]]:
        """Replicates the local checkpoint.

        Args:
            local_ckpt (TensorAwareStateDict): The local checkpoint to be replicated.
            id_ (str): Identifier for the checkpoint.

        Returns:
            A list of replicated checkpoints together with corresponding IDs.
        """
        pass

    @abstractmethod
    def retrieve_plan(
        self, globally_available_ids: Mapping[int, List[str]], wanted: Sequence[str]
    ) -> ExchangePlan:
        """Generates a retrieval plan based on globally available IDs.

        Args:
            globally_available_ids (Mapping[int, List[str]]): Mapping of ranks to available IDs.
            wanted (Sequence[str]): List of IDs to retrieve.

        Returns:
            ExchangePlan: A plan detailing how to retrieve the requested IDs.
        """
        pass

    @abstractmethod
    def retrieve_execute(self, *args, **kwargs):
        """Executes the retrieval plan."""
        pass
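

# How a concrete strategy fits together (informal note, not part of the API): `replicate`
# is used when a local checkpoint is saved, exchanging copies with peer ranks, while
# `retrieve_plan` and `retrieve_execute` are used when checkpoints need to be recovered,
# first mapping each wanted ID to a rank that holds a replica and then performing the
# exchange. A custom strategy would subclass the ABC above; the sketch below is purely
# illustrative and the class name is hypothetical.
#
#     class MyReplicationStrategy(ReplicationStrategy):
#         def replicate(self, local_ckpt, id_):
#             ...  # return peer checkpoints and their IDs
#
#         def retrieve_plan(self, globally_available_ids, wanted):
#             ...  # return an ExchangePlan assigning a sender rank to every wanted ID
#
#         def retrieve_execute(self, *args, **kwargs):
#             ...  # carry out the ExchangePlan and return the retrieved checkpoints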


class CliqueReplicationStrategy(ReplicationStrategy):
    """Implements a replication strategy where all participants are in a single group.

    This strategy replicates local checkpoints among all ranks in the local process group,
    enabling efficient retrieval and communication of tensor data.
    """

    def __init__(self, local_group: ProcessGroupLike, target_device="cpu"):
        self.local_group: GroupWrapper = GroupWrapper.wrap(local_group)
        self.target_device = target_device

    @debug_time('CliqueReplicationStrategy.replicate', logger)
    def replicate(
        self, local_ckpt: TensorAwareStateDict, id_: str
    ) -> Tuple[List[TensorAwareStateDict], List[str]]:
        """Replicates the local checkpoint and returns the replicated checkpoints with IDs.

        This method splits the local checkpoint into a hollow state dictionary and its
        tensor data, gathers replicated copies from other ranks, and reconstructs the
        state dictionaries.

        Args:
            local_ckpt (TensorAwareStateDict): The local checkpoint to replicate.
            id_ (str): Identifier for the state dict.

        Returns:
            Tuple[List[TensorAwareStateDict], List[str]]:
                - List[TensorAwareStateDict]: A list of replicated checkpoints.
                - List[str]: A list of identifiers for the replicated checkpoints.
        """
        sent_bytes = 0
        recv_bytes = 0
        # Note: it makes the original local_ckpt hollow
        # Split local_ckpt into a list of tensors and a picklable hollow state dict
        my_tensor_data = local_ckpt.pop_tensors()
        # Send hollow state dicts and tensors separately
        with debug_time("all_gather_hollow_ckpt"):
            others_local_ckpts = self.local_group.all_gather_object(local_ckpt)
        assert all(lch.is_hollow for lch in others_local_ckpts)
        my_tensor_data_nbytes = sum(ten.nbytes for ten in my_tensor_data)
        with debug_time("all_gather_others_tensor_data"):
            others_tensor_data = self.local_group.all_gather_batch(
                my_tensor_data, target_device=self.target_device
            )
        others_tensor_data_nbytes = sum(
            [sum(ten.nbytes for ten in tensor_list) for tensor_list in others_tensor_data]
        )
        sent_bytes += my_tensor_data_nbytes
        recv_bytes += others_tensor_data_nbytes - my_tensor_data_nbytes
        # Assemble hollow state dicts and tensors back into whole state dicts
        for lch, td in zip_strict(others_local_ckpts, others_tensor_data):
            lch.insert_tensors(td)
        assert all(not lch.is_hollow for lch in others_local_ckpts)
        # Label obtained state dicts with ids
        with debug_time("all_gather_other_ids"):
            other_ids = self.local_group.all_gather_object(id_)
        debug_msg(f"{sent_bytes=}")
        debug_msg(f"{recv_bytes=}")
        assert local_ckpt.is_hollow
        return others_local_ckpts, other_ids

    @debug_time('CliqueReplicationStrategy.retrieve_plan', logger)
    def retrieve_plan(
        self, globally_available_ids: Mapping[int, List[str]], wanted: Sequence[str]
    ) -> ExchangePlan:
        """Creates a plan for retrieving the specified IDs from globally available replicas.

        Args:
            globally_available_ids (Mapping[int, List[str]]): Mapping of ranks to available IDs.
            wanted (Sequence[str]): List of IDs to retrieve.

        Returns:
            ExchangePlan: A plan detailing how to retrieve the requested IDs.

        Raises:
            NoReplicasAvailableError: If no replicas are found for a requested ID.
        """
        # TODO: expand the function to multiple wanted IDs, and with smarter "routing"
        rng = random.Random(0)
        with debug_time("all_gather_wanted_ids"):
            globally_wanted = self.local_group.all_gather_object(wanted)
        result = ExchangePlan(group=self.local_group)
        for receiver, currently_wanted in zip(self.local_group.ranks, globally_wanted):
            for wanted_id in currently_wanted:
                available = set(
                    rank
                    for rank in self.local_group.ranks
                    if wanted_id in globally_available_ids[rank]
                )
                if not available:
                    raise NoReplicasAvailableError(
                        f"No replicated copies for id={wanted_id} found!"
                    )
                if receiver in available:
                    sender = receiver
                else:
                    sender = rng.choice(sorted(list(available)))
                result.plan(sender=sender, receiver=receiver, id_=wanted_id)
        return result

    @debug_time('CliqueReplicationStrategy.retrieve_execute', logger)
    def retrieve_execute(self, *args, **kwargs):
        """Executes the retrieval plan using the local group.

        Returns:
            The result of executing the retrieval plan.
        """
        return self.local_group.execute_plan(*args, **kwargs)

    @classmethod
    @debug_time('CliqueReplicationStrategy.from_replication_params', logger)
    def from_replication_params(
        cls, replication_jump: int = torch.cuda.device_count(), replication_factor: int = 2
    ) -> 'CliqueReplicationStrategy':
        """Instantiates process groups necessary for checkpoint replication.

        Training ranks are divided into `W // F` distinct groups of size `F`,
        where `W` is the world size and `F` is the `replication_factor`.
        Each group consists of ranks:

            `n`, `n + J`, `n + 2J`, ..., `n + (F - 1)J`,

        where `J` is the `replication_jump` and `n = aJF + b`, with:

        - `a = 0, 1, ..., (W / (JF)) - 1`
        - `b = 0, 1, ..., J - 1`

        Checkpoint shards are exchanged and fully replicated within each group.

        **Important:** The world size (`W`) must be divisible by `J * F`.

        This grouping enables replication across different failure domains
        by specifying `J` equal to the failure blast radius.

        **Example:** For a world size of 32, `replication_jump = 8`, and
        `replication_factor = 2`, the replication groups (cliques) are:

            0-8, 1-9, 2-10, 3-11, 4-12, 5-13, 6-14, 7-15,
            16-24, 17-25, 18-26, 19-27, 20-28, 21-29, 22-30, 23-31

        Args:
            replication_jump (int, optional): `J` in the formula above. Represents the gap
                between successive ranks storing replicas of a given rank's data.
            replication_factor (int, optional): `F` in the formula above. Denotes the number
                of ranks storing replicas of a given rank's data.
        """
        logger.debug(f'Initializing {cls.__name__}')
        repl_process_groups_ranks: List[List[int]] = parse_group_sequence(
            replication_jump=replication_jump,
            replication_factor=replication_factor,
            world_size=torch.distributed.get_world_size(),
        )
        repl_process_groups: List[torch.distributed.ProcessGroup] = [
            torch.distributed.new_group(g) for g in repl_process_groups_ranks
        ]
        my_process_group = GroupWrapper.from_list_of_groups(repl_process_groups)
        return cls(my_process_group, target_device="cpu")
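

# Illustrative usage sketch (kept as comments; not part of the module). Assumptions:
# torch.distributed is already initialized, the world size is divisible by
# replication_jump * replication_factor, `local_ckpt` is a TensorAwareStateDict produced
# by the caller, and `globally_available_ids` is a caller-built mapping of rank -> IDs
# available on that rank. Arguments to `retrieve_execute` are forwarded to
# `GroupWrapper.execute_plan`, so the exact call shown here is an assumption.
#
#     strategy = CliqueReplicationStrategy.from_replication_params(
#         replication_jump=8, replication_factor=2
#     )
#     my_id = str(torch.distributed.get_rank())
#     replicas, ids = strategy.replicate(local_ckpt, id_=my_id)       # exchange copies within the clique
#     plan = strategy.retrieve_plan(globally_available_ids, [my_id])  # pick a sender rank for each wanted ID
#     retrieved = strategy.retrieve_execute(plan)                     # run the exchange described by the plan

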
EagerT = TypeVar('EagerT')


class LazyReplicationStrategyBuilder(ReplicationStrategy, ABC, Generic[EagerT]):
    """Represents an uninitialized replication strategy.

    A replication strategy needs process groups, which can be impossible to initialize
    at the time the `ReplicationStrategy` instance is created. This class allows for lazy
    initialization of an instance of type `EagerT`:

    >>> lazy_repl_strategy = LazyReplicationStrategyBuilder()
    >>> ...
    >>> lazy_repl_strategy.replicate(...)  # performs lazy init transparently
    >>> lazy_repl_strategy.retrieve_execute(...)  # reuses previously initialized instance transparently
    """

    def __init__(self):
        self._replication_strategy: Optional[EagerT] = None

    @property
    def replication_strategy(self) -> EagerT:
        """Lazy build on demand."""
        if self._replication_strategy is None:
            self._replication_strategy = self._eager_build()
        return self._replication_strategy

    def replicate(
        self, local_ckpt: TensorAwareStateDict, id_: str
    ) -> Tuple[List[TensorAwareStateDict], List[str]]:
        """Delegate to the underlying replication strategy."""
        return self.replication_strategy.replicate(local_ckpt, id_)

    def retrieve_plan(
        self, globally_available_ids: Mapping[int, List[str]], wanted: Sequence[str]
    ) -> ExchangePlan:
        """Delegate to the underlying replication strategy."""
        return self.replication_strategy.retrieve_plan(globally_available_ids, wanted)

    def retrieve_execute(self, *args, **kwargs):
        """Delegate to the underlying replication strategy."""
        return self.replication_strategy.retrieve_execute(*args, **kwargs)

    @abstractmethod
    def _eager_build(self) -> EagerT:
        """Instantiates the eager class."""


class LazyCliqueReplicationStrategy(LazyReplicationStrategyBuilder[CliqueReplicationStrategy]):
    """Lazy version of `CliqueReplicationStrategy` that delays process group formation.

    Training ranks are divided into `W // F` distinct groups of size `F`,
    where `W` is the world size and `F` is the `replication_factor`.
    Each group consists of ranks:

        `n`, `n + J`, `n + 2J`, ..., `n + (F - 1)J`,

    where `J` is the `replication_jump` and `n = aJF + b`, with:

    - `a = 0, 1, ..., (W / (JF)) - 1`
    - `b = 0, 1, ..., J - 1`

    Checkpoint shards are exchanged and fully replicated within each group.

    **Important:** The world size (`W`) must be divisible by `J * F`.

    This grouping enables replication across different failure domains
    by specifying `J` equal to the failure blast radius.

    **Example:** For a world size of 32, `replication_jump = 8`, and
    `replication_factor = 2`, the replication groups (cliques) are:

        0-8, 1-9, 2-10, 3-11, 4-12, 5-13, 6-14, 7-15,
        16-24, 17-25, 18-26, 19-27, 20-28, 21-29, 22-30, 23-31

    Args:
        replication_jump (int, optional): `J` in the formula above. Represents the gap
            between successive ranks storing replicas of a given rank's data.
        replication_factor (int, optional): `F` in the formula above. Denotes the number
            of ranks storing replicas of a given rank's data.
    """

    def __init__(
        self, replication_jump: int = torch.cuda.device_count(), replication_factor: int = 2
    ):
        super().__init__()
        self.replication_jump = replication_jump
        self.replication_factor = replication_factor

    def _eager_build(self):
        return CliqueReplicationStrategy.from_replication_params(
            self.replication_jump, self.replication_factor
        )
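

# Illustrative usage sketch (kept as comments; not part of the module). The point of the
# lazy builder is that it can be constructed before torch.distributed is initialized; the
# replication process groups are only created on first use. `local_ckpt` below is an
# assumed TensorAwareStateDict supplied by the caller.
#
#     # Safe to construct before torch.distributed.init_process_group() has been called:
#     lazy_strategy = LazyCliqueReplicationStrategy(replication_jump=8, replication_factor=2)
#     ...
#     # Later, once the default process group exists, the first call transparently builds
#     # the eager CliqueReplicationStrategy together with its replication groups:
#     replicas, ids = lazy_strategy.replicate(local_ckpt, id_=str(torch.distributed.get_rank()))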