Source code for megatron.energon.wrappers.mix_batch_dataset

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import dataclasses
from typing import Any, Callable, Dict, Generator, Generic, Iterator, List, Tuple, TypeVar, Union

import torch

from megatron.energon.flavors.base_dataset import SavableDataset
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset
from megatron.energon.wrappers.batch_dataset import BatchDataset
from megatron.energon.wrappers.blend_dataset import BlendDataset

# Type of the items produced by the wrapped input datasets, i.e. the elements of the list that is
# handed to the batch mix function.
T_batch_in = TypeVar("T_batch_in")
# Type of the batches returned (or generated) by the batch mix function and yielded by
# MixBatchDataset.
T_batch = TypeVar("T_batch")


def generic_concat(batch: List[Any]) -> Any:
    """Merge a list of samples, choosing the strategy from the type of the first element.

    Dispatch rules (only ``batch[0]`` is inspected):

    * ``torch.Tensor``: the whole list is passed to :func:`concat_pad`.
    * ``dict``: a dict with the keys of the first element is returned; each value is the
      recursive merge of that key over all elements.
    * dataclass instance or namedtuple: a new object of the same type is constructed, with
      every field merged recursively.
    * anything else (including an empty list): the input list is returned as is.

    Args:
        batch: The samples to merge. All elements are expected to share the structure of the
            first one.

    Returns:
        The merged structure, or ``batch`` itself if no merge rule applies.
    """
    if not batch:
        # There is no first element to dispatch on; an empty list is already a valid
        # "return as list" result.
        return batch

    first = batch[0]
    if isinstance(first, torch.Tensor):
        return concat_pad(batch)
    if isinstance(first, dict):
        return {key: generic_concat([sample[key] for sample in batch]) for key in first.keys()}
    # is_dataclass() is also true for dataclass *classes*; only instances carry field values.
    if dataclasses.is_dataclass(first) and not isinstance(first, type):
        return type(first)(
            **{
                field.name: generic_concat([getattr(sample, field.name) for sample in batch])
                for field in dataclasses.fields(first)
            }
        )
    if isinstance(first, tuple) and hasattr(first, "_fields"):
        # NamedTuple
        return type(first)(
            **{
                field: generic_concat([getattr(sample, field) for sample in batch])
                for field in first._fields
            }
        )
    return batch
def concat_pad(batch: List[Any]) -> Any:
    """Concatenate tensors along dim 0, zero-padding all remaining dims to a common size.

    Args:
        batch: Tensors to concatenate. The number of dimensions is taken from the first tensor.

    Returns:
        A tensor created via ``new_zeros`` on the first tensor. Its first dim is the sum of all
        first dims, every other dim is the maximum over the batch. Each input occupies its own
        block of rows, starting at index 0 in every trailing dim; the rest stays 0.
    """
    rows_total = 0
    for tensor in batch:
        rows_total += tensor.shape[0]

    trailing_axes = range(1, batch[0].ndim)
    padded_shape = [max(tensor.shape[axis] for tensor in batch) for axis in trailing_axes]
    result = batch[0].new_zeros((rows_total, *padded_shape))

    offset = 0
    for tensor in batch:
        rows = tensor.shape[0]
        # Rows [offset, offset + rows) belong to this tensor; trailing dims are filled from 0.
        region = (slice(offset, offset + rows),) + tuple(
            slice(0, extent) for extent in tensor.shape[1:]
        )
        result[region] = tensor
        offset += rows
    return result
def homogeneous_concat_mix(samples: List[T_batch_in]) -> T_batch:
    """
    Mixes a list of same-typed batches into a single batch.

    Every sample must have exactly the type of the first sample; this is checked with an
    assertion. The samples are then merged by :func:`generic_concat`.

    Args:
        samples: The samples to mix.

    Returns:
        The mixed batch.
    """
    expected_type = type(samples[0])
    for candidate in samples:
        assert type(candidate) is expected_type
    # Homogeneous input -> merge field by field.
    return generic_concat(samples)
class MixBatchDataset(BaseWrapperDataset[T_batch_in, T_batch], Generic[T_batch_in, T_batch]):
    """
    Dataset wrapper that blends several iterable datasets according to given weights and groups
    the blended samples into batches. The datasets may be infinite. This dataset is always
    infinite.

    Effectively combines :class:`megatron.energon.BlendDataset` and
    :class:`megatron.energon.BatchDataset`.
    """

    def __init__(
        self,
        *dataset_weights: Tuple[SavableDataset[T_batch_in], float],
        batch_size: int,
        batch_mix_fn: Callable[
            [List[T_batch_in]], Union[T_batch, Generator[T_batch, None, None]]
        ] = lambda x: x,
        worker_config: WorkerConfig,
    ):
        """Construct a MixBatchDataset.

        Args:
            dataset_weights: Each argument should be a tuple of (dataset, weight) with a weight
                between 0 and 1. The output samples are sampled from the input datasets with the
                given probabilities. The datasets should have a batch size of 1, otherwise the
                whole batches will be sampled.
            batch_size: The batch size to output.
            batch_mix_fn: A function that takes a list of samples from the input datasets and
                returns a batch sample. The default implementation returns the list of batches
                unchanged. For homogeneous datasets, it is recommended to use the
                :func:`megatron.energon.homogeneous_concat_mix` which concatenates the batches.
                May raise :exc:`megatron.energon.SkipSample` to skip a sample. May also return a
                generator, which will be iterated over to produce batches.
            worker_config: Configuration for the workers.
        """
        # Inner stage: draw samples from the weighted input datasets.
        blended = BlendDataset(*dataset_weights, worker_config=worker_config)
        # Outer stage: group the blended samples and combine each group with the mix function.
        batched = BatchDataset(
            blended,
            batch_size=batch_size,
            batcher=batch_mix_fn,
            worker_config=worker_config,
        )
        super().__init__(batched, worker_config=worker_config)

    def reset_state_own(self) -> None:
        """No-op: this method resets nothing on the wrapper itself."""
        return None

    def __len__(self) -> int:
        """Return the length reported by the wrapped dataset."""
        wrapped = self.dataset
        return len(wrapped)

    def __iter__(self) -> Iterator[T_batch]:
        """Yield the batches produced by the wrapped dataset."""
        wrapped = self.dataset
        yield from wrapped

    def config(self) -> Dict[str, Any]:
        """Return a dict with this wrapper's type name and the wrapped dataset's config."""
        description: Dict[str, Any] = {}
        description["type"] = type(self).__qualname__
        description["dataset"] = self.dataset.config()
        return description

    def __str__(self):
        """Return a short description that includes the wrapped dataset."""
        wrapped = self.dataset
        return f"MixBatchDataset(dataset={wrapped})"