Source code for nvidia_resiliency_ext.fault_tolerance.rank_monitor_server

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import contextlib
import functools
import logging
import multiprocessing.resource_sharer as mp_resource_sharer
import os
import signal
import socket
import sys
import tempfile
import time
import traceback
from typing import Mapping, Optional

import torch
import torch.multiprocessing as mp

from ..shared_utils.health_check import GPUHealthCheck, NicHealthCheck
from .config import FaultToleranceConfig
from .data import (
    AuthkeyMsg,
    ErrorMsg,
    HeartbeatMsg,
    InitMsg,
    OkMsg,
    SectionAction,
    SectionMsg,
    UpdateConfigMsg,
)
from .rank_monitor_state_machine import RankMonitorStateMachine
from .utils import is_process_alive, read_obj_from_ipc_stream, write_obj_to_ipc_stream


class RankMonitorLogger(logging.Logger):
    """Logger used in a rank monitor process"""

    def __init__(
        self,
        name="RankMonServer",
        level=logging.INFO,
        connected_rank=None,
        is_restarter_logger=False,
    ):
        super().__init__(name, level)
        self.is_restarter_logger = is_restarter_logger
        self.hostname = socket.gethostname()
        self.pid = os.getpid()
        self._setup_logger()
        self._create_restarter_sublogger()
        self.set_connected_rank(connected_rank)

    def set_connected_rank(self, rank: Optional[int]):
        self.rank = rank
        rank_str = f"{self.rank}" if self.rank is not None else "<none>"
        name_fmt = f"{self.name}{self.pid}@{self.hostname}:{rank_str}"
        for handler in self.handlers:
            formatter = logging.Formatter(
                f"[%(asctime)s] [%(levelname)s] [{name_fmt}] %(message)s"
            )
            handler.setFormatter(formatter)

    def log_for_restarter(self, message, *args, **kwargs):
        self.restarter_logger.log(logging.DEBUG, message, *args, **kwargs)

    def _setup_logger(self):
        self.setLevel(self.level)
        ch = logging.StreamHandler(sys.stderr)
        ch.setLevel(self.level)
        self.addHandler(ch)
        self.propagate = False

    def _create_restarter_sublogger(self):
        self.restarter_logger = logging.getLogger(f"{self.name}.Restarter")
        self.restarter_logger.setLevel(
            logging.DEBUG if self.is_restarter_logger else logging.CRITICAL
        )
        stdout_handler = logging.StreamHandler(sys.stdout)
        stdout_handler.setLevel(logging.DEBUG)
        formatter = logging.Formatter("[%(asctime)s] [ALWAYS] [%(name)s] %(message)s")
        stdout_handler.setFormatter(formatter)
        self.restarter_logger.addHandler(stdout_handler)
        self.restarter_logger.propagate = False
        return self.restarter_logger
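
# Illustrative note (not part of the original module): with the formatter set up
# above, a record emitted while monitoring global rank 3 from PID 12345 on host
# "node01" would look roughly like:
#   [2024-01-01 12:00:00,000] [INFO] [RankMonServer12345@node01:3] <message>
# Before a rank connects (or after it disconnects), the rank field is rendered
# as "<none>". The hostname, PID, and timestamp above are hypothetical examples.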


class RankMonitorServer:
    """
    RankMonitorServer, running in a separate process, is responsible for monitoring the ranks.

    RankMonitorClient is initialized in each rank and is used to communicate with the
    RankMonitorServer.
    """

    """
    How much time to wait for the rank monitor process to initialize.
    The default of 3 minutes should be enough, even on the slowest machines.
    """
    RANK_MONITOR_INIT_TIMEOUT = 180.0

    def __init__(
        self,
        cfg: FaultToleranceConfig,
        ipc_socket_path: str,
        rank_monitor_ready_event,
        is_restarter_logger: bool,
        logger: RankMonitorLogger,
    ):
        """
        Initializes the RankMonitorServer object.

        Args:
            cfg (FaultToleranceConfig): The configuration object for fault tolerance.
            ipc_socket_path (str): Path of the IPC socket connecting this monitor with its rank.
            rank_monitor_ready_event (mp.Event): The event indicating that the rank monitor is ready.
            is_restarter_logger (bool): True if this monitor writes state transition logs.
            logger (RankMonitorLogger): The logger object for logging.
        """
        self.cfg = cfg
        self.ipc_socket_path = ipc_socket_path
        self.rank_info = None
        self.start_time = None
        self.last_hb_time = None
        self.out_of_section_time = None
        self.open_sections: Mapping[str, float] = dict()  # section name -> start time
        self.is_closing = None
        self.server = None
        self.connection_lock = asyncio.Lock()
        self.rank_monitor_ready_event = rank_monitor_ready_event
        self.logger = logger
        self.state_machine = RankMonitorStateMachine(logger)
        self._periodic_restart_task = None
        self.health_checker = GPUHealthCheck(
            interval=self.cfg.node_health_check_interval,
            on_failure=self._handle_unhealthy_node,
        )
        self.current_writer = None
        self.launcher_ipc_socket_path = (
            f"{tempfile.gettempdir()}/_ft_launcher{os.getpid()}_to_rmon.socket"
        )
        self.launcher_server = None
        if self.cfg.enable_nic_monitor:
            self.logger.info("NIC health monitoring enabled.")
            self.nic_health_checker = NicHealthCheck(
                interval=self.cfg.node_health_check_interval,
                pci_topo_file=self.cfg.pci_topo_file,
                link_down_path_template=self.cfg.link_down_path_template,
                on_failure=self._handle_unhealthy_nic,
            )
        else:
            self.nic_health_checker = None

    def start_periodic_restart_check(self):
        if self._periodic_restart_task and not self._periodic_restart_task.done():
            self.logger.warning("Periodic restart check is already running.")
            return
        self._periodic_restart_task = asyncio.get_running_loop().create_task(
            self._periodic_restart_check()
        )
        self.logger.info("Started periodic restart check.")

    async def stop_periodic_restart_check(self):
        if self._periodic_restart_task:
            self._periodic_restart_task.cancel()
            try:
                await self._periodic_restart_task
            except asyncio.CancelledError:
                self.logger.info("Periodic restart check task cancelled.")
            self._periodic_restart_task = None

    def _shutdown_rank(self):
        # First send SIGCONT to wake up the process, then the configured
        # "rank_termination_signal" to terminate it.
        sig = self.cfg.rank_termination_signal
        rank_pid = self.rank_info.pid
        try:
            os.kill(rank_pid, signal.SIGCONT)
            os.kill(rank_pid, sig)
            self.logger.debug(
                f"Rank monitor sent SIGCONT then {sig.name} to the rank (PID={rank_pid})"
            )
        except Exception as e:
            self.logger.error(
                f"Rank monitor could not send {sig.name} to the rank (PID={rank_pid}). Exception={e}"
            )
            # Do nothing; most probably the rank process is already terminated.

    def _shutdown_rank_if_alive(self):
        if self.rank_info is None:
            return
        try:
            if is_process_alive(self.rank_info.pid):
                self._shutdown_rank()
            else:
                self.logger.debug(
                    f"_shutdown_rank_if_alive: rank process {self.rank_info.pid} is dead."
                )
        except Exception:
            self.logger.error("_shutdown_rank_if_alive exception! ", exc_info=True)

    async def _handle_unhealthy_node(self):
        self.logger.debug("In RankMonitorServer._handle_unhealthy_node.")
        self._shutdown_rank_if_alive()

    async def _handle_unhealthy_nic(self):
        self.logger.debug("In RankMonitorServer._handle_unhealthy_nic.")
        self._shutdown_rank_if_alive()

    async def _handle_timeout(self):
        self.logger.debug("In RankMonitorServer._handle_timeout.")
        self._shutdown_rank_if_alive()

    async def _handle_authkey_msg(self, msg, writer):
        # The resource sharer needs to be restarted to use the updated authkey.
        # After we stop it, it will be started again automatically.
        # TODO: are there any side effects?
        mp_resource_sharer.stop()
        # This process' authkey needs to be set to the same authkey as in the client,
        # otherwise we won't be able to receive pickled Managers, Tensors, etc.
        mp.current_process().authkey = msg.authkey
        await write_obj_to_ipc_stream(OkMsg(), writer)

    async def _handle_update_config_msg(self, msg, writer):
        # Received when new timeouts are computed by the client.
        if msg.hb_timeouts is not None:
            self.cfg.initial_rank_heartbeat_timeout = msg.hb_timeouts.initial
            self.cfg.rank_heartbeat_timeout = msg.hb_timeouts.subsequent
            self.logger.debug(f"Updated heartbeat timeouts: {msg.hb_timeouts}")
        if msg.section_timeouts is not None:
            self.cfg.rank_section_timeouts = msg.section_timeouts.section
            self.cfg.rank_out_of_section_timeout = msg.section_timeouts.out_of_section
            self.logger.debug(f"Updated section timeouts: {msg.section_timeouts}")
        await write_obj_to_ipc_stream(OkMsg(), writer)

    async def _handle_init_msg(self, msg, writer):
        self.rank_info = msg.rank_info
        self.start_time = time.monotonic()
        self.out_of_section_time = time.monotonic()
        self.open_sections.clear()
        self.last_hb_time = None
        # Update the NIC health checker with the rank to monitor.
        if self.nic_health_checker is not None:
            self.nic_health_checker.set_nic_device(local_rank=self.rank_info.local_rank)
        self.logger.set_connected_rank(msg.rank_info.global_rank)
        await write_obj_to_ipc_stream(OkMsg(cfg=self.cfg), writer)

    async def _handle_heartbeat_msg(self, msg, writer):
        if self._periodic_restart_task is not None:
            await self.stop_periodic_restart_check()
        self.state_machine.handle_heartbeat_msg()
        self.last_hb_time = time.monotonic()
        assert not msg.state, "state in heartbeat is not supported in this version"
        await write_obj_to_ipc_stream(OkMsg(), writer)

    async def _handle_section_msg(self, msg: SectionMsg, writer):
        if self._periodic_restart_task is not None:
            await self.stop_periodic_restart_check()
        self.state_machine.handle_section_msg()
        resp = ErrorMsg()
        current_time = time.monotonic()
        if msg.action is SectionAction.OPEN:
            if msg.section not in self.open_sections:
                self.open_sections[msg.section] = current_time
                self.out_of_section_time = None
                resp = OkMsg()
            else:
                resp = ErrorMsg(cause=f"Section '{msg.section}' is already open.")
        elif msg.action is SectionAction.CLOSE:
            if msg.section in self.open_sections:
                del self.open_sections[msg.section]
                if not self.open_sections:
                    self.out_of_section_time = current_time
                resp = OkMsg()
            else:
                resp = ErrorMsg(cause=f"Section '{msg.section}' is not open.")
        elif msg.action is SectionAction.CLOSE_ALL:
            self.open_sections.clear()
            self.out_of_section_time = current_time
            resp = OkMsg()
        else:
            raise AssertionError(f"Unknown SectionAction: {msg.action}")
        await write_obj_to_ipc_stream(resp, writer)

    def _handle_ipc_connection_lost(self):
        self.state_machine.handle_ipc_connection_lost()
        if self.state_machine.is_restarting():
            self.start_periodic_restart_check()
        self.rank_info = None
        self.start_time = None
        self.last_hb_time = None
        self.out_of_section_time = None
        if self.open_sections:
            open_section_names = list(self.open_sections.keys())
            self.logger.warning(
                f"Section(s) {open_section_names} were still open. "
                f"You can use `.end_all_sections` to avoid this warning."
            )
            self.open_sections.clear()
        self.logger.set_connected_rank(None)
        if self.connection_lock.locked():
            self.connection_lock.release()

    async def _handle_launcher_ipc_connection(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        """Handle IPC connection from the launcher."""
        try:
            while True:
                msg = await read_obj_from_ipc_stream(reader)
                if msg == "close_worker_ipc_connection":
                    self.logger.info(
                        "Received request from launcher to close worker IPC connection"
                    )
                    await self.close_current_connection()
                else:
                    self.logger.warning(f"Received unknown message from launcher: {msg}")
        except (asyncio.IncompleteReadError, ConnectionResetError, BrokenPipeError, EOFError):
            # Valid stream close exceptions.
            pass
        except Exception as e:
            self.logger.error(f"Error handling launcher IPC connection: {e}")
        finally:
            writer.close()
            await writer.wait_closed()

    async def _handle_ipc_connection(self, reader, writer):
        if not self.is_closing:
            try:
                if self.connection_lock.locked():
                    self.logger.warning(
                        "Got a new connection while the previous one is still active. Need to wait..."
                    )
                await asyncio.wait_for(
                    self.connection_lock.acquire(),
                    RankMonitorServer.RANK_MONITOR_INIT_TIMEOUT,
                )
                self.current_writer = writer
                while msg := await read_obj_from_ipc_stream(reader):
                    if isinstance(msg, AuthkeyMsg):
                        await self._handle_authkey_msg(msg, writer)
                    elif isinstance(msg, InitMsg):
                        await self._handle_init_msg(msg, writer)
                    elif isinstance(msg, HeartbeatMsg):
                        await self._handle_heartbeat_msg(msg, writer)
                    elif isinstance(msg, UpdateConfigMsg):
                        await self._handle_update_config_msg(msg, writer)
                    elif isinstance(msg, SectionMsg):
                        await self._handle_section_msg(msg, writer)
                    else:
                        raise AssertionError(f"Unknown msg type: {type(msg)}")
            except (asyncio.IncompleteReadError, ConnectionResetError, BrokenPipeError, EOFError):
                # Valid stream close exceptions.
                pass
            except Exception:
                self.logger.error(f"Exception in _handle_ipc_connection: {traceback.format_exc()}")
            finally:
                self.current_writer = None
                self._handle_ipc_connection_lost()
                assert not self.connection_lock.locked()
        else:
            self.logger.warning("Ignored incoming connection: service is shutting down.")

        with contextlib.suppress(Exception):
            writer.close()
            await writer.wait_closed()

        self.logger.debug("Leaving _handle_ipc_connection.")

    def _handle_signal(self, sig):
        self.state_machine.handle_signal()
        self.logger.debug(f"Received signal: {sig}")
        if not self.is_closing:
            self.is_closing = True
            self.server.close()
            self.launcher_server.close()

    def _is_hb_timeout_elapsed(self, curr_time) -> bool:
        is_elapsed = False
        if self.last_hb_time is None:
            # No heartbeats received yet; check the initial heartbeat timeout.
            time_since_start = curr_time - self.start_time
            timeout = self.cfg.initial_rank_heartbeat_timeout
            is_elapsed = timeout is not None and time_since_start > timeout
            if is_elapsed:
                self.logger.warning(
                    f"Did not get initial heartbeat. Waited {timeout:.2f} seconds."
                )
        else:
            # Some heartbeats were received; check the subsequent heartbeat timeout.
            time_since_last_hb = curr_time - self.last_hb_time
            timeout = self.cfg.rank_heartbeat_timeout
            is_elapsed = timeout is not None and time_since_last_hb > timeout
            if is_elapsed:
                self.logger.warning(
                    f"Did not get subsequent heartbeat. "
                    f"Waited {timeout:.2f} seconds."
                )
        return is_elapsed

    def _is_section_timeout_elapsed(self, curr_time) -> bool:
        is_elapsed = False
        # If any sections are open, check their timeouts,
        # otherwise check the timeout for "out of section".
        for section, section_start_time in self.open_sections.items():
            elapsed = curr_time - section_start_time
            timeout = self.cfg.rank_section_timeouts[section]
            if timeout is not None and elapsed > timeout:
                self.logger.warning(
                    f"Section '{section}' has been open for {elapsed:.2f} seconds. "
                    f"Timeout is {timeout:.2f} seconds."
                )
                is_elapsed = True
        if not self.open_sections:
            elapsed = curr_time - self.out_of_section_time
            timeout = self.cfg.rank_out_of_section_timeout
            is_elapsed = timeout is not None and elapsed > timeout
            if is_elapsed:
                self.logger.warning(
                    f"Was out of section for {elapsed:.2f} seconds. "
                    f"Timeout is {timeout:.2f} seconds."
                )
        return is_elapsed

    async def _periodic_rank_check(self):
        while True:
            if self.rank_info is not None:
                curr_time = time.monotonic()
                hb_timed_out = self._is_hb_timeout_elapsed(curr_time)
                section_timed_out = self._is_section_timeout_elapsed(curr_time)
                if hb_timed_out or section_timed_out:
                    self.logger.warning("FT timeout elapsed.")
                    await self._handle_timeout()
            await asyncio.sleep(self.cfg.workload_check_interval)

    async def _periodic_restart_check(self):
        await asyncio.sleep(self.cfg.restart_check_interval)
        while True:
            self.state_machine.periodic_restart_check()
            await asyncio.sleep(self.cfg.restart_check_interval)

    async def _periodic_node_health_check(self):
        await self.health_checker.async_check()

    async def _periodic_nic_health_check(self):
        await self.nic_health_checker.async_check()

    async def _rank_monitor_loop(self):
        # Handle usual termination signals
        for sig_to_handle in [
            signal.SIGTERM,
            signal.SIGINT,
            signal.SIGUSR1,
            signal.SIGUSR2,
        ]:
            asyncio.get_event_loop().add_signal_handler(
                sig_to_handle,
                functools.partial(self._handle_signal, sig_to_handle),
            )

        # Run local server
        if os.path.exists(self.ipc_socket_path):
            os.unlink(self.ipc_socket_path)
        self.server = await asyncio.start_unix_server(
            self._handle_ipc_connection,
            self.ipc_socket_path,
        )

        # Start launcher IPC server
        if os.path.exists(self.launcher_ipc_socket_path):
            os.unlink(self.launcher_ipc_socket_path)
        self.launcher_server = await asyncio.start_unix_server(
            self._handle_launcher_ipc_connection,
            self.launcher_ipc_socket_path,
        )

        # Periodic checks
        asyncio.get_running_loop().create_task(self._periodic_rank_check())
        # Periodic node health check
        asyncio.get_running_loop().create_task(self._periodic_node_health_check())
        # Periodic NIC health check
        if self.nic_health_checker is not None:
            asyncio.get_running_loop().create_task(self._periodic_nic_health_check())

        self.rank_monitor_ready_event.set()

        try:
            async with self.server, self.launcher_server:
                await asyncio.gather(
                    self.server.serve_forever(), self.launcher_server.serve_forever()
                )
        except asyncio.exceptions.CancelledError:
            self.logger.debug("server.serve_forever() cancelled.")
        finally:
            with contextlib.suppress(Exception):
                os.unlink(self.ipc_socket_path)
                os.unlink(self.launcher_ipc_socket_path)

    @staticmethod
    def run(
        cfg: FaultToleranceConfig,
        ipc_socket_path: str,
        rank_monitor_ready_event,
        is_restarter_logger: bool,
    ) -> None:
        try:
            logger = RankMonitorLogger(
                level=cfg.log_level, is_restarter_logger=is_restarter_logger
            )
            logger.debug(f"Starting RankMonitorServer... PID={os.getpid()}")
            inst = RankMonitorServer(
                cfg,
                ipc_socket_path,
                rank_monitor_ready_event,
                is_restarter_logger,
                logger,
            )
            asyncio.run(inst._rank_monitor_loop())
            logger.debug("Leaving RankMonitorServer process")
        except asyncio.exceptions.CancelledError:
            logger.debug("asyncio.exceptions.CancelledError in RankMonitorServer.run")
            pass  # Ignore the exception
        except Exception as e:
            logger.error(f"Exception caught in RankMonitorServer.run: {e}")
            pass  # Ignore the exception

    @staticmethod
    def run_in_subprocess(
        cfg, ipc_socket_path: str, is_restarter_logger: bool = False, mp_ctx=torch.multiprocessing
    ):
        rank_monitor_ready_event = mp_ctx.Event()
        rank_monitor_process_kwargs = {
            "cfg": cfg,
            "ipc_socket_path": ipc_socket_path,
            "rank_monitor_ready_event": rank_monitor_ready_event,
            "is_restarter_logger": is_restarter_logger,
        }
        rank_monitor_process = mp_ctx.Process(
            target=RankMonitorServer.run, kwargs=rank_monitor_process_kwargs
        )
        rank_monitor_process.daemon = True
        rank_monitor_process.start()
        if not rank_monitor_ready_event.wait(timeout=RankMonitorServer.RANK_MONITOR_INIT_TIMEOUT):
            raise RuntimeError(
                f"Could not start rank monitor. Waited {RankMonitorServer.RANK_MONITOR_INIT_TIMEOUT} sec."
            )
        return rank_monitor_process

    async def close_current_connection(self):
        if self.current_writer is not None:
            self.current_writer.close()
            await self.current_writer.wait_closed()
            self.logger.debug("Closed current IPC connection")
        else:
            self.logger.warning("No active connection to close")
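

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It shows how a
# launcher-side component might start a rank monitor via run_in_subprocess()
# and where the rank's RankMonitorClient would connect. The socket path and the
# way the config is obtained below are hypothetical placeholders.
#
#   from nvidia_resiliency_ext.fault_tolerance.config import FaultToleranceConfig
#   from nvidia_resiliency_ext.fault_tolerance.rank_monitor_server import RankMonitorServer
#
#   cfg = FaultToleranceConfig()  # hypothetical: defaults, or loaded from the launcher config
#   ipc_socket_path = "/tmp/_ft_rank0_monitor.socket"  # hypothetical path
#
#   # Spawns the monitor process and blocks until it signals readiness
#   # (or raises RuntimeError after RANK_MONITOR_INIT_TIMEOUT seconds).
#   monitor_proc = RankMonitorServer.run_in_subprocess(
#       cfg=cfg,
#       ipc_socket_path=ipc_socket_path,
#   )
#
#   # The launched rank process then creates a RankMonitorClient pointed at
#   # ipc_socket_path; the client typically sends AuthkeyMsg and InitMsg first,
#   # followed by HeartbeatMsg / SectionMsg updates that this server checks
#   # against the configured timeouts.
#
#   monitor_proc.terminate()  # shut the monitor down when the launcher exits
#   monitor_proc.join()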