Source code for megatron.energon.metadataset.dataset_loader

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Any, Dict, List, Literal, Optional, Tuple, Union

from megatron.energon.dataclass_slots import dataclass_slots
from megatron.energon.dataset_config import get_dataset_from_config
from megatron.energon.epathlib import EPath
from megatron.energon.flavors import BaseCoreDatasetFactory
from megatron.energon.metadataset.loader_interface import DatasetBlendMode, DatasetLoaderInterface
from megatron.energon.worker import WorkerConfig


@dataclass_slots
class DatasetLoader(DatasetLoaderInterface):
    """Loader for a single dataset located at ``path``.

    The instance fields act as per-dataset settings that are combined with the
    arguments passed to :meth:`get_dataset` before delegating to
    ``get_dataset_from_config``.
    """

    # Location of the dataset; handed unchanged to ``get_dataset_from_config``.
    path: Union[str, EPath]
    # When set, takes precedence over the ``split_part`` argument of ``get_dataset``.
    split_part: Optional[str] = None
    # Used only when the caller does not pass a ``subflavor``.
    subflavor: Optional[str] = None
    # Base subflavors; entries passed by the caller win on key conflicts.
    subflavors: Optional[Dict[str, Any]] = None
    # NOTE(review): stored on the instance but not read by any method of this
    # class; presumably consumed by the code that builds the loader - confirm.
    shuffle_over_epochs_multiplier: Optional[int] = 1
    # Used when ``get_dataset`` is called without ``dataset_config``.
    dataset_config: str = "dataset.yaml"
    # Used when ``get_dataset`` is called without ``split_config``.
    split_config: str = "split.yaml"

    def post_initialize(self, mds_path: Optional[EPath] = None):
        """No-op: this loader needs no post-initialization work."""

    def get_dataset(
        self,
        *,
        training: bool,
        split_part: Optional[str] = None,
        worker_config: WorkerConfig,
        subflavor: Optional[str] = None,
        subflavors: Optional[Dict[str, Any]] = None,
        shuffle_over_epochs: Optional[int] = 1,
        split_config: Optional[str] = None,
        dataset_config: Optional[str] = None,
        **kwargs,
    ) -> BaseCoreDatasetFactory:
        """Resolve the effective settings and load the dataset at ``self.path``.

        Args:
            training: If true, apply training randomization.
            split_part: Split part to use when the loader itself has none
                configured; ``self.split_part`` wins if it is set.
            worker_config: Worker configuration.
            subflavor: Subflavor to use; falls back to ``self.subflavor`` when
                ``None``.
            subflavors: Subflavors to use. If ``self.subflavors`` is set, the
                two are merged and the entries given here take precedence.
            shuffle_over_epochs: Shuffle the dataset over this many epochs.
            split_config: Split config name; falls back to ``self.split_config``
                when ``None``.
            dataset_config: Dataset config name; falls back to
                ``self.dataset_config`` when ``None``.
            **kwargs: Additional arguments to the dataset constructor.

        Returns:
            The loaded dataset.

        Raises:
            ValueError: If neither the loader nor the caller supplies a split
                part.
        """
        # The loader's own split part overrides whatever the caller asked for.
        effective_split = split_part if self.split_part is None else self.split_part
        if effective_split is None:
            raise ValueError("Missing split part")

        effective_subflavor = self.subflavor if subflavor is None else subflavor

        if self.subflavors is None:
            effective_subflavors = subflavors
        else:
            # Start from the loader's mapping, then let caller entries override.
            effective_subflavors = dict(self.subflavors)
            effective_subflavors.update(subflavors or {})

        effective_split_config = self.split_config if split_config is None else split_config
        effective_dataset_config = (
            self.dataset_config if dataset_config is None else dataset_config
        )

        return get_dataset_from_config(
            self.path,
            training=training,
            split_part=effective_split,
            worker_config=worker_config,
            subflavor=effective_subflavor,
            subflavors=effective_subflavors,
            dataset_config=effective_dataset_config,
            split_config=effective_split_config,
            shuffle_over_epochs=shuffle_over_epochs,
            **kwargs,
        )

    def get_datasets(
        self,
        *,
        training: bool,
        split_part: Union[Literal["train", "val", "test"], str],
        worker_config: WorkerConfig,
        subflavor: Optional[str] = None,
        subflavors: Optional[Dict[str, Any]] = None,
        shuffle_over_epochs_multiplier: Optional[int] = 1,
        **kwargs,
    ) -> Tuple[DatasetBlendMode, List[Tuple[BaseCoreDatasetFactory, Union[float, int, None]]]]:
        """Load this dataset and wrap it as a one-element, unblended dataset list.

        Args:
            training: Forwarded to :meth:`get_dataset`.
            split_part: Forwarded to :meth:`get_dataset`.
            worker_config: Forwarded to :meth:`get_dataset`.
            subflavor: Forwarded to :meth:`get_dataset`.
            subflavors: Forwarded to :meth:`get_dataset`.
            shuffle_over_epochs_multiplier: Passed to :meth:`get_dataset` as
                its ``shuffle_over_epochs`` argument.
            **kwargs: Forwarded to :meth:`get_dataset`.

        Returns:
            A tuple of ``DatasetBlendMode.NONE`` and a list holding the single
            pair ``(dataset, None)``.
        """
        dataset = self.get_dataset(
            training=training,
            split_part=split_part,
            worker_config=worker_config,
            subflavor=subflavor,
            subflavors=subflavors,
            shuffle_over_epochs=shuffle_over_epochs_multiplier,
            **kwargs,
        )
        # A single dataset has nothing to blend with, hence no weight.
        return DatasetBlendMode.NONE, [(dataset, None)]