Source code for megatron.energon.task_encoder.loader

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from pathlib import Path
from typing import Any, List, Literal, Optional, Tuple, TypeVar, Union

from megatron.energon.epathlib import EPath
from megatron.energon.errors import warn_deprecated
from megatron.energon.flavors import BaseCoreDatasetFactory, SavableDataset
from megatron.energon.flavors.webdataset.sample_decoder import SampleDecoder
from megatron.energon.metadataset import load_dataset
from megatron.energon.task_encoder.base import DefaultTaskEncoder, TaskEncoder, WorkerConfig

T = TypeVar("T", covariant=True)


def _split_kwargs(kwargs: dict) -> dict:
    loader_kwargs = {}
    try:
        loader_kwargs["split_part"] = kwargs.pop("split_part")
    except KeyError:
        pass
    try:
        loader_kwargs["dataset_config"] = kwargs.pop("dataset_config")
    except KeyError:
        pass
    try:
        loader_kwargs["split_config"] = kwargs.pop("split_config")
    except KeyError:
        pass
    return loader_kwargs


def _split_deprecated_decoder_kwargs(kwargs: dict, task_encoder: TaskEncoder) -> None:
    """
    auto_decode: bool = True,
    image_decode: ImageDecoder = "torchrgb",
    ignore_decoder_errors: bool = False,
    av_decode: AVDecoder = "AVDecoder",
    video_decode_audio: bool = False,
    """
    auto_decode = True

    decoder_kwargs = {}
    if "auto_decode" in kwargs:
        auto_decode = kwargs.pop("auto_decode")
    if "image_decode" in kwargs:
        decoder_kwargs["image_decode"] = kwargs.pop("image_decode")
    if "av_decode" in kwargs:
        decoder_kwargs["av_decode"] = kwargs.pop("av_decode")
    if "video_decode_audio" in kwargs:
        decoder_kwargs["video_decode_audio"] = kwargs.pop("video_decode_audio")

    if not auto_decode:
        task_encoder.decoder = None
    elif len(decoder_kwargs) > 0:
        warn_deprecated(
            "The following decoder kwargs are deprecated and will be removed in a future version: "
            + ", ".join(decoder_kwargs.keys())
            + ". Instead, set the decoder directly in your task encoder."
        )

        assert (
            not hasattr(task_encoder, "decoder")
            or task_encoder.decoder is DefaultTaskEncoder.decoder
        ), "Task encoder already has a decoder, and setting using deprecated kwargs is not allowed."

        task_encoder.decoder = SampleDecoder(**decoder_kwargs)


[docs] def get_train_dataset( path: Union[str, EPath, Path], *, split_part: Union[Literal["train"], str] = "train", worker_config: WorkerConfig, batch_size: Optional[int], batch_drop_last: bool = False, packing_buffer_size: Optional[int] = None, shuffle_buffer_size: Optional[int], max_samples_per_sequence: Optional[int], virtual_epoch_length: int = 0, shuffle_over_epochs_multiplier: Optional[int] = 1, task_encoder: TaskEncoder[Any, Any, Any, T] = DefaultTaskEncoder(), repeat: bool = True, **kwargs, ) -> SavableDataset[T]: """ Get training data loader with sensible defaults. See `get_dataset` for more details. The following recipe will be used: - :func:`megatron.energon.dataset_config.get_dataset_from_config` (loads the raw dataset) - `task_encoder.encode_sample` - (:class:`megatron.energon.MixDataset` if mixing) - :class:`megatron.energon.BatchDataset` with `task_encoder.batch` for collation - `task_encoder.encode_batch` - :class:`megatron.energon.EpochizeDataset` (if `virtual_epoch_length` is set) Args: path: Path to the dataset. split_part: Default split part to use. worker_config: Worker configuration to use. batch_size: Size of a batch. If None, do not batch batch_drop_last: If true, drop the last batch if it is smaller than `batch_size`. shuffle_buffer_size: Size of the sample shuffle buffer (before task encoding). max_samples_per_sequence: If set, limit the number of samples per sample-sequence to this. virtual_epoch_length: If set, the dataset will be epochized to this length (=iterating will be suspended and the for-loop returns, next for-loop continues iterating). Otherwise, the dataset will loop indefinitely. shuffle_over_epochs_multiplier: Shuffle the shards over this many epochs. task_encoder: Task encoder to use. repeat: By default, the inner datasets will loop. If set to False, stop iteration after one epoch. Must only be set to False in conjunction with blend_epochized in the metadataset if one is used. cache_pool: If set, the cache pool to use for the dataset. **kwargs: Additional arguments to the dataset constructor. Returns: The dataloader. """ loader = load_dataset(path, **_split_kwargs(kwargs)) _split_deprecated_decoder_kwargs(kwargs, task_encoder) datasets = loader.get_datasets( training=True, split_part=split_part, worker_config=worker_config, max_samples_per_sequence=max_samples_per_sequence, shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier, decoder=task_encoder.decoder, **kwargs, ) return task_encoder.build_train_datasets( datasets=datasets.datasets, worker_config=worker_config, batch_size=batch_size, batch_drop_last=batch_drop_last, packing_buffer_size=packing_buffer_size, virtual_epoch_length=virtual_epoch_length, shuffle_buffer_size=shuffle_buffer_size, blend_mode=datasets.blend_mode, repeat=repeat, )
[docs] def get_val_dataset( path: Union[str, EPath, Path], *, split_part: Union[Literal["val", "test"], str] = "val", worker_config: WorkerConfig, batch_size: int, batch_drop_last: bool = False, packing_buffer_size: Optional[int] = None, limit: Optional[int] = None, task_encoder: TaskEncoder[Any, Any, Any, T] = DefaultTaskEncoder(), **kwargs, ) -> SavableDataset[T]: """ Get the validation/test dataset with sensible defaults. See `get_dataset` for more details. The following recipe will be used: - :func:`megatron.energon.dataset_config.get_dataset_from_config` (loads the raw dataset) - `task_encoder.encode_sample` - (:class:`megatron.energon.MixDataset` if mixing) - :class:`megatron.energon.BatchDataset` with `task_encoder.batch` for collation - :class:`megatron.energon.LimitDataset` (if `limit` is set) - `task_encoder.encode_batch` Args: path: Path to the dataset. split_part: Default split part to use. worker_config: Worker configuration to use. batch_size: Size of a batch batch_drop_last: If true, drop the last batch if it is smaller than `batch_size`. limit: If set, limit the number of batches loaded from the dataset to this. task_encoder: Task encoder to use. **kwargs: Additional arguments to the dataset constructor. Returns: The loaded dataset. """ _split_deprecated_decoder_kwargs(kwargs, task_encoder) loader = load_dataset(path, **_split_kwargs(kwargs)) datasets = loader.get_datasets( training=False, split_part=split_part, worker_config=worker_config, decoder=task_encoder.decoder, **kwargs, ) return task_encoder.build_val_datasets( datasets=datasets.datasets, worker_config=worker_config, batch_size=batch_size, batch_drop_last=batch_drop_last, packing_buffer_size=packing_buffer_size, limit=limit, )
[docs] def get_val_datasets( path: Union[str, EPath, Path], *, split_part: Union[Literal["val", "test"], str] = "val", worker_config: WorkerConfig, batch_size: int, batch_drop_last: bool = False, packing_buffer_size: Optional[int] = None, limit: Optional[int] = None, task_encoder: TaskEncoder[Any, Any, Any, T] = DefaultTaskEncoder(), **kwargs, ) -> List[Tuple[SavableDataset[T], BaseCoreDatasetFactory]]: """ Get the validation/test dataset with sensible defaults. See `get_dataset` for more details. The following recipe will be used: - :func:`megatron.energon.dataset_config.get_dataset_from_config` (loads the raw dataset) - `task_encoder.encode_sample` - (:class:`megatron.energon.MixDataset` if mixing) - :class:`megatron.energon.BatchDataset` with `task_encoder.batch` for collation - :class:`megatron.energon.LimitDataset` (if `limit` is set) - `task_encoder.encode_batch` Args: path: Path to the dataset. split_part: Default split part to use. worker_config: Worker configuration to use. batch_size: Size of a batch batch_drop_last: If true, drop the last batch if it is smaller than `batch_size`. limit: If set, limit the number of batches loaded from the dataset to this. task_encoder: Task encoder to use. **kwargs: Additional arguments to the dataset constructor. Returns: The loaded val datasets, with the source datasets. """ _split_deprecated_decoder_kwargs(kwargs, task_encoder) loader = load_dataset(path, **_split_kwargs(kwargs)) datasets = loader.get_datasets( training=False, split_part=split_part, worker_config=worker_config, decoder=task_encoder.decoder, **kwargs, ) return [ ( task_encoder.build_val_datasets( datasets=[dataset], worker_config=worker_config, batch_size=batch_size, batch_drop_last=batch_drop_last, packing_buffer_size=packing_buffer_size, limit=limit, ), dataset.dataset, ) for dataset in datasets.datasets ]