# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import inspect
from typing import (
Any,
Callable,
Dict,
Generator,
Generic,
Hashable,
Iterator,
List,
Optional,
Tuple,
TypeVar,
Union,
)
from megatron.energon.dataclass_slots import dataclass_slots
from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError
from megatron.energon.flavors.base_dataset import FlexState, SavableDataset, set_sample_restore_key
from megatron.energon.savable import Savable
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers._log_exception import log_exception
from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex
from megatron.energon.wrappers.buffer import SavableSampleBuffer
from megatron.energon.wrappers.skip import SkipSample

T_batch = TypeVar("T_batch", covariant=True)
T_batch_sample = TypeVar("T_batch_sample", covariant=True)


@dataclass_slots
class Bucket(Savable, Generic[T_batch_sample]):
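    """A bucket buffers the samples of one group key until ``batch_size`` samples
    are available to form a batch."""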
batch_size: int
samples: SavableSampleBuffer[T_batch_sample]

    def save_state(self) -> FlexState:
return FlexState(
batch_size=self.batch_size,
samples=self.samples.save_state(),
)

    def restore_state(self, state: FlexState):
self.batch_size = state["batch_size"]
self.samples.restore_state(state["samples"])


class GroupBatchDataset(
    BaseWrapperDataset[T_batch_sample, T_batch], Generic[T_batch_sample, T_batch]
):
"""This dataset wrapper transforms a dataset of samples into a dataset of batches, grouped by some criterion.
The length is not correct, as this function can not predict the number of batches as there is no fixed batch size,
instead it returns the inner dataset size.
An example use case is: Image-Text samples, which are to be grouped by the image size into three
size categories (e.g. 128x128, 256x256, 512x512) for efficient augmentation and batching.
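
    Example (illustrative sketch; ``ImageTextSample``, its fields, and the size threshold
    are hypothetical, not part of this module)::

        def group_by_image_size(sample: ImageTextSample) -> Tuple[Hashable, Optional[int]]:
            # Smaller images allow a larger batch size.
            if max(sample.image.shape[-2:]) <= 256:
                return "small", 32
            return "large", 8

        grouped = GroupBatchDataset(
            dataset,
            fixed_batch_size=None,
            sample_group_key=group_by_image_size,
            batcher=my_batcher,
            worker_config=worker_config,
        )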
"""
    dataset: SavableDataset[T_batch_sample]
    fixed_batch_size: Optional[int]
    sample_group_key: Callable[[T_batch_sample], Tuple[Hashable, Optional[int]]]
    batcher: Callable[[List[T_batch_sample]], T_batch]
    batcher_stateless: bool
    batcher_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]]
    drop_last: bool
    error_handler: Callable[[Exception, List[T_batch_sample]], None]

    _group_key_sample_index: SampleIndex
    _batch_sample_index: SampleIndex
    _buckets: Dict[Hashable, Bucket[T_batch_sample]]

    def __init__(
self,
dataset: SavableDataset[T_batch_sample],
fixed_batch_size: Optional[int],
sample_group_key: Callable[[T_batch_sample], Tuple[Hashable, Optional[int]]],
batcher: Callable[[List[T_batch_sample]], T_batch],
*,
batcher_stateless: bool = False,
batcher_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] = None,
drop_last: bool = False,
error_handler: Callable[[Exception, List[T_batch_sample]], None] = log_exception,
worker_config: WorkerConfig,
):
"""Construct a GroupBatchDataset.
Args:
dataset: The input dataset to wrap
fixed_batch_size: Fixed batch size to use for all buckets. If None, the batch size is determined by the sample_group_key function.
sample_group_key: Function which determines the bucket of a sample.
batcher: Function which combines separate samples into a single object. May raise
:exc:`megatron.energon.SkipSample` to skip a sample.
drop_last: If True, the last batch is dropped if it is smaller than the batch size.
error_handler: Handler for errors. Defaults to logging and ignoring the exception.
worker_config: Configuration for the workers.
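
        Example (sketch; ``sample.aspect`` is a hypothetical attribute)::

            # With a fixed batch size, the group key returns None as the size:
            GroupBatchDataset(
                dataset,
                fixed_batch_size=16,
                sample_group_key=lambda sample: (sample.aspect, None),
                batcher=my_batcher,
                worker_config=worker_config,
            )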
"""
super().__init__(dataset, worker_config=worker_config)
self.fixed_batch_size = fixed_batch_size
self.sample_group_key = sample_group_key
self.batcher = batcher
self.batcher_stateless = batcher_stateless
self.batcher_config = batcher_config
self.drop_last = drop_last
self.error_handler = error_handler
self.reset_state_own()
assert not inspect.isgeneratorfunction(batcher), (
f"Batcher {batcher} must not be a generator function for grouped batching."
)

    def reset_state_own(self) -> None:
self._group_key_sample_index = SampleIndex(self.worker_config, src=self)
self._batch_sample_index = SampleIndex(self.worker_config, src=self)
self._buckets = {}

    def __len__(self):
        # The number of batches cannot be known in advance, since batch sizes vary per group;
        # return the inner dataset's length as an upper bound.
        return len(self.dataset)

    def __iter__(self) -> Iterator[T_batch]:
buckets = self._buckets
if buckets is None:
buckets = self._buckets = dict()
# Load saved state if available
for bucket in buckets.values():
bucket.samples.worker_start()
# print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] initial GroupBatchDataset state:\n", end="")
# for bucket_key, bucket in buckets.items():
# print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] - Bucket [{bucket_key}] (bs={bucket.batch_size}, len(samples)={len(bucket.samples)}):\n", end="")
# bucket.samples.debug_print(" ")
# print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] initial done\n", end="")
def flush(bucket: Bucket[T_batch_sample]) -> Generator[T_batch, None, None]:
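            """Build one batch from the bucket's buffered samples and yield it.

            Skips the batch on :exc:`SkipSample`, escalates system exceptions as
            :exc:`FatalSampleError`, and passes all other exceptions to the error handler.
            """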
            batch_items, sample_restore_keys = bucket.samples.flush()
try:
with self._batch_sample_index.ctx() as sample_idx:
batch_sample = self.batcher(batch_items)
assert not isinstance(batch_sample, Generator), (
f"Batcher {self.batcher} returned a generator, which is not supported for grouped batching yet."
)
set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self)
yield batch_sample
except SkipSample:
pass
except SYSTEM_EXCEPTIONS:
raise FatalSampleError.from_sample(batch_items)
except Exception as e:
self.error_handler(e, batch_items)

        # Add samples to the buckets
        for sample in self.dataset:
try:
with self._group_key_sample_index.ctx():
bucket_key, batch_size = self.sample_group_key(sample)
                    assert (batch_size is None) != (self.fixed_batch_size is None), (
                        f"A sample in the group for key {bucket_key} returned batch size {batch_size}, but fixed "
                        f"batch size is set to {self.fixed_batch_size}. Exactly one of the two must be None."
                    )
if self.fixed_batch_size is not None:
batch_size = self.fixed_batch_size
except SkipSample:
continue
except SYSTEM_EXCEPTIONS:
raise FatalSampleError.from_sample(sample)
except Exception as e:
self.error_handler(e, [sample])
continue
bucket = buckets.get(bucket_key)
if bucket is None:
assert batch_size is not None
buckets[bucket_key] = bucket = Bucket(
batch_size=batch_size,
samples=SavableSampleBuffer(self.dataset, worker_config=self.worker_config),
)
else:
assert bucket.batch_size == batch_size, (
f"Got different batch size for group {bucket_key}: {bucket.batch_size} != {batch_size}."
)
bucket.samples.append(sample)
if len(bucket.samples) >= bucket.batch_size:
yield from flush(bucket)
# Flush out last samples
if not self.drop_last:
for bucket in buckets.values():
if len(bucket.samples) > 0:
yield from flush(bucket)
# Clear the buckets
self._buckets.clear()

    def save_state(self) -> FlexState:
return FlexState(
bucket_sample_index=self._group_key_sample_index.save_state(),
batch_sample_index=self._batch_sample_index.save_state(),
buckets={key: bucket.save_state() for key, bucket in self._buckets.items()},
**super().save_state(),
)

    def restore_state(self, state: FlexState) -> None:
super().restore_state(state)
self._group_key_sample_index.restore_state(state["bucket_sample_index"])
self._batch_sample_index.restore_state(state["batch_sample_index"])
for key, bucket_state in state["buckets"].items():
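            # batch_size=-1 is a placeholder; Bucket.restore_state overwrites it below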
self._buckets[key] = Bucket(
batch_size=-1,
samples=SavableSampleBuffer(self.dataset, worker_config=self.worker_config),
)
self._buckets[key].restore_state(bucket_state)

    def can_restore_sample(self) -> bool:
return super().can_restore_sample() and self.batcher_stateless

    def assert_can_restore(self) -> None:
assert self.batcher_stateless, (
f"Batcher {self.batcher} must be stateless to restore samples"
)
super().assert_can_restore()

    def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T_batch:
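        """Restore a batch from a restore key of the form
        ``(type name, batch sample index, *inner sample restore keys)``."""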
self.assert_can_restore()
id, sample_idx, *sample_restore_keys = index
assert id == type(self).__name__
batch = [self.dataset.restore_sample(inner_idx) for inner_idx in sample_restore_keys]
with self._batch_sample_index.ctx(sample_idx):
batch_sample = self.batcher(batch)
set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self)
return batch_sample

    def config(self) -> Dict[str, Any]:
return {
"type": type(self).__qualname__,
"bucket": self._function_config(self.sample_group_key),
"batcher": self._function_config(self.batcher),
**(
{
"batcher_config": (
self.batcher_config()
if callable(self.batcher_config)
else self.batcher_config
)
}
if self.batcher_config
else {}
),
"batcher_stateless": self.batcher_stateless,
"drop_last": self.drop_last,
"error_handler": self._function_config(self.error_handler),
"worker_config": self.worker_config.config(),
"dataset": self.dataset.config(),
}

    def __str__(self):
return f"GroupBatchDataset(bucket={self.sample_group_key}, batcher={self.batcher}, drop_last={self.drop_last}, dataset={self.dataset})"