Source code for megatron.energon.metadataset.loader_interface

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union

from megatron.energon.epathlib import EPath
from megatron.energon.flavors.base_dataset import BaseCoreDatasetFactory
from megatron.energon.worker import WorkerConfig


class DatasetBlendMode(Enum):
    """Determines how the the datasets are to be blended. Either by using the associated number as
    the weight for sampling from that dataset, or alternatively by using the number as the number
    of repetitions for samples in that dataset in one epoch (effectively, that corresponds to the
    weight for samples)."""

    NONE = "none"
    DATASET_WEIGHT = "dataset_weight"
    SAMPLE_REPETITIONS = "sample_repetitions"


[docs] class DatasetLoaderInterface(ABC): """General interface for a dataset loader."""
[docs] @abstractmethod def post_initialize(self, mds_path: Optional[EPath] = None): """Called to finally initialize the dataset.""" ...
[docs] @abstractmethod def get_datasets( self, *, training: bool, split_part: Union[Literal["train", "val", "test"], str], worker_config: WorkerConfig, subflavor: Optional[str] = None, subflavors: Optional[Dict[str, Any]] = None, shuffle_over_epochs_multiplier: Optional[int] = 1, **kwargs, ) -> Tuple[DatasetBlendMode, List[Tuple[BaseCoreDatasetFactory, Union[float, int, None]]]]: """ Calls :func:`megatron.energon.dataset_config.get_dataset_from_config` (loads the raw dataset) for all innermost datasets and resolves their relative weights to absolute weights. Args: training: If true, apply training randomization. split_part: Default split part to use. worker_config: Worker configuration to use. subflavor: Set the default subflavor for all datasets. subflavors: Set the default subflavors for all datasets. shuffle_over_epochs_multiplier: Multiply the inner datasets `shuffle_over_epochs(_multiplier)` by this factor. E.g. if the inner dataset has `shuffle_over_epochs_multiplier=2` and this function has `shuffle_over_epochs_multiplier=3`, the inner dataset will be shuffled over 6 epochs. Shuffling over `n` epochs guarantees that each sample is seen exactly `n` times in `n` epochs of the inner dataset. Use -1 for shuffling over an infinite number of epochs (effectively, this will draw shard slices with replacement). **kwargs: Additional arguments to the dataset constructor. Returns: The dataset blending mode and the instantiated core datasets with their weights/repetitions. """ ...
[docs] def prepare(self, split_part: Optional[str] = None) -> Sequence[EPath]: """ Prepares the loader by creating caches and other necessary structures on disk. Args: split_part: Name of the split to load. Returns: List of paths to the cache paths. This is used for cleanup. """ return ()