# Source code for megatron.energon.task_encoder.cooking

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import dataclasses
import functools
from dataclasses import dataclass
from typing import Callable, Generic, Optional, TypeVar, Union, overload

from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.crude import CrudeSample
from megatron.energon.source_info import SourceInfo

#: Sample type produced by a cooker. Declared covariant, so a ``Cooker`` parameterized with a
#: sample subclass can stand in where a ``Cooker`` of the base sample type is expected.
T_sample = TypeVar("T_sample", covariant=True, bound=Sample)

#: Type of the function wrapped by the ``cooker`` decorator: any callable returning a ``Sample``.
F = TypeVar("F", bound=Callable[..., Sample])


@overload
def cooker(fn: F) -> F: ...


@overload
def cooker(
    fn: None = None,
) -> Callable[[F], F]: ...


@overload
def cooker(
    *,
    need_cache: bool = False,
    need_primary: bool = False,
) -> Callable[[F], F]: ...


def cooker(
    fn: Optional[F] = None,
    *,
    need_cache: bool = False,
    need_primary: bool = False,
) -> Union[
    F,
    Callable[[F], F],
]:
    """Decorator to mark a function as a cooker, optionally enabling cache and primary dataset
    arguments.

    Usable bare (``@cooker``) or with keyword options
    (``@cooker(need_cache=True, need_primary=True)``).

    Args:
        fn: The cooking function when the decorator is applied without parentheses. ``None``
            when the decorator is called with options; a decorator is returned in that case.
        need_cache: Stored on the returned wrapper as ``__cooker_need_cache__`` (read back by
            ``get_cooker_need_cache``).
        need_primary: Stored on the returned wrapper as ``__cooker_need_primary__`` (read back
            by ``get_cooker_need_primary``).

    Returns:
        A wrapper around ``fn`` carrying both flags, or, if ``fn`` is ``None``, a decorator that
        produces such a wrapper with the given flags.
    """
    if fn is None:
        # Called with options, e.g. ``@cooker(need_cache=True)``: defer until the function
        # itself is supplied.
        return functools.partial(
            cooker,
            need_cache=need_cache,
            need_primary=need_primary,
        )

    # The flags are set on a fresh wrapper, so ``fn`` itself is not modified.
    # ``functools.wraps`` carries over name, docstring, ``__dict__`` and ``__wrapped__``.
    @functools.wraps(fn)
    def fn_wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    setattr(fn_wrapper, "__cooker_need_cache__", need_cache)
    setattr(fn_wrapper, "__cooker_need_primary__", need_primary)

    return fn_wrapper
def get_cooker_need_cache(fn: Callable[..., object]) -> bool:
    """Return the ``need_cache`` flag that the ``cooker`` decorator stored on ``fn``.

    Reads the ``__cooker_need_cache__`` attribute. Callables that were not decorated with
    ``cooker`` report ``False``.
    """
    return getattr(fn, "__cooker_need_cache__", False)
def get_cooker_need_primary(fn: Callable[..., object]) -> bool:
    """Return the ``need_primary`` flag that the ``cooker`` decorator stored on ``fn``.

    Reads the ``__cooker_need_primary__`` attribute. Callables that were not decorated with
    ``cooker`` report ``False``.
    """
    return getattr(fn, "__cooker_need_primary__", False)
@dataclass
class Cooker(Generic[T_sample]):
    """A cooker transforms a crude sample (simple dict) into a specific sample type inheriting
    from `Sample`.

    The `cook` callable performs the transformation; the other fields are used to select the
    samples which this cooker can transform. If no filters are provided, the cooker will
    transform any `CrudeSample`.
    """

    #: The callable that performs the cooking (i.e. loading / transforming the crude sample).
    # Signature is:
    # `(/, raw_sample: dict, *, primary?: RandomAccessDataset, **aux: RandomAccessDataset, cache?: Cache) -> Sample`.
    # `primary` is passed only if `need_primary` is true (i.e. `cook` was decorated with
    # `cooker(need_primary=True)`).
    # `cache` is passed only if `need_cache` is true (i.e. `cook` was decorated with
    # `cooker(need_cache=True)`).
    cook: Callable[..., T_sample]

    #: The subflavors to be present in the sample to be cooked by this cooker. All keys and values
    # must match.
    has_subflavors: Optional[dict] = None

    @property
    def need_primary(self) -> bool:
        """Whether `cook` carries the ``need_primary`` flag set by the ``cooker`` decorator."""
        return get_cooker_need_primary(self.cook)

    @property
    def need_cache(self) -> bool:
        """Whether `cook` carries the ``need_cache`` flag set by the ``cooker`` decorator."""
        return get_cooker_need_cache(self.cook)

    def is_match(self, crude_sample: CrudeSample) -> bool:
        """Return whether this cooker should handle `crude_sample`.

        Without a `has_subflavors` filter, every crude sample matches. Otherwise every key of the
        filter must be present in the sample's ``__subflavors__`` with an equal value; the sample
        may carry additional subflavors. A sample whose ``__subflavors__`` entry is missing or
        ``None`` is treated as having no subflavors, so it only matches an empty filter.
        """
        if self.has_subflavors is None:
            return True

        # Previously a missing entry raised KeyError and a None entry raised TypeError on the
        # `in` test; both now simply mean "no subflavors".
        try:
            subflavors = crude_sample["__subflavors__"]
        except KeyError:
            subflavors = None
        if subflavors is None:
            subflavors = {}

        # Checks if the dict entries provided as a filter all match
        # the ones in the sample. The sample may have additional entries.
        for key, expected in self.has_subflavors.items():
            if key not in subflavors or subflavors[key] != expected:
                return False
        return True
def basic_sample_keys(
    crude_sample: dict, additional_source_info: tuple[SourceInfo, ...] = ()
) -> dict:
    """A convenience helper to extract the basic keys from a crude sample, which you will always
    need to forward to the cooked sample.

    Args:
        crude_sample: The crude sample to read the basic keys from.
        additional_source_info: Extra source infos to append after the crude sample's own
            ``__sources__``.

    Returns:
        A dict with every field of `Sample` that is present in `crude_sample`. If
        `additional_source_info` is non-empty, ``__sources__`` is set to the crude sample's
        sources (empty if missing or ``None``) followed by the additional ones.
    """
    res = {
        field.name: crude_sample[field.name]
        for field in dataclasses.fields(Sample)
        if field.name in crude_sample
    }
    if additional_source_info:
        # Like the comprehension above, tolerate an absent (or None) ``__sources__`` entry
        # instead of failing with KeyError / TypeError while unpacking it.
        existing_sources = crude_sample.get("__sources__") or ()
        res["__sources__"] = (*existing_sources, *additional_source_info)
    return res