# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from pathlib import Path
from typing import Any, List, Literal, Optional, Tuple, TypeVar, Union

from megatron.energon.epathlib import EPath
from megatron.energon.flavors import BaseCoreDatasetFactory, SavableDataset
from megatron.energon.metadataset import load_dataset
from megatron.energon.metadataset.loader_interface import DatasetBlendMode
from megatron.energon.task_encoder.base import DefaultTaskEncoder, TaskEncoder, WorkerConfig

T = TypeVar("T", covariant=True)

def _split_kwargs(kwargs: dict) -> dict:
    """Pops the metadataset-loader arguments (split_part, subflavor, dataset_config,
    split_config) out of `kwargs` in place and returns them as a separate dict; the
    remaining entries are passed on to the dataset constructor."""
    loader_kwargs = {}
    try:
        loader_kwargs["split_part"] = kwargs.pop("split_part")
    except KeyError:
        pass
    try:
        loader_kwargs["subflavor"] = kwargs.pop("subflavor")
    except KeyError:
        pass
    try:
        loader_kwargs["dataset_config"] = kwargs.pop("dataset_config")
    except KeyError:
        pass
    try:
        loader_kwargs["split_config"] = kwargs.pop("split_config")
    except KeyError:
        pass
    return loader_kwargs
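
# Illustration of `_split_kwargs` (not executed): loader-level options are popped out of
# the given kwargs in place and returned separately; everything else is left for the
# dataset constructor. The "image_decode" entry is a made-up example kwarg.
#
#   kwargs = {"split_part": "train", "image_decode": "pil"}
#   loader_kwargs = _split_kwargs(kwargs)
#   # loader_kwargs == {"split_part": "train"}
#   # kwargs == {"image_decode": "pil"}
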
def get_train_dataset(
    path: Union[str, EPath, Path],
    *,
    split_part: Union[Literal["train"], str] = "train",
    worker_config: WorkerConfig,
    batch_size: Optional[int],
    batch_drop_last: bool = False,
    packing_buffer_size: Optional[int] = None,
    shuffle_buffer_size: Optional[int],
    max_samples_per_sequence: Optional[int],
    virtual_epoch_length: int = 0,
    shuffle_over_epochs_multiplier: Optional[int] = 1,
    task_encoder: TaskEncoder[Any, Any, Any, T] = DefaultTaskEncoder(),
    repeat: bool = True,
    **kwargs,
) -> SavableDataset[T]:
"""
Get training data loader with sensible defaults. See `get_dataset` for more details.
The following recipe will be used:
- :func:`megatron.energon.dataset_config.get_dataset_from_config` (loads the raw dataset)
- `task_encoder.encode_sample`
- (:class:`megatron.energon.MixDataset` if mixing)
- :class:`megatron.energon.BatchDataset` with `task_encoder.batch` for collation
- `task_encoder.encode_batch`
- :class:`megatron.energon.EpochizeDataset` (if `virtual_epoch_length` is set)
Args:
path: Path to the dataset.
split_part: Default split part to use.
worker_config: Worker configuration to use.
batch_size: Size of a batch. If None, do not batch
batch_drop_last: If true, drop the last batch if it is smaller than `batch_size`.
shuffle_buffer_size: Size of the sample shuffle buffer (before task encoding).
max_samples_per_sequence: If set, limit the number of samples per sample-sequence to this.
virtual_epoch_length: If set, the dataset will be epochized to this length (=iterating
will be suspended and the for-loop returns, next for-loop continues iterating).
Otherwise, the dataset will loop indefinitely.
shuffle_over_epochs_multiplier: Shuffle the shards over this many epochs.
task_encoder: Task encoder to use.
repeat: By default, the inner datasets will loop. If set to False, stop iteration after
one epoch. Must only be set to False in conjunction with blend_epochized in the
metadataset if one is used.
**kwargs: Additional arguments to the dataset constructor.
Returns:
The dataloader.
"""
    loader = load_dataset(path, **_split_kwargs(kwargs))
    blend_mode, datasets = loader.get_datasets(
        training=True,
        split_part=split_part,
        worker_config=worker_config,
        max_samples_per_sequence=max_samples_per_sequence,
        shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier,
        **kwargs,
    )
    assert isinstance(blend_mode, DatasetBlendMode)
    assert isinstance(datasets, list)
    assert all(isinstance(d, tuple) and len(d) == 2 for d in datasets)
    assert all(
        isinstance(dataset, BaseCoreDatasetFactory) and isinstance(value, (type(None), int, float))
        for dataset, value in datasets
    )
    return task_encoder.build_train_datasets(
        datasets=datasets,
        worker_config=worker_config,
        batch_size=batch_size,
        batch_drop_last=batch_drop_last,
        packing_buffer_size=packing_buffer_size,
        virtual_epoch_length=virtual_epoch_length,
        shuffle_buffer_size=shuffle_buffer_size,
        blend_mode=blend_mode,
        repeat=repeat,
    )
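
# Example usage of `get_train_dataset` (illustrative sketch only; the path, batch size and
# buffer sizes below are placeholder assumptions, and the loop assumes the `get_loader`
# helper from `megatron.energon`):
#
#   from megatron.energon import WorkerConfig, get_loader, get_train_dataset
#
#   worker_config = WorkerConfig(rank=0, world_size=1, num_workers=2)
#   train_ds = get_train_dataset(
#       "/data/my_dataset",            # hypothetical prepared dataset path
#       worker_config=worker_config,
#       batch_size=4,
#       shuffle_buffer_size=100,
#       max_samples_per_sequence=100,
#   )
#   for batch in get_loader(train_ds):
#       ...  # feed `batch` to the model
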
def get_val_dataset(
    path: Union[str, EPath, Path],
    *,
    split_part: Union[Literal["val", "test"], str] = "val",
    worker_config: WorkerConfig,
    batch_size: int,
    batch_drop_last: bool = False,
    packing_buffer_size: Optional[int] = None,
    limit: Optional[int] = None,
    task_encoder: TaskEncoder[Any, Any, Any, T] = DefaultTaskEncoder(),
    **kwargs,
) -> SavableDataset[T]:
"""
Get the validation/test dataset with sensible defaults. See `get_dataset` for more details.
The following recipe will be used:
- :func:`megatron.energon.dataset_config.get_dataset_from_config` (loads the raw dataset)
- `task_encoder.encode_sample`
- (:class:`megatron.energon.MixDataset` if mixing)
- :class:`megatron.energon.BatchDataset` with `task_encoder.batch` for collation
- :class:`megatron.energon.LimitDataset` (if `limit` is set)
- `task_encoder.encode_batch`
Args:
path: Path to the dataset.
split_part: Default split part to use.
worker_config: Worker configuration to use.
batch_size: Size of a batch
batch_drop_last: If true, drop the last batch if it is smaller than `batch_size`.
limit: If set, limit the number of batches loaded from the dataset to this.
task_encoder: Task encoder to use.
**kwargs: Additional arguments to the dataset constructor.
Returns:
The loaded dataset.
"""
    loader = load_dataset(path, **_split_kwargs(kwargs))
    _blend_mode, datasets = loader.get_datasets(
        training=False, split_part=split_part, worker_config=worker_config, **kwargs
    )
    assert isinstance(_blend_mode, DatasetBlendMode)
    assert isinstance(datasets, list)
    assert all(isinstance(d, tuple) and len(d) == 2 for d in datasets)
    assert all(
        isinstance(dataset, BaseCoreDatasetFactory) and isinstance(value, (type(None), int, float))
        for dataset, value in datasets
    )
    return task_encoder.build_val_datasets(
        datasets=[dataset for dataset, _weight in datasets],
        worker_config=worker_config,
        batch_size=batch_size,
        batch_drop_last=batch_drop_last,
        packing_buffer_size=packing_buffer_size,
        limit=limit,
    )
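
# Example usage of `get_val_dataset` (illustrative sketch only; the path, batch size and
# limit are placeholder assumptions; `worker_config` and `get_loader` as in the training
# example above):
#
#   val_ds = get_val_dataset(
#       "/data/my_dataset",            # hypothetical prepared dataset path
#       split_part="val",
#       worker_config=worker_config,
#       batch_size=4,
#       limit=100,                     # evaluate on at most 100 batches
#   )
#   for batch in get_loader(val_ds):
#       ...  # accumulate validation metrics
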
def get_val_datasets(
    path: Union[str, EPath, Path],
    *,
    split_part: Union[Literal["val", "test"], str] = "val",
    worker_config: WorkerConfig,
    batch_size: int,
    batch_drop_last: bool = False,
    packing_buffer_size: Optional[int] = None,
    limit: Optional[int] = None,
    task_encoder: TaskEncoder[Any, Any, Any, T] = DefaultTaskEncoder(),
    **kwargs,
) -> List[Tuple[SavableDataset[T], BaseCoreDatasetFactory]]:
"""
Get the validation/test dataset with sensible defaults. See `get_dataset` for more details.
The following recipe will be used:
- :func:`megatron.energon.dataset_config.get_dataset_from_config` (loads the raw dataset)
- `task_encoder.encode_sample`
- (:class:`megatron.energon.MixDataset` if mixing)
- :class:`megatron.energon.BatchDataset` with `task_encoder.batch` for collation
- :class:`megatron.energon.LimitDataset` (if `limit` is set)
- `task_encoder.encode_batch`
Args:
path: Path to the dataset.
split_part: Default split part to use.
worker_config: Worker configuration to use.
batch_size: Size of a batch
batch_drop_last: If true, drop the last batch if it is smaller than `batch_size`.
limit: If set, limit the number of batches loaded from the dataset to this.
task_encoder: Task encoder to use.
**kwargs: Additional arguments to the dataset constructor.
Returns:
The loaded val datasets, with the source datasets.
"""
    loader = load_dataset(path, **_split_kwargs(kwargs))
    _blend_mode, datasets = loader.get_datasets(
        training=False, split_part=split_part, worker_config=worker_config, **kwargs
    )
    assert isinstance(_blend_mode, DatasetBlendMode)
    assert isinstance(datasets, list)
    assert all(isinstance(d, tuple) and len(d) == 2 for d in datasets)
    assert all(
        isinstance(dataset, BaseCoreDatasetFactory) and isinstance(value, (type(None), int, float))
        for dataset, value in datasets
    )
    return [
        (
            task_encoder.build_val_datasets(
                datasets=[dataset],
                worker_config=worker_config,
                batch_size=batch_size,
                batch_drop_last=batch_drop_last,
                packing_buffer_size=packing_buffer_size,
                limit=limit,
            ),
            dataset,
        )
        for dataset, _weight in datasets
    ]
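
# Example usage of `get_val_datasets` (illustrative sketch only; useful when a metadataset
# blends several sources and per-source validation metrics are wanted; the path is a
# placeholder assumption):
#
#   for val_ds, source_dataset in get_val_datasets(
#       "/data/my_metadataset.yaml",   # hypothetical metadataset path
#       worker_config=worker_config,
#       batch_size=4,
#   ):
#       for batch in get_loader(val_ds):
#           ...  # compute metrics for this source_dataset separately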