# Source code for megatron.energon.wrappers.concat_dataset

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Any, Dict, Generic, Iterator, TypeVar

from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset

T_sample = TypeVar("T_sample")


class ConcatDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]):
    """Wrapper that chains several iterable datasets one after another.

    Every wrapped dataset must be finite, otherwise the later datasets would
    never be reached. This makes the wrapper useful only for validation / test
    splits, not for (infinite) training streams.
    """

    def __init__(
        self,
        *datasets: SavableDataset[T_sample],
        worker_config: WorkerConfig,
    ):
        """Construct a concatenated dataset.

        Args:
            *datasets: The finite datasets to iterate in order.
            worker_config: Worker configuration forwarded to the base wrapper.
        """
        super().__init__(datasets, worker_config=worker_config)
        # Forces __len__ evaluation up front: an infinite dataset is expected
        # to fail here rather than silently starving the trailing datasets.
        assert len(self) >= 0, "Datasets must be finite."

    def reset_state_own(self) -> None:
        # This wrapper keeps no state of its own beyond the wrapped datasets.
        return

    def __len__(self):
        # Total sample count is simply the sum over all wrapped datasets.
        total = 0
        for ds in self.datasets:
            total += len(ds)
        return total

    def __iter__(self) -> Iterator[T_sample]:
        # Exhaust each dataset fully before moving on to the next one,
        # tagging every sample with the index of its source dataset so the
        # sample can be restored later.
        for dataset_index, dataset in enumerate(self.datasets):
            for sample in dataset:
                tagged = add_sample_restore_key(
                    sample,
                    dataset_index,
                    src=self,
                )
                yield tagged

    def config(self) -> Dict[str, Any]:
        """Return a config dict describing this wrapper and its children."""
        child_configs = [dataset.config() for dataset in self.datasets]
        return {
            "type": type(self).__qualname__,
            "datasets": child_configs,
        }

    def __str__(self):
        return f"ConcatDataset(datasets={self.datasets})"