# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Any, Callable, Dict, Generic, Iterator, Optional, Tuple, TypeVar, Union
from megatron.energon.flavors.base_dataset import SavableDataset
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex
T_sample = TypeVar("T_sample")
class FilterDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]):
    """Dataset wrapper that drops samples rejected by a user-supplied predicate.

    Every sample of the wrapped dataset is passed to ``filter_fn``; the sample is yielded
    only if the predicate returns a truthy value. Samples are passed through unmodified.
    """

    # Predicate deciding whether a sample is kept (truthy) or dropped (falsy).
    filter_fn: Callable[[T_sample], bool]
    # Optional extra configuration reported by `config()`; either a dict or a
    # zero-argument callable producing one.
    filter_fn_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]]
    # Index helper whose context wraps each `filter_fn` call; recreated on reset.
    _sample_index: SampleIndex

    # Attribute names listed for the base class; presumably the state that gets
    # saved/restored -- confirm against `BaseWrapperDataset`.
    _savable_fields = ("_sample_index",)

    def __init__(
        self,
        dataset: SavableDataset[T_sample],
        *,
        filter_fn: Callable[[T_sample], bool],
        filter_fn_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] = None,
        worker_config: WorkerConfig,
    ):
        """Construct a FilterDataset.

        Args:
            dataset: The input dataset to wrap.
            filter_fn: The function to apply to each sample. If it returns `True`, the sample
                is accepted, otherwise it is skipped.
            filter_fn_config: Configuration for the filter function. If callable, it should
                return the configuration. Defaults to None.
            worker_config: Configuration for the workers.
        """
        super().__init__(dataset, worker_config=worker_config)
        self.filter_fn = filter_fn
        self.filter_fn_config = filter_fn_config
        # Initialise `_sample_index` through the same path used for resets.
        self.reset_state_own()

    def reset_state_own(self) -> None:
        """Reset this wrapper's own state by creating a fresh sample index."""
        self._sample_index = SampleIndex(self.worker_config, src=self)

    def __len__(self):
        """Return the length of the wrapped dataset.

        This does not account for samples removed by the filter, so it is an upper bound
        on the number of samples actually yielded.
        """
        return len(self.dataset)

    def __iter__(self) -> Iterator[T_sample]:
        """Yield the samples of the wrapped dataset for which ``filter_fn`` is truthy."""
        for sample in self.dataset:
            # NOTE(review): only the predicate call is placed inside the sample-index
            # context; the yield happens after the context has exited so it is never held
            # open across a generator suspension. Confirm against upstream formatting.
            with self._sample_index.ctx():
                filter_res = self.filter_fn(sample)
            if filter_res:
                yield sample

    def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T_sample:
        """Restore a sample by delegating to the wrapped dataset.

        The index is forwarded unchanged and the filter is not re-applied.
        """
        return self.dataset.restore_sample(index)

    def config(self) -> Dict[str, Any]:
        """Return a dict describing this wrapper and the dataset it wraps.

        The ``filter_fn_config`` entry is only included when a truthy configuration was
        given; a callable configuration is invoked to obtain the value.
        """
        cfg: Dict[str, Any] = {
            "type": type(self).__qualname__,
            "dataset": self.dataset.config(),
            "filter_fn": self._function_config(self.filter_fn),
        }
        # Truthiness check kept on purpose: `None` and an empty dict are both omitted.
        if self.filter_fn_config:
            if callable(self.filter_fn_config):
                cfg["filter_fn_config"] = self.filter_fn_config()
            else:
                cfg["filter_fn_config"] = self.filter_fn_config
        return cfg

    def __str__(self):
        """Return a short description naming the filter function and wrapped dataset."""
        return f"FilterDataset(filter_fn={self.filter_fn}, dataset={self.dataset})"