Source code for nvidia_resiliency_ext.ptl_resiliency.straggler_det_callback

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import sys
import time
from typing import Optional

import torch

from ._utils import is_module_available

if is_module_available("lightning"):
    from lightning.pytorch.callbacks import Callback
elif is_module_available("pytorch_lightning"):
    from pytorch_lightning.callbacks import Callback
else:
    raise ImportError("Could not find 'lightning' or 'pytorch_lightning' module")

import nvidia_resiliency_ext.straggler as straggler


class StragglerDetectionCallback(Callback):
    def __init__(
        self,
        report_time_interval: float,
        calc_relative_gpu_perf: bool,
        calc_individual_gpu_perf: bool,
        num_gpu_perf_scores_to_print: int,
        gpu_relative_perf_threshold: float,
        gpu_individual_perf_threshold: float,
        stop_if_detected: bool,
        enable_ptl_logging: bool,
        profiling_interval: int = 1,
        logger_name: Optional[str] = "nemo_logger.StragglerDetectionCallback",
    ):
        """
        Initialize straggler detection callback instance.

        Args:
            report_time_interval (float): Interval [seconds] of the straggler check
            calc_relative_gpu_perf (bool): Calculate relative GPU performance
            calc_individual_gpu_perf (bool): Calculate individual GPU performance
            num_gpu_perf_scores_to_print (int): How many best and worst perf scores to print
                (0 - does not print periodically, but only if stragglers are detected)
            gpu_relative_perf_threshold (float): Threshold for relative GPU performance scores
            gpu_individual_perf_threshold (float): Threshold for individual GPU performance scores
            stop_if_detected (bool): Set to True, to terminate the workload if stragglers are detected
            enable_ptl_logging (bool): Set to True, to log GPU performance scores to all PTL loggers
                enabled through trainer
            profiling_interval (int): `profiling_interval` passed to `straggler.Detector.initialize`.
                Defaults to 1.
            logger_name (Optional[str], optional): Defaults to "nemo_logger.StragglerDetectionCallback".

        Raises:
            ValueError: If invalid config was provided.
        """
        self.initialized: bool = False
        self.logger = logging.getLogger(logger_name)
        self.report_time_interval: float = report_time_interval
        self.calc_relative_gpu_perf: bool = calc_relative_gpu_perf
        self.calc_individual_gpu_perf: bool = calc_individual_gpu_perf
        self.num_gpu_perf_scores_to_print: int = num_gpu_perf_scores_to_print
        self.gpu_relative_perf_threshold: float = gpu_relative_perf_threshold
        self.gpu_individual_perf_threshold: float = gpu_individual_perf_threshold
        self.stop_if_detected: bool = stop_if_detected
        self.enable_ptl_logging: bool = enable_ptl_logging
        self.profiling_interval: int = profiling_interval
        self.scores_to_compute = []
        if self.calc_relative_gpu_perf:
            self.scores_to_compute += ['relative_perf_scores']
        if self.calc_individual_gpu_perf:
            self.scores_to_compute += ['individual_perf_scores']
        if not self.scores_to_compute:
            raise ValueError(
                "No straggler performance scores specified. "
                "Check if calc_relative_gpu_perf=True or calc_individual_gpu_perf=True"
            )
        self.interval_est_was_reset = False

    def _wrap_ptl_callables(self, trainer):
        assert getattr(
            trainer.strategy, 'training_step', None
        ), f"{type(trainer.strategy)} does not have 'training_step' method."
        straggler.Detector.wrap_callables(
            callable_ids=[straggler.CallableId(trainer.strategy, "training_step")]
        )

    def setup(self, trainer, pl_module, stage):
        if not self.initialized:
            straggler.Detector.initialize(
                scores_to_compute=self.scores_to_compute,
                gather_on_rank0=True,
                profiling_interval=self.profiling_interval,
                report_time_interval=self.report_time_interval,
            )
            self._wrap_ptl_callables(trainer)
            self.initialized = True

    def teardown(self, trainer, pl_module, stage):
        if self.initialized:
            straggler.Detector.shutdown()
            self.initialized = False

    def _print_stragglers(self, stragglers):
        if rel_stragglers := stragglers['straggler_gpus_relative']:
            self.logger.warning(
                f"STRAGGLER DETECTION WARNING: Some GPUs have worse relative performance. "
                f"Affected ranks: {rel_stragglers}"
            )
        if indiv_stragglers := stragglers['straggler_gpus_individual']:
            self.logger.warning(
                f"STRAGGLER DETECTION WARNING: Some GPUs performance dropped. "
                f"Affected ranks: {indiv_stragglers}"
            )

    @staticmethod
    def _format_gpu_scores(rank_to_score, rank_to_node, num_best=3, num_worst=3) -> str:
        num_ranks = len(rank_to_score)
        scores_and_ranks = [(s, r) for r, s in rank_to_score.items()]
        scores_and_ranks.sort(reverse=True)
        res = ""
        if num_ranks > (num_best + num_worst):
            res += f" Worst performing {num_worst}/{num_ranks} ranks:\n"
            for s, r in reversed(scores_and_ranks[-num_worst:]):
                res += f" Rank={r} Node={rank_to_node[r]} Score={s:.2f}\n"
            res += f" Best performing {num_best}/{num_ranks} ranks:\n"
            for s, r in scores_and_ranks[:num_best]:
                res += f" Rank={r} Node={rank_to_node[r]} Score={s:.2f}\n"
        else:
            # if the number of ranks is small enough, print them all
            for s, r in reversed(scores_and_ranks):
                res += f" Rank={r} Node={rank_to_node[r]} Score={s:.2f}\n"
        return res

    def _print_gpu_scores(self, report):
        assert self.num_gpu_perf_scores_to_print > 0
        if self.calc_relative_gpu_perf:
            rel_perf_str = self._format_gpu_scores(
                report.gpu_relative_perf_scores,
                report.rank_to_node,
                num_best=self.num_gpu_perf_scores_to_print,
                num_worst=self.num_gpu_perf_scores_to_print,
            )
            self.logger.info(f"\nGPU relative performance:\n{rel_perf_str}")
        if self.calc_individual_gpu_perf:
            indiv_perf_str = self._format_gpu_scores(
                report.gpu_individual_perf_scores,
                report.rank_to_node,
                num_best=self.num_gpu_perf_scores_to_print,
                num_worst=self.num_gpu_perf_scores_to_print,
            )
            self.logger.info(f"\nGPU individual performance:\n{indiv_perf_str}")

    def _log_gpu_perf_scores(self, pl_module, rank_to_score, rank_to_node, score_prefix):
        """
        Logs GPU performance scores with rank and node information to all PTL loggers
        enabled through trainer.
        """
        scores_log = {}
        min_val = float('nan')
        med_val = float('nan')
        max_val = float('nan')
        scores = list(rank_to_score.values())
        if scores:
            scores = torch.tensor(scores, dtype=torch.float32)
            min_val = torch.min(scores).item()
            med_val = torch.median(scores).item()
            max_val = torch.max(scores).item()
        scores_log[f"{score_prefix}/min"] = min_val
        scores_log[f"{score_prefix}/median"] = med_val
        scores_log[f"{score_prefix}/max"] = max_val
        try:
            pl_module.log_dict(scores_log, logger=True, batch_size=1, rank_zero_only=True)
        except Exception as e:
            self.logger.error(f"Failed to log GPU performance scores: {e}")

    def _log_gpu_scores(self, pl_module, report):
        assert self.enable_ptl_logging is True
        if self.calc_relative_gpu_perf:
            self._log_gpu_perf_scores(
                pl_module,
                rank_to_score=report.gpu_relative_perf_scores,
                rank_to_node=report.rank_to_node,
                score_prefix="gpu_relative_perf",
            )
        if self.calc_individual_gpu_perf:
            self._log_gpu_perf_scores(
                pl_module,
                rank_to_score=report.gpu_individual_perf_scores,
                rank_to_node=report.rank_to_node,
                score_prefix="gpu_individual_perf",
            )

    def _handle_straggler_report(self, pl_module, report) -> bool:
        stragglers = report.identify_stragglers(
            gpu_rel_threshold=self.gpu_relative_perf_threshold,
            gpu_indiv_threshold=self.gpu_individual_perf_threshold,
        )
        stragglers_found = (
            stragglers['straggler_gpus_relative'] or stragglers['straggler_gpus_individual']
        )
        if stragglers_found:
            self._print_stragglers(stragglers)
        if self.num_gpu_perf_scores_to_print > 0:
            self._print_gpu_scores(report)
        if self.enable_ptl_logging:
            self._log_gpu_scores(pl_module, report)
        return stragglers_found

    def _gather_flag_from_rank0(self, flag):
        flag = torch.tensor(
            [1.0 if flag else 0], device=torch.cuda.current_device(), dtype=torch.float32
        )
        torch.distributed.broadcast(flag, 0)
        flag = bool(flag.item() > 0)
        return flag

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        time_started = time.monotonic()
        rank = trainer.global_rank

        report = straggler.Detector.generate_report_if_interval_elapsed()

        stragglers_found = False
        if rank == 0 and report:
            # gather_on_rank0 is True, so only rank 0 has the report
            stragglers_found = self._handle_straggler_report(pl_module, report)

        # check if the report was generated
        if straggler.Detector.is_interval_elapsed():
            # report was generated on rank 0
            if self.stop_if_detected and self._gather_flag_from_rank0(stragglers_found):
                self._stop_training(trainer)
            # log reporting time
            elapsed = time.monotonic() - time_started
            self.logger.info(f"Straggler report processing time: {elapsed:.3f} sec.")

    def _stop_training(self, trainer) -> None:
        self.logger.error("Detected stragglers. Terminating training...")
        trainer.should_stop = True
        if trainer.checkpoint_callback:
            monitor_candidates = trainer.checkpoint_callback._monitor_candidates(trainer)
            trainer.checkpoint_callback._save_last_checkpoint(trainer, monitor_candidates)
            if hasattr(trainer.strategy.checkpoint_io, 'maybe_finalize_save_checkpoint'):
                self.logger.info("Async checkpointing detected, waiting for it to complete...")
                trainer.strategy.checkpoint_io.maybe_finalize_save_checkpoint(blocking=True)
        sys.exit(1)
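
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the module): one way the callback
# might be attached to a Lightning Trainer. The `model`, `datamodule`, and the
# interval/threshold values below are placeholders chosen for the example;
# only the constructor parameters themselves mirror __init__ above.
#
#   from lightning.pytorch import Trainer
#
#   straggler_cb = StragglerDetectionCallback(
#       report_time_interval=300.0,            # check for stragglers every 5 minutes
#       calc_relative_gpu_perf=True,
#       calc_individual_gpu_perf=True,
#       num_gpu_perf_scores_to_print=5,
#       gpu_relative_perf_threshold=0.7,
#       gpu_individual_perf_threshold=0.7,
#       stop_if_detected=False,
#       enable_ptl_logging=True,
#   )
#   trainer = Trainer(callbacks=[straggler_cb])  # plus the usual Trainer arguments
#   trainer.fit(model, datamodule=datamodule)
#
# With enable_ptl_logging=True, min/median/max scores are logged via
# pl_module.log_dict under keys such as "gpu_relative_perf/min",
# "gpu_relative_perf/median", and "gpu_relative_perf/max" (and the
# "gpu_individual_perf/*" equivalents).
# ---------------------------------------------------------------------------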