Source code for megatron.energon.task_encoder.cooking

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import dataclasses
from dataclasses import dataclass
from typing import Callable, Optional

from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.crude import CrudeSample


@dataclass
class Cooker:
    """Turns a crude sample (a plain dict) into a concrete `Sample` subclass.

    The conversion itself is done by `cook`. All other fields are optional
    filters that decide which crude samples this cooker is applicable to.
    When no filter is set, the cooker accepts every sample.
    """

    # Conversion function, invoked with the crude sample dict.
    cook: Callable[[dict], Sample]

    # Optional filters. If more than one is set, a sample has to pass
    # every one of them to be considered a match.
    is_subflavor: Optional[str] = None
    has_subflavors: Optional[dict] = None
    condition: Optional[Callable[[dict], bool]] = None

    def is_match(self, crude_sample: CrudeSample) -> bool:
        """Tell whether this cooker is applicable to `crude_sample`.

        Args:
            crude_sample: The crude sample to test. It is indexed by
                `"__subflavor__"` when `is_subflavor` is set, and by
                `"__subflavors__"` when `has_subflavors` has entries.

        Returns:
            True if the sample passes every configured filter (or if no
            filter is configured), False otherwise.
        """
        wanted_subflavor = self.is_subflavor
        if wanted_subflavor is not None and crude_sample["__subflavor__"] != wanted_subflavor:
            return False

        required_subflavors = self.has_subflavors
        if required_subflavors is not None:
            # Every entry of the filter dict must be present in the sample
            # with an equal value; extra entries in the sample are ignored.
            for name, expected in required_subflavors.items():
                sample_subflavors = crude_sample["__subflavors__"]
                if name not in sample_subflavors or sample_subflavors[name] != expected:
                    return False

        # The custom predicate, if any, has the final say.
        predicate = self.condition
        return predicate is None or bool(predicate(crude_sample))
def basic_sample_keys(crude_sample: dict) -> dict:
    """Collect the basic `Sample` fields that are present in a crude sample.

    Convenience helper for cookers: these are the keys that always need to
    be forwarded from the crude sample to the cooked sample.

    Args:
        crude_sample: The crude sample dict to read from.

    Returns:
        A new dict holding, for each dataclass field of `Sample` whose name
        is a key of `crude_sample`, that name mapped to the sample's value.
    """
    forwarded = {}
    for sample_field in dataclasses.fields(Sample):
        key = sample_field.name
        # Fields the crude sample does not carry are simply left out.
        if key in crude_sample:
            forwarded[key] = crude_sample[key]
    return forwarded