Source code for nvidia_resiliency_ext.fault_tolerance.config

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import contextlib
import dataclasses
import logging
import signal
from dataclasses import dataclass, fields
from typing import Optional

import yaml


@dataclass
class FaultToleranceConfig:
    """
    Configuration of fault tolerance

    - `workload_check_interval` [float] periodic rank check interval (in seconds) in rank monitors.
    - `initial_rank_heartbeat_timeout` [float] timeout (in seconds) for the first heartbeat from a rank.
      Usually, it takes a bit longer for the first heartbeat to be sent, as the rank needs to initialize.
      If rank does not send the first heartbeat within `initial_rank_heartbeat_timeout`, failure is detected.
      If None this timeout needs to be deduced and set during runtime, based on the observed
      heartbeat intervals.
    - `rank_heartbeat_timeout` [float] timeout (in seconds) for subsequent heartbeats from a rank.
      If no rank heartbeat is received within `rank_heartbeat_timeout`, failure is detected.
      If None this timeout needs to be deduced and set during runtime, based on the observed
      heartbeat intervals.
    - `safety_factor` [float] when deducing the timeouts, observed heartbeat intervals are
      multiplied by this factor to obtain the timeouts.
    - `rank_termination_signal` signal used to terminate the rank when failure is detected.
      Accepted as a `signal.Signals` member, an integer signal number or a
      (case-insensitive) signal name such as "SIGKILL"; normalized to `signal.Signals`.
    - `log_level` log level of fault tolerance components. Accepted as an integer in
      [logging.DEBUG, logging.CRITICAL] or as a (case-insensitive) level name;
      normalized to an integer.
    """

    workload_check_interval: float = 5.0
    initial_rank_heartbeat_timeout: Optional[float] = 60.0 * 60.0
    rank_heartbeat_timeout: Optional[float] = 45.0 * 60.0
    safety_factor: float = 5.0
    rank_termination_signal: signal.Signals = signal.SIGKILL
    log_level: int = logging.INFO

    @staticmethod
    def from_kwargs(ignore_not_recognized: bool = True, **kwargs) -> 'FaultToleranceConfig':
        """
        Create a FaultToleranceConfig object from keyword arguments.

        Args:
            ignore_not_recognized (bool, optional): Whether to ignore unrecognized arguments.
                Defaults to True.
            **kwargs: Keyword arguments representing the fields of the FaultToleranceConfig object.

        Returns:
            FaultToleranceConfig: The created FaultToleranceConfig object.

        Raises:
            ValueError: If there are unrecognized arguments and ignore_not_recognized is False,
                or if a provided field value fails validation.
        """
        fields_set = {f.name for f in fields(FaultToleranceConfig) if f.init}
        matching_args = {k: v for k, v in kwargs.items() if k in fields_set}
        extra_args = {k: v for k, v in kwargs.items() if k not in fields_set}
        if extra_args and not ignore_not_recognized:
            raise ValueError(f"Not recognized args: {extra_args}")
        return FaultToleranceConfig(**matching_args)

    @staticmethod
    def from_yaml_file(cfg_path: str, ignore_not_recognized: bool = True) -> 'FaultToleranceConfig':
        """
        Load the fault tolerance configuration from a YAML file.

        YAML file should contain `fault_tolerance` section.
        `fault_tolerance` section can be at the top level or nested in any other section.

        Args:
            cfg_path (str): The path to the YAML configuration file.
            ignore_not_recognized (bool, optional): Whether to ignore unrecognized
                configuration options. Defaults to True.

        Returns:
            FaultToleranceConfig: The fault tolerance configuration object.

        Raises:
            ValueError: If the 'fault_tolerance' section is not found (or is empty)
                in the config file.
        """
        with open(cfg_path, 'r') as file:
            yaml_data = yaml.safe_load(file)
        ft_cfg = FaultToleranceConfig._find_fault_tol_section(yaml_data)
        if ft_cfg:
            return FaultToleranceConfig.from_kwargs(
                **ft_cfg, ignore_not_recognized=ignore_not_recognized
            )
        raise ValueError(f"'fault_tolerance' section not found in config file {cfg_path}")

    @staticmethod
    def from_args(
        args: argparse.Namespace,
        cfg_file_arg: Optional[str] = None,
        ft_args_prefix: str = '',
    ) -> 'FaultToleranceConfig':
        """
        Init FT config object from parsed CLI args.

        Implements the following logic:
        - Use default FT config as a base.
        - If there is a config file argument defined, first try to read the FT config from the file.
        - Update the FT config with FT args provided via CLI.
        - If can't read from file and there are no related args in CLI, raise an exception.

        Args:
            args (argparse.Namespace): Parsed arguments
            cfg_file_arg (str, optional): Name of the argument that contains the FT config YAML file.
                Defaults to None - do not try to read from file.
            ft_args_prefix (str, optional): Prefix of the FT related args.
                Defaults to empty str - assume no prefix.

        Returns:
            FaultToleranceConfig: The resulting configuration object.

        Raises:
            ValueError: If no configuration could be obtained from the file or the CLI,
                or if a CLI-provided value fails validation.
            AssertionError: If a prefixed CLI argument does not name a config field.
        """
        ft_cfg = FaultToleranceConfig()
        is_read_from_file = False
        if cfg_file_arg:
            cfg_path = getattr(args, cfg_file_arg)
            if cfg_path is not None:
                # Best effort: a file without a usable FT section is not fatal here,
                # CLI args may still provide the configuration.
                with contextlib.suppress(ValueError):
                    ft_cfg = FaultToleranceConfig.from_yaml_file(cfg_path)
                    is_read_from_file = True
        # extract FT args specified via CLI, remove the common FT args prefix
        # so we should get FaultToleranceConfig field name -> value mapping
        provided_ft_args = {
            k.removeprefix(ft_args_prefix): v
            for k, v in vars(args).items()
            if k.startswith(ft_args_prefix) and v is not None
        }
        for arg_name, arg_val in provided_ft_args.items():
            assert hasattr(
                ft_cfg, arg_name
            ), f"Invalid FT parameter specified via CLI: {ft_args_prefix}{arg_name}."
            setattr(ft_cfg, arg_name, arg_val)
        # values assigned via setattr bypass __post_init__, so normalize them explicitly
        ft_cfg._fix_log_level_type()
        ft_cfg._fix_rank_termination_signal_type()
        if not (is_read_from_file or provided_ft_args):
            raise ValueError("No fault tolerance configuration provided.")
        return ft_cfg

    def to_yaml_file(self, cfg_path: str) -> None:
        """
        Convert the configuration object to a YAML file and save it to the specified path.

        The configuration is written under a top-level `fault_tolerance` key, with
        `rank_termination_signal` stored as the signal name.

        Args:
            cfg_path (str): The path to save the YAML file.

        Returns:
            None
        """
        # first, ensure that `rank_termination_signal` and `log_level` have their native types
        # this might not be the case, if the object was modified after creation
        self._fix_rank_termination_signal_type()
        self._fix_log_level_type()
        # build the document before opening the file, so a failure here
        # does not leave a truncated file behind
        ft_cfg_dict = dataclasses.asdict(self)
        ft_cfg_dict['rank_termination_signal'] = self.rank_termination_signal.name
        ft_cfg_dict['log_level'] = self.log_level
        ft_cfg_dict = {'fault_tolerance': ft_cfg_dict}
        with open(cfg_path, 'w') as file:
            yaml.dump(ft_cfg_dict, file)

    @staticmethod
    def _find_fault_tol_section(yaml_data):
        """
        Depth-first search for a non-empty `fault_tolerance` section in parsed YAML data.

        Returns the value stored under the first `fault_tolerance` key found in a dict
        (searching nested dicts and lists), or None if there is no such section.
        """
        if isinstance(yaml_data, dict):
            if "fault_tolerance" in yaml_data:
                return yaml_data["fault_tolerance"]
            children = yaml_data.values()
        elif isinstance(yaml_data, list):
            children = yaml_data
        else:
            return None
        for child in children:
            sub_config = FaultToleranceConfig._find_fault_tol_section(child)
            if sub_config:
                return sub_config
        return None

    def _fix_rank_termination_signal_type(self):
        """
        Normalize `rank_termination_signal` to a `signal.Signals` member.

        Raises:
            ValueError: If the value is not a valid signal number, signal name
                or `signal.Signals` member.
        """
        sig = self.rank_termination_signal
        if isinstance(sig, str):
            try:
                # look the name up in the enum itself: the `signal` module also exposes
                # non-signal names (e.g. SIG_DFL, NSIG) that must be rejected
                self.rank_termination_signal = signal.Signals[sig.upper()]
            except KeyError:
                raise ValueError(f"Invalid rank_termination_signal string: {sig}") from None
        elif isinstance(sig, int):
            # also covers signal.Signals members, as Signals is an IntEnum;
            # raises ValueError for numbers that are not valid signals
            self.rank_termination_signal = signal.Signals(sig)
        else:
            raise ValueError(f"Invalid value for rank_termination_signal: {sig}")

    def _fix_log_level_type(self):
        """
        Normalize `log_level` to an integer logging level.

        Raises:
            ValueError: If the value is an integer outside [logging.DEBUG, logging.CRITICAL],
                an unknown level name, or neither an int nor a str.
        """
        if isinstance(self.log_level, int):
            if not (logging.DEBUG <= self.log_level <= logging.CRITICAL):
                raise ValueError(
                    f"Invalid log level value ({self.log_level}). "
                    f"Should be in [{logging.DEBUG} (DEBUG), {logging.FATAL} (CRITICAL)]"
                )
        elif isinstance(self.log_level, str):
            level_by_name = {
                'DEBUG': logging.DEBUG,
                'DBG': logging.DEBUG,
                'INFO': logging.INFO,
                'WARNING': logging.WARNING,
                'WARN': logging.WARNING,
                'ERROR': logging.ERROR,
                'CRITICAL': logging.CRITICAL,
            }
            try:
                self.log_level = level_by_name[self.log_level.upper()]
            except KeyError:
                raise ValueError(f"Invalid log level string: {self.log_level}") from None
        else:
            raise ValueError(f"Invalid value for log_level: {self.log_level}")

    def __post_init__(self):
        """Validate and normalize the fields that accept several input types."""
        self._fix_rank_termination_signal_type()
        self._fix_log_level_type()