# Source code for megatron.energon.errors

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from contextlib import contextmanager
from typing import Any, Callable, Generator, Type, TypeVar

from megatron.energon.sample_utils import format_sample_compact, format_sample_detailed
from megatron.energon.source_info import SourceInfo, get_source_info

# Type variable used in the signatures of the SampleException alternate
# constructors (``cls: Type[T]`` -> ``T``), so that calling them on a subclass
# is typed as returning that subclass.
T = TypeVar("T")


class SkipSample(Exception):
    """Signal that the sample currently being processed should be skipped.

    Any processing function may raise this exception to drop the current sample.
    """


class SampleException(ValueError):
    """Error describing a sample that failed.

    Provides alternate constructors that build the error message either from a
    sample key or from the sample object itself.
    """

    @classmethod
    def from_sample_key(cls: Type[T], sample_key: str) -> T:
        """Create an instance whose message names the sample by its key.

        Args:
            sample_key: Key identifying the failed sample.

        Returns:
            A new instance of the class this constructor was called on.
        """
        text = f"Sample {sample_key} failed"
        return cls(text)

    @classmethod
    def from_sample(cls: Type[T], sample: Any, message: str = "") -> T:
        """Create an instance whose message contains a compact rendering of the sample.

        Args:
            sample: The failed sample; rendered with ``format_sample_compact``.
            message: Optional detail text. When non-empty it is appended to the
                message after a ``": "`` separator.

        Returns:
            A new instance of the class this constructor was called on.
        """
        suffix = f": {message}" if message else ""
        summary = format_sample_compact(sample)
        return cls(f"Sample {summary} failed{suffix}")


class FatalSampleError(SampleException):
    """Sample failure that will not be handled by the error handler."""


# Exception types that this module treats as unrecoverable. Both
# ``ErrorContext.handle_errors`` and ``handle_restore_errors`` catch these
# before their generic ``except Exception`` branch and re-raise them as
# ``FatalSampleError`` instead of passing them to the error handler.
# ``FatalSampleError`` itself is part of the tuple, so an already-fatal error
# is never passed to the handler either.
SYSTEM_EXCEPTIONS = (
    SystemError,
    SyntaxError,
    ImportError,
    StopIteration,
    StopAsyncIteration,
    MemoryError,
    RecursionError,
    ReferenceError,
    NameError,
    UnboundLocalError,
    FatalSampleError,
)


class ErrorContext:
    """Tracks consecutive errors and enforces error tolerance limits.

    This class helps prevent infinite error loops by tracking consecutive failures
    and raising a FatalSampleError when a tolerance threshold is exceeded.

    Example:
        error_ctx = ErrorContext(
            name="MapDataset.map_fn",
            handler=self.worker_config.global_error_handler
            tolerance=100,
        )

        with error_ctx.handle_errors(sample):
            result = process_sample(sample)
    """

    name: str
    tolerance: int
    handler: Callable[[Exception, Any, list["SourceInfo"] | None], None]

    _consecutive_failures: int = 0

    def __init__(
        self,
        name: str,
        handler: Callable[[Exception, Any, list["SourceInfo"] | None], None],
        tolerance: int = 100,
    ):
        """Initialize error context.

        Args:
            name: Name of the operation being tracked (for error messages).
            handler: Error handler function to call on exceptions. Takes (exception, sample, sources).
                If None, exceptions will be raised after incrementing the counter.
            tolerance: Maximum number of consecutive failures before raising FatalSampleError.
                Set to 0 to disable tolerance checking.
        """
        self.name = name
        self.tolerance = tolerance
        self.handler = handler

    def reset(self) -> None:
        """Reset the consecutive failures counter."""
        self._consecutive_failures = 0

    @contextmanager
    def handle_errors(
        self,
        sample: Any,
    ) -> Generator[None, None, None]:
        """Context manager for handling exceptions during sample processing.

        Automatically tracks consecutive failures and resets on success.

        Args:
            sample: The sample being processed (used in error reporting).
        """
        try:
            yield
            # Success - reset counter
            self._consecutive_failures = 0
        except GeneratorExit:
            raise
        except SkipSample:
            pass
        except SYSTEM_EXCEPTIONS as e:
            raise FatalSampleError.from_sample(
                sample, f"{self.name} failed due to system exception: {e}."
            )
        except Exception as e:
            print(f"Except {e} in {self.name}")
            # Call the error handler if provided
            if self.handler is not None:
                # Call the error handler
                self.handler(e, sample, get_source_info(sample))

            # Increment counter (may raise FatalSampleError if tolerance exceeded)
            self._consecutive_failures += 1

            if self._consecutive_failures > 1:
                print(
                    f"ErrorContext {self.name} failed {self._consecutive_failures}/{self.tolerance} times in a row."
                )
            if self.tolerance > 0 and self._consecutive_failures >= self.tolerance:
                raise FatalSampleError.from_sample(
                    sample,
                    (
                        f"{self.name} failed {self._consecutive_failures} times in a row. "
                        f"Likely your code or dataset are broken."
                    ),
                )

    def __repr__(self) -> str:
        return f"ErrorContext(name={self.name!r}, tolerance={self.tolerance}, count={self._consecutive_failures})"


@contextmanager
def handle_restore_errors(
    error_handler: Callable[[Exception, Any, list["SourceInfo"] | None], None],
    sample: Any,
) -> Generator[None, None, None]:
    """Context manager for handling exceptions during sample restoration.

    Args:
        error_handler: Function to call when an exception occurs. Takes (exception, sample, sources).
        sample: The sample being restored.
    """
    try:
        yield
    except SkipSample as e:
        # Unexpected skip sample
        try:
            raise ValueError(f"Unexpected skip sample {sample} during restoration.") from e
        except Exception as e:
            error_handler(e, sample, get_source_info(sample))
    except GeneratorExit as e:
        # Unexpected skip sample
        try:
            raise ValueError(
                f"Unexpected generator early stopping for sample {sample} during restoration."
            ) from e
        except Exception as e:
            error_handler(e, sample, get_source_info(sample))
    except SYSTEM_EXCEPTIONS as e:
        raise FatalSampleError.from_sample(sample) from e
    except Exception as e:
        error_handler(e, sample, get_source_info(sample))


def log_exception(e: Exception, sample: Any, sources: list["SourceInfo"] | None = None) -> None:
    """Error handler that logs exceptions with sample information.

    This function prints the traceback of ``e``, source information if available,
    and a detailed representation of the failed sample (as produced by
    ``format_sample_detailed``) to help with debugging.

    Args:
        e: The exception that was raised.
        sample: The sample that caused the exception.
        sources: Optional list of SourceInfo objects with sample provenance.
    """
    import traceback

    # Print the traceback of the exception that was passed in. traceback.print_exc()
    # would instead report whichever exception is currently being handled, and prints
    # "NoneType: None" when this function is called outside an except block.
    traceback.print_exception(type(e), e, e.__traceback__)
    print("-" * 10)

    if sources:
        print("Sources:")
        for source in sources:
            # Entries without a ``dataset_path`` attribute are skipped.
            if hasattr(source, "dataset_path"):
                print(
                    f" - {source.dataset_path}[{source.index}] {source.shard_name}{source.file_names!r}"
                )
        print("-" * 10)

    sample_str = format_sample_detailed(sample)
    print(sample_str)
    print("-" * 10)


[docs] def reraise_exception( e: Exception, _sample: Any, _sources: list["SourceInfo"] | None = None ) -> None: """Error handler that simply reraises the exception. This is useful when you want failures to propagate immediately without any tolerance or logging. Args: e: The exception to reraise. _sample: The sample (unused). _sources: Source info (unused). Raises: The original exception. """ raise e