# Source code for megatron.energon.wrappers.map_dataset

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import inspect
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Generic,
    Iterator,
    Optional,
    Tuple,
    TypeVar,
    Union,
)

from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError
from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers._log_exception import log_exception
from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key
from megatron.energon.wrappers.skip import SkipSample

# Type of the samples produced by the wrapped (inner) dataset and passed to ``map_fn``.
T_sample = TypeVar("T_sample")
# Type of the samples produced by ``map_fn`` and emitted by the wrapper.
T_sample_out = TypeVar("T_sample_out")


class MapDataset(BaseWrapperDataset[T_sample, T_sample_out], Generic[T_sample, T_sample_out]):
    """This dataset wrapper applies a custom function to transform each sample.

    ``map_fn`` is called once per sample of the wrapped dataset. It may return a single
    mapped sample, return a generator yielding any number of samples, or raise
    :exc:`megatron.energon.SkipSample` to drop the input sample.

    While a generator returned by ``map_fn`` is being consumed, the restore key of its input
    sample and the number of samples already yielded are kept in ``_generator_sample_key``
    and ``_generator_offset``. Both are listed in ``_savable_fields`` (presumably consumed by
    the base class when saving/restoring state -- confirm there), and ``__iter__`` uses them
    to resume in the middle of that generator.
    """

    # Function applied to every input sample; returns a sample or a generator of samples.
    map_fn: Callable[[T_sample], Union[T_sample_out, Generator[T_sample_out, None, None]]]
    # Called with (exception, input sample) for non-fatal exceptions raised while mapping.
    error_handler: Callable[[Exception, T_sample], None]
    # Whether map_fn is deterministic and stateless; required for `restore_sample`.
    stateless_map_fn: bool
    # Optional configuration of map_fn (or a callable returning it), reported by `config()`.
    map_fn_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]]
    # Provides the sample index passed along with every mapped sample.
    _sample_index: SampleIndex
    # Restore key of the input sample whose generator is currently being consumed, else None.
    _generator_sample_key: Optional[Any]
    # Number of samples already yielded from the current generator, else None.
    _generator_offset: Optional[int]

    _savable_fields = (
        "_sample_index",
        "_generator_sample_key",
        "_generator_offset",
    )

    def __init__(
        self,
        dataset: SavableDataset[T_sample],
        map_fn: Callable[[T_sample], Union[T_sample_out, Generator[T_sample_out, None, None]]],
        *,
        error_handler: Callable[[Exception, T_sample], None] = log_exception,
        stateless_map_fn: bool = False,
        map_fn_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] = None,
        worker_config: WorkerConfig,
    ):
        """Construct a MapDataset.

        If this should be savable, the map_fn must only return a sample, or a generator
        yielding 0 or 1 sample per input sample. Otherwise this will be broken
        (see `IterMapDataset`).

        Args:
            dataset: The input dataset to wrap
            map_fn: The function to apply to each sample. May raise
                :exc:`megatron.energon.SkipSample` to skip a sample. Alternatively, may
                return a generator to yield multiple or no samples.
            error_handler: Handler for errors. Defaults to logging and ignoring the
                exception.
            stateless_map_fn: If true, the map_fn is deterministic and stateless
                (thus key for random access can propagate to inner dataset).
                Defaults to False.
            map_fn_config: Configuration for the map_fn function. If callable, it should
                return the configuration. Defaults to None.
            worker_config: Worker configuration.
        """
        super().__init__(dataset, worker_config=worker_config)
        self.map_fn = map_fn
        self.error_handler = error_handler
        self.stateless_map_fn = stateless_map_fn
        self.map_fn_config = map_fn_config

        self.reset_state_own()

    def reset_state_own(self) -> None:
        """Reset this wrapper's own state: a fresh sample index and no pending generator."""
        self._sample_index = SampleIndex(self.worker_config, src=self)
        self._generator_sample_key = None
        self._generator_offset = None

    def __len__(self):
        """Return the length of the wrapped dataset.

        This is the number of input samples; skipped samples and generators yielding a
        different number of samples are not accounted for.
        """
        return len(self.dataset)

    def __iter__(self) -> Iterator[T_sample_out]:
        """Iterate over the wrapped dataset and yield the mapped samples.

        If the state indicates that a generator was only partially consumed, that generator
        is first re-created from its restored input sample and fast-forwarded to the stored
        offset. Afterwards the wrapped dataset is iterated normally.

        In the main loop, :exc:`SkipSample` drops the current input sample, exceptions in
        ``SYSTEM_EXCEPTIONS`` are re-raised as :exc:`FatalSampleError`, and any other
        exception is passed to ``error_handler`` before continuing with the next sample.
        """
        if self._generator_sample_key is not None:
            # A generator was in progress for this input sample: rebuild it and continue
            # where it left off.
            # NOTE(review): unlike the main loop below, this resume path is not wrapped in
            # try/except, so exceptions raised by map_fn here propagate to the caller --
            # confirm this is intended.
            assert self._generator_offset is not None
            sample = self.dataset.restore_sample(self._generator_sample_key)
            # Do not increment the sample index, use previous index
            with self._sample_index.ctx(self._sample_index.current_idx) as sample_idx:
                mapped_sample = self.map_fn(sample)
            assert isinstance(mapped_sample, Generator)
            assert inspect.isgeneratorfunction(self.map_fn), (
                f"Generator in {self.map_fn} but not marked as such."
            )
            target_offset = self._generator_offset
            self._generator_offset = 0
            for idx, (sample_idx, inner_sample) in enumerate(
                self._sample_index.iter_ctx(mapped_sample, sample_idx)
            ):
                # Skip other samples (those before the stored offset)
                if idx >= target_offset:
                    # Update the offset before yielding, so it already accounts for the
                    # sample handed out while this generator is suspended at the yield.
                    self._generator_offset = idx + 1
                    yield add_sample_restore_key(
                        inner_sample,
                        sample_idx,
                        idx,
                        src=self,
                    )
            self._generator_sample_key = None
            self._generator_offset = None

        for sample in self.dataset:
            try:
                with self._sample_index.ctx() as sample_idx:
                    mapped_sample = self.map_fn(sample)
                if isinstance(mapped_sample, Generator):
                    assert inspect.isgeneratorfunction(self.map_fn), (
                        f"Generator in {self.map_fn} but not marked as such."
                    )
                    # Remember which input sample the generator belongs to, so that
                    # iteration can resume within it.
                    self._generator_sample_key = get_sample_restore_key(sample)
                    self._generator_offset = 0
                    # In case of a generator, additionally store the index of the yielded
                    # samples per input sample
                    for idx, (sample_idx, inner_sample) in enumerate(
                        self._sample_index.iter_ctx(mapped_sample, sample_idx)
                    ):
                        self._generator_offset = idx + 1
                        yield add_sample_restore_key(
                            inner_sample,
                            sample_idx,
                            idx,
                            src=self,
                        )
                    # Generator exhausted: nothing to resume anymore.
                    self._generator_sample_key = None
                    self._generator_offset = None
                else:
                    yield add_sample_restore_key(
                        mapped_sample,
                        sample_idx,
                        src=self,
                    )
            except SkipSample:
                # map_fn asked to drop this input sample.
                pass
            except SYSTEM_EXCEPTIONS as e:
                # Not recoverable by the error handler; keep the original error as cause.
                raise FatalSampleError.from_sample(sample) from e
            except Exception as e:
                self.error_handler(e, sample)

    def can_restore_sample(self) -> bool:
        """Return True if the parent implementation allows restoring and map_fn is stateless."""
        return super().can_restore_sample() and self.stateless_map_fn

    def assert_can_restore(self) -> None:
        """Assert that map_fn is stateless, then delegate to the parent implementation."""
        assert self.stateless_map_fn, (
            f"MapDataset can only restore samples if map_fn {self.map_fn} is stateless."
        )
        super().assert_can_restore()

    def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T_sample_out:
        """Re-create a single mapped sample from its restore key.

        Args:
            index: The restore key. Its leading entries belong to this wrapper:
                ``(type name, sample index)`` for a plain map_fn, or
                ``(type name, sample index, index within the generator)`` if map_fn is a
                generator function. The remaining entries are passed to the wrapped
                dataset's ``restore_sample``.

        Returns:
            The mapped sample, with the restore key of this wrapper added again.

        Raises:
            AssertionError: If map_fn is not stateless, if the key does not belong to this
                wrapper type, or if the generator yields fewer samples than the key
                requires.
        """
        self.assert_can_restore()
        if inspect.isgeneratorfunction(self.map_fn):
            wrapper_name, sample_idx, local_idx = index[:3]
            assert wrapper_name == type(self).__name__
            index = index[3:]
            assert isinstance(local_idx, int)
        else:
            wrapper_name, sample_idx = index[:2]
            assert wrapper_name == type(self).__name__
            index = index[2:]
        inner_sample = self.dataset.restore_sample(index)
        # Run map_fn under the sample index stored in the key.
        with self._sample_index.ctx(sample_idx):
            mapped_sample = self.map_fn(inner_sample)
        if isinstance(mapped_sample, Generator):
            assert inspect.isgeneratorfunction(self.map_fn), (
                f"Generator in {self.map_fn} but not marked as such."
            )
            for idx, (sample_idx, res_sample) in enumerate(
                self._sample_index.iter_ctx(mapped_sample, sample_idx)
            ):
                if idx == local_idx:
                    return add_sample_restore_key(res_sample, sample_idx, local_idx, src=self)
            # Raised explicitly (instead of `assert False`) so that running with `-O` cannot
            # turn this into a silent `return None`.
            raise AssertionError(
                "Generator did not yield enough samples, but is marked stateless/deterministic."
            )
        else:
            return add_sample_restore_key(mapped_sample, sample_idx, src=self)

    def config(self) -> Dict[str, Any]:
        """Return the configuration of this wrapper, including the wrapped dataset's.

        The ``"map_fn_config"`` entry is only present if ``map_fn_config`` is truthy; a
        callable ``map_fn_config`` is called to obtain the value.
        """
        cfg: Dict[str, Any] = {
            "type": type(self).__qualname__,
            "dataset": self.dataset.config(),
            "map_fn": self._function_config(self.map_fn),
        }
        if self.map_fn_config:
            cfg["map_fn_config"] = (
                self.map_fn_config() if callable(self.map_fn_config) else self.map_fn_config
            )
        cfg["map_fn_stateless"] = self.stateless_map_fn
        return cfg

    def __str__(self):
        """Return a short description showing map_fn and the wrapped dataset."""
        return f"MapDataset(map_fn={self.map_fn}, dataset={self.dataset})"