# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Any, Dict, Iterator, List, Tuple, TypeVar
import torch
from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key
from megatron.energon.rng import WorkerRng
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset
T_sample = TypeVar("T_sample")
class BlendDataset(BaseWrapperDataset[T_sample, T_sample]):
"""
    This dataset wrapper blends multiple iterable datasets together according to the given
    weights. The inner datasets may be infinite; iteration only stops once all of them are
    exhausted.
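
    Example (a minimal sketch; ``ds_a``, ``ds_b`` and ``worker_config`` are assumed to be
    already constructed :class:`SavableDataset` instances and a :class:`WorkerConfig`)::

        blended = BlendDataset((ds_a, 0.7), (ds_b, 0.3), worker_config=worker_config)
        for sample in blended:
            ...  # roughly 70% of samples come from ds_a, 30% from ds_b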
"""
weights: Tuple[float, ...]
exhausted: List[bool]
_worker_rng: WorkerRng
_savable_fields = ("exhausted", "_worker_rng")
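    # Fields captured and restored by the savable-dataset state machinery.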
def __init__(
self,
*dataset_weights: Tuple[SavableDataset[T_sample], float],
worker_config: WorkerConfig,
):
"""Construct a BlendDataset.
Args:
            dataset_weights: Each argument should be a tuple of (dataset, weight) with a
                positive weight. The output samples are drawn from the input datasets with
                probability proportional to the given weights.
worker_config: Configuration for the workers.
"""
self.datasets, self.weights = zip(*dataset_weights)
super().__init__(self.datasets, worker_config=worker_config)
self.dataset_weights = dataset_weights
self.reset_state_own()
def reset_state_own(self) -> None:
self._worker_rng = WorkerRng(self.worker_config)
self.exhausted = [False] * len(self.weights)
def __len__(self) -> int:
        # Return the total number of samples across the inner datasets, disregarding the
        # weights.
        return sum(len(dataset) for dataset, _weight in self.dataset_weights)
def __iter__(self) -> Iterator[T_sample]:
        assert self.worker_has_samples(), "Cannot blend datasets that are all empty on this worker"
# Create a list of datasets and their weights, but
# set the weight to 0 if the dataset has no samples on this worker.
dataset_iters = []
weights = []
        for dataset, weight in self.dataset_weights:
assert weight > 0, "All blending weights must be > 0"
if dataset.worker_has_samples():
dataset_iters.append(iter(dataset))
weights.append(weight)
else:
dataset_iters.append(None)
weights.append(0)
weights = torch.tensor(weights, dtype=torch.float32)
if weights.sum() == 0:
raise RuntimeError(
"There is a worker with no samples in any of the blended datasets. "
"This can happen if you have a lot of workers and your dataset is too small. "
"Currently this case is not supported."
)
# Some may already be exhausted on this worker when restoring a state.
for idx, exhausted in enumerate(self.exhausted):
if exhausted:
weights[idx] = 0
dataset_iters[idx] = None
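        # Repeatedly draw a dataset index in proportion to the (unnormalized) weights
        # and yield its next sample, until every inner dataset is exhausted.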
while True:
ds_idx = self._worker_rng.choice_idx(probs=weights)
if dataset_iters[ds_idx] is None:
if all(dataset_iter is None for dataset_iter in dataset_iters):
break
continue
try:
sample = next(dataset_iters[ds_idx])
except StopIteration:
dataset_iters[ds_idx] = None
weights[ds_idx] = 0
self.exhausted[ds_idx] = True
if all(dataset_iter is None for dataset_iter in dataset_iters):
break
else:
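                # Prepend the dataset index to the sample's restore key, so that a
                # restored state can route the sample back to the dataset it came from.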
yield add_sample_restore_key(sample, ds_idx, src=self)
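        # All inner datasets ran dry: clear the exhaustion flags so a subsequent
        # iteration starts a fresh pass instead of terminating immediately.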
self.exhausted = [False] * len(self.dataset_weights)
def config(self) -> Dict[str, Any]:
return {
"type": type(self).__qualname__,
"dataset_weights": [
(dataset.config(), weight) for dataset, weight in self.dataset_weights
],
"worker_config": self.worker_config.config(),
}
def __str__(self):
return f"BlendDataset(dataset_weights={self.dataset_weights})"