# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import inspect
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Generic,
    Iterator,
    List,
    Optional,
    Tuple,
    TypeVar,
    Union,
)

from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError
from megatron.energon.flavors.base_dataset import SavableDataset, set_sample_restore_key
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers._log_exception import log_exception
from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key
from megatron.energon.wrappers.skip import SkipSample

T_batch = TypeVar("T_batch", covariant=True)
T_batch_sample = TypeVar("T_batch_sample", covariant=True)


class BatchDataset(BaseWrapperDataset[T_batch_sample, T_batch], Generic[T_batch_sample, T_batch]):
    """This dataset wrapper transforms a dataset of samples into a dataset of batches.

    A brief usage sketch follows the class definition at the end of this module."""

    batch_size: int
    batcher: Callable[[List[T_batch_sample]], T_batch]
    drop_last: bool
    error_handler: Callable[[Exception, List[T_batch_sample]], None]
    _sample_index: SampleIndex
    _generator_sample_keys: Optional[Any]
    _generator_offset: Optional[int]

    _savable_fields = ("_sample_index", "_generator_sample_keys", "_generator_offset")

    def __init__(
        self,
        dataset: SavableDataset[T_batch_sample],
        batch_size: int,
        batcher: Callable[[List[T_batch_sample]], T_batch],
        *,
        batcher_stateless: bool = False,
        batcher_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] = None,
        drop_last: bool = False,
        error_handler: Callable[[Exception, List[T_batch_sample]], None] = log_exception,
        worker_config: WorkerConfig,
    ):
        """Construct a BatchDataset.

        Args:
            dataset: The input dataset to wrap
            batch_size: The desired batch size. The last batch may be smaller.
            batcher: Function which combines separate samples into a single object. May raise
                :exc:`megatron.energon.SkipSample` to skip a sample.
            batcher_stateless: If True, the batcher is stateless, thus samples can be
                stored/restored.
            batcher_config: Configuration for the batcher function. If callable, it should
                return the configuration. Defaults to None.
            drop_last: If True, the last batch is dropped if it is smaller than the batch size.
            error_handler: Function which handles exceptions raised by the batcher. The default
                implementation logs the exception.
            worker_config: Configuration for the workers.
        """
        super().__init__(dataset, worker_config=worker_config)
        self.batch_size = batch_size
        self.batcher = batcher
        self.batcher_stateless = batcher_stateless
        self.batcher_config = batcher_config
        self.drop_last = drop_last
        self.error_handler = error_handler
        self.reset_state_own()

    def reset_state_own(self) -> None:
        self._sample_index = SampleIndex(self.worker_config, src=self)
        self._generator_sample_keys = None
        self._generator_offset = None
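
    # __len__ estimates the total number of batches produced across all workers: samples
    # are assumed to be split as evenly as possible over the workers, so
    # `remaining_n_sample_workers` workers receive one extra sample; each worker's sample
    # count is then divided by `batch_size`, rounding up unless `drop_last` is set.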
    def __len__(self):
        n_samples = len(self.dataset)
        num_workers = max(self.worker_config.num_workers, 1)
        n_samples_per_worker_floor = n_samples // num_workers
        remaining_n_sample_workers = n_samples % num_workers
        n_batches_per_worker_floor = n_samples_per_worker_floor // self.batch_size
        if n_samples_per_worker_floor % self.batch_size != 0 and not self.drop_last:
            n_batches_per_worker_floor += 1
        # Correct number of batches for the workers which yield 1 more sample (to balance)
        n_batches_per_worker_ceil = (n_samples_per_worker_floor + 1) // self.batch_size
        if (n_samples_per_worker_floor + 1) % self.batch_size != 0 and not self.drop_last:
            n_batches_per_worker_ceil += 1
        return (
            n_batches_per_worker_floor * (num_workers - remaining_n_sample_workers)
            + n_batches_per_worker_ceil * remaining_n_sample_workers
        )
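
    # __iter__ first replays a partially consumed generator batch if state was restored
    # mid-batch (tracked via `_generator_sample_keys`/`_generator_offset`), then
    # accumulates samples from the wrapped dataset and flushes a batch whenever
    # `batch_size` samples have been collected.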
    def __iter__(self) -> Iterator[T_batch]:
        batch: List[T_batch_sample] = []
        sample_restore_keys = []

        if self._generator_sample_keys is not None:
            sample_restore_keys = self._generator_sample_keys
            assert self._generator_offset is not None
            batch = [self.dataset.restore_sample(inner_idx) for inner_idx in sample_restore_keys]
            with self._sample_index.ctx(self._sample_index.current_idx) as sample_idx:
                batch_sample = self.batcher(batch)
            assert isinstance(batch_sample, Generator)
            assert inspect.isgeneratorfunction(self.batcher), (
                f"Generator in {self.batcher} but not marked as such."
            )
            target_offset = self._generator_offset
            self._generator_offset = 0
            for batch_sub_idx, (sample_idx, inner_batch_sample) in enumerate(
                self._sample_index.iter_ctx(batch_sample, sample_idx)
            ):
                # Skip other samples
                if batch_sub_idx >= target_offset:
                    self._generator_offset = batch_sub_idx + 1
                    yield set_sample_restore_key(
                        inner_batch_sample,
                        sample_idx,
                        batch_sub_idx,
                        *sample_restore_keys,
                        src=self,
                    )
            self._generator_sample_keys = None
            self._generator_offset = None
            batch.clear()
            sample_restore_keys = []
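
        # flush() turns the accumulated `batch` into one batched object via the batcher. A
        # generator batcher may yield several batch objects; its progress is tracked in
        # `_generator_offset` so iteration can resume mid-batch after a state restore.
        # Skipped batches are dropped via SkipSample, system exceptions escalate to
        # FatalSampleError, and other exceptions are passed to the error handler.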
        def flush():
            try:
                with self._sample_index.ctx() as sample_idx:
                    batch_sample = self.batcher(batch)
                if isinstance(batch_sample, Generator):
                    assert inspect.isgeneratorfunction(self.batcher), (
                        f"Generator in {self.batcher} but not marked as such."
                    )
                    self._generator_sample_keys = sample_restore_keys
                    self._generator_offset = 0
                    for batch_sub_idx, (sample_idx, inner_batch_sample) in enumerate(
                        self._sample_index.iter_ctx(batch_sample, sample_idx)
                    ):
                        self._generator_offset = batch_sub_idx + 1
                        yield set_sample_restore_key(
                            inner_batch_sample,
                            sample_idx,
                            batch_sub_idx,
                            *sample_restore_keys,
                            src=self,
                        )
                    self._generator_sample_keys = None
                    self._generator_offset = None
                else:
                    set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self)
                    yield batch_sample
                sample_restore_keys.clear()
            except SkipSample:
                pass
            except SYSTEM_EXCEPTIONS:
                raise FatalSampleError.from_sample(batch)
            except Exception as e:
                self.error_handler(e, batch)
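
        # Main loop: collect samples until a full batch is available, then flush it; a
        # trailing partial batch is flushed unless `drop_last` is set.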
        for sample in self.dataset:
            batch.append(sample)
            sample_restore_keys.append(get_sample_restore_key(sample))
            if len(batch) == self.batch_size:
                yield from flush()
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield from flush()

    def can_restore_sample(self) -> bool:
        # Cannot really verify if the returned elements contain a __restore_key__.
        # If the user wants to use this, well...
        return super().can_restore_sample() and self.batcher_stateless

    def assert_can_restore(self) -> None:
        assert self.batcher_stateless, (
            f"Batcher {self.batcher} must be stateless to restore samples"
        )
        super().assert_can_restore()
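
    # restore_sample() rebuilds one batch (or one sub-sample of a generator batch) from a
    # restore key of the form (class name, sample index[, batch sub-index], inner restore
    # keys of the batch's samples): it restores the inner samples and re-runs the batcher.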
    def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T_batch:
        # We need to store multiple indices to restore a batch.
        self.assert_can_restore()
        if inspect.isgeneratorfunction(self.batcher):
            id, sample_idx, batch_sub_idx, *samples_restore_keys = index
            assert id == type(self).__name__
        else:
            id, sample_idx, *samples_restore_keys = index
            assert id == type(self).__name__
        batch = [self.dataset.restore_sample(inner_idx) for inner_idx in samples_restore_keys]
        with self._sample_index.ctx(sample_idx):
            batch_sample = self.batcher(batch)
        if isinstance(batch_sample, Generator):
            assert inspect.isgeneratorfunction(self.batcher), (
                f"Generator in {self.batcher} but not marked as such."
            )
            for cur_batch_sub_idx, (sample_idx, inner_batch_sample) in enumerate(
                self._sample_index.iter_ctx(batch_sample, sample_idx)
            ):
                if cur_batch_sub_idx == batch_sub_idx:
                    return set_sample_restore_key(
                        inner_batch_sample,
                        sample_idx,
                        batch_sub_idx,
                        *samples_restore_keys,
                        src=self,
                    )
            assert False, f"Batch sub-index {batch_sub_idx} not found in batch"
        else:
            return set_sample_restore_key(
                batch_sample,
                sample_idx,
                *samples_restore_keys,
                src=self,
            )

    def config(self) -> Dict[str, Any]:
        return {
            "type": type(self).__qualname__,
            "batch_size": self.batch_size,
            "batcher": self._function_config(self.batcher),
            **(
                {
                    "batcher_config": (
                        self.batcher_config()
                        if callable(self.batcher_config)
                        else self.batcher_config
                    )
                }
                if self.batcher_config
                else {}
            ),
            "batcher_stateless": self.batcher_stateless,
            "drop_last": self.drop_last,
            "error_handler": self._function_config(self.error_handler),
            "worker_config": self.worker_config.config(),
            "dataset": self.dataset.config(),
        }

    def __str__(self):
        return f"BatchDataset(batch_size={self.batch_size}, drop_last={self.drop_last}, batcher={self.batcher}, dataset={self.dataset})"