Source code for megatron.energon.metadataset.metadataset_v2

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import re
from dataclasses import dataclass
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
    Type,
    Union,
)

from megatron.energon.cache import FileStore
from megatron.energon.cache.file_store import (
    SystemFileStore,
    WebdatasetFileStore,
)
from megatron.energon.dataset_config import load_config
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib import EPath
from megatron.energon.flavors import Sample
from megatron.energon.flavors.webdataset.config import INDEX_SQLITE_FILENAME, MAIN_FOLDER_NAME
from megatron.energon.flavors.webdataset.metadata import EnergonDatasetType, get_dataset_type
from megatron.energon.flavors.webdataset.structs import DatasetSubset
from megatron.energon.metadataset.dataset_loader import DatasetLoader
from megatron.energon.metadataset.join_dataset_loader import JoinDatasetLoader, JoinedDatasetInfo
from megatron.energon.metadataset.loader_interface import (
    DatasetBlendMode,
    DatasetLoaderInterface,
    LoadedDatasetList,
    TraversedDatasetReference,
)
from megatron.energon.metadataset.metadataset import Metadataset
from megatron.energon.worker import WorkerConfig

# Matches URL-like strings of any protocol, e.g. "msc://bucket/path" or "filesystem+s3://x".
# Named groups: ``protocol`` (the scheme, matched case-insensitively) and ``path`` (everything
# after "://"). Strings without "://" (plain relative/absolute paths) do not match.
url_regex = re.compile(
    r"^(?P<protocol>[a-z][a-z0-9+.-]*)://(?P<path>.*)",
    flags=re.IGNORECASE,
)


@edataclass
class AuxDatasetReference:
    """Auxiliary reference to a prepared energon dataset, handed to cooking as a file store."""

    path: Union[str, EPath]

    def _resolve_path(self, mds_path: Optional[EPath]) -> EPath:
        """Resolve a string ``path`` against the directory of ``mds_path`` (stored in place)."""
        assert mds_path is not None
        if isinstance(self.path, EPath):
            return self.path
        self.path = mds_path.parent / self.path
        return self.path

    def post_initialize(self, mds_path: Optional[EPath] = None) -> None:
        """Resolve the path and check that it points to a prepared (indexed) dataset folder."""
        resolved = self._resolve_path(mds_path)
        # A file here would be a metadataset; only direct dataset folders are allowed.
        assert not resolved.is_file(), (
            "Auxiliary datasets must not be metadataset, but direct dataset references"
        )
        assert (resolved / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME).is_file(), (
            "Auxiliary datasets must be prepared Energon datasets. This one does not exist or is not prepared: "
            + str(resolved)
        )

    def get_file_store(self) -> FileStore:
        """Return a webdataset-backed file store for the resolved dataset path."""
        assert isinstance(self.path, EPath), "Missing call to post_initialize"
        store = WebdatasetFileStore(self.path)
        return store


@edataclass
class AuxFilesystemReference:
    """Auxiliary reference to a plain filesystem location, handed to cooking as a file store."""

    fs_path: Union[str, EPath]

    def _resolve_path(self, mds_path: Optional[EPath]) -> EPath:
        """Resolve a string ``fs_path`` against the directory of ``mds_path`` (stored in place)."""
        assert mds_path is not None
        if isinstance(self.fs_path, EPath):
            return self.fs_path
        self.fs_path = mds_path.parent / self.fs_path
        return self.fs_path

    def post_initialize(self, mds_path: Optional[EPath] = None) -> None:
        """Only resolve the path; filesystem references are not validated further."""
        self._resolve_path(mds_path)

    def get_file_store(self) -> FileStore:
        """Return a system file store rooted at the resolved filesystem path."""
        assert isinstance(self.fs_path, EPath), "Missing call to post_initialize"
        store = SystemFileStore(self.fs_path)
        return store


@edataclass
class Subset:
    """
    A subset range to be applied to a dataset. The range is always consecutive.

    The range is a tuple of two values, where the first value is the start of the subset and the second value is the end of the subset (end not included).
    The range can either be an absolute range with sample indices, or a ratio of the dataset size.
    Relative range example: [25%, 75%]. This would limit the subset to the middle 50% of the dataset.
    Absolute range example: [100, 200]. This would limit the subset to the 100 samples with indices 100-199.
    The end can be set to "end" to indicate the end of the dataset, for example [100, end] for an
    absolute range, or [25%, end] for a relative range (equivalent to [25%, 100%]).

    Since subsets can be specified at multiple levels of a hierarchy, for example in a blend,
    their effects can be merged to a single subset.
    Note however, that absolute ranges are only allowed for leaf datasets, while relative ranges
    can be applied at any level.
    """

    range: tuple[str | int, str | int]

    def as_dataset_subset(self) -> DatasetSubset:
        """Convert the subset with string values to a DatasetSubset object with `range` and `absolute_range`.

        Integer bounds produce an absolute range (with ``range=(0, 1)``); percentage bounds
        produce a relative range in [0, 1] (with ``absolute_range=None``).

        Raises:
            AssertionError: If the bounds are malformed, mix integers and percentages, lie
                outside [0%, 100%], or the relative start is after the relative end.
        """

        def _conv(value: str | int) -> float | int | None:
            """Convert one bound: int -> index, "end" -> None, "N%" -> N / 100."""
            if isinstance(value, int):
                return value
            assert isinstance(value, str), "Range must be a string if it's not an integer"
            # Strip once so surrounding whitespace is tolerated for both "end" and "N%".
            value = value.strip()
            if value == "end":
                return None
            assert value.endswith("%"), "Range must be a percentage"
            percentage = float(value.removesuffix("%"))
            assert 0 <= percentage <= 100, "Percentage must be between 0 and 100"
            return percentage / 100.0

        raw_start, raw_end = self.range
        start = _conv(raw_start)
        end = _conv(raw_end)

        if isinstance(start, int):
            assert isinstance(end, int) or end is None, (
                "End must be an integer if start is an integer"
            )
            return DatasetSubset(absolute_range=(start, end), range=(0, 1))

        assert isinstance(start, float), "Range start must be a float if it's not an integer"
        assert isinstance(end, float) or end is None, "End must be a float if start is a float"
        if end is None:
            # "end" in a relative range means the end of the dataset, i.e. 100%. Without this,
            # the numeric comparisons below would compare against None and raise a TypeError.
            end = 1.0
        assert 0 <= start <= 1, "Start must be between 0 and 1"
        assert 0 <= end <= 1, "End must be between 0 and 1"
        assert start <= end, "Start must be less than end"
        return DatasetSubset(range=(start, end), absolute_range=None)

    def merge(self, parent_subset: DatasetSubset | None) -> DatasetSubset:
        """Merge this subset with a parent subset.

        If the parent subset is None, return the subset.
        If the parent subset is an absolute range, fail, because that's not allowed.
        If the parent subset is a ratio, merge it with the subset.

        Merging a child absolute range with a parent relative range:
        In this case, both are kept in the DatasetSubset object and applies in "absolute first" order later.

        Merging a child relative range with a parent relative range:
        In this case, the relative parent range is applied to the child's relative range.
        The absolute range is not affected.

        For details on how this is applied, see `DatasetSubset.compute_subset`.
        """

        # The message must only reference attributes that exist: `Subset` has no
        # `absolute_range`, so report the offending parent range instead.
        assert parent_subset is None or parent_subset.absolute_range is None, (
            f"Cannot merge absolute subset ranges. Absolute ranges are only allowed for a leaf dataset. {parent_subset.absolute_range=} {self.range=}"
        )
        my_subset = self.as_dataset_subset()
        if parent_subset is None or parent_subset.range is None:
            return my_subset

        # Assuming inner ratio: [0.25, 0.75] and outer ratio: [0, 0.5]
        # Then the total ratio is supposed to be: [0.25 + 0*0.5, 0.25 + 0.5 * 0.5] = [0.25, 0.5]
        total = my_subset.range[1] - my_subset.range[0]
        return DatasetSubset(
            range=(
                my_subset.range[0] + parent_subset.range[0] * total,
                my_subset.range[0] + parent_subset.range[1] * total,
            ),
            absolute_range=my_subset.absolute_range,
        )


@edataclass
class SubsetRatioMixin:
    """Mixin adding an optional ``subset`` that is combined with the subset inherited from above."""

    subset: Optional[Subset] = None

    def _get_subset(self, parent_subset: Optional[DatasetSubset]) -> Optional[DatasetSubset]:
        """Combine this node's subset with ``parent_subset``.

        Returns the parent subset unchanged if this node has none, this node's subset if there
        is no parent, the merged subset if both exist, or None if neither exists.
        """
        # Absolute ranges may only be set on leaves, so an inherited one is always an error.
        if parent_subset is not None:
            assert parent_subset.absolute_range is None, (
                f"Can only use absolute subset ranges for a leaf dataset (Range {parent_subset.absolute_range=})"
            )
        if self.subset is None:
            return parent_subset
        return self.subset.merge(parent_subset)


@edataclass
class DatasetReference(SubsetRatioMixin, DatasetLoaderInterface):
    """Reference from a V2 metadataset to a leaf dataset or to a nested metadataset.

    ``path`` may point to a prepared webdataset / JSONL dataset or to another metadataset file.
    String paths are resolved relative to the directory of the referencing metadataset.
    Filesystem datasets are rejected here; they may only be used as auxiliary references.
    """

    path: Union[str, EPath]

    split_part: Optional[str] = None
    subflavors: Optional[Dict[str, Any]] = None
    shuffle_over_epochs_multiplier: Optional[int] = 1
    dataset_config: Optional[str] = None
    split_config: Optional[str] = None

    #: Auxiliary datasets. May only be specified for crude datasets for cooking. Cooking will get
    # these references to load data from. If specified as string, it will be interpreted as a
    # dataset path, or as a filesystem path when prefixed with ``filesystem://`` or
    # ``filesystem+<protocol>://``.
    aux: Optional[Dict[str, Union[str, AuxDatasetReference, AuxFilesystemReference]]] = None

    #: Loader created by ``post_initialize`` (a nested metadataset or a ``DatasetLoader``).
    _dataset: Optional[DatasetLoaderInterface] = None

    def _resolve_path(self, mds_path: Optional[EPath]) -> EPath:
        """Resolve a string ``path`` against the directory of ``mds_path`` (stored in place)."""
        assert mds_path is not None
        if not isinstance(self.path, EPath):
            self.path = mds_path.parent / self.path
        return self.path

    @staticmethod
    def _normalize_aux_reference(
        reference: Union[str, AuxDatasetReference, AuxFilesystemReference],
    ) -> Union[AuxDatasetReference, AuxFilesystemReference]:
        """Convert an aux reference given as a string into a typed reference object.

        Accepted string forms:
          * ``filesystem+<prot>://<path>``: filesystem reference at ``<prot>://<path>``.
          * ``filesystem://<path>``: filesystem reference at ``<path>`` (relative or absolute).
          * any other ``<prot>://...`` URL or a plain path: prepared energon dataset reference.

        The ``filesystem`` marker is matched case-insensitively, like the URL protocol itself
        (``url_regex`` is compiled with ``re.IGNORECASE``).

        Raises:
            AssertionError: If the protocol has the form ``<type>+<prot>`` with a type other
                than ``filesystem``.
        """
        if isinstance(reference, (AuxDatasetReference, AuxFilesystemReference)):
            return reference
        m = url_regex.match(reference)
        if m is None:
            # Plain path without protocol: a dataset path.
            return AuxDatasetReference(path=reference)
        prot = m.group("protocol")
        if prot.count("+") == 1:
            # filesystem+fs_prot://
            fs_type, fs_prot = prot.split("+")
            assert fs_type.lower() == "filesystem", (
                f"Invalid filesystem type: {fs_type} in path {reference}"
            )
            return AuxFilesystemReference(fs_path=f"{fs_prot}://{m.group('path')}")
        if prot.lower() == "filesystem":
            # filesystem:// (may be relative or absolute)
            return AuxFilesystemReference(fs_path=m.group("path"))
        # msc:// or other protocol: keep the full URL as a dataset path.
        return AuxDatasetReference(path=reference)

    def _normalize_aux_references(self, mds_path: Optional[EPath], *, validate: bool) -> None:
        """Replace ``self.aux`` entries by typed references with resolved paths.

        With ``validate=True`` each reference's ``post_initialize`` runs (for dataset references
        this checks that the dataset is prepared); otherwise only the path is resolved.
        """
        if self.aux is None:
            return
        new_aux: Dict[str, Union[AuxDatasetReference, AuxFilesystemReference]] = {}
        for key, value in self.aux.items():
            normalized = self._normalize_aux_reference(value)
            if validate:
                normalized.post_initialize(mds_path)
            else:
                normalized._resolve_path(mds_path)
            new_aux[key] = normalized
        self.aux = new_aux

    def _get_traversed_aux_references(self) -> dict[str, EPath]:
        """Return the resolved path of every (already normalized) aux reference."""
        if self.aux is None:
            return {}
        traversed_aux: dict[str, EPath] = {}
        for key, value in self.aux.items():
            if isinstance(value, AuxDatasetReference):
                assert isinstance(value.path, EPath)
                traversed_aux[key] = value.path
            else:
                assert isinstance(value, AuxFilesystemReference)
                assert isinstance(value.fs_path, EPath)
                traversed_aux[key] = value.fs_path
        return traversed_aux

    def _merge_traversed_subflavors(
        self, inherited_subflavors: Optional[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Merge this reference's subflavors with the inherited traversal subflavors.

        The merge order mirrors `get_datasets(...)`: this reference contributes the base mapping,
        and inherited outer-hierarchy subflavors override on key conflicts.

        Args:
            inherited_subflavors: Effective subflavors accumulated from outer metadataset
                references during traversal.

        Returns:
            The effective subflavor mapping for this reference, after applying outer-overrides-inner
            merge semantics.
        """
        if self.subflavors is not None:
            return {**self.subflavors, **(inherited_subflavors or {})}
        return dict(inherited_subflavors or {})

    def _load_nested_metadataset(self) -> DatasetLoaderInterface:
        """Load the metadataset at ``self.path``; options only valid for leaves must be unset."""
        assert isinstance(self.path, EPath)
        # Aux datasets are only allowed on direct (crude) dataset references, not on
        # references to nested metadatasets.
        assert self.aux is None, "Cannot specify auxiliary datasets for nested metadatasets"
        assert self.dataset_config is None, "Must not set dataset_config"
        assert self.split_config is None, "Must not set split_config"
        # Note: For backwards compatibility, the type must be Metadataset (V1).
        return load_config(
            self.path,
            default_type=Metadataset,
            default_kwargs=dict(path=self.path),
        )

    def post_initialize(self, mds_path: Optional[EPath] = None) -> None:
        """Resolve the path and create the loader for the referenced dataset.

        Raises:
            ValueError: If the path is a filesystem dataset.
            FileNotFoundError: For any other unrecognized dataset type (presumably a missing or
                unprepared path).
        """
        self._resolve_path(mds_path)
        ds_type = get_dataset_type(self.path)
        if ds_type == EnergonDatasetType.METADATASET:
            self._dataset = self._load_nested_metadataset()
            self._dataset.post_initialize()
        elif ds_type in (EnergonDatasetType.WEBDATASET, EnergonDatasetType.JSONL):
            self._dataset = DatasetLoader(
                path=self.path,
                split_config=self.split_config,
                dataset_config=self.dataset_config,
            )
            self._dataset.post_initialize()
            self._normalize_aux_references(mds_path, validate=True)
        elif ds_type == EnergonDatasetType.FILESYSTEM:
            raise ValueError(
                "Filesystem datasets are not supported within metadatasets except as auxiliary datasets."
            )
        else:
            raise FileNotFoundError(self.path)

    def traverse(
        self,
        mds_path: Optional[EPath] = None,
        *,
        split_part: Union[Literal["train", "val", "test"], str],
        _subflavors: Optional[Dict[str, Any]] = None,
    ) -> List[TraversedDatasetReference]:
        """Traverse this V2 dataset reference into flattened leaf references.

        For direct leaf datasets, traversal resolves the dataset path and any auxiliary references
        into plain `EPath` values. For nested metadatasets, traversal recurses immediately into the
        referenced split instead of building an intermediate object graph.

        Args:
            mds_path: Parent metadataset path used internally to resolve relative dataset and
                auxiliary paths. Must be set for nested references and inner traversal nodes;
                use None only for top-level metadatasets.
            split_part: Split inherited from the parent traversal. If this reference defines its
                own split override, that split takes precedence for nested traversal and the
                returned leaf reference.

        Returns:
            A single leaf `TraversedDatasetReference` for direct dataset references, or the
            flattened traversal result of the nested metadataset when this reference points to one.
        """
        self._resolve_path(mds_path)
        effective_subflavors = self._merge_traversed_subflavors(_subflavors)
        ds_type = get_dataset_type(self.path)
        if ds_type == EnergonDatasetType.METADATASET:
            return self._load_nested_metadataset().traverse(
                split_part=self.split_part or split_part,
                _subflavors=effective_subflavors,
            )
        self._normalize_aux_references(mds_path, validate=False)
        return [
            TraversedDatasetReference(
                path=self.path,
                split_part=self.split_part or split_part,
                aux=self._get_traversed_aux_references(),
                subflavors=effective_subflavors,
            )
        ]

    def prepare(self, split_part: Optional[str] = None) -> Sequence[EPath]:
        """Delegate preparation to the loader created in ``post_initialize``."""
        assert self._dataset is not None
        return self._dataset.prepare(split_part=split_part)

    def get_datasets(
        self,
        *,
        training: bool,
        split_part: Union[Literal["train", "val", "test"], str],
        worker_config: WorkerConfig,
        subflavors: Optional[Dict[str, Any]] = None,
        shuffle_over_epochs_multiplier: Optional[int] = 1,
        subset: Optional[DatasetSubset] = None,
        **kwargs,
    ) -> LoadedDatasetList:
        """Load the referenced dataset(s), applying this reference's overrides.

        Subflavors: this reference's mapping is the base, the passed ones override it.
        Shuffle multiplier: ``None`` on either side wins, then ``-1`` on either side, otherwise
        both multipliers are multiplied. Aux file stores are added to every loaded dataset.
        """
        if self.subflavors is not None:
            subflavors = {**self.subflavors, **(subflavors or {})}
        assert self._dataset is not None

        if shuffle_over_epochs_multiplier is None or self.shuffle_over_epochs_multiplier is None:
            # If no shuffling is requested, this has override priority.
            new_shuffle_over_epochs_multiplier = None
        elif shuffle_over_epochs_multiplier == -1 or self.shuffle_over_epochs_multiplier == -1:
            # Next priority is sampling without replacement.
            new_shuffle_over_epochs_multiplier = -1
        else:
            # Otherwise, multiply the shuffle over epochs multiplier.
            new_shuffle_over_epochs_multiplier = (
                shuffle_over_epochs_multiplier * self.shuffle_over_epochs_multiplier
            )
        subset = self._get_subset(subset)

        result = self._dataset.get_datasets(
            training=training,
            split_part=self.split_part or split_part,
            worker_config=worker_config,
            subflavors=subflavors,
            shuffle_over_epochs_multiplier=new_shuffle_over_epochs_multiplier,
            subset=subset,
            **kwargs,
        )
        if self.aux is not None:
            aux = {k: v.get_file_store() for k, v in self.aux.items()}
            for loaded_dataset in result.datasets:
                # Build a fresh dict per dataset: neither share one mapping between datasets nor
                # mutate an aux mapping that may be owned by the inner loader.
                if loaded_dataset.aux is None:
                    loaded_dataset.aux = dict(aux)
                else:
                    loaded_dataset.aux = {**loaded_dataset.aux, **aux}
        return result


@edataclass
class JoinDatasetReference(DatasetReference):
    """One input of a `MetadatasetJoin`; only direct webdataset references are supported.

    This class must not be used as a standalone loader: `prepare` and `get_datasets` always
    fail, and `post_initialize` returns the loader instead of storing it.
    """

    #: Passed to `JoinedDatasetInfo.nonmatch` by `MetadatasetJoin`; presumably selects how
    #: samples without a join partner are handled (skip / none / error) — confirm in the joiner.
    nonmatch: Literal["skip", "none", "error"] = "error"

    def post_initialize(self, mds_path: Optional[EPath] = None) -> DatasetLoader:
        """Resolve the path and return a `DatasetLoader` for it.

        Raises:
            ValueError: If the path is not a webdataset.
        """
        assert mds_path is not None
        # Override and disable another metadataset reference, only allow direct dataset references.
        # Do not store the loader, the parent MetadatasetJoin will do that.
        self._resolve_path(mds_path)
        ds_type = get_dataset_type(self.path)
        if ds_type == EnergonDatasetType.WEBDATASET:
            return DatasetLoader(
                path=self.path,
                split_part=self.split_part,
                subflavors=self.subflavors,
                shuffle_over_epochs_multiplier=self.shuffle_over_epochs_multiplier,
                dataset_config=self.dataset_config,
                split_config=self.split_config,
            )
        else:
            raise ValueError(f"Not a joinable dataset at {self.path}")

    def traverse(
        self,
        mds_path: Optional[EPath] = None,
        *,
        split_part: Union[Literal["train", "val", "test"], str],
        _subflavors: Optional[Dict[str, Any]] = None,
    ) -> List[TraversedDatasetReference]:
        raise NotImplementedError("traverse_metadataset() does not support joined datasets.")

    def prepare(self, split_part: Optional[str] = None):
        # Explicit raise instead of `assert False`: asserts are stripped under `python -O`,
        # which would make this silently return None.
        raise AssertionError(
            "JoinDatasetReference should not be used directly, but only by MetadatasetJoin"
        )

    def get_datasets(
        self,
        **kwargs,
    ) -> LoadedDatasetList:
        # Explicit raise instead of `assert False` (see `prepare`).
        raise AssertionError(
            "JoinDatasetReference should not be used directly, but only by MetadatasetJoin"
        )


@edataclass
class MetadatasetJoin(SubsetRatioMixin, DatasetLoaderInterface):
    """Joins several webdatasets sample-wise into one dataset using ``joiner``.

    ``join`` is either a list (positional inputs) or a dict (named inputs) of references.
    """

    join: Union[List[JoinDatasetReference], Dict[str, JoinDatasetReference]]
    joiner: Union[Type[Sample], Callable[..., Sample]]

    split_part: Optional[str] = None
    subflavors: Optional[Dict[str, Any]] = None
    shuffle_over_epochs_multiplier: Optional[int] = 1
    dataset_config: Optional[str] = None
    split_config: Optional[str] = None

    #: Join loader created by ``post_initialize``.
    _dataset: Optional[JoinDatasetLoader] = None

    def post_initialize(self, mds_path: Optional[EPath] = None):
        """Create the inner loaders and the `JoinDatasetLoader` combining them."""
        assert mds_path is not None
        assert self.join is not None
        assert self.joiner is not None, "Must set joiner for joining datasets"
        assert self.dataset_config is None, "Cannot set dataset_config for joining datasets"
        assert self.split_config is None, "Cannot set split_config for joining datasets"

        def _as_info(ref: JoinDatasetReference) -> JoinedDatasetInfo:
            # The reference returns its loader instead of storing it.
            return JoinedDatasetInfo(dataset=ref.post_initialize(mds_path), nonmatch=ref.nonmatch)

        if isinstance(self.join, list):
            inner_loaders = [_as_info(ref) for ref in self.join]
        elif isinstance(self.join, dict):
            inner_loaders = {name: _as_info(ref) for name, ref in self.join.items()}
        else:
            raise ValueError("Invalid join type")

        loader = JoinDatasetLoader(
            datasets=inner_loaders,
            joiner=self.joiner,
            split_part=self.split_part,
            subflavors=self.subflavors,
            shuffle_over_epochs_multiplier=self.shuffle_over_epochs_multiplier,
            split_config=self.split_config,
        )
        self._dataset = loader
        loader.post_initialize(mds_path)

    def traverse(
        self,
        mds_path: Optional[EPath] = None,
        *,
        split_part: Union[Literal["train", "val", "test"], str],
        _subflavors: Optional[Dict[str, Any]] = None,
    ) -> List[TraversedDatasetReference]:
        """Joined datasets cannot be flattened into leaf references."""
        raise NotImplementedError("traverse_metadataset() does not support joined datasets.")

    def prepare(self, split_part: Optional[str] = None) -> Sequence[EPath]:
        """Delegate preparation to the join loader."""
        loader = self._dataset
        assert loader is not None, "Missing post_initialize call."
        return loader.prepare(split_part=split_part)

    def get_datasets(
        self,
        *,
        training: bool,
        split_part: Union[Literal["train", "val", "test"], str],
        worker_config: WorkerConfig,
        subflavors: Optional[Dict[str, Any]] = None,
        shuffle_over_epochs_multiplier: Optional[int] = 1,
        subset: Optional[DatasetSubset] = None,
        **kwargs,
    ) -> LoadedDatasetList:
        """Load the joined dataset, merging this node's subset into the inherited one."""
        loader = self._dataset
        assert loader is not None, "Missing post_initialize call."
        merged_subset = self._get_subset(subset)
        return loader.get_datasets(
            training=training,
            split_part=split_part,
            worker_config=worker_config,
            subflavors=subflavors,
            shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier,
            subset=merged_subset,
            **kwargs,
        )


@dataclass
class BlendWeightMixin:
    """Mixin giving a blend entry a relative sampling ``weight`` (defaults to ``1.0``)."""

    #: Relative weight; `MetadatasetBlend` normalizes it by the sum of all entry weights.
    weight: float = 1.0


@edataclass
class BlendDatasetReference(BlendWeightMixin, DatasetReference):
    """A `DatasetReference` used as a weighted entry of a `MetadatasetBlend`."""


@edataclass
class BlendJoinDatasetReference(BlendWeightMixin, MetadatasetJoin):
    """A `MetadatasetJoin` used as a weighted entry of a `MetadatasetBlend`."""


@edataclass
class MetadatasetBlend(DatasetLoaderInterface, SubsetRatioMixin):
    """Blending of datasets by specifying the sampling weight for the inner datasets."""

    blend: List[Union[BlendDatasetReference, BlendJoinDatasetReference]]

    def post_initialize(self, mds_path: Optional[EPath] = None):
        """Initialize every blend entry relative to ``mds_path``."""
        assert mds_path is not None
        for entry in self.blend:
            entry.post_initialize(mds_path)

    def traverse(
        self,
        mds_path: Optional[EPath] = None,
        *,
        split_part: Union[Literal["train", "val", "test"], str],
        _subflavors: Optional[Dict[str, Any]] = None,
    ) -> List[TraversedDatasetReference]:
        """Concatenate the traversal results of all entries, in blend order."""
        assert mds_path is not None
        return [
            leaf
            for entry in self.blend
            for leaf in entry.traverse(
                mds_path,
                split_part=split_part,
                _subflavors=_subflavors,
            )
        ]

    def prepare(self, split_part: Optional[str] = None) -> Sequence[EPath]:
        """Prepare all entries and return the concatenated list of their files."""
        return [
            prepared_file
            for entry in self.blend
            for prepared_file in entry.prepare(split_part=split_part)
        ]

    def get_datasets(
        self,
        *,
        training: bool,
        split_part: Union[Literal["train", "val", "test"], str],
        worker_config: WorkerConfig,
        subflavors: Optional[Dict[str, Any]] = None,
        shuffle_over_epochs_multiplier: Optional[int] = 1,
        subset: Optional[DatasetSubset] = None,
        **kwargs,
    ) -> LoadedDatasetList:
        """Load all entries and scale each loaded dataset's weight by its entry's share.

        An unweighted inner result counts as weight 1.0 per dataset; nested weighted results
        keep their inner weights, multiplied by ``entry.weight / sum(all entry weights)``.

        Raises:
            ValueError: If an entry yields a repetition-based (epochized) blend.
        """
        subset = self._get_subset(subset)
        total_weight = sum(entry.weight for entry in self.blend)
        collected = []
        for entry in self.blend:
            inner = entry.get_datasets(
                training=training,
                split_part=split_part,
                worker_config=worker_config,
                subflavors=subflavors,
                shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier,
                subset=subset,
                **kwargs,
            )
            mode = inner.blend_mode
            if mode not in (DatasetBlendMode.NONE, DatasetBlendMode.DATASET_WEIGHT):
                raise ValueError(
                    "Can only blend datasets which are of the same blend mode. Cannot mix blend with blend_epochized."
                )
            for loaded in inner.datasets:
                if mode == DatasetBlendMode.NONE:
                    assert loaded.weight is None
                    assert loaded.repetitions is None
                    base_weight = 1.0
                else:
                    assert isinstance(loaded.weight, float)
                    base_weight = loaded.weight
                loaded.weight = base_weight * entry.weight / total_weight
                collected.append(loaded)
        return LoadedDatasetList(
            blend_mode=DatasetBlendMode.DATASET_WEIGHT,
            datasets=collected,
        )


@dataclass
class BlendRepetitionsMixin:
    """Mixin giving an epochized blend entry a ``repetitions`` count (defaults to ``1``)."""

    #: Repetition factor; `MetadatasetBlendEpochized` multiplies it into the inner repetitions.
    repetitions: Union[int, float] = 1


@edataclass
class BlendEpochizedDatasetReference(BlendRepetitionsMixin, DatasetReference):
    """A `DatasetReference` used as an entry of a `MetadatasetBlendEpochized`."""


@edataclass
class BlendEpochizedJoinDatasetReference(BlendRepetitionsMixin, MetadatasetJoin):
    """A `MetadatasetJoin` used as an entry of a `MetadatasetBlendEpochized`."""


@edataclass
class MetadatasetBlendEpochized(SubsetRatioMixin, DatasetLoaderInterface):
    """Blending of datasets, by specifying the number of repetitions for samples from the inner
    datasets. Ensures that the constraint, that samples are seen exactly this many times before
    repeating the "epoch" (i.e. one epoch contains the total number of repetitions for each inner
    dataset)."""

    blend_epochized: List[Union[BlendEpochizedDatasetReference, BlendEpochizedJoinDatasetReference]]

    def post_initialize(self, mds_path: Optional[EPath] = None):
        """Initialize every entry relative to ``mds_path``."""
        assert mds_path is not None
        for entry in self.blend_epochized:
            entry.post_initialize(mds_path)

    def traverse(
        self,
        mds_path: Optional[EPath] = None,
        *,
        split_part: Union[Literal["train", "val", "test"], str],
        _subflavors: Optional[Dict[str, Any]] = None,
    ) -> List[TraversedDatasetReference]:
        """Concatenate the traversal results of all entries, in declaration order."""
        assert mds_path is not None
        return [
            leaf
            for entry in self.blend_epochized
            for leaf in entry.traverse(
                mds_path,
                split_part=split_part,
                _subflavors=_subflavors,
            )
        ]

    def prepare(self, split_part: Optional[str] = None) -> Sequence[EPath]:
        """Prepare all entries and return the concatenated list of their files."""
        return [
            prepared_file
            for entry in self.blend_epochized
            for prepared_file in entry.prepare(split_part=split_part)
        ]

    def get_datasets(
        self,
        *,
        training: bool,
        split_part: Union[Literal["train", "val", "test"], str],
        worker_config: WorkerConfig,
        subflavors: Optional[Dict[str, Any]] = None,
        shuffle_over_epochs_multiplier: Optional[int] = 1,
        subset: Optional[DatasetSubset] = None,
        **kwargs,
    ) -> LoadedDatasetList:
        """Load all entries and multiply each loaded dataset's repetitions by its entry's count.

        An unblended inner result counts as 1 repetition per dataset.

        Raises:
            ValueError: If an entry yields a weight-based blend.
        """
        subset = self._get_subset(subset)
        collected = []
        for entry in self.blend_epochized:
            inner = entry.get_datasets(
                training=training,
                split_part=split_part,
                worker_config=worker_config,
                subflavors=subflavors,
                shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier,
                subset=subset,
                **kwargs,
            )
            mode = inner.blend_mode
            if mode not in (DatasetBlendMode.NONE, DatasetBlendMode.SAMPLE_REPETITIONS):
                raise ValueError(
                    "Can only blend datasets which are of the same blend mode. Cannot mix blend with blend_epochized."
                )
            for loaded in inner.datasets:
                if mode == DatasetBlendMode.SAMPLE_REPETITIONS:
                    assert isinstance(loaded.repetitions, (int, float))
                    base_repetitions = loaded.repetitions
                else:
                    assert loaded.weight is None
                    assert loaded.repetitions is None
                    base_repetitions = 1
                loaded.repetitions = entry.repetitions * base_repetitions
                collected.append(loaded)
        return LoadedDatasetList(
            blend_mode=DatasetBlendMode.SAMPLE_REPETITIONS,
            datasets=collected,
        )


@edataclass
class MetadatasetV2(DatasetLoaderInterface):
    """Top-level V2 metadataset mapping split names to blends, joins or dataset references."""

    #: Path of the metadataset file itself. Inner references are resolved relative to its
    #: directory, and ``<path.parent>/<path.name>.cache`` is its cache folder.
    path: EPath
    #: Mapping from split name (e.g. "train", "val", "test") to the definition of that split.
    splits: Dict[
        str, Union[MetadatasetBlend, MetadatasetBlendEpochized, MetadatasetJoin, DatasetReference]
    ]

    def post_initialize(self, mds_path: Optional[EPath] = None):
        """Initialize all splits relative to this metadataset's own path."""
        assert mds_path is None
        for split in self.splits.values():
            split.post_initialize(self.path)

    def traverse(
        self,
        mds_path: Optional[EPath] = None,
        *,
        split_part: Union[Literal["train", "val", "test"], str],
        _subflavors: Optional[Dict[str, Any]] = None,
    ) -> List[TraversedDatasetReference]:
        """Traverse the selected V2 split and flatten all reachable leaf references.

        Args:
            mds_path: Unused for top-level metadatasets. Present to satisfy the shared interface.
            split_part: Split to traverse.

        Returns:
            The flattened list of traversed leaf dataset references for `split_part`.
        """
        assert mds_path is None
        return self.splits[split_part].traverse(
            self.path,
            split_part=split_part,
            _subflavors=_subflavors,
        )

    def prepare(self, split_part: Optional[str] = None) -> Sequence[EPath]:
        """Prepare one split (or all splits) and clean up stale files in the cache folder.

        Inner references prepare their own artifacts: if this metadataset refers to another
        metadataset, any actions concerning the inner metadataset are done on the inner
        metadataset's path.

        Stale files in ``<path.parent>/<path.name>.cache`` are only removed when all splits are
        prepared (``split_part is None``). When preparing a single split, the files belonging
        to the other splits are not known here, so deleting everything else would wipe their
        caches. This case also occurs when an outer metadataset prepares this one as a nested
        reference, which forwards a split name.

        Returns:
            Always an empty tuple. Paths are intentionally not returned; this is a barrier.
        """
        if split_part is not None:
            self.splits[split_part].prepare(split_part=split_part)
            return ()

        files: List[EPath] = []
        for name, split in self.splits.items():
            files.extend(split.prepare(split_part=name))

        # Cleanup: remove everything in the cache folder that no split produced.
        cache_path = EPath(self.path.parent / f"{self.path.name}.cache")
        keep = set(files)
        for file in set(cache_path.glob("*")) - keep:
            print(f"Cleanup cache file {file}")
            file.unlink()

        # Do not return paths. This is a barrier
        return ()

    def get_datasets(
        self,
        *,
        training: bool,
        split_part: Union[Literal["train", "val", "test"], str],
        worker_config: WorkerConfig,
        subflavors: Optional[Dict[str, Any]] = None,
        shuffle_over_epochs_multiplier: Optional[int] = 1,
        subset: Optional[DatasetSubset] = None,
        **kwargs,
    ) -> LoadedDatasetList:
        """Load the datasets of ``split_part`` by delegating to that split's definition."""
        return self.splits[split_part].get_datasets(
            training=training,
            split_part=split_part,
            worker_config=worker_config,
            subflavors=subflavors,
            shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier,
            subset=subset,
            **kwargs,
        )