Source code for nvidia_resiliency_ext.inprocess.wrap

# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import datetime
import functools
import gc
import itertools
import logging
import pathlib
import threading
import time
from datetime import timedelta
from typing import Any, Callable, Optional

import torch

from . import param_utils
from .abort import Abort
from .abort import AbortTorchDistributed
from .attribution import Interruption
from .attribution import InterruptionRecord
from .exception import HealthCheckError
from .exception import InternalError
from .exception import RestartError
from .finalize import Finalize
from .health_check import HealthCheck
from .initialize import Initialize
from .logging import Logging
from .logging import log_exc
from .monitor_process import MonitorProcess
from .monitor_thread import MonitorThread
from .monitor_thread import RankShouldRestart
from .monitor_thread import reraise_if_unraisable
from .param_utils import enforce_subclass
from .param_utils import enforce_type
from .param_utils import enforce_value
from .progress_watchdog import ProgressWatchdog
from .rank_assignment import RankAssignment
from .rank_assignment import ShiftRanks
from .rank_filter import RankFilter
from .state import State
from .state import Status
from .store import PrefixStore
from .store import StoreMixin
from .store import TCPStore


class HealthCheckPassed(Exception):
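    # Internal control-flow exception: raised in CallWrapper.__call__ when the
    # rank-local health check passes after an interruption, and caught at the
    # restart-loop level to begin the next restart iteration.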
    pass


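# Idle loop executed by inactive (spare) ranks: keeps pinging the progress
# watchdog until any active rank reports completion of the wrapped function.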
def spare_fn(state, progress_watchdog, progress_watchdog_interval):
    log = logging.getLogger(__name__)
    rank = state.rank
    log.debug(f'{rank=} starting spare_fn')
    progress_watchdog.ping()
    while not state.store.is_any_rank_completed():
        progress_watchdog.ping()
        time.sleep(progress_watchdog_interval.total_seconds())
    log.debug(f'{rank=} finished spare_fn')


class Wrapper:
    r'''
    Python function wrapper that adds restart capabilities to an existing
    Python function implementing a distributed PyTorch workload.

    Upon a fault, the wrapped function is restarted across all distributed
    ranks, within the same operating system process. The restarted invocation
    excludes distributed ranks that are terminated, missing, or deemed
    unhealthy.

    When a failure occurs on any worker, the wrapper ensures the function
    restarts simultaneously on all healthy ranks. This process continues
    until all ranks complete execution successfully, or a termination
    condition is met.

    See the :doc:`Usage Guide <../usage_guide>` for detailed documentation.

    Args:
        store_factory: Factory to construct the internal distributed store
            used for communication between ranks.
        store_kwargs: Dictionary of keyword arguments used to construct the
            internal store with ``store_factory(**store_kwargs)``.
        initialize: Rank-local initialization.
        abort: Asynchronously aborts execution.
        finalize: Rank-local finalization.
        health_check: Rank-local health check.
        rank_assignment: Reassigns ranks and computes the new world size for
            the next restart iteration.
        rank_filter: Specifies which ranks actively call the wrapped function.
        monitor_thread_interval: Monitoring interval for the monitor thread.
        monitor_process_interval: Monitoring interval for the monitor process.
        progress_watchdog_interval: Interval for automatic progress watchdog
            timestamp updates.
        soft_timeout: Soft progress timeout. A timed-out rank executes the
            asynchronous abort and participates in the restart if healthy.
        hard_timeout: Hard progress timeout. A timed-out rank is terminated.
        heartbeat_timeout: Timeout for detecting a missing rank heartbeat.
        barrier_timeout: Barrier timeout.
        completion_timeout: Completion barrier timeout.
        last_call_wait: Time interval during which other ranks may report
            concurrent terminal failures.
        termination_grace_time: Interval between the ``SIGTERM`` and
            ``SIGKILL`` signals issued by the hard timeout mechanism.
        monitor_process_logfile: Absolute filepath of the monitor process
            logfile. It may contain a "{rank}" placeholder, to be filled with
            the initial integer rank id.
        enabled: Enables the wrapper.

    Returns:
        The value returned by the wrapped function if all ranks completed
        execution successfully. Inactive ranks return :py:obj:`None`.
    '''

    def __init__(
        self,
        *,
        store_factory: type[StoreMixin] = TCPStore,
        store_kwargs: Optional[dict[str, Any]] = None,
        initialize: Optional[Initialize] = None,
        abort: Optional[Abort] = AbortTorchDistributed(),
        finalize: Optional[Finalize] = None,
        health_check: Optional[HealthCheck] = None,
        rank_assignment: RankAssignment = ShiftRanks(),
        rank_filter: Optional[RankFilter] = None,
        monitor_thread_interval: datetime.timedelta = timedelta(seconds=1),
        monitor_process_interval: datetime.timedelta = timedelta(seconds=1),
        progress_watchdog_interval: datetime.timedelta = timedelta(seconds=1),
        soft_timeout: datetime.timedelta = timedelta(seconds=60),
        hard_timeout: datetime.timedelta = timedelta(seconds=90),
        heartbeat_timeout: datetime.timedelta = timedelta(seconds=30),
        barrier_timeout: datetime.timedelta = timedelta(seconds=120),
        completion_timeout: datetime.timedelta = timedelta(seconds=120),
        last_call_wait: datetime.timedelta = timedelta(seconds=1),
        termination_grace_time: datetime.timedelta = timedelta(seconds=5),
        monitor_process_logfile: Optional[str] = None,
        enabled: bool = True,
    ):
        enforce_subclass('store_factory', StoreMixin)
        enforce_type('store_kwargs', (dict, type(None)))
        enforce_type('initialize', (Initialize, type(None)))
        enforce_type('abort', (Abort, type(None)))
        enforce_type('finalize', (Finalize, type(None)))
        enforce_type('health_check', (HealthCheck, type(None)))
        enforce_type('rank_assignment', (RankAssignment, type(None)))
        enforce_type('rank_filter', (RankFilter, type(None)))
        enforce_type('monitor_thread_interval', datetime.timedelta)
        enforce_type('monitor_process_interval', datetime.timedelta)
        enforce_type('progress_watchdog_interval', datetime.timedelta)
        enforce_type('soft_timeout', datetime.timedelta)
        enforce_type('hard_timeout', datetime.timedelta)
        enforce_type('heartbeat_timeout', datetime.timedelta)
        enforce_type('barrier_timeout', datetime.timedelta)
        enforce_type('completion_timeout', datetime.timedelta)
        enforce_type('last_call_wait', datetime.timedelta)
        enforce_type('termination_grace_time', datetime.timedelta)
        enforce_type('monitor_process_logfile', (str, type(None)))
        enforce_type('enabled', bool)

        enforce_value(soft_timeout < hard_timeout < barrier_timeout)
        enforce_value(monitor_process_interval < barrier_timeout)
        enforce_value(heartbeat_timeout < barrier_timeout)
        enforce_value(monitor_process_interval < heartbeat_timeout)
        enforce_value(monitor_thread_interval < soft_timeout)
        enforce_value(progress_watchdog_interval < soft_timeout)
        if monitor_process_logfile is not None:
            enforce_value(pathlib.Path(monitor_process_logfile).is_absolute())
        enforce_value(torch.distributed.is_available())

        if store_kwargs is None:
            store_kwargs = {}

        self.store_factory = store_factory
        self.store_kwargs = store_kwargs
        self.initialize = initialize
        self.abort = abort
        self.finalize = finalize
        self.health_check = health_check
        self.rank_assignment = rank_assignment
        self.rank_filter = rank_filter
        self.monitor_thread_interval = monitor_thread_interval
        self.monitor_process_interval = monitor_process_interval
        self.progress_watchdog_interval = progress_watchdog_interval
        self.soft_timeout = soft_timeout
        self.hard_timeout = hard_timeout
        self.heartbeat_timeout = heartbeat_timeout
        self.barrier_timeout = barrier_timeout
        self.completion_timeout = completion_timeout
        self.last_call_wait = last_call_wait
        self.termination_grace_time = termination_grace_time
        self.monitor_process_logfile = monitor_process_logfile
        self.enabled = enabled

    def __call__(self, fn):
        if not self.enabled:
            return fn

        @functools.wraps(fn)
        def wrapped(*args, **kwargs):
            with CallWrapper(self) as call_wrapper:
                return call_wrapper(fn, args, kwargs)

        return wrapped
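
# Illustrative usage sketch (not part of the library source): decorating a
# distributed training entry point with the Wrapper defined above. The
# function name ``train`` and its body are hypothetical; ``Wrapper`` and
# ``CallWrapper`` are the classes defined in this module.
#
#     import datetime
#     from typing import Optional
#
#     from nvidia_resiliency_ext.inprocess.wrap import CallWrapper, Wrapper
#
#     @Wrapper(
#         soft_timeout=datetime.timedelta(seconds=60),
#         hard_timeout=datetime.timedelta(seconds=90),
#     )
#     def train(call_wrapper: Optional[CallWrapper] = None):
#         ...  # distributed PyTorch workload; restarted in-process on faults
#
#     train()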

class CallWrapper:
    r'''
    The :py:class:`CallWrapper` encapsulates the state and execution flow of
    the restart capabilities of the :py:class:`Wrapper` for a single
    invocation of the wrapped function. This design ensures that each call
    operates independently, with the restart state tied to the specific
    invocation rather than the function's definition.

    The :py:class:`CallWrapper` is automatically created by the
    :py:class:`Wrapper` with every invocation of the wrapped function. The
    active :py:class:`CallWrapper` instance is then passed as the value for
    any function argument annotated with ``CallWrapper`` or
    ``typing.Optional[CallWrapper]``. This allows the wrapped function to
    access and interact with the state of the :py:class:`Wrapper` during its
    execution.
    '''

    def __init__(self, wrapper: Wrapper):
        monitor_process = None
        progress_watchdog = None

        try:
            Logging.initialize()
            log = logging.getLogger(__name__)
            enforce_value(not torch.distributed.is_initialized())

            state = State()

            monitor_process = MonitorProcess(
                rank=state.rank,
                world_size=state.world_size,
                progress_timeout=wrapper.hard_timeout,
                termination_grace_time=wrapper.termination_grace_time,
                barrier_timeout=wrapper.barrier_timeout,
                interval=wrapper.monitor_process_interval,
                heartbeat_timeout=wrapper.heartbeat_timeout,
                log_filename=wrapper.monitor_process_logfile,
                store_factory=wrapper.store_factory,
                store_kwargs=wrapper.store_kwargs,
            )

            store_kwargs = wrapper.store_kwargs
            base_store = wrapper.store_factory(**store_kwargs)
            log.debug(f'{base_store=} {store_kwargs=}')
            state.store = base_store

            base_store.initial_barrier(
                rank=state.rank,
                rendezvous_count=state.world_size,
                timeout=wrapper.barrier_timeout,
            )
            base_store.set_initial_rank(state.rank, state.initial_rank)
            monitor_process.can_create_store()

            progress_watchdog = ProgressWatchdog(
                rank=state.rank,
                connection=monitor_process.parent_conn_time,
                interval=wrapper.progress_watchdog_interval,
                daemon=True,
            )
            progress_watchdog.start()
            progress_watchdog.loop_started.wait()

            self.atomic_lock = threading.RLock()
            self.store = base_store
            self.base_store = base_store
            self.state = state
            self.monitor_process = monitor_process
            self.wrapper = wrapper
            self.progress_watchdog = progress_watchdog
        except Exception:
            if monitor_process is not None:
                monitor_process.shutdown()
            if progress_watchdog is not None:
                progress_watchdog.shutdown()
            Logging.deinitialize()
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.progress_watchdog.shutdown()
        self.monitor_process.shutdown()
        Logging.deinitialize()

    @property
    def iteration(self) -> int:
        r'''
        Returns integer index of the current restart iteration.
        '''
        return self.state.iteration
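
    # Illustrative sketch (not part of the library source), continuing the
    # usage sketch after the Wrapper class: the wrapped function receives the
    # active CallWrapper through an annotated argument and can query the
    # current restart iteration. The ``train`` function is hypothetical.
    #
    #     @Wrapper()
    #     def train(call_wrapper: Optional[CallWrapper] = None):
    #         if call_wrapper is not None:
    #             print(f'restart iteration: {call_wrapper.iteration}')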

    def ping(self) -> None:
        r'''
        Sends a heartbeat to indicate that the workload is making meaningful
        forward progress.

        The optional manual progress timeout is initiated with the first call
        to :py:meth:`CallWrapper.ping` on each rank in a restart iteration.
        Once the timeout is activated, every distributed rank must
        periodically invoke :py:meth:`CallWrapper.ping` to confirm ongoing
        progress. If any rank fails to report progress within the
        ``soft_timeout`` or ``hard_timeout`` intervals specified for the
        :py:class:`Wrapper`, the rank is considered unresponsive, and a
        restart of the wrapped function is attempted.
        '''
        self.progress_watchdog.ping()
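
    # Illustrative sketch (not part of the library source): reporting manual
    # progress from a training loop with ping(). The ``train`` function and
    # its ``steps`` argument are hypothetical.
    #
    #     @Wrapper()  # soft_timeout/hard_timeout bound the time between pings
    #     def train(steps=1000, call_wrapper: Optional[CallWrapper] = None):
    #         for step in range(steps):
    #             ...  # one training step
    #             if call_wrapper is not None:
    #                 call_wrapper.ping()  # this rank is making progress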

    @contextlib.contextmanager
    def atomic(self):
        r'''
        A context manager to wrap a section of the workload that must not be
        executed while the termination procedure is in progress.

        :py:meth:`atomic` is implemented with a reentrant lock, shared between
        the termination procedure and the atomic section in the wrapped
        function. The termination procedure won't be launched while the main
        thread is executing an :py:meth:`inprocess.CallWrapper.atomic` code
        block, and the main thread won't enter an
        :py:meth:`inprocess.CallWrapper.atomic` code block if the termination
        procedure is already in progress.
        '''
        while not self.atomic_lock.acquire(blocking=False):
            pass
        try:
            yield
        finally:
            self.atomic_lock.release()
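
    # Illustrative sketch (not part of the library source): guarding a
    # checkpoint write with atomic() so the termination procedure cannot
    # interrupt it mid-write. ``save_checkpoint`` is a hypothetical helper.
    #
    #     def train(call_wrapper: Optional[CallWrapper] = None):
    #         ...
    #         with call_wrapper.atomic():
    #             save_checkpoint()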

    @reraise_if_unraisable(RankShouldRestart)
    def __call__(self, fn, args, kwargs):
        log = logging.getLogger(__name__)

        store = self.base_store
        base_store = self.base_store
        state = self.state
        monitor_process = self.monitor_process
        wrapper = self.wrapper
        progress_watchdog = self.progress_watchdog

        # Initial rank assignment; no ranks have been terminated yet.
        state, terminated_ranks = wrapper.rank_assignment(state, set())
        if terminated_ranks:
            msg = f'{terminated_ranks=} is not empty'
            raise RestartError(msg)
        state.activate_all_ranks()
        if wrapper.rank_filter is not None:
            state = wrapper.rank_filter(state)
        state.set_distributed_vars()

        monitor_process.start()
        monitor_thread = None
        ret = None

        try:
            # Restart loop: one pass per restart iteration of the wrapped
            # function.
            while True:
                # Store namespaced by the current restart iteration.
                store = PrefixStore(state.iteration, base_store)
                store.set_initial_rank(state.rank, state.initial_rank)
                state.store = store

                store.iteration_barrier(
                    rank=state.rank,
                    rendezvous_count=state.world_size,
                    timeout=wrapper.barrier_timeout,
                )
                monitor_process.start_iteration(
                    state.rank,
                    state.world_size,
                    state.iteration,
                )

                progress_watchdog.reset()
                monitor_thread = MonitorThread(
                    state=state,
                    abort=wrapper.abort,
                    interval=wrapper.monitor_thread_interval,
                    abort_signal=None,
                    progress_watchdog=progress_watchdog,
                    soft_timeout=wrapper.soft_timeout,
                    monitor_timeout=datetime.timedelta(seconds=5),
                    last_call_wait=wrapper.last_call_wait,
                    atomic_lock=self.atomic_lock,
                    daemon=True,
                )
                monitor_thread.start()

                args, kwargs = param_utils.substitute_param_value(
                    fn,
                    args,
                    kwargs,
                    {
                        CallWrapper: self,
                    },
                )

                try:
                    try:
                        monitor_thread.start_monitoring()

                        try:
                            try:
                                if wrapper.initialize is not None:
                                    state = wrapper.initialize(state)
                            except Exception as init_ex:
                                log.error(log_exc(state, init_ex, 'init_ex'))
                                raise

                            try:
                                if wrapper.health_check is not None:
                                    wrapper.health_check(state, None)
                            except Exception as health_ex:
                                log.error(
                                    log_exc(state, health_ex, 'health_ex')
                                )
                                raise HealthCheckError from health_ex

                            if state.status == Status.ACTIVE:
                                ret = fn(*args, **kwargs)
                                state.store.record_completed()
                            elif state.status == Status.INACTIVE:
                                ret = spare_fn(
                                    state,
                                    progress_watchdog,
                                    wrapper.progress_watchdog_interval,
                                )
                            else:
                                raise InternalError(f'{state}')

                            progress_watchdog.reset()
                            completion_timeout_chunk = (
                                wrapper.soft_timeout
                                - wrapper.progress_watchdog_interval
                            ) / 2
                            store.completion_barrier(
                                rank=state.rank,
                                rendezvous_count=state.world_size,
                                timeout=wrapper.completion_timeout,
                                timeout_chunk=completion_timeout_chunk,
                            )

                        except Exception as train_ex:
                            # The wrapped function (or initialization/health
                            # check) raised: record the interruption and let
                            # the monitor thread deliver an asynchronous
                            # RankShouldRestart.
                            try:
                                log.error(
                                    log_exc(state, train_ex, 'train_ex')
                                )
                                store.record_interrupted(
                                    InterruptionRecord(
                                        state.rank, Interruption.EXCEPTION
                                    )
                                )
                                progress_watchdog.spin_drain()
                                monitor_thread.final_check_and_shutdown()
                            except RankShouldRestart as async_ex:
                                log.debug(
                                    log_exc(state, async_ex, 'async_ex')
                                )
                                monitor_thread.shutdown()
                                raise async_ex from train_ex
                            except Exception as other_ex:
                                log.critical(
                                    log_exc(state, other_ex, 'other_ex')
                                )
                                raise InternalError(f'{state}') from other_ex
                            else:
                                raise InternalError(f'{state}') from train_ex
                        else:
                            monitor_thread.shutdown()

                    except RankShouldRestart as term_ex:
                        # Asynchronous restart request: stop monitoring, run
                        # finalize and the health check, then signal readiness
                        # to restart with HealthCheckPassed.
                        log.warning(log_exc(state, term_ex, 'term_ex'))
                        monitor_thread.shutdown()
                        progress_watchdog.resume()

                        try:
                            if wrapper.finalize is not None:
                                wrapper.finalize(state, term_ex.__cause__)
                        except Exception as finalize_ex:
                            log.error(
                                log_exc(state, finalize_ex, 'finalize_ex')
                            )
                            raise finalize_ex from term_ex

                        try:
                            if wrapper.health_check is not None:
                                wrapper.health_check(state, term_ex.__cause__)
                        except Exception as health_ex:
                            log.error(log_exc(state, health_ex, 'health_ex'))
                            try:
                                raise health_ex from term_ex
                            except Exception as ex:
                                raise HealthCheckError from health_ex
                        else:
                            raise HealthCheckPassed from term_ex

                    except Exception as term_ex:
                        log.critical(log_exc(state, term_ex, 'term_ex'))
                        raise InternalError(f'{state}') from term_ex

                except HealthCheckPassed as restart_ex:
                    # Health check passed: synchronize with the other ranks,
                    # reassign ranks, and proceed to the next iteration.
                    log.info(log_exc(state, restart_ex, 'restart_ex'))

                    store.termination_barrier(
                        rank=state.rank,
                        rendezvous_count=state.world_size,
                        timeout=wrapper.barrier_timeout,
                    )

                    terminated_ranks = store.get_terminated_ranks()
                    state, terminated_ranks = wrapper.rank_assignment(
                        state, terminated_ranks
                    )
                    if terminated_ranks:
                        msg = f'{terminated_ranks=} is not empty'
                        raise RestartError(msg)
                    state.activate_all_ranks()
                    if wrapper.rank_filter is not None:
                        state = wrapper.rank_filter(state)
                    state.set_distributed_vars()
                else:
                    break
                finally:
                    while gc.collect():
                        pass
                    state.advance()

        except BaseException as exit_ex:
            log.critical(log_exc(state, exit_ex, 'exit_ex'))
            if monitor_thread is not None:
                monitor_thread.shutdown()
            store.record_interrupted(
                InterruptionRecord(state.rank, Interruption.BASE_EXCEPTION)
            )
            store.record_terminated_rank(state.rank)
            store.termination_barrier(
                rank=state.rank,
                rendezvous_count=state.world_size,
                timeout=wrapper.barrier_timeout,
            )
            raise exit_ex
        else:
            rank = state.rank
            log.debug(f'{rank=} clean exit')
        finally:
            # Record termination in the base store; ranks critical to the
            # store wait for all other ranks to terminate before shutting
            # down.
            initial_rank = state.initial_rank
            rank = state.rank
            initial_world_size = state.initial_world_size

            if initial_rank in base_store.critical_ranks:
                base_store.record_base_terminated_rank(initial_rank)
                log.debug(f'{rank=} {initial_rank=} waiting for other ranks')
                while (
                    base_store.get_base_terminated_count()
                    != initial_world_size
                ):
                    time.sleep(
                        wrapper.monitor_thread_interval.total_seconds()
                    )
                log.debug(f'{rank=} {initial_rank=} done waiting')
                progress_watchdog.shutdown()
                monitor_process.shutdown()
            else:
                progress_watchdog.shutdown()
                monitor_process.shutdown()
                base_store.record_base_terminated_rank(state.initial_rank)

            Logging.deinitialize()

        rank = state.rank
        log.debug(f'{rank=} returning')
        return ret