# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
import pathlib
import random
import signal
import sys
import threading
import time
from dataclasses import dataclass
from typing import Optional, Union
import torch
from ._utils import is_module_available
if is_module_available("lightning"):
from lightning.pytorch.callbacks import Callback
elif is_module_available("pytorch_lightning"):
from pytorch_lightning.callbacks import Callback
else:
raise ImportError("Could not find 'lightning' or 'pytorch_lightning' module")
import nvidia_resiliency_ext.fault_tolerance as ft
@dataclass
class SimulatedFaultParams:
"""
Description of a simulated rank fault, used for FT testing and debugging.
Simulated fault types are:
- 'rank_killed': a rank is killed with SIGKILL
- 'rank_hung': a rank is stopped with SIGSTOP
- 'random': one of the above faults is selected at random.
Fault delay is computed as:
- `base_delay` + (random float from 0.0 to 1.0) * `rand_delay`
Attributes:
fault_type (str): The type of fault, one of: ['random', 'rank_killed', 'rank_hung'].
base_delay (float): The base (minimum) delay [seconds] for the fault.
rand_delay (float, optional): The max additional random delay for the fault. Defaults to 0.0.
rank_to_fail (int, optional): The rank to fail. Defaults to None, in which case a random rank is picked.
"""
fault_type: str
base_delay: float
rand_delay: float = 0.0
rank_to_fail: Optional[int] = None
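# Illustrative example (not executed; the values are arbitrary): a simulated fault that
# kills a randomly selected rank 120-180 seconds into the run could be described as
#
#   SimulatedFaultParams(fault_type='rank_killed', base_delay=120.0, rand_delay=60.0)
#
# i.e. the effective delay is base_delay + (random float from 0.0 to 1.0) * rand_delay.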
class _TrainingStateMachine:
"""
This class encapsulates logic for determining when:
- training is finished successfully (`.is_training_completed` method)
- FT timeouts can be updated (`.can_update_timeouts` property)
`on_*` methods update the state and should be called from the corresponding PTL callback methods.
`on_ft_heartbeat_sent` should be called after each FT heartbeat.
"""
MIN_ITERS_FOR_TIMEOUT_UPDATE = 2
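# Minimal usage sketch (for illustration only; FaultToleranceCallback below is the real driver):
#
#   sm = _TrainingStateMachine()
#   sm.on_setup()
#   sm.on_load_checkpoint()
#   for _ in range(num_batches):   # 'num_batches' is a placeholder
#       sm.on_train_batch_end()
#       sm.on_ft_heartbeat_sent()
#   sm.on_save_checkpoint()
#   sm.on_ft_heartbeat_sent()
#   if sm.can_update_timeouts:
#       ...                        # recompute FT timeouts, then call sm.on_ft_timeouts_updated()
#   sm.on_teardown()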
def __init__(self):
self.num_tr_iters_total = 0
self.num_hb_total = 0
self.num_hb_at_last_save = None
self.seen_checkpointing = False
self.loaded_checkpoint = False
self.caught_exception = False
self.is_stop_exception = False
self.training_ended = False
self.timeouts_updated = False
def on_setup(self):
assert self.num_tr_iters_total == 0
assert self.num_hb_total == 0
def on_teardown(self):
self.training_ended = True
def on_load_checkpoint(self):
self.loaded_checkpoint = True
def on_save_checkpoint(self):
self.num_hb_at_last_save = self.num_hb_total
def on_train_start(self):
pass
def on_train_batch_end(self):
self.num_tr_iters_total += 1
def on_train_end(self):
pass
def on_validation_start(self):
pass
def on_validation_batch_end(self):
pass
def on_validation_end(self):
pass
def on_exception(self, exc=None):
self.caught_exception = True
# check if `sys.exit(0)` was invoked and interpret that as a "clean exit".
# it's used, e.g., by the NeMo preemption callback to stop the training.
# NOTE: _TunerExitException raised by NeMo StatelessTimer is NOT captured here,
# but `teardown` hook is called when _TunerExitException is raised.
self.is_stop_exception = isinstance(exc, SystemExit) and not exc.code
def on_ft_timeouts_updated(self):
self.timeouts_updated = True
def on_ft_heartbeat_sent(self):
self.num_hb_total += 1
if not self.seen_checkpointing and self.num_hb_at_last_save is not None:
# detect checkpointing that makes the heartbeat interval longer
# NOTE: need at least one post-checkpointing heartbeat
num_pre_save = self.num_hb_at_last_save
num_post_save = self.num_hb_total - self.num_hb_at_last_save
self.seen_checkpointing = num_pre_save > 0 and num_post_save > 0
def is_training_completed(self, trainer=None) -> bool:
"""
Returns True if training has finished successfully.
"""
# if we are exiting AND at most 1 training iteration was made AND no error was raised,
# assume training has finished successfully and there is nothing else to do.
# a single iteration is made when we run a workload whose 'max_time' has already elapsed,
# so that special case needs to be handled.
has_failed = self.caught_exception and not self.is_stop_exception
if self.training_ended and self.num_tr_iters_total <= 1 and not has_failed:
return True
if trainer is not None:
# if iters limit is reached:
if (
isinstance(trainer.max_steps, int)
and trainer.max_steps > 0
and trainer.global_step >= trainer.max_steps
):
return True
# if epochs limit is reached
if (
isinstance(trainer.max_epochs, int)
and trainer.max_epochs > 0
and trainer.current_epoch >= trainer.max_epochs
):
return True
return False
@property
def can_update_timeouts(self) -> bool:
"""
Returns True if new timeouts can be computed.
`.on_timeouts_updated()` resets this property back to False.
"""
if self.timeouts_updated:
# timeouts are updated at most once per training run
return False
if self.num_tr_iters_total < self.MIN_ITERS_FOR_TIMEOUT_UPDATE:
# need a few training iters
return False
if self.caught_exception and not self.is_stop_exception:
# stopping due to an exception that is not a "graceful stop" exception
return False
# require that both checkpoint loading and saving were observed,
# as these make the heartbeat interval longer than usual.
return self.loaded_checkpoint and self.seen_checkpointing
class FaultToleranceCallback(Callback):
"""
FaultToleranceCallback is a PyTorch Lightning callback for integration with the Fault Tolerance package.
FT is active only during the 'fit' stage.
Training should be run with 'ft_launcher' for the callback to work.
"""
TIMEOUTS_FILENAME = "_ft_state.json"
FT_DIR_NAME = "ft_state"
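# Example usage (illustrative sketch; 'MyLightningModule' and the trainer settings
# are placeholders, not part of this module):
#
#   import lightning.pytorch as pl
#
#   ft_callback = FaultToleranceCallback(autoresume=False, calculate_timeouts=True)
#   trainer = pl.Trainer(max_steps=1000, callbacks=[ft_callback])
#   trainer.fit(MyLightningModule())
#
# The training script needs to be started via the 'ft_launcher' tool from
# nvidia_resiliency_ext, so that the ft.RankMonitorClient used by this callback
# has a rank monitor to connect to.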
def __init__(
self,
autoresume: bool,
calculate_timeouts: bool,
simulated_fault_params: Union[SimulatedFaultParams, dict, None] = None,
exp_dir: Union[str, pathlib.Path, None] = None,
logger_name: Optional[str] = "nemo_logger.FaultToleranceCallback",
):
"""
Initialize callback instance.
This is a lightweight initialization. Most of the initialization is conducted in the 'setup' hook.
Args:
autoresume (bool): Set to `True` if the FT auto-resume feature is used (e.g., there are multiple training jobs to be run).
calculate_timeouts (bool): Set to `True` if FT timeouts should be calculated based on observed heartbeat intervals.
Calculated timeouts overwrite the timeouts from the FT config.
Timeouts are computed at the end of a training job, if there was checkpoint loading and saving.
For example, for training started from scratch, the timeouts are computed at the end of the second job.
simulated_fault_params (SimulatedFaultParams, dict, DictConfig, None): Simulated fault spec. It's for debugging only. Defaults to None.
Should be a `SimulatedFaultParams` instance or any object that can be used for SimulatedFaultParams initialization with `SimulatedFaultParams(**obj)`.
exp_dir (Union[str, pathlib.Path, None], optional): Directory where the FT state should be saved.
Must be available for all training jobs. NOTE: Beware that PTL can move files written to its `trainer.log_dir`.
Defaults to None, in which case it defaults to `trainer.log_dir/ft_state`.
logger_name (Optional[str], optional): Logger name to be used.
Defaults to "nemo_logger.FaultToleranceCallback".
"""
self.logger = logging.getLogger(logger_name)
self.fault_tol_client = None
self.autoresume = autoresume
self.calculate_timeouts = calculate_timeouts
self.simulated_fault_params = self._parse_simulated_fault_params(simulated_fault_params)
self.state_machine = None
self.provided_exp_dir = exp_dir
self.timeouts_file_path = None
self._verify_env()
@property
def is_initialized(self):
return self.fault_tol_client is not None
def setup(self, trainer, pl_module, stage):
if stage == "fit":
self.state_machine = _TrainingStateMachine()
self.state_machine.on_setup()
self._setup_fault_tolerance(trainer)
def teardown(self, trainer, pl_module, stage):
# FT might already be deinitialized due to an exception
if stage == "fit" and self.is_initialized:
self.state_machine.on_teardown()
if self._is_rank0():
if self.autoresume and self.state_machine.is_training_completed(trainer):
self._create_finished_flag_file()
self._send_ft_heartbeat()
self._maybe_update_ft_timeouts()
self._shutdown_fault_tolerance()
def on_train_start(self, *args, **kwargs):
self.state_machine.on_train_start()
self._send_ft_heartbeat()
def on_train_batch_end(self, *args, **kwargs):
self.state_machine.on_train_batch_end()
self._send_ft_heartbeat()
def on_train_end(self, *args, **kwargs):
self.state_machine.on_train_end()
self._send_ft_heartbeat()
def on_validation_start(self, *args, **kwargs):
# this can be called outside of `fit` stage
if self.is_initialized:
self.state_machine.on_validation_start()
self._send_ft_heartbeat()
def on_validation_batch_end(self, *args, **kwargs):
# this can be called outside of `fit` stage
if self.is_initialized:
self.state_machine.on_validation_batch_end()
self._send_ft_heartbeat()
def on_validation_end(self, *args, **kwargs):
# this can be called outside of `fit` stage
if self.is_initialized:
self.state_machine.on_validation_end()
self._send_ft_heartbeat()
def on_load_checkpoint(self, *args, **kwargs):
# this can be called outside of `fit` stage
if self.is_initialized:
self.state_machine.on_load_checkpoint()
def on_save_checkpoint(self, *args, **kwargs):
# this can be called outside of `fit` stage
if self.is_initialized:
self.state_machine.on_save_checkpoint()
# in NeMo, two checkpointing operations can happen one after another,
# without any training/eval iteration in between.
# send a heartbeat, so that in such a case we won't get an unusually long interval.
self._send_ft_heartbeat()
def on_exception(self, trainer, pl_module, exception):
# this can be called outside of `fit` stage
if self.is_initialized:
self.state_machine.on_exception(exception)
self._send_ft_heartbeat()
self._maybe_update_ft_timeouts()
self._shutdown_fault_tolerance()
def _is_rank0(self):
return torch.distributed.is_initialized() and torch.distributed.get_rank() == 0
def _log_info_on_rank0(self, msg):
if self._is_rank0():
self.logger.info("[FaultToleranceCallback@rank0] " + str(msg))
def _verify_env(self):
if self.autoresume and not os.environ.get('FAULT_TOL_FINISHED_FLAG_FILE', ''):
raise RuntimeError(
"'FAULT_TOL_FINISHED_FLAG_FILE' env variable is not set. Was this job launched with FT launcher?"
)
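# Note for illustration: with `autoresume=True`, the launcher environment is expected to
# provide the flag-file location, e.g. (the path below is just a placeholder):
#
#   export FAULT_TOL_FINISHED_FLAG_FILE=/path/to/job_finished.flag
#
# `_create_finished_flag_file` touches this file in `teardown` once training has completed,
# so an auto-resume wrapper script can detect completion and stop resubmitting the job.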
def _send_ft_heartbeat(self):
self.fault_tol_client.send_heartbeat()
self.state_machine.on_ft_heartbeat_sent()
def _maybe_update_ft_timeouts(self):
if self.calculate_timeouts and self.state_machine.can_update_timeouts:
self.fault_tol_client.calculate_and_set_timeouts()
self.state_machine.on_ft_timeouts_updated()
self._log_info_on_rank0(
f'Updated FT timeouts. New values: {self.fault_tol_client.timeouts}'
)
if self._is_rank0():
# FT state is the same on all ranks, so we can save it only on rank 0
with open(self.timeouts_file_path, mode='w') as f:
json.dump(self.fault_tol_client.state_dict(), f)
def _maybe_load_ft_timeouts(self):
if self.calculate_timeouts:
# we load the timeouts only when calculate_timeouts=True
loaded_ft_state_dict = {}
if self.timeouts_file_path.exists():
with open(self.timeouts_file_path, mode='r') as f:
loaded_ft_state_dict = json.load(f)
if loaded_ft_state_dict:
self.fault_tol_client.load_state_dict(loaded_ft_state_dict)
ft_timeouts = self.fault_tol_client.timeouts
self._log_info_on_rank0(f"Fault tolerance timeouts loaded: {ft_timeouts}")
def _setup_fault_tolerance(self, trainer):
assert not self.is_initialized, "Fault tolerance client already initialized."
self.fault_tol_client = ft.RankMonitorClient()
# Format timeouts file path
if self.provided_exp_dir:
ft_dir = pathlib.Path(self.provided_exp_dir)
else:
ft_dir = pathlib.Path(trainer.log_dir) / self.FT_DIR_NAME
if self._is_rank0():
ft_dir.mkdir(exist_ok=True)
trainer.strategy.barrier()
self._log_info_on_rank0(f"Fault tolerance dir: {ft_dir}")
if not ft_dir.exists():
raise ValueError(f"Fault tolerance save directory does not exist: {ft_dir}")
self.timeouts_file_path = ft_dir / self.TIMEOUTS_FILENAME
self.fault_tol_client.init_workload_monitoring()
self._maybe_load_ft_timeouts()
ft_timeouts = self.fault_tol_client.timeouts
if ft_timeouts.are_valid:
self._log_info_on_rank0(f"Fault tolerance client initialized. Timeouts: {ft_timeouts}")
else:
if self.calculate_timeouts:
self._log_info_on_rank0(
"Fault tolerance client initialized. Timeouts: not calculated yet."
)
else:
raise RuntimeError(
"Fault tolerance doesn't have valid timeouts set and 'calculate_timeouts' is False."
)
# Simulated fault for testing/debug purposes
if self.simulated_fault_params:
self._setup_simulated_fault()
assert self.is_initialized
def _shutdown_fault_tolerance(self):
if self.is_initialized:
self.fault_tol_client.shutdown_workload_monitoring()
self.fault_tol_client = None
assert not self.is_initialized
def _create_finished_flag_file(self):
try:
flag_file_path = pathlib.Path(os.environ["FAULT_TOL_FINISHED_FLAG_FILE"])
flag_file_path.touch()
except Exception as e:
self.logger.error(f"_create_finished_flag_file exception: {e}")
def _parse_simulated_fault_params(
self, simulated_fault_params
) -> Optional[SimulatedFaultParams]:
# TODO: this is for testing only, should be removed in the release version
if simulated_fault_params is None:
return None
if isinstance(simulated_fault_params, SimulatedFaultParams):
return simulated_fault_params
try:
return SimulatedFaultParams(**simulated_fault_params)
except Exception as e:
raise ValueError(
f"Failed to parse simulated fault params, "
"it should be SimulatedFaultParams instance or "
"an object that can be unpacked with '**' and passed to the "
f"SimulatedFaultParams.__init__ Got: {simulated_fault_params}"
) from e
def _setup_simulated_fault(self):
# TODO: this is for testing only, should be removed in the release version
rng = random.Random()
fault_desc = self.simulated_fault_params
self._log_info_on_rank0(f"Initializing simulated fault: {fault_desc}")
# the rank that simulates a fault can be explicitly specified via the `rank_to_fail` field;
# if not specified, a random rank is picked
rank = torch.distributed.get_rank()
rand_rank = rng.randint(0, torch.distributed.get_world_size() - 1)
rank_to_fail = fault_desc.rank_to_fail if fault_desc.rank_to_fail is not None else rand_rank
rank_to_fail = torch.tensor([rank_to_fail], device=torch.cuda.current_device())
torch.distributed.broadcast(rank_to_fail, 0)
rank_to_fail = int(rank_to_fail.item())
if rank != rank_to_fail:
# this rank is not going to simulate a fault, nothing more to do
return
fault_type = fault_desc.fault_type
if fault_type == 'random':
fault_type = rng.choice(['rank_killed', 'rank_hung'])
if fault_type == 'rank_killed':
target_pid = os.getpid()
elif fault_type == 'rank_hung':
target_pid = os.getpid()
else:
raise Exception(f"Unknown fault type {fault_type}")
delay = fault_desc.base_delay + fault_desc.rand_delay * rng.random()
self._log_info_on_rank0(
f"Selected fault={fault_type}; target rank={rank_to_fail}; delay={delay}"
)
def __fault_thread():
time.sleep(delay)
for of in [sys.stdout, sys.stderr]:
print(
f"\n####\nSimulating fault: {fault_type}; rank to fail: {rank_to_fail}\n#####\n",
file=of,
flush=True,
)
if fault_type == 'rank_hung':
os.kill(target_pid, signal.SIGSTOP)
else:
os.kill(target_pid, signal.SIGKILL)
fault_sim_thread = threading.Thread(target=__fault_thread)
fault_sim_thread.daemon = True
fault_sim_thread.start()