# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import re
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Type, Union
from megatron.energon.cache import FileStore
from megatron.energon.cache.file_store import (
SystemFileStore,
WebdatasetFileStore,
)
from megatron.energon.dataset_config import load_config
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib import EPath
from megatron.energon.flavors import Sample
from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME
from megatron.energon.flavors.webdataset.metadata import EnergonDatasetType, get_dataset_type
from megatron.energon.flavors.webdataset.structs import DatasetSubset
from megatron.energon.metadataset.dataset_loader import DatasetLoader
from megatron.energon.metadataset.join_dataset_loader import JoinDatasetLoader, JoinedDatasetInfo
from megatron.energon.metadataset.loader_interface import (
DatasetBlendMode,
DatasetLoaderInterface,
LoadedDatasetList,
)
from megatron.energon.metadataset.metadataset import Metadataset
from megatron.energon.worker import WorkerConfig
# Regex matching any URL-like string of the form "<protocol>://<path>";
# the protocol group is matched case-insensitively.
url_regex = re.compile(r"^(?P<protocol>[a-z][a-z0-9+.-]*)://(?P<path>.*)", re.IGNORECASE)
@edataclass
class AuxDatasetReference:
    """Reference to a prepared Energon dataset used as an auxiliary file store."""

    #: Path to the auxiliary dataset; resolved to an `EPath` by `post_initialize`.
    path: Union[str, EPath]

    def post_initialize(self, mds_path: Optional[EPath] = None) -> None:
        """Resolve `path` relative to the metadataset location and validate the target.

        Args:
            mds_path: Path of the metadataset file this reference belongs to (required).
        """
        assert mds_path is not None
        if not isinstance(self.path, EPath):
            self.path = mds_path.parent / self.path
        # A file would be a metadataset (or similar); aux references must point at a
        # prepared dataset directory instead.
        assert not self.path.is_file(), (
            "Auxiliary datasets must not be metadataset, but direct dataset references"
        )
        index_path = self.path / MAIN_FOLDER_NAME / "index.sqlite"
        assert index_path.is_file(), (
            "Auxiliary datasets must be prepared Energon dataset"
        )

    def get_file_store(self) -> FileStore:
        """Open the referenced dataset as a random-access file store."""
        assert isinstance(self.path, EPath), "Missing call to post_initialize"
        return WebdatasetFileStore(self.path)
@edataclass
class AuxFilesystemReference:
    """Reference to a plain filesystem location used as an auxiliary file store."""

    #: Filesystem path; resolved to an `EPath` by `post_initialize`.
    fs_path: Union[str, EPath]

    def post_initialize(self, mds_path: Optional[EPath] = None) -> None:
        """Resolve `fs_path` relative to the metadataset location.

        Args:
            mds_path: Path of the metadataset file this reference belongs to (required).
        """
        assert mds_path is not None
        if isinstance(self.fs_path, EPath):
            return
        self.fs_path = mds_path.parent / self.fs_path

    def get_file_store(self) -> FileStore:
        """Open the referenced filesystem location as a file store."""
        assert isinstance(self.fs_path, EPath), "Missing call to post_initialize"
        return SystemFileStore(self.fs_path)
@edataclass
class Subset:
    """
    A subset range to be applied to a dataset. The range is always consecutive.
    The range is a tuple of two values, where the first value is the start of the subset and the second value is the end of the subset (end not included).
    The range can either be an absolute range with sample indices, or a ratio of the dataset size.
    Relative range example: [25%, 75%]. This would limit the subset to the middle 50% of the dataset.
    Absolute range example: [100, 200]. This would limit the subset to the 100 samples with indices 100-199.
    The end can be set to "end" to indicate the end of the dataset, for example [100, end].
    Since subsets can be specified at multiple levels of a hierarchy, for example in a blend,
    their effects can be merged to a single subset.
    Note however, that absolute ranges are only allowed for leaf datasets, while relative ranges
    can be applied at any level.
    """

    #: (start, end) pair. Each entry is either an int sample index, a percentage string
    # like "25%", or the literal string "end" meaning the end of the dataset.
    range: tuple[str | int, str | int]

    def as_dataset_subset(self) -> DatasetSubset:
        """Convert the subset with string values to a DatasetSubset object with `range` and `absolute_range`."""
        start, end = self.range

        def _conv(value: str | int) -> float | int | None:
            # int -> absolute sample index; "end" -> None (open end); "NN%" -> ratio in [0, 1].
            if isinstance(value, int):
                return value
            else:
                assert isinstance(value, str), "Range must be a string if it's not an integer"
                if value.strip() == "end":
                    return None
                assert value.endswith("%"), "Range must be a percentage"
                percentage = float(value.removesuffix("%"))
                assert 0 <= percentage <= 100, "Percentage must be between 0 and 100"
                return percentage / 100.0

        start = _conv(start)
        end = _conv(end)
        if isinstance(start, int):
            assert isinstance(end, int) or end is None, (
                "End must be an integer if start is an integer"
            )
            return DatasetSubset(absolute_range=(start, end), range=(0, 1))
        else:
            assert isinstance(start, float), "Range start must be a float if it's not an integer"
            assert isinstance(end, float) or end is None, "End must be a float if start is a float"
            if end is None:
                # "end" in a relative range means the end of the dataset, i.e. ratio 1.0.
                # (Keeping None here would break the comparisons below and the arithmetic
                # in `merge`.)
                end = 1.0
            assert 0 <= start <= 1, "Start must be between 0 and 1"
            assert 0 <= end <= 1, "End must be between 0 and 1"
            assert start <= end, "Start must be less than or equal to end"
            return DatasetSubset(range=(start, end), absolute_range=None)

    def merge(self, parent_subset: DatasetSubset | None) -> DatasetSubset:
        """Merge this subset with a parent subset.
        If the parent subset is None, return the subset.
        If the parent subset is an absolute range, fail, because that's not allowed.
        If the parent subset is a ratio, merge it with the subset.
        Merging a child absolute range with a parent relative range:
        In this case, both are kept in the DatasetSubset object and applies in "absolute first" order later.
        Merging a child relative range with a parent relative range:
        In this case, the relative parent range is applied to the child's relative range.
        The absolute range is not affected.
        For details on how this is applied, see `DatasetSubset.compute_subset`.
        """
        # Note: The message must reference the parent's fields; `Subset` itself has no
        # `absolute_range` attribute, so referencing `self` would raise AttributeError
        # instead of producing the AssertionError.
        assert parent_subset is None or parent_subset.absolute_range is None, (
            f"Cannot merge absolute subset ranges. Absolute ranges are only allowed for a leaf dataset. {parent_subset.absolute_range=} {self.range=}"
        )
        my_subset = self.as_dataset_subset()
        if parent_subset is None or parent_subset.range is None:
            return my_subset
        # Assuming inner ratio: [0.25, 0.75] and outer ratio: [0, 0.5]
        # Then the total ratio is supposed to be: [0.25 + 0*0.5, 0.25 + 0.5 * 0.5] = [0.25, 0.5]
        total = my_subset.range[1] - my_subset.range[0]
        return DatasetSubset(
            range=(
                my_subset.range[0] + parent_subset.range[0] * total,
                my_subset.range[0] + parent_subset.range[1] * total,
            ),
            absolute_range=my_subset.absolute_range,
        )
@edataclass
class SubsetRatioMixin:
    """Mixin adding an optional subset restriction to a dataset loader node."""

    #: Optional subset range of this node; merged with any parent subset on load.
    subset: Optional[Subset] = None

    def _get_subset(self, parent_subset: Optional[DatasetSubset]) -> Optional[DatasetSubset]:
        """Combine this node's subset with the parent's subset (if any)."""
        if parent_subset is None:
            # No outer restriction; just normalize our own subset if present.
            return None if self.subset is None else self.subset.merge(None)
        # Absolute ranges cannot be propagated through intermediate hierarchy levels.
        assert parent_subset.absolute_range is None, (
            f"Can only use absolute subset ranges for a leaf dataset (Range {parent_subset.absolute_range=})"
        )
        if self.subset is None:
            return parent_subset
        return self.subset.merge(parent_subset)
@edataclass
class DatasetReference(SubsetRatioMixin, DatasetLoaderInterface):
    """Reference from a metadataset to an inner dataset: either a prepared Energon dataset
    (webdataset/jsonl) or a nested metadataset."""

    #: Path to the referenced dataset; resolved relative to the metadataset in `post_initialize`.
    path: Union[str, EPath]
    #: Optional override of the split part when loading the referenced dataset.
    split_part: Optional[str] = None
    #: Subflavors to set as defaults for the referenced dataset (caller values take precedence).
    subflavors: Optional[Dict[str, Any]] = None
    #: Shuffle-over-epochs multiplier; combined with the caller's multiplier in `get_datasets`.
    shuffle_over_epochs_multiplier: Optional[int] = 1
    #: Optional dataset config file name; only allowed for direct dataset references.
    dataset_config: Optional[str] = None
    #: Optional split config file name; only allowed for direct dataset references.
    split_config: Optional[str] = None
    #: Auxiliary datasets. May only be specified for crude datasets for cooking. Cooking will get
    # these references to load data from. If specified as string, it will be interpreted as a
    # dataset path.
    # NOTE(review): after `post_initialize`, the values are AuxDatasetReference /
    # AuxFilesystemReference objects, not strings, despite the declared type.
    aux: Optional[Dict[str, str]] = None

    #: The resolved inner loader; set by `post_initialize`.
    _dataset: Optional[DatasetLoaderInterface] = None

    def post_initialize(self, mds_path: Optional[EPath] = None) -> None:
        """Resolve `path`, detect the dataset type, and instantiate the inner loader.

        Args:
            mds_path: Path of the metadataset file this reference belongs to (required).

        Raises:
            FileNotFoundError: If the path is not a recognized Energon dataset type.
        """
        assert mds_path is not None
        if not isinstance(self.path, EPath):
            self.path = mds_path.parent / self.path
        ds_type = get_dataset_type(self.path)
        if ds_type == EnergonDatasetType.METADATASET:
            # Nested metadataset: aux / dataset_config / split_config only apply to
            # direct dataset references.
            assert self.aux is None, (
                "Cannot specify auxiliary datasets for metadataset references"
            )
            assert self.dataset_config is None, "Must not set dataset_config"
            assert self.split_config is None, "Must not set split_config"
            # Note: For backwards compatibility, the type must be Metadataset (V1).
            self._dataset = load_config(
                self.path,
                default_type=Metadataset,
                default_kwargs=dict(path=self.path),
            )
            self._dataset.post_initialize()
        elif ds_type in (EnergonDatasetType.WEBDATASET, EnergonDatasetType.JSONL):
            self._dataset = DatasetLoader(
                path=self.path,
                split_config=self.split_config,
                dataset_config=self.dataset_config,
            )
            self._dataset.post_initialize()
            if self.aux is not None:
                # Resolve aux string values into file-store references.
                new_aux = {}
                for k, v in self.aux.items():
                    if m := url_regex.match(v):
                        # URL-style reference; only the "filesystem" protocol is supported.
                        if m.group("protocol") == "filesystem":
                            new_aux[k] = AuxFilesystemReference(fs_path=m.group("path"))
                        else:
                            raise ValueError(f"Unsupported protocol: {m.group('protocol')}")
                    else:
                        # Plain string: interpreted as a dataset path.
                        new_aux[k] = AuxDatasetReference(path=v)
                    new_aux[k].post_initialize(mds_path)
                self.aux = new_aux
        else:
            raise FileNotFoundError(self.path)

    def prepare(self, split_part: Optional[str] = None) -> Sequence[EPath]:
        """Prepare the inner dataset for the given split part."""
        assert self._dataset is not None
        return self._dataset.prepare(split_part=split_part)

    def get_datasets(
        self,
        *,
        training: bool,
        split_part: Union[Literal["train", "val", "test"], str],
        worker_config: WorkerConfig,
        subflavors: Optional[Dict[str, Any]] = None,
        shuffle_over_epochs_multiplier: Optional[int] = 1,
        subset: Optional[DatasetSubset] = None,
        **kwargs,
    ) -> LoadedDatasetList:
        """Load the inner dataset(s), merging subflavors, shuffle multipliers, subsets, and
        attaching aux file stores to each loaded dataset."""
        if self.subflavors is not None:
            # This reference's subflavors are defaults; caller-provided ones take precedence.
            subflavors = {**self.subflavors, **(subflavors or {})}
        assert self._dataset is not None
        if shuffle_over_epochs_multiplier is None or self.shuffle_over_epochs_multiplier is None:
            # If no shuffling is requested, this has override priority.
            new_shuffle_over_epochs_multiplier = None
        elif shuffle_over_epochs_multiplier == -1 or self.shuffle_over_epochs_multiplier == -1:
            # Next priority is sampling without replacement.
            new_shuffle_over_epochs_multiplier = -1
        else:
            # Otherwise, multiply the shuffle over epochs multiplier.
            new_shuffle_over_epochs_multiplier = (
                shuffle_over_epochs_multiplier * self.shuffle_over_epochs_multiplier
            )
        subset = self._get_subset(subset)
        result = self._dataset.get_datasets(
            training=training,
            split_part=self.split_part or split_part,
            worker_config=worker_config,
            subflavors=subflavors,
            shuffle_over_epochs_multiplier=new_shuffle_over_epochs_multiplier,
            subset=subset,
            **kwargs,
        )
        if self.aux is not None:
            # Instantiate the aux file stores once and merge into each loaded dataset.
            aux = {k: v.get_file_store() for k, v in self.aux.items()}
            for loaded_dataset in result.datasets:
                if loaded_dataset.aux is None:
                    loaded_dataset.aux = aux
                else:
                    loaded_dataset.aux.update(aux)
        return result
@edataclass
class JoinDatasetReference(DatasetReference):
    """Reference to one inner dataset of a `MetadatasetJoin`.

    Only direct (prepared webdataset) references are allowed here; all usage goes through
    the parent `MetadatasetJoin`, which stores the resolved loader."""

    #: How to handle primary samples that have no match in this dataset.
    nonmatch: Literal["skip", "none", "error"] = "error"

    def post_initialize(self, mds_path: Optional[EPath] = None) -> DatasetLoader:
        """Resolve the path and return a `DatasetLoader` for the joined dataset.

        Raises:
            ValueError: If the path does not point to a prepared (joinable) webdataset.
        """
        assert mds_path is not None
        # Override and disable another metadataset reference, only allow direct dataset references.
        # Do not store the loader, the parent MetadatasetJoin will do that.
        if not isinstance(self.path, EPath):
            self.path = mds_path.parent / self.path
        ds_type = get_dataset_type(self.path)
        if ds_type == EnergonDatasetType.WEBDATASET:
            return DatasetLoader(
                path=self.path,
                split_part=self.split_part,
                subflavors=self.subflavors,
                shuffle_over_epochs_multiplier=self.shuffle_over_epochs_multiplier,
                dataset_config=self.dataset_config,
                split_config=self.split_config,
            )
        else:
            raise ValueError(f"Not a joinable dataset at {self.path}")

    def prepare(self, split_part: Optional[str] = None):
        assert False, (
            "JoinDatasetReference should not be used directly, but only by MetadatasetJoin"
        )

    def get_datasets(
        self,
        **kwargs,
    ) -> LoadedDatasetList:
        assert False, (
            "JoinDatasetReference should not be used directly, but only by MetadatasetJoin"
        )
@edataclass
class MetadatasetJoin(SubsetRatioMixin, DatasetLoaderInterface):
    """Joins multiple datasets sample-by-sample, combining matching samples via `joiner`."""

    #: The inner datasets to join, either positional (list) or keyed (dict).
    join: Union[List[JoinDatasetReference], Dict[str, JoinDatasetReference]]
    #: Sample type or callable that constructs the joined sample.
    joiner: Union[Type[Sample], Callable[..., Sample]]
    split_part: Optional[str] = None
    subflavors: Optional[Dict[str, Any]] = None
    shuffle_over_epochs_multiplier: Optional[int] = 1
    dataset_config: Optional[str] = None
    split_config: Optional[str] = None

    #: The resolved join loader; set by `post_initialize`.
    _dataset: Optional[JoinDatasetLoader] = None

    def post_initialize(self, mds_path: Optional[EPath] = None):
        """Resolve all inner references and construct the `JoinDatasetLoader`."""
        assert mds_path is not None
        assert self.join is not None
        assert self.joiner is not None, "Must set joiner for joining datasets"
        assert self.dataset_config is None, "Cannot set dataset_config for joining datasets"
        assert self.split_config is None, "Cannot set split_config for joining datasets"

        def _resolve(ref: JoinDatasetReference) -> JoinedDatasetInfo:
            # Each reference resolves to a plain DatasetLoader; the resolved loader is
            # stored here, not on the reference itself.
            return JoinedDatasetInfo(
                dataset=ref.post_initialize(mds_path),
                nonmatch=ref.nonmatch,
            )

        if isinstance(self.join, list):
            resolved = [_resolve(ref) for ref in self.join]
        elif isinstance(self.join, dict):
            resolved = {key: _resolve(ref) for key, ref in self.join.items()}
        else:
            raise ValueError("Invalid join type")
        self._dataset = JoinDatasetLoader(
            datasets=resolved,
            joiner=self.joiner,
            split_part=self.split_part,
            subflavors=self.subflavors,
            shuffle_over_epochs_multiplier=self.shuffle_over_epochs_multiplier,
            split_config=self.split_config,
        )
        self._dataset.post_initialize(mds_path)

    def prepare(self, split_part: Optional[str] = None) -> Sequence[EPath]:
        """Prepare the underlying join loader for the given split part."""
        assert self._dataset is not None, "Missing post_initialize call."
        return self._dataset.prepare(split_part=split_part)

    def get_datasets(
        self,
        *,
        training: bool,
        split_part: Union[Literal["train", "val", "test"], str],
        worker_config: WorkerConfig,
        subflavors: Optional[Dict[str, Any]] = None,
        shuffle_over_epochs_multiplier: Optional[int] = 1,
        subset: Optional[DatasetSubset] = None,
        **kwargs,
    ) -> LoadedDatasetList:
        """Delegate loading to the join loader, merging this node's subset with the parent's."""
        assert self._dataset is not None, "Missing post_initialize call."
        merged_subset = self._get_subset(subset)
        return self._dataset.get_datasets(
            training=training,
            split_part=split_part,
            worker_config=worker_config,
            subflavors=subflavors,
            shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier,
            subset=merged_subset,
            **kwargs,
        )
@dataclass
class BlendWeightMixin:
    """Mixin adding a relative sampling weight for weighted blending."""

    # Relative sampling weight within the enclosing blend; normalized by the sum of
    # all weights in `MetadatasetBlend.get_datasets`.
    weight: float = 1.0
@edataclass
class BlendDatasetReference(BlendWeightMixin, DatasetReference):
    """A `DatasetReference` carrying a blend `weight`, for use in `MetadatasetBlend`."""

    pass
@edataclass
class BlendJoinDatasetReference(BlendWeightMixin, MetadatasetJoin):
    """A `MetadatasetJoin` carrying a blend `weight`, for use in `MetadatasetBlend`."""

    pass
@edataclass
class MetadatasetBlend(SubsetRatioMixin, DatasetLoaderInterface):
    """Blending of datasets by specifying the sampling weight for the inner datasets."""

    #: The datasets to blend, each carrying a relative sampling weight.
    blend: List[Union[BlendDatasetReference, BlendJoinDatasetReference]]

    def post_initialize(self, mds_path: Optional[EPath] = None):
        """Initialize all inner dataset references."""
        assert mds_path is not None
        for dataset in self.blend:
            dataset.post_initialize(mds_path)

    def prepare(self, split_part: Optional[str] = None) -> Sequence[EPath]:
        """Collect the prepared files of all inner datasets."""
        files = []
        for dataset in self.blend:
            files.extend(dataset.prepare(split_part=split_part))
        return files

    def get_datasets(
        self,
        *,
        training: bool,
        split_part: Union[Literal["train", "val", "test"], str],
        worker_config: WorkerConfig,
        subflavors: Optional[Dict[str, Any]] = None,
        shuffle_over_epochs_multiplier: Optional[int] = 1,
        subset: Optional[DatasetSubset] = None,
        **kwargs,
    ) -> LoadedDatasetList:
        """Load all inner datasets and normalize their weights to sum to 1.

        Raises:
            ValueError: If an inner dataset uses the epochized (repetitions) blend mode,
                or if the blend weights do not sum to a positive value.
        """
        subset = self._get_subset(subset)
        sum_weight = sum(dataset.weight for dataset in self.blend)
        # Guard the normalization below; an all-zero weight set would otherwise raise a
        # bare ZeroDivisionError deep inside the loop.
        if self.blend and sum_weight <= 0:
            raise ValueError("Sum of blend weights must be positive")
        datasets = []
        for dataset in self.blend:
            inner_result = dataset.get_datasets(
                training=training,
                split_part=split_part,
                worker_config=worker_config,
                subflavors=subflavors,
                shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier,
                subset=subset,
                **kwargs,
            )
            if inner_result.blend_mode not in (
                DatasetBlendMode.NONE,
                DatasetBlendMode.DATASET_WEIGHT,
            ):
                raise ValueError(
                    "Can only blend datasets which are of the same blend mode. Cannot mix blend with blend_epochized."
                )
            for loaded_dataset in inner_result.datasets:
                if inner_result.blend_mode == DatasetBlendMode.DATASET_WEIGHT:
                    assert isinstance(loaded_dataset.weight, float)
                else:
                    # Unweighted inner dataset: default to weight 1 before scaling.
                    assert inner_result.blend_mode == DatasetBlendMode.NONE
                    assert loaded_dataset.weight is None
                    assert loaded_dataset.repetitions is None
                    loaded_dataset.weight = 1.0
                loaded_dataset.weight = loaded_dataset.weight * dataset.weight / sum_weight
                datasets.append(loaded_dataset)
        return LoadedDatasetList(
            blend_mode=DatasetBlendMode.DATASET_WEIGHT,
            datasets=datasets,
        )
@dataclass
class BlendRepetitionsMixin:
    """Mixin adding a per-sample repetition count for epochized blending."""

    # Number of repetitions for each sample of this dataset within one blend "epoch";
    # may be fractional.
    repetitions: Union[int, float] = 1
@edataclass
class BlendEpochizedDatasetReference(BlendRepetitionsMixin, DatasetReference):
    """A `DatasetReference` carrying `repetitions`, for use in `MetadatasetBlendEpochized`."""

    pass
@edataclass
class BlendEpochizedJoinDatasetReference(BlendRepetitionsMixin, MetadatasetJoin):
    """A `MetadatasetJoin` carrying `repetitions`, for use in `MetadatasetBlendEpochized`."""

    pass
@edataclass
class MetadatasetBlendEpochized(SubsetRatioMixin, DatasetLoaderInterface):
    """Blending of datasets by specifying the number of repetitions for samples from the
    inner datasets. Guarantees that each sample is seen exactly its number of repetitions
    before the "epoch" repeats (i.e. one epoch contains the total number of repetitions
    for each inner dataset)."""

    #: The datasets to blend, each carrying a per-sample repetition count.
    blend_epochized: List[Union[BlendEpochizedDatasetReference, BlendEpochizedJoinDatasetReference]]

    def post_initialize(self, mds_path: Optional[EPath] = None):
        """Initialize all inner dataset references."""
        assert mds_path is not None
        for inner in self.blend_epochized:
            inner.post_initialize(mds_path)

    def prepare(self, split_part: Optional[str] = None) -> Sequence[EPath]:
        """Collect the prepared files of all inner datasets."""
        return [
            prepared_file
            for inner in self.blend_epochized
            for prepared_file in inner.prepare(split_part=split_part)
        ]

    def get_datasets(
        self,
        *,
        training: bool,
        split_part: Union[Literal["train", "val", "test"], str],
        worker_config: WorkerConfig,
        subflavors: Optional[Dict[str, Any]] = None,
        shuffle_over_epochs_multiplier: Optional[int] = 1,
        subset: Optional[DatasetSubset] = None,
        **kwargs,
    ) -> LoadedDatasetList:
        """Load all inner datasets and scale each loaded dataset's repetitions."""
        merged_subset = self._get_subset(subset)
        collected = []
        for inner in self.blend_epochized:
            inner_result = inner.get_datasets(
                training=training,
                split_part=split_part,
                worker_config=worker_config,
                subflavors=subflavors,
                shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier,
                subset=merged_subset,
                **kwargs,
            )
            if inner_result.blend_mode not in (
                DatasetBlendMode.NONE,
                DatasetBlendMode.SAMPLE_REPETITIONS,
            ):
                raise ValueError(
                    "Can only blend datasets which are of the same blend mode. Cannot mix blend with blend_epochized."
                )
            for loaded in inner_result.datasets:
                if inner_result.blend_mode == DatasetBlendMode.NONE:
                    # Unweighted inner dataset: default to one repetition before scaling.
                    assert loaded.weight is None
                    assert loaded.repetitions is None
                    loaded.repetitions = 1
                else:
                    assert isinstance(loaded.repetitions, (int, float))
                loaded.repetitions = inner.repetitions * loaded.repetitions
                collected.append(loaded)
        return LoadedDatasetList(
            blend_mode=DatasetBlendMode.SAMPLE_REPETITIONS,
            datasets=collected,
        )