Task Encoder

The “Task Encoder” is an Energon-specific concept. It is a class that defines how each sample is processed at the different stages of the pipeline. See also Data Flow for an overview of that pipeline.

If you don’t specify a task encoder, the default implementation, the DefaultTaskEncoder, will be used.

To write your own task encoder, create a class derived from DefaultTaskEncoder and override one or more of the following methods. The data flow of get_train_dataset or get_val_dataset is as follows:

  • def cook_crude_sample(self, sample: Union[T_sample, CrudeSample]) -> T_sample

    • Optional. Converts “crude” samples (plain dicts) into typed samples; only needed if your dataset yields crude samples.

  • def encode_sample(self, sample: T_sample) -> T_encoded_sample

    • Transform the raw data from the dataset (e.g. augment/transform images, tokenize a single sample).

  • def select_samples_to_pack(self, samples: List[T_encoded_sample]) -> List[List[T_encoded_sample]]

    • Optional. Allows for efficient sample packing; see Packing and the sketch after this list.

  • def pack_selected_samples(self, samples: List[T_encoded_sample]) -> T_batch_sample

    • Required if select_samples_to_pack is used. Compresses a group of samples into a single sample.

  • (samples are collected for a batch)

  • def batch(self, batch: List[T_encoded_sample]) -> T_raw_batch

    • Collate the batch into a single sample; defaults to padded batching for tensors and lists for everything else.

  • def encode_batch(self, batch_data: T_raw_batch) -> T_batch

    • Transform the batched data (e.g. tokenize the whole batch).

  • (optionally limit the dataset size, based on the limit argument)

  • (optionally epochize the dataset)

  • (move data from the worker to the main process through the torch.utils.data.DataLoader via get_loader)

  • For batches based on Batch, def pin_memory(self, batch: T_batch) -> T_batch is called; for batches that are not dataclasses, default torch pinning is used. This must happen in the main process, i.e. after data loading (a sketch follows the captioning example below).
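
The two packing hooks from the list above can be sketched as follows. This is a hedged illustration, not Energon’s prescribed usage: the fixed pack size and the caption-joining merge rule are assumptions chosen only to show where each hook runs.

from typing import List

from megatron.energon import CaptioningSample, DefaultTaskEncoder


class PackingExampleTaskEncoder(DefaultTaskEncoder):
    """Illustrative sketch only; real packing would typically group and merge by token length."""

    # Assumed knob for this sketch, not an Energon parameter.
    max_pack_size: int = 4

    def select_samples_to_pack(
        self, samples: List[CaptioningSample]
    ) -> List[List[CaptioningSample]]:
        # Greedily split the buffered samples into groups of at most `max_pack_size`.
        return [
            samples[i : i + self.max_pack_size]
            for i in range(0, len(samples), self.max_pack_size)
        ]

    def pack_selected_samples(self, samples: List[CaptioningSample]) -> CaptioningSample:
        # Merge one group into a single sample; here we keep the first image
        # and join the captions, purely for illustration.
        packed = samples[0]
        packed.caption = " ".join(s.caption for s in samples)
        return packed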

If a sample or batch is to be ignored, any of these methods may raise IgnoreSample to skip the sample being processed.
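
For example, a minimal sketch that drops samples with empty captions (the import path for IgnoreSample is assumed here; adjust it to your Energon version):

# The IgnoreSample import path is an assumption for this sketch.
from megatron.energon import CaptioningSample, DefaultTaskEncoder, IgnoreSample


class FilteringTaskEncoder(DefaultTaskEncoder):
    def encode_sample(self, sample: CaptioningSample) -> CaptioningSample:
        # Skip degenerate samples; the pipeline continues with the next one.
        if not sample.caption.strip():
            raise IgnoreSample()
        return sample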

The types T_sample, T_encoded_sample, T_raw_batch and T_batch are generics and depend on your task. You do not necessarily have to specify them; they are only used for proper typing in your IDE.

from dataclasses import dataclass
from typing import Callable, List, Optional

import torch
from transformers import PreTrainedTokenizerBase

from megatron.energon import CaptioningSample, DefaultTaskEncoder, batch_list, batch_stack


# Type for intermediate batch, after batching operation
@dataclass
class CaptioningRawBatch:
    # (n,)
    __key__: List[str]
    # (n, c, h, w)
    image: torch.Tensor
    # (n,)
    caption: List[str]


# Typing for the resulting batch data
@dataclass
class CaptioningBatch:
    __keys__: List[str]
    # (n, c, h, w)
    images: torch.Tensor
    # (n, c)
    text_tokens: torch.Tensor
    # (n, c, c)
    text_attn_mask: torch.Tensor


# All the typing is optional
class CaptioningTaskEncoder(
    DefaultTaskEncoder[CaptioningSample, CaptioningSample, CaptioningRawBatch, CaptioningBatch]
):
    """A simple task encoder for captioning."""

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        image_transform: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        max_length: int = 128,
    ):
        # Specify the batch_type used for default batching. It is not strictly required here,
        # since the `batch` method below overrides the default batching anyway.
        super().__init__(batch_type=CaptioningRawBatch)
        self.tokenizer = tokenizer
        self.image_transform = image_transform
        self.max_length = max_length

    def encode_sample(self, sample: CaptioningSample) -> CaptioningSample:
        # Augment/transform the image; the captions are tokenized later, on the whole batch.
        if self.image_transform is not None:
            sample.image = self.image_transform(sample.image)
        return sample

    def batch(self, samples: List[CaptioningSample]) -> CaptioningRawBatch:
        # Batch the samples.
        # The actions dict specifies how to batch each field of the sample. In addition to
        # `batch_stack` and `batch_list`, you may also use `batch_pad_stack`.
        # By default, `batch_pad_stack` is used for all tensor fields and `batch_list` for all
        # non-tensor fields, so this override behaves like the default implementation (i.e. like
        # not overriding `batch` at all).
        return self._batch(
            samples,
            result_type=CaptioningRawBatch,
            actions={"image": batch_stack, "caption": batch_list},
        )

    def encode_batch(self, batch_data: CaptioningRawBatch) -> CaptioningBatch:
        # Run the tokenizer on the whole batch of captions at once, padding/truncating
        # to `max_length` and returning torch tensors.
        tokenized = self.tokenizer(
            batch_data.caption,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # Return the final batch as it goes into the network
        return CaptioningBatch(
            __keys__=batch_data.__key__,
            images=batch_data.image,
            text_tokens=tokenized["input_ids"],
            text_attn_mask=tokenized["attention_mask"],
        )
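
The pin_memory step from the pipeline list could look like the following sketch. It assumes CaptioningBatch is changed to derive from Energon’s Batch so that the override is actually invoked; the field-wise pinning shown is an illustration, not prescribed API:

class PinnedCaptioningTaskEncoder(CaptioningTaskEncoder):
    def pin_memory(self, batch: CaptioningBatch) -> CaptioningBatch:
        # Pin each tensor field so host-to-device copies can be asynchronous.
        # This runs in the main process, after the workers hand over the batch.
        batch.images = batch.images.pin_memory()
        batch.text_tokens = batch.text_tokens.pin_memory()
        batch.text_attn_mask = batch.text_attn_mask.pin_memory()
        return batch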

Usage in your training script:

from torchvision import transforms
from transformers import AutoTokenizer
from megatron.energon import get_loader, get_train_dataset

train_img_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop((224, 224)),
        transforms.RandomHorizontalFlip(),
    ]
)

train_loader = get_loader(get_train_dataset(
    '/my/dataset/path',
    batch_size=32,
    shuffle_buffer_size=100,
    max_samples_per_sequence=100,
    task_encoder=CaptioningTaskEncoder(
        tokenizer=AutoTokenizer.from_pretrained('gpt2'),
        image_transform=train_img_transform,
    ),
    # Change this to set how images are decoded.
    # E.g. "pil" is another commonly used valid option.
    # See `webdataset.imagehandler` for more options.
    image_decode="torchrgb",
))

for data in train_loader:
    # data is a CaptioningBatch
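    # e.g. data.images, data.text_tokens and data.text_attn_mask are ready to use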
    pass