Source code for nvidia_resiliency_ext.inprocess.rank_assignment

# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import bisect
import collections
import dataclasses
import datetime
import enum
import heapq
import itertools
import warnings
from typing import Callable, Optional, Union

from . import exception
from .state import Mode, State
from .store import StoreMixin


class RankDiscarded(exception.RestartError):
    r'''
    Exception raised when a distributed rank is discarded by
    :py:class:`RankAssignment`.
    '''

    pass


@dataclasses.dataclass
class RankAssignmentCtx:
    r'''
    Represents inputs and outputs of :py:class:`RankAssignment`.

    Args:
        state: :py:class:`Wrapper` state
        store: distributed store
        terminated_ranks: a set containing indices of terminated ranks
    '''

    state: State
    store: StoreMixin
    terminated_ranks: set[int]


class RankAssignment(abc.ABC):
    r'''
    Abstract base class for ``rank_assignment`` argument for
    :py:class:`inprocess.Wrapper`.

    :py:class:`RankAssignment` is responsible for reassigning distributed
    ranks, computing the new world size and selecting which ranks are active in
    the next iteration of the wrapped function.

    Active ranks call the provided wrapped function. Inactive ranks wait idle
    and may serve as a pool of static, preallocated, and preinitialized
    reserve ranks. Reserve ranks are activated in a subsequent restart
    iteration if previously active ranks were terminated or became unhealthy.

    Multiple composable :py:class:`RankAssignment` instances can be combined
    with :py:class:`inprocess.Compose` to achieve the desired behavior.
    '''

    @abc.abstractmethod
    def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
        r'''
        Implementation of a :py:class:`RankAssignment`.

        Args:
            ctx: :py:class:`RankAssignmentCtx`

        Returns:
            Modified :py:class:`RankAssignmentCtx`
        '''
        raise NotImplementedError


class RankFilter(RankAssignment):
    r'''
    :py:class:`RankFilter` is a subclass of :py:class:`RankAssignment` which
    selects which ranks are active in the current restart iteration of
    :py:class:`inprocess.Wrapper`.

    Active ranks call the wrapped function. Inactive ranks wait idle and may
    serve as a pool of static, preallocated, and preinitialized reserve ranks.
    Reserve ranks are activated in a subsequent restart iteration if one of
    the active ranks is terminated or becomes unhealthy.

    Multiple :py:class:`RankFilter` or :py:class:`RankAssignment` instances can
    be composed using :py:class:`inprocess.Compose` to achieve the desired
    behavior. Typically, all :py:class:`RankFilter` instances should follow any
    :py:class:`RankAssignment` steps that recalculate rank indices or adjust
    the world size.
    '''

    @abc.abstractmethod
    def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
        r'''
        Implementation of a :py:class:`RankFilter`.

        Args:
            ctx: :py:class:`RankAssignmentCtx`

        Returns:
            Modified :py:class:`RankAssignmentCtx`
        '''
        raise NotImplementedError


class ActivateAllRanks(RankFilter):
    r'''
    Activates all distributed ranks.

    All healthy distributed ranks will call the provided wrapped function in
    the next iteration of :py:class:`inprocess.Wrapper`.

    :py:class:`ActivateAllRanks` unconditionally activates all ranks, and
    cannot be combined with any other :py:class:`RankAssignment` performing
    rank activation.
    '''

    def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
        state = dataclasses.replace(
            ctx.state,
            mode=Mode.ACTIVE,
            active_rank=ctx.state.rank,
            active_world_size=ctx.state.world_size,
        )
        ctx.state = state
        return ctx
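

# Illustrative sketch (not part of the library): one possible pairing in which
# ShiftRanks closes the gaps left by terminated ranks and ActivateAllRanks then
# activates every remaining healthy rank. This assumes inprocess.Compose
# applies its arguments in reverse order, as in the FilterCountGroupedByKey
# docstring example further below.
def _activate_all_ranks_sketch():
    from nvidia_resiliency_ext import inprocess

    return inprocess.Compose(
        inprocess.rank_assignment.ActivateAllRanks(),
        inprocess.rank_assignment.ShiftRanks(),
    )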


class MaxActiveWorldSize(RankFilter):
    r'''
    :py:class:`MaxActiveWorldSize` ensures that the active world size is no
    greater than the specified ``max_active_world_size``.

    Ranks with indices less than the active world size are active and call the
    wrapped function, while ranks outside this range are inactive.

    Args:
        max_active_world_size: maximum active world size, no limit if
            :py:obj:`None`
    '''

    def __init__(self, max_active_world_size: Optional[int] = None):
        self.max_active_world_size = max_active_world_size

    def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
        state = ctx.state

        if state.active_world_size is None:
            active_world_size = state.world_size
        else:
            active_world_size = min(state.active_world_size, state.world_size)

        if self.max_active_world_size is not None:
            active_world_size = min(active_world_size, self.max_active_world_size)

        if state.rank < active_world_size:
            mode = Mode.ACTIVE
            active_rank = state.rank
        else:
            mode = Mode.INACTIVE
            active_rank = None

        state = dataclasses.replace(
            state,
            mode=mode,
            active_rank=active_rank,
            active_world_size=active_world_size,
        )
        ctx.state = state
        return ctx


class ActiveWorldSizeDivisibleBy(RankFilter):
    r'''
    :py:class:`ActiveWorldSizeDivisibleBy` ensures that the active world size
    is divisible by a given number.

    Ranks within the adjusted world size are marked as active and call the
    wrapped function, while ranks outside this range are marked as inactive.

    Args:
        divisor: the divisor to adjust the active world size by
    '''

    def __init__(self, divisor: int = 1) -> None:
        self.divisor = divisor

    def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
        state = ctx.state
        divisor = self.divisor

        if state.active_world_size is None:
            active_world_size = state.world_size
        else:
            active_world_size = min(state.active_world_size, state.world_size)

        active_world_size = active_world_size // divisor * divisor

        if state.rank < active_world_size:
            mode = Mode.ACTIVE
            active_rank = state.rank
        else:
            mode = Mode.INACTIVE
            active_rank = None

        state = dataclasses.replace(
            state,
            mode=mode,
            active_rank=active_rank,
            active_world_size=active_world_size,
        )
        ctx.state = state
        return ctx
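

# Illustrative sketch (not part of the library): composing both rank filters
# with ShiftRanks so the surviving ranks are first renumbered contiguously,
# then capped at 256 active ranks, then rounded down to a multiple of 8; ranks
# above the resulting active world size stay idle as reserve. This assumes
# inprocess.Compose applies its arguments in reverse order, as in the
# FilterCountGroupedByKey docstring example further below.
def _capped_divisible_world_size_sketch():
    from nvidia_resiliency_ext import inprocess

    return inprocess.Compose(
        inprocess.rank_assignment.ActiveWorldSizeDivisibleBy(divisor=8),
        inprocess.rank_assignment.MaxActiveWorldSize(max_active_world_size=256),
        inprocess.rank_assignment.ShiftRanks(),
    )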


class LayerFlag(enum.Flag):
    r'''
    A flag to modify rank assignment or rank filtering policy in a given
    :py:class:`Layer` of a :py:class:`Tree` rank assignment.

    Attributes:
        RESERVE: indicates that branches at this layer of the topology tree
            may be traversed while searching for a replacement inactive rank
        BACKFILL: indicates that branches at this layer of the topology tree
            may be traversed while searching for a replacement active rank
    '''

    RESERVE = enum.auto()
    BACKFILL = enum.auto()
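

# Illustrative sketch (not part of the library): LayerFlag is an enum.Flag, so
# both behaviors can be enabled for a single Layer by OR-ing the members; the
# Tree algorithm tests membership with bitwise AND, as shown below.
def _layer_flag_sketch():
    combined = LayerFlag.RESERVE | LayerFlag.BACKFILL
    assert combined & LayerFlag.RESERVE
    assert combined & LayerFlag.BACKFILL
    return combined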


@dataclasses.dataclass
class Layer:
    r'''
    Represents a configuration for a layer of branches at a certain depth in a
    topology tree constructed by :py:class:`Tree`.

    Args:
        min_ranks: the minimum number of healthy ranks in a subtree
        max_ranks: the maximum number of ranks to activate in a subtree, no
            limit if :py:obj:`None`
        key_or_fn: a string key, or a ``Callable`` evaluated with
            :py:class:`inprocess.State` as input to produce a grouping string
            key
        flag: an optional flag that modifies rank assignment policy in a given
            branch
    '''

    min_ranks: int = 1
    max_ranks: Optional[int] = None
    key_or_fn: Union[str, Callable[[State], str]] = ''
    flag: Optional[LayerFlag] = None


class Node:
    def __init__(self, parent, name, layer, state):
        self.parent = parent
        self.name = name
        self.layer = layer
        self.state = state

        self.active_count = 0
        self.children = {}
        self.inactive_nodes = {}
        self.backfill_domain = None

    def add_child(self, name, layer, state):
        child = Node(self, name, layer, state)
        self.children[name] = child
        return child

    def is_leaf(self):
        return not self.children

    def iter_leaves(self):
        if self.is_leaf():
            yield self
        else:
            for child in self.children.values():
                yield from child.iter_leaves()

    def deactivate(self):
        self.state.mode = Mode.INACTIVE
        self.state.active_rank = None

        node = self
        while (parent := node.parent) is not None:
            parent.inactive_nodes[self.state.initial_rank] = self
            node = parent

    def terminate(self):
        self.state.mode = Mode.TERMINATED

        node = self
        while (parent := node.parent) is not None:
            if self.state.initial_rank in parent.inactive_nodes:
                parent.inactive_nodes.pop(self.state.initial_rank)
            node = parent

    def activate(self, active_rank):
        self.state.active_rank = active_rank
        self.state.mode = Mode.ACTIVE

        node = self
        while (parent := node.parent) is not None:
            if self.state.initial_rank in parent.inactive_nodes:
                parent.inactive_nodes.pop(self.state.initial_rank)
            node = parent

    def assign_backfill_domain(self):
        assert self.is_leaf()
        backfill_domain = None
        parent = self.parent
        while parent and parent.layer.flag and parent.layer.flag & LayerFlag.BACKFILL:
            backfill_domain = parent
            parent = parent.parent
        self.backfill_domain = backfill_domain

    def __repr__(self):
        return f'{type(self).__name__}({self.name=})'


def bounded_activate(node, counter, path=None):
    if path is None:
        path = []

    if node.is_leaf():
        if all(
            (
                ascendant.layer.max_ranks is None
                or ascendant.active_count < ascendant.layer.max_ranks
                for ascendant in path
            )
        ):
            node.activate(counter)
            counter += 1
            for ascendant in path:
                ascendant.active_count += 1
        else:
            node.deactivate()

    path.append(node)
    for child in node.children.values():
        counter = bounded_activate(child, counter, path)
    path.pop()

    return counter


def propagate_terminations(node, terminated_ranks):
    def count_not_terminated(node, terminated_ranks):
        return sum(
            1
            for leaf in node.iter_leaves()
            if leaf.state.mode != Mode.TERMINATED and leaf.state.rank not in terminated_ranks
        )

    for child in node.children.values():
        terminated_ranks = propagate_terminations(child, terminated_ranks)

    if not node.is_leaf() and count_not_terminated(node, terminated_ranks) < node.layer.min_ranks:
        terminated_ranks.update(
            set(
                leaf.state.rank
                for leaf in node.iter_leaves()
                if leaf.state.mode != Mode.TERMINATED
            )
        )

    return terminated_ranks
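

# Illustrative sketch (not part of the library): exercising the private
# bounded_activate helper on a tiny hand-built topology tree. Leaf states are
# stand-in SimpleNamespace objects carrying only the attributes touched here;
# the real State dataclass from .state has more fields. With max_ranks=1 per
# host, only one rank per host is activated and the rest are deactivated into
# their ancestors' inactive_nodes pools.
def _bounded_activate_sketch():
    from types import SimpleNamespace

    root = Node(parent=None, name='root', layer=Layer(key_or_fn='root'), state=None)
    rank = 0
    for host in ('host0', 'host1'):
        host_node = root.add_child(host, Layer(max_ranks=1, key_or_fn=host), None)
        for _ in range(2):
            leaf_state = SimpleNamespace(
                initial_rank=rank, rank=rank, mode=None, active_rank=None
            )
            host_node.add_child(rank, None, leaf_state)
            rank += 1

    active_world_size = bounded_activate(root, 0)
    assert active_world_size == 2  # one ACTIVE rank per host, two INACTIVE reserves
    return root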


class Tree(RankAssignment):
    r'''
    Implements an integrated rank assignment and activation algorithm that
    builds a multi-layer topology tree for distributed ranks. Each layer in
    this tree specifies constraints and policies for assigning and activating
    ranks. Grouping keys in each layer can align with hardware properties
    (e.g., to confine ranks within a compute node) or application-driven
    requirements (e.g., ensuring a particular divisibility).

    :py:class:`Tree` constructs a rooted topology tree whose depth equals the
    number of layers. Each layer corresponds to a :py:class:`Layer`,
    determining the rank assignment policy within its subtree. The distributed
    ranks are represented as leaves.

    **Algorithm**

    **Initialization**

    The algorithm traverses all ranks in depth-first order. For each visited
    rank, if all ancestor layers permit more active ranks (i.e., if the
    already-active ranks do not exceed any ancestor layer's
    :py:attr:`Layer.max_ranks`), that rank is activated.

    **Rank reassignment**

    When some ranks terminate or become unhealthy, the algorithm proceeds in
    several steps:

    1. **Propagate termination**
       Using a reverse depth-first search (children before parents), if the
       number of healthy ranks in a branch falls below
       :py:attr:`Layer.min_ranks`, that entire branch (and its subtree) is
       terminated.

    2. **Replace ranks from a reserve domain**
       The algorithm attempts to replace terminated or unhealthy active ranks
       with inactive ranks from the nearest ancestor subtree that has the
       :py:attr:`LayerFlag.RESERVE` flag. This search for an inactive rank
       continues recursively upward until a branch without the
       :py:attr:`LayerFlag.RESERVE` flag is reached.

    3. **Backfill ranks**
       Within any ancestor subtree flagged as :py:attr:`LayerFlag.BACKFILL`,
       an active rank with the largest rank index swaps places with a
       terminated rank, effectively filling local gaps (similar to
       :py:class:`FillGaps`).

    4. **Shift ranks**
       After local backfills, remaining gaps from unhandled terminations are
       closed by shifting healthy ranks left to fill any vacated indices. This
       global step reassigns all rank indices past the first unhealthy rank
       (similar to :py:class:`ShiftRanks`).

    5. **Optional filter**
       If a ``world_size_filter`` callable is provided, it can reduce the
       active ranks to a smaller ``world_size`` necessary for the workload.
       ``world_size_filter`` is invoked with the current number of active
       ranks as an argument, and returns the adjusted number of requested
       active ranks, no greater than the input. Healthy ranks with indices
       greater than or equal to the returned value are deactivated and become
       part of the reserve pool.

    .. note::
        :py:class:`Tree` cannot be composed with any other instance of
        :py:class:`RankAssignment` or :py:class:`RankFilter`.

    **Example**

    .. code-block:: python

        inprocess.rank_assignment.Tree(
            [
                inprocess.rank_assignment.Layer(
                    min_ranks=128,
                    max_ranks=256,
                    key_or_fn='root',
                    flag=inprocess.rank_assignment.LayerFlag.RESERVE,
                ),
                inprocess.rank_assignment.Layer(
                    min_ranks=8,
                    max_ranks=8,
                    key_or_fn=lambda _: socket.gethostname(),
                    flag=inprocess.rank_assignment.LayerFlag.RESERVE,
                ),
            ]
        )

    In this two-level topology tree:

    - The hostname-based layer allows up to 8 active ranks per host
      (:py:attr:`Layer.max_ranks=8`). If the number of healthy ranks in any
      host drops below 8 (:py:attr:`Layer.min_ranks=8`), that entire host's
      subtree is terminated. The algorithm can look for inactive reserve ranks
      within the same hostname because of the :py:attr:`LayerFlag.RESERVE`
      flag.

    - All hosts are grouped under the ``'root'`` layer, which permits up to
      256 active ranks (:py:attr:`Layer.max_ranks=256`). If the global healthy
      rank count drops below 128 (:py:attr:`Layer.min_ranks=128`), all ranks
      are terminated. The :py:attr:`LayerFlag.RESERVE` flag at the ``'root'``
      level lets the algorithm traverse upward from one host to another host
      through the ``'root'`` to search for reserve ranks.

    Args:
        layers: a list of :py:class:`Layer` instances, each layer specifies
            properties corresponding to one grouping level in a topology tree
        world_size_filter: an optional ``Callable`` that takes the final
            application-visible world size, and returns the new world size, no
            greater than the input
    '''

    def __init__(
        self,
        layers: list[Layer],
        world_size_filter: Optional[Callable[[int], int]] = None,
    ):
        self.layers = layers
        self.world_size_filter = world_size_filter

        self.tree = None
        self.rank_map = {}
        self.init_rank_map = {}

    def build_tree(self, state, store):
        key = [
            (layer.key_or_fn(state) if callable(layer.key_or_fn) else layer.key_or_fn)
            for layer in self.layers
        ]

        store.send_state(state, state.rank)
        states = store.get_states(range(state.world_size))
        store.send_key(key, state.rank)
        keys = store.get_keys(range(state.world_size))

        root_keys = set(key[0] for key in keys)
        if len(root_keys) != 1:
            msg = (
                f'all distributed ranks are required to share the same '
                f'grouping key at the root level of the topology tree, but '
                f'got {root_keys=}'
            )
            raise RuntimeError(msg)

        self.tree = Node(parent=None, name=keys[0], layer=self.layers[0], state=None)

        for key, state in zip(keys, states):
            node = self.tree
            for k, layer in zip(key[1:], self.layers[1:]):
                if k in node.children:
                    node = node.children[k]
                else:
                    node.add_child(k, layer, None)
                    node = node.children[k]
            child = node.add_child(state.initial_rank, None, state)
            self.init_rank_map[state.initial_rank] = child
            self.rank_map[state.rank] = child

        for idx, leaf in enumerate(self.tree.iter_leaves()):
            if idx != leaf.state.rank:
                topology_rank = idx
                environment_rank = leaf.state.rank
                msg = (
                    f'Initial environment rank assignment does not match the '
                    f'specified topology: {topology_rank=} {environment_rank=}'
                )
                warnings.warn(msg)

    def replace_with_inactive(self, terminated_active_ranks):
        replaced_terminate_active_ranks = set()

        for terminated_active_rank in sorted(terminated_active_ranks):
            terminated_active_node = self.rank_map[terminated_active_rank]
            node = terminated_active_node
            while (
                (parent := node.parent)
                and parent.layer.flag
                and parent.layer.flag & LayerFlag.RESERVE
            ):
                if parent.inactive_nodes:
                    _, inactive = parent.inactive_nodes.popitem()
                    inactive.activate(terminated_active_node.state.active_rank)
                    replaced_terminate_active_ranks.add(terminated_active_rank)
                    break
                node = parent

        return replaced_terminate_active_ranks

    def replace_with_backfill(self, unhandled_terminations):
        replaced_active = set()

        backfill_domains = collections.defaultdict(list)
        backfill_domains_id_map = {}
        for rank in sorted(unhandled_terminations):
            terminated_node = self.rank_map[rank]
            backfill_domain = terminated_node.backfill_domain
            if backfill_domain is not None:
                backfill_domains[id(backfill_domain)].append(terminated_node)
                backfill_domains_id_map[id(backfill_domain)] = backfill_domain
            else:
                replaced_active.add(terminated_node.state.active_rank)

        for domain_id, terminated_nodes in backfill_domains.items():
            backfill_domain = backfill_domains_id_map[domain_id]
            largest_active_nodes = heapq.nlargest(
                len(terminated_nodes),
                (
                    leaf
                    for leaf in backfill_domain.iter_leaves()
                    if leaf.state.mode == Mode.ACTIVE
                ),
                key=lambda node: node.state.active_rank,
            )
            for backfill_node, terminated_node in itertools.zip_longest(
                reversed(largest_active_nodes),
                terminated_nodes,
                fillvalue=None,
            ):
                if backfill_node is not None:
                    replaced_active.add(backfill_node.state.active_rank)
                    backfill_node.state.active_rank = terminated_node.state.active_rank
                else:
                    replaced_active.add(terminated_node.state.active_rank)

        return replaced_active

    def shift_ranks(self, replaced_active, unhandled_terminations):
        sorted_replaced_active = sorted(replaced_active)
        for n in self.rank_map.values():
            n.state.active_world_size -= len(unhandled_terminations)
            if n.state.active_rank is not None:
                count_less = bisect.bisect_left(sorted_replaced_active, n.state.active_rank)
                n.state.active_rank -= count_less

    def filter_active_world_size(self):
        active_world_size = next(iter(self.rank_map.values())).state.active_world_size
        new_active_world_size = self.world_size_filter(active_world_size)
        assert new_active_world_size <= active_world_size
        for leaf in self.tree.iter_leaves():
            leaf.state.active_world_size = new_active_world_size
            if leaf.state.mode == Mode.ACTIVE and leaf.state.active_rank >= new_active_world_size:
                leaf.deactivate()

    def recompute_rank(self):
        leaves = list(self.tree.iter_leaves())
        leaves.sort(key=lambda leaf: leaf.state.mode == Mode.TERMINATED)
        for idx, leaf in enumerate(leaves):
            leaf.state.rank = idx
            self.rank_map[idx] = leaf

    def update_tree(self, state, terminated_ranks):
        world_size = state.world_size - len(terminated_ranks)
        for node in self.tree.iter_leaves():
            node.state.world_size = world_size

        for terminated_rank in terminated_ranks:
            self.rank_map[terminated_rank].state.rank = None
            self.rank_map[terminated_rank].state.active_rank = None

    def get_state_from_tree(self, state, terminated_ranks):
        tree_state = self.init_rank_map[state.initial_rank].state
        if tree_state.mode == Mode.TERMINATED:
            raise RankDiscarded(f'{state.rank=} {terminated_ranks=}')
        state = State(**dataclasses.asdict(tree_state))
        return state

    def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
        state = ctx.state
        store = ctx.store
        terminated_ranks = ctx.terminated_ranks

        if self.tree is None:
            self.build_tree(state, store)
            active_world_size = bounded_activate(self.tree, 0)
            for node in self.rank_map.values():
                node.state.active_world_size = active_world_size
            for leaf in self.tree.iter_leaves():
                leaf.assign_backfill_domain()

        for leaf in self.tree.iter_leaves():
            leaf.state.copy_from(state, fields=['fn_exception', 'iteration'])

        terminated_ranks = propagate_terminations(self.tree, terminated_ranks)
        terminated_active_ranks = set(
            rank for rank in terminated_ranks if self.rank_map[rank].state.mode == Mode.ACTIVE
        )
        for terminated_rank in terminated_ranks:
            self.rank_map[terminated_rank].terminate()

        replaced_terminate_active_ranks = self.replace_with_inactive(terminated_active_ranks)
        unhandled_terminations = terminated_active_ranks - replaced_terminate_active_ranks

        if unhandled_terminations:
            replaced_active = self.replace_with_backfill(unhandled_terminations)
            self.shift_ranks(replaced_active, unhandled_terminations)

        if self.world_size_filter is not None:
            self.filter_active_world_size()

        self.update_tree(state, terminated_ranks)
        self.recompute_rank()

        ctx.state = self.get_state_from_tree(state, terminated_ranks)
        ctx.terminated_ranks = set()
        return ctx
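

# Illustrative sketch (not part of the library): a Tree rank assignment whose
# optional world_size_filter rounds the active world size down to a multiple
# of 8, holding surplus healthy ranks as inactive reserve. The hostname
# grouping mirrors the docstring example above; the concrete min_ranks and
# max_ranks values are placeholders, not recommendations.
def _tree_rank_assignment_sketch():
    import socket

    return Tree(
        layers=[
            Layer(
                min_ranks=16,
                max_ranks=None,
                key_or_fn='root',
                flag=LayerFlag.RESERVE,
            ),
            Layer(
                min_ranks=8,
                max_ranks=8,
                key_or_fn=lambda _: socket.gethostname(),
                flag=LayerFlag.RESERVE | LayerFlag.BACKFILL,
            ),
        ],
        world_size_filter=lambda active_ranks: active_ranks // 8 * 8,
    )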


class FillGaps(RankAssignment):
    r'''
    A class for reassigning distributed ranks, filling in gaps caused by
    terminated or unhealthy ranks.

    The :py:class:`FillGaps` class is a specialized rank assignment strategy
    that reorders ranks to fill gaps created by terminated or unhealthy ranks.
    It preserves the previous rank assignment for the first ``world_size -
    len(terminated_ranks)`` healthy ranks; the remaining healthy ranks are
    reassigned to fill in gaps left by unhealthy ranks.

    Example:

    .. code-block:: python

        |<--- preserved --->|<- moved ->|     |<--new world size->|
        +---+---+---+---+---+---+---+---+     +---+---+---+---+---+
        | 0 | X | 2 | 3 | X | X | 6 | 7 | --> | 0 | 6 | 2 | 3 | 7 |
        +---+---+---+---+---+---+---+---+     +---+---+---+---+---+
    '''

    def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
        state = ctx.state
        rank = state.rank
        world_size = state.world_size
        terminated_ranks = ctx.terminated_ranks

        ordered_terminated_ranks = sorted(list(terminated_ranks))
        world_size = world_size - len(terminated_ranks)
        if rank in terminated_ranks:
            raise RankDiscarded(f'{rank=} {terminated_ranks=}')
        elif rank >= world_size:
            rank = ordered_terminated_ranks[rank - world_size]

        state = dataclasses.replace(
            state,
            rank=rank,
            world_size=world_size,
        )
        ctx.state = state
        ctx.terminated_ranks = set()
        return ctx
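

# Illustrative sketch (not part of the library): the index arithmetic used by
# FillGaps, worked by hand for world_size=6 and terminated_ranks={1, 3}.
# Healthy ranks 4 and 5 land in the gaps left at indices 1 and 3, while ranks
# 0 and 2 keep their original indices.
def _fill_gaps_arithmetic_sketch():
    world_size = 6
    terminated_ranks = {1, 3}
    ordered = sorted(terminated_ranks)
    new_world_size = world_size - len(terminated_ranks)  # 4

    mapping = {}
    for rank in range(world_size):
        if rank in terminated_ranks:
            continue  # would raise RankDiscarded in FillGaps.__call__
        mapping[rank] = ordered[rank - new_world_size] if rank >= new_world_size else rank

    assert mapping == {0: 0, 2: 2, 4: 1, 5: 3}
    return mapping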


class ShiftRanks(RankAssignment):
    r'''
    A class for reassigning distributed ranks, filling in gaps caused by
    terminated or unhealthy ranks.

    The :py:class:`ShiftRanks` class is a specialized rank assignment strategy
    that shifts all healthy ranks to the left to fill gaps created by
    terminated or unhealthy ranks. :py:class:`ShiftRanks` preserves the
    relative order of all healthy ranks, but all ranks past the first
    unhealthy rank are reassigned (shifted).

    Example:

    .. code-block:: python

        |<->|<--------- moved --------->|     |<--new world size->|
        +---+---+---+---+---+---+---+---+     +---+---+---+---+---+
        | 0 | X | 2 | 3 | X | X | 6 | 7 | --> | 0 | 2 | 3 | 6 | 7 |
        +---+---+---+---+---+---+---+---+     +---+---+---+---+---+
    '''

    def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
        state = ctx.state
        rank = state.rank
        world_size = state.world_size
        terminated_ranks = ctx.terminated_ranks

        world_size = world_size - len(terminated_ranks)
        if rank in terminated_ranks:
            raise RankDiscarded(f'{rank=} {terminated_ranks=}')
        else:
            rank = rank - sum(rank > terminated_rank for terminated_rank in terminated_ranks)

        state = dataclasses.replace(
            state,
            rank=rank,
            world_size=world_size,
        )
        ctx.state = state
        ctx.terminated_ranks = set()
        return ctx
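

# Illustrative sketch (not part of the library): the index arithmetic used by
# ShiftRanks, worked by hand for the docstring example above (world_size=8,
# terminated_ranks={1, 4, 5}). Each healthy rank is shifted left by the number
# of terminated ranks with a smaller index, preserving relative order.
def _shift_ranks_arithmetic_sketch():
    world_size = 8
    terminated_ranks = {1, 4, 5}
    new_world_size = world_size - len(terminated_ranks)  # 5

    mapping = {}
    for rank in range(world_size):
        if rank in terminated_ranks:
            continue  # would raise RankDiscarded in ShiftRanks.__call__
        mapping[rank] = rank - sum(rank > t for t in terminated_ranks)

    assert mapping == {0: 0, 2: 1, 3: 2, 6: 3, 7: 4}
    assert sorted(mapping.values()) == list(range(new_world_size))
    return mapping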


class FilterCountGroupedByKey(RankAssignment):
    r'''
    A class for filtering distributed ranks by grouping them by a key.

    :py:class:`FilterCountGroupedByKey` organizes ranks into groups based on a
    specified string key. For each group, it increments a group counter by 1
    for every healthy rank. A given boolean ``condition`` is then evaluated
    for each rank, with the corresponding group counter passed as input.

    - If ``condition(group_counter)`` evaluates to ``True``, the rank is
      preserved.
    - If it evaluates to ``False``, the rank is considered unhealthy and
      marked for termination.

    :py:class:`FilterCountGroupedByKey` needs to be followed by another
    :py:class:`RankAssignment` that performs the actual rank termination by
    raising the :py:exc:`RankDiscarded` exception.

    .. code-block:: python

        condition = lambda count: count == 2

        +---+---+---+---+---+---+---+---+     +---+---+---+---+---+---+---+---+
        | 0 | X | 2 | 3 | X | X | 6 | 7 | --> | X | X | 2 | 3 | X | X | 6 | 7 |
        +---+---+---+---+---+---+---+---+     +---+---+---+---+---+---+---+---+
        | key=0 | key=1 | key=2 | key=3 |     | key=0 | key=1 | key=2 | key=3 |
        |       |       |       |       |     |       |       |       |       |
        |count=1|count=2|count=0|count=2|     | False | True  | False | True  |

    Example:

    .. code-block:: python

        # hostname is the group key, and condition checks if exactly 8 ranks
        # corresponding to a given hostname are in a healthy state; if the
        # count is different from 8, all ranks from the corresponding hostname
        # are considered unhealthy and terminated; remaining healthy ranks are
        # shifted to the left to fill all gaps created by unhealthy ranks.

        rank_assignment = inprocess.Compose(
            inprocess.rank_assignment.ShiftRanks(),
            inprocess.rank_assignment.FilterCountGroupedByKey(
                key_or_fn=lambda _: socket.gethostname(),
                condition=lambda count: count == 8,
            ),
        )

    Args:
        key_or_fn: a string key, or a ``Callable`` evaluated with
            :py:class:`inprocess.state.State` as the input to produce a string
            key
        condition: condition to be evaluated with the group counter as the
            input; if ``False``, the rank is terminated
        timeout: timeout for the distributed barrier
    '''

    instance_count = 0

    def __init__(
        self,
        key_or_fn: Union[str, Callable[[State], str]],
        condition: Callable[[int], bool],
        timeout: datetime.timedelta = datetime.timedelta(seconds=60),
    ):
        self.key_or_fn = key_or_fn
        self.condition = condition
        self.timeout = timeout

        self.name = f'{type(self).__name__}_{type(self).instance_count}'
        type(self).instance_count += 1

    def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
        COUNT_ALIVE_BARRIER = f'count_alive_barrier_{self.name}'
        SUBMIT_MISMATCHING_BARRIER = f'submit_mismatching_barrier_{self.name}'
        RANKS_TO_TERMINATE = f'ranks_to_terminate_{self.name}'

        rank = ctx.state.rank
        world_size = ctx.state.world_size
        store = ctx.store
        terminated_ranks = ctx.terminated_ranks

        alive_world_size = world_size - len(terminated_ranks)

        if rank not in terminated_ranks:
            key = self.key_or_fn(ctx.state) if callable(self.key_or_fn) else self.key_or_fn
            prefixed_key = f'filter_grouped_by_key_{self.name}_{key}'
            store.add(prefixed_key, 1)
            store.barrier(
                ranks=[rank],
                group_name=COUNT_ALIVE_BARRIER,
                rendezvous_count=alive_world_size,
                timeout=self.timeout,
            )
            if not self.condition(int(store.get(prefixed_key))):
                store.append(RANKS_TO_TERMINATE, f'{rank},')
            store.barrier(
                ranks=[rank],
                group_name=SUBMIT_MISMATCHING_BARRIER,
                rendezvous_count=alive_world_size,
                timeout=self.timeout,
            )
            if store.check([RANKS_TO_TERMINATE]):
                ranks_to_terminate = set(
                    int(r)
                    for r in store.get(RANKS_TO_TERMINATE).decode().rstrip(',').split(',')
                )
            else:
                ranks_to_terminate = set()

            ctx.terminated_ranks = terminated_ranks.union(ranks_to_terminate)

        return ctx
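

# Illustrative sketch (not part of the library): wiring a composed rank
# assignment into inprocess.Wrapper via its rank_assignment argument. Only
# rank_assignment is shown; every other Wrapper argument is left at its
# default, and the decorator form follows the library's usage examples but is
# an assumption of this sketch rather than part of this module.
def _wrapper_usage_sketch():
    import socket

    from nvidia_resiliency_ext import inprocess

    @inprocess.Wrapper(
        rank_assignment=inprocess.Compose(
            inprocess.rank_assignment.ShiftRanks(),
            inprocess.rank_assignment.FilterCountGroupedByKey(
                key_or_fn=lambda _: socket.gethostname(),
                condition=lambda count: count == 8,
            ),
        ),
    )
    def train(*args, **kwargs):
        ...

    return train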