Source code for megatron.energon.wrappers.shuffle_buffer_dataset

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Any, Dict, Generic, Iterator, Tuple, TypeVar, Union

from megatron.energon.flavors.base_dataset import SavableDataset
from megatron.energon.rng import WorkerRng
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset
from megatron.energon.wrappers.buffer import SavableSampleBuffer

T_sample = TypeVar("T_sample")


class ShuffleBufferDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]):
    """Shuffle buffer for the dataset.

    Samples of the wrapped dataset are collected in an internal
    :class:`SavableSampleBuffer`. As soon as the buffer holds at least ``size``
    samples, one buffered sample is popped at an index drawn from the worker
    RNG and yielded. When the wrapped dataset is exhausted, the remaining
    buffered samples are yielded the same way until the buffer is empty.
    """

    # Number of buffered samples required before a sample is yielded.
    size: int
    # Random number generator used to draw the index of the next sample to pop.
    _worker_rng: WorkerRng
    # Samples that have been read from the wrapped dataset but not yet yielded.
    _active_buffer: SavableSampleBuffer[T_sample]

    # NOTE(review): presumably the attribute names the base class includes when
    # saving/restoring state -- confirm against BaseWrapperDataset/SavableDataset.
    _savable_fields = ("_active_buffer", "_worker_rng")

    def __init__(
        self,
        dataset: SavableDataset[T_sample],
        size: int,
        *,
        worker_config: WorkerConfig,
    ):
        """Create a shuffle buffer for the dataset.

        Args:
            dataset: The dataset whose samples are shuffled.
            size: Number of samples the buffer must hold before samples are
                yielded. Must be at least 1.
            worker_config: Worker configuration; passed on to the base class,
                the RNG and the sample buffer.

        Raises:
            ValueError: If ``size`` is smaller than 1.
        """
        # With size <= 0 the "buffer is full" check in __iter__ is already true
        # for an empty buffer, so iteration would try to pop from an empty
        # buffer. Reject such sizes up front with a clear message instead.
        if size < 1:
            raise ValueError(f"ShuffleBufferDataset size must be at least 1, got {size}")
        super().__init__(dataset, worker_config=worker_config)
        self.size = size
        self.reset_state_own()

    def reset_state_own(self) -> None:
        """Re-create this wrapper's own state: a fresh RNG and a fresh, empty buffer object."""
        self._worker_rng = WorkerRng(self.worker_config)
        self._active_buffer = SavableSampleBuffer(self.dataset, worker_config=self.worker_config)

    def __len__(self) -> int:
        """Return the length of the wrapped dataset."""
        return len(self.dataset)

    def __iter__(self) -> Iterator[T_sample]:
        """Yield the samples of the wrapped dataset in shuffled order.

        The statement order is kept deliberately: the fill level is checked
        *before* the source iterator is advanced, so a buffer that already
        holds ``size`` samples is drained first.
        """
        self._active_buffer.worker_start()
        # NOTE(review): each next() on this iterator is expected to add one
        # sample to the buffer; its return value is intentionally ignored.
        it = iter(self._active_buffer.append_iter())
        while True:
            if len(self._active_buffer) >= self.size:
                # Buffer is full: hand out one buffered sample at a drawn index.
                pop_idx = self._worker_rng.randbelow(len(self._active_buffer))
                yield self._active_buffer.pop(pop_idx)
            else:
                # Buffer below target size: pull the next sample into it.
                try:
                    next(it)
                except StopIteration:
                    break
        # Source exhausted: drain whatever is still buffered, in drawn order.
        while len(self._active_buffer) > 0:
            pop_idx = self._worker_rng.randbelow(len(self._active_buffer))
            yield self._active_buffer.pop(pop_idx)

    def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T_sample:
        """Restore a sample from its index by delegating to the internal buffer."""
        return self._active_buffer.restore_sample(index)

    def config(self) -> Dict[str, Any]:
        """Return a dict describing this wrapper: type, wrapped dataset config, size, worker config."""
        return {
            "type": type(self).__qualname__,
            "dataset": self.dataset.config(),
            "size": self.size,
            "worker_config": self.worker_config.config(),
        }

    def __str__(self):
        """Return a short human-readable description of this wrapper."""
        return f"ShuffleBufferDataset(size={self.size}, dataset={self.dataset})"