Source code for megatron.energon.wrappers.epochize_dataset

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Any, Dict, Generic, Iterator, Optional, TypeVar

from megatron.energon.flavors.base_dataset import SavableDataset
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset

# Type of a single sample. EpochizeDataset is parametrized with it for both the wrapped
# dataset (`SavableDataset[T_sample]`) and its own iterator (`Iterator[T_sample]`).
T_sample = TypeVar("T_sample")


class EpochizeDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]):
    """
    Uses the base dataset, and creates one epoch, which has `length` samples. Keeps the underlying
    dataset iterator alive over epochs (i.e. if it is an infinite dataset, it will keep the state).
    Repeats the underlying dataset if the iterator is exhausted: the running epoch then ends early
    and the next epoch starts with a fresh iterator over the underlying dataset.
    """

    # Number of samples of one epoch; `__iter__` splits it across the workers.
    length: int
    # Iterator over the wrapped dataset, kept alive between epochs. `None` until the first
    # sample is requested and again after the wrapped dataset was exhausted.
    _active_iter: Optional[Iterator[T_sample]]
    # Position within the current epoch of this worker (index of the next sample to yield).
    _offset: int

    _savable_fields = ("_offset",)

    def __init__(
        self,
        dataset: SavableDataset[T_sample],
        length: int,
        worker_config: WorkerConfig,
    ):
        """
        Create the epochized dataset.

        Args:
            dataset: The source dataset (possibly infinite)
            length: Number of samples to iterate before iteration stops (i.e. one epoch). When
                iteration continues, the original dataset iterator is resumed and does only restart
                if exhausted.
            worker_config: Configuration for the workers.
        """
        super().__init__(dataset, worker_config=worker_config)
        self.length = length
        self._active_iter = None

        self.reset_state_own()

    def reset_state_own(self) -> None:
        """Reset the state owned by this wrapper: the epoch starts at its first sample again."""
        self._offset = 0

    def _worker_local_length(self) -> int:
        """Return this worker's share of `length`; all workers' shares sum up to `length`."""
        num_workers = self.worker_config.num_workers
        if num_workers <= 1:
            return self.length
        local_length, remainder = divmod(self.length, num_workers)
        # The first `remainder` workers take one extra sample each.
        if self.worker_config.rank_worker_id() < remainder:
            local_length += 1
        return local_length

    def _log_epoch_event(self, event: str, local_length: int) -> None:
        """Emit the epoch start/end log entry `event` if logging at level 2 is enabled."""
        if self.worker_config.should_log(level=2):
            self.worker_config.worker_log(
                {
                    "t": event,
                    "r": self.worker_config.rank,
                    "w": self.worker_config.rank_worker_id(),
                    "offset": self._offset,
                    "local_length": local_length,
                    "length": self.length,
                }
            )

    def __iter__(self) -> Iterator[T_sample]:
        """
        Yield the remaining samples of the current epoch for this worker.

        Iteration starts at `_offset`, so an epoch that was interrupted is continued instead of
        restarted. If the wrapped dataset runs out of samples, the epoch ends early and the next
        call restarts the wrapped dataset with a full epoch.
        """
        local_length = self._worker_local_length()

        self._log_epoch_event("EpochizeDataset.epoch_start", local_length)

        # Only iterate if there are samples to iterate. This also keeps the modulo below from
        # ever dividing by a `local_length` of zero.
        if self._offset < local_length:
            if self._active_iter is None:
                self._active_iter = iter(self.dataset)

            for idx in range(self._offset, local_length):
                # Advance the offset before yielding, so it already points to the next sample
                # while the generator is suspended. It wraps to 0 on the last sample.
                # NOTE(review): presumably this is what makes a saved `_offset` resume correctly.
                self._offset = (idx + 1) % local_length
                try:
                    sample = next(self._active_iter)
                except StopIteration:
                    # An exhausted iterator stays exhausted, so keeping it would make every
                    # following epoch empty. Drop it to restart the dataset on the next epoch,
                    # and undo the offset advance made for a sample that was never produced.
                    self._active_iter = None
                    self._offset = 0
                    break
                yield sample

        self._log_epoch_event("EpochizeDataset.epoch_end", local_length)

    def __len__(self) -> int:
        """Return the configured epoch length (the total, not this worker's share)."""
        return self.length

    def config(self) -> Dict[str, Any]:
        """Return the configuration of this wrapper, including the wrapped dataset's config."""
        return {
            "type": type(self).__qualname__,
            "dataset": self.dataset.config(),
            "length": self.length,
            "worker_config": self.worker_config.config(),
        }

    def __str__(self):
        return f"EpochizeDataset(length={self.length}, dataset={self.dataset})"