Source code for megatron.energon.wrappers.repeat_dataset

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import math
from typing import Any, Dict, Generic, Iterator, Optional, TypeVar, Union

from megatron.energon.flavors.base_dataset import SavableDataset
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset

T_sample = TypeVar("T_sample")


class RepeatDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]):
    """This dataset repeats the inner dataset indefinitely or a specific number of repeats.

    ``repeats`` may be fractional: the integer part is the number of full passes over
    the inner dataset, and the fractional part selects how many samples
    (``floor(len(dataset) * fraction)``) are taken from one final, partial pass.
    """

    repeats: Optional[Union[int, float]]
    # Number of passes over the inner dataset completed in the current cycle.
    _repetition: int
    # Number of samples yielded so far within the current pass.
    _index: int

    # Attributes listed here are the ones declared as savable state of this wrapper.
    _savable_fields = ("_repetition", "_index")

    def __init__(
        self,
        dataset: SavableDataset[T_sample],
        *,
        repeats: Optional[Union[int, float]] = None,
        restart: bool = True,
        worker_config: WorkerConfig,
    ):
        """Construct a RepeatDataset.

        Args:
            dataset: The input dataset to repeat.
            repeats: Number of repeats, `None` for indefinitely repeating.
            restart: If true, restart the underlying dataset after iterating once through the
                repeats if repeats is set to an integer, but still stop iterating.
            worker_config: Configuration for the workers.
        """
        super().__init__(dataset, worker_config=worker_config)
        self.repeats = repeats
        self.restart = restart
        self.reset_state_own()

    def reset_state_own(self) -> None:
        """Reset this wrapper's own counters to the start of a fresh cycle."""
        self._repetition = 0
        self._index = 0

    def __len__(self) -> int:
        """Return the number of samples of one cycle.

        When repeating indefinitely (``repeats is None``) the length of the inner
        dataset is reported; otherwise it is ``len(dataset) * repeats`` truncated to
        an integer.
        """
        if self.repeats is None:
            return len(self.dataset)
        return int(len(self.dataset) * self.repeats)

    def __iter__(self) -> Iterator[T_sample]:
        """Yield the samples of the inner dataset, repeated as configured.

        The generator picks up from the current values of ``_repetition`` and
        ``_index``, so iteration continues where those counters point.
        """
        assert self.repeats is not None or self.dataset.worker_has_samples(), (
            "Cannot repeat empty dataset indefinitely"
        )

        ds_len = len(self.dataset)

        while self.repeats is None or self._repetition < self.repeats:
            if self.repeats is not None and self._repetition == math.floor(self.repeats):
                # Last iteration, adjust the number of samples
                fraction = self.repeats - math.floor(self.repeats)
                stop_after = math.floor(ds_len * fraction)
                if self._index >= stop_after:
                    # We restored an index and it is already past the stop_after
                    break
            else:
                # Full pass: consume the inner dataset until it is exhausted.
                stop_after = None

            for sample in self.dataset:
                # Count the sample before yielding, so the counter already includes
                # it while the generator is suspended at the yield.
                self._index += 1
                yield sample
                if stop_after is not None and self._index >= stop_after:
                    break

            if self.worker_config.should_log(level=2):
                self.worker_config.worker_log(
                    {
                        "t": "RepeatDataset.repeat",
                        "r": self.worker_config.rank,
                        "w": self.worker_config.rank_worker_id(),
                        "offset": self._repetition,
                        "repeats": self.repeats,
                    }
                )
            self._repetition += 1
            self._index = 0

        # The loop normally ends right after `_index` was zeroed, but the early
        # `break` above (index already at/past `stop_after`) skips that reset.
        # Without clearing it here, a following cycle would start with a stale
        # `_index`; for `repeats < 1` the first pass is also the final partial pass,
        # so every later cycle would break immediately and yield nothing.
        self._index = 0

        # This point is only reached with a finite `repeats`: for `repeats is None`
        # the loop condition never becomes false and the `break` is never taken.
        if self.restart:
            self._repetition = 0
        else:
            # No more repeats
            self._repetition = math.ceil(self.repeats)

    def config(self) -> Dict[str, Any]:
        """Return a dict describing this wrapper, the inner dataset and the worker config."""
        return {
            "type": type(self).__qualname__,
            "dataset": self.dataset.config(),
            "repeats": self.repeats,
            "worker_config": self.worker_config.config(),
        }

    def __str__(self) -> str:
        """Return a short description containing ``repeats`` and the inner dataset."""
        return f"RepeatDataset(repeats={self.repeats}, dataset={self.dataset})"