Custom Sample Loader

As discussed in Dataset Format on Disk, you can specify how samples are loaded from disk and mapped to the sample dataclass in memory.

Instead of using a field_map in your dataset.yaml, you can write custom Python code for sample loading.

Here’s an example for your updated dataset.yaml:

sample_type:
  __module__: megatron.energon
  __class__: OCRSample
sample_loader: sample_loader.py:sample_loader
part_filter: sample_loader.py:part_filter

In addition, you need to create a Python file inside the .nv-meta folder. In this case it’s called sample_loader.py. That file needs to contain the two functions referenced above:

import json
import torch

def sample_loader(raw: dict) -> dict:
    data = json.loads(raw["bbox.json"])

    return dict(
        __key__=raw["__key__"],
        image=raw["jpg"],
        text=raw["gt.txt"],
        lines_boxes=torch.tensor([box["bbox"] for box in data], dtype=torch.int64),
        lines_text=[box["text"] for box in data],
    )

def part_filter(part: str) -> bool:
    return part in ("bbox.json", "gt.txt", "jpg")

Your sample_loader function must accept a dictionary as its argument and return a dictionary. It operates directly on the webdataset samples, and the keys of the resulting dictionary should map to the members of the corresponding sample class, in this case OCRSample.

With the optional part_filter function, you can prevent some webdataset fields from being loaded. Given a field name, the function should return True if the field is to be kept.
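To make the interplay between the two functions concrete, here is a minimal sketch that runs without energon or torch. It is not how energon calls your code internally. OCRSampleStandIn is a hypothetical stand-in for OCRSample, the boxes are kept as plain lists instead of a tensor, and the raw sample (including the debug.log field) is invented for illustration.

```python
import json
from dataclasses import dataclass
from typing import List


@dataclass
class OCRSampleStandIn:
    """Hypothetical stand-in for OCRSample, reduced to the fields used here."""
    __key__: str
    image: bytes
    text: str
    lines_boxes: List[List[int]]
    lines_text: List[str]


def part_filter(part: str) -> bool:
    # Keep only the fields that sample_loader actually reads.
    return part in ("bbox.json", "gt.txt", "jpg")


def sample_loader(raw: dict) -> dict:
    data = json.loads(raw["bbox.json"])
    return dict(
        __key__=raw["__key__"],
        image=raw["jpg"],
        text=raw["gt.txt"],
        # Plain lists here; the real example builds a torch tensor.
        lines_boxes=[box["bbox"] for box in data],
        lines_text=[box["text"] for box in data],
    )


# Invented webdataset entry with one field we do not want to load.
parts = {
    "jpg": b"<jpeg bytes>",
    "gt.txt": "Hello\nWorld",
    "bbox.json": json.dumps(
        [
            {"bbox": [0, 0, 10, 5], "text": "Hello"},
            {"bbox": [0, 6, 10, 11], "text": "World"},
        ]
    ),
    "debug.log": "not needed",
}

# Apply the filter first, then hand the remaining fields to the loader.
raw = {"__key__": "sample-000001"}
raw.update({name: value for name, value in parts.items() if part_filter(name)})

# The returned keys line up one-to-one with the sample class members.
sample = OCRSampleStandIn(**sample_loader(raw))
```

If the loader returned a key that the sample class does not define, or left one out, constructing the sample would fail, which is why the dictionary keys must match the class members.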

Crude Data and How to Cook It 👨‍🍳

Sometimes you cannot do all the conversion work needed to obtain a sample in a small sample_loader.py next to the dataset. Maybe your data

  • is complex and contains custom binary formats and pickle data

  • or needs project-specific pre-processing that you don’t want to place next to the dataset

In this case, you can feed the data in its “crude” or “raw” form into your code and do the conversion there.

Let’s try it. If you already have an existing energon-compliant dataset, you may want to create a copy of your dataset.yaml specifically for this use case and keep the original:

Original dataset.yaml:

sample_type:
  __module__: megatron.energon
  __class__: OCRSample
sample_loader: sample_loader.py:sample_loader
part_filter: sample_loader.py:part_filter

New crude.yaml next to it:

__module__: megatron.energon
__class__: CrudeWebdataset
subflavors:
  example_crude_selector: foobar

We recommend that you assign subflavors to differentiate the data. This helps if you decide to mix different crude datasets, or crude with non-crude data, using a Metadataset. You can then use the subflavors to decide how each sample is processed.

Let’s see how we set up our TaskEncoder to cook the crude samples and turn them into real samples like TextSample.

A crude sample is essentially a dictionary (CrudeSample inherits from dict) and the loader will look for a Cooker that can handle and convert your crude samples. You can provide cookers as follows:

from megatron.energon import DefaultTaskEncoder, TextSample
from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys

# ...

def cook_text(sample: dict) -> TextSample:
    return TextSample(
        **basic_sample_keys(sample),
        text=f">{sample['txt'].decode()}<",
    )


class MyTaskEncoder(DefaultTaskEncoder[TextSample, TextSample, TextRawBatch, TextBatch]):
    cookers = [
        Cooker(cook_text, has_subflavors={"example_crude_selector": "foobar"}),
        Cooker(...)  # other cookers for other crude data if needed
    ]

    # ...

In the example above, the cooker acts on all crude samples whose subflavor example_crude_selector is set to foobar. It converts the dictionary to a TextSample by decoding the raw bytes and wrapping the text in angle brackets. You probably noticed the basic_sample_keys helper. All it does is forward the key, the restore key, and the flavors from the dict to the real sample. You always need to forward these, or your dataset will not be restorable.
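The following sketch illustrates what forwarding these entries means, without importing energon. forward_basic_keys is a hypothetical stand-in for basic_sample_keys, TextSampleStandIn is a hypothetical stand-in for TextSample, and the dictionary key names (__key__, __restore_key__, __subflavors__) as well as the sample values are assumptions made for illustration. Check the energon source for the exact set of forwarded keys.

```python
from dataclasses import dataclass
from typing import Any, Dict, Tuple

# Assumed names of the bookkeeping entries carried by a crude sample.
FORWARDED_KEYS = ("__key__", "__restore_key__", "__subflavors__")


@dataclass
class TextSampleStandIn:
    """Hypothetical stand-in for TextSample."""
    __key__: str
    __restore_key__: Tuple[Any, ...]
    __subflavors__: Dict[str, Any]
    text: str


def forward_basic_keys(crude: dict) -> dict:
    """Hypothetical stand-in for basic_sample_keys: copy bookkeeping entries."""
    return {name: crude[name] for name in FORWARDED_KEYS}


def cook_text(crude: dict) -> TextSampleStandIn:
    # Bookkeeping entries are forwarded unchanged; only the payload is converted.
    return TextSampleStandIn(
        **forward_basic_keys(crude),
        text=f">{crude['txt'].decode()}<",
    )


# Invented crude sample: bookkeeping entries plus the raw bytes payload.
crude = {
    "__key__": "sample-000001",
    "__restore_key__": ("shard-000.tar", 1),
    "__subflavors__": {"example_crude_selector": "foobar"},
    "txt": b"hello",
}

cooked = cook_text(crude)
```

The cooked sample carries the same key and restore key as the crude one, which is the information needed to find the sample again later.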

In a real use case you will want to do a lot more here. We recommend keeping the cook functions in separate files and importing them where you define your TaskEncoder.

Other Filters for Cookers

You can filter using the subflavors as shown above. You can also filter using the deprecated single subflavor like this:

Cooker(cook_text, is_subflavor="helloworld")

Or, if you need custom filtering, you can provide a function that decides whether the sample matches:

Cooker(cook_text, condition=lambda sample: sample['myprop'] == 'yes_thats_it')

If you use multiple filters, they must all be satisfied for the sample to match.
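The "all filters must be satisfied" rule can be sketched as follows. This is a hypothetical re-implementation for illustration only, not energon's Cooker code. The function name cooker_matches, the __subflavors__ key, and the sample contents are assumptions.

```python
from typing import Callable, Optional


def cooker_matches(
    sample: dict,
    *,
    has_subflavors: Optional[dict] = None,
    condition: Optional[Callable[[dict], bool]] = None,
) -> bool:
    """Return True only if every filter that was provided accepts the sample."""
    if has_subflavors is not None:
        subflavors = sample.get("__subflavors__") or {}
        # Every requested subflavor must be present with the requested value.
        for name, value in has_subflavors.items():
            if name not in subflavors or subflavors[name] != value:
                return False
    if condition is not None and not condition(sample):
        return False
    return True


# Invented crude sample for the checks below.
sample = {
    "__subflavors__": {"example_crude_selector": "foobar"},
    "myprop": "yes_thats_it",
}

both_pass = cooker_matches(
    sample,
    has_subflavors={"example_crude_selector": "foobar"},
    condition=lambda s: s["myprop"] == "yes_thats_it",
)
subflavor_fails = cooker_matches(
    sample,
    has_subflavors={"example_crude_selector": "other"},
    condition=lambda s: s["myprop"] == "yes_thats_it",
)
condition_fails = cooker_matches(
    sample,
    has_subflavors={"example_crude_selector": "foobar"},
    condition=lambda s: s["myprop"] == "no",
)
```

Only the first call matches. In the other two, a single failing filter is enough to reject the sample.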