Source code for megatron.energon.source_info

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from dataclasses import dataclass
from typing import Any, Union

from megatron.energon.epathlib import EPath


[docs] @dataclass(frozen=True, slots=True, kw_only=True) class SourceInfo: """Information about the source of a sample, i.e. where the data was loaded from.""" #: The path to the dataset dataset_path: EPath #: The index of the sample in the dataset index: Union[str, int] #: The name of the shard tar file shard_name: str #: The names of the files in the shard used to create the sample file_names: tuple[str, ...]
def add_source_info(sample: Any, *source_infos: SourceInfo) -> Any: """Adds a source info to a sample.""" from megatron.energon.flavors.base_dataset import Sample if isinstance(sample, Sample) or hasattr(sample, "__sources__"): if sample.__sources__ is None: sample.__sources__ = source_infos else: sample.__sources__ = (*sample.__sources__, *source_infos) elif isinstance(sample, dict): if sample.get("__sources__") is None: sample["__sources__"] = source_infos else: sample["__sources__"] = (*sample["__sources__"], *source_infos) return sample