Task Encoder

The “Task Encoder” is an Energon-specific concept. It is a class that describes how each sample is processed at the different stages of the pipeline. See also Data Flow for an overview of the pipeline.

If you don’t specify a task encoder, the default one, DefaultTaskEncoder, is used.

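To write your own task encoder, create a subclass of DefaultTaskEncoder and override one or more of the following methods. In get_train_dataset and get_val_dataset, the data flows through them in this order:

1. encode_sample(self, sample: T_sample) -> T_encoded_sample: transforms a single sample, e.g. augments its image or tokenizes its text.
2. batch(self, samples: List[T_encoded_sample]) -> T_raw_batch: collates a list of encoded samples into one raw batch.
3. encode_batch(self, batch_data: T_raw_batch) -> T_batch: transforms the whole batch, e.g. tokenizes all captions at once.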
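
The order of these calls can be modeled in a few lines of plain Python. This is a simplified sketch, not Energon's implementation: Sample, ToyTaskEncoder and run_pipeline are hypothetical names, and the real pipeline also takes care of loading, shuffling and worker parallelism. It only shows which stage receives which data.

```python
from dataclasses import dataclass
from typing import Dict, Iterable, Iterator, List


@dataclass
class Sample:
    # Hypothetical stand-in for a loaded sample (T_sample)
    key: str
    caption: str


class ToyTaskEncoder:
    """Hypothetical encoder with the same three stages as a task encoder."""

    def encode_sample(self, sample: Sample) -> Sample:
        # Stage 1, per sample: normalize the caption
        return Sample(sample.key, sample.caption.strip().lower())

    def batch(self, samples: List[Sample]) -> Dict[str, list]:
        # Stage 2: collate the encoded samples into one raw batch
        return {"keys": [s.key for s in samples], "captions": [s.caption for s in samples]}

    def encode_batch(self, raw_batch: Dict[str, list]) -> Dict[str, list]:
        # Stage 3, whole batch: add the caption lengths
        return {**raw_batch, "lengths": [len(c) for c in raw_batch["captions"]]}


def run_pipeline(
    encoder: ToyTaskEncoder, samples: Iterable[Sample], batch_size: int
) -> Iterator[Dict[str, list]]:
    """Hypothetical driver: encode_sample on each sample, then batch and encode_batch per group."""
    pending: List[Sample] = []
    for sample in samples:
        pending.append(encoder.encode_sample(sample))
        if len(pending) == batch_size:
            yield encoder.encode_batch(encoder.batch(pending))
            pending = []
    if pending:
        # In this toy driver, a final partial batch is still emitted
        yield encoder.encode_batch(encoder.batch(pending))


if __name__ == "__main__":
    data = [Sample("a", " A Cat "), Sample("b", "Dog"), Sample("c", "Bird ")]
    for out in run_pipeline(ToyTaskEncoder(), data, batch_size=2):
        print(out)
```

Note that encode_sample sees one sample at a time, while batch and encode_batch only ever see groups of already encoded samples.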

Any of these methods may raise IgnoreSample to skip the sample (or batch) currently being processed.
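
The skip semantics can be sketched without any dependencies. The exception class below is only a local stand-in named after IgnoreSample (in a real task encoder you would raise Energon's own class), and encode_all is a hypothetical driver. The point is that raising drops just the current item while the rest of the stream continues.

```python
from typing import Callable, Iterable, List


class IgnoreSample(Exception):
    """Local stand-in for Energon's IgnoreSample exception."""


def encode_sample(caption: str) -> str:
    # Skip samples with empty captions instead of failing the whole run
    if not caption.strip():
        raise IgnoreSample()
    return caption.strip()


def encode_all(samples: Iterable[str], encode: Callable[[str], str]) -> List[str]:
    """Hypothetical driver: apply `encode` and drop every sample it rejects."""
    kept = []
    for sample in samples:
        try:
            kept.append(encode(sample))
        except IgnoreSample:
            continue  # only this sample is dropped; processing goes on
    return kept
```

For example, encode_all(["a cat", "   ", "dog "], encode_sample) keeps only the two non-empty captions.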

The types T_sample, T_encoded_sample, T_raw_batch and T_batch are generic type parameters that depend on your task. You don’t have to specify them; they are only used for proper typing in your IDE.
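
To illustrate what these parameters do, the standard typing module can express the same four-parameter pattern. ToyEncoderBase, LengthEncoder and UntypedEncoder are hypothetical classes, not Energon's. The parameters tie each stage's output type to the next stage's input type for type checkers and IDEs, but Python does not enforce them at runtime.

```python
from typing import Generic, List, TypeVar, get_args

T_sample = TypeVar("T_sample")
T_encoded_sample = TypeVar("T_encoded_sample")
T_raw_batch = TypeVar("T_raw_batch")
T_batch = TypeVar("T_batch")


class ToyEncoderBase(Generic[T_sample, T_encoded_sample, T_raw_batch, T_batch]):
    """Hypothetical base class: each stage's output type is the next stage's input type."""

    def encode_sample(self, sample: T_sample) -> T_encoded_sample:
        raise NotImplementedError

    def batch(self, samples: List[T_encoded_sample]) -> T_raw_batch:
        raise NotImplementedError

    def encode_batch(self, batch_data: T_raw_batch) -> T_batch:
        raise NotImplementedError


# Fully typed: str samples -> int encodings -> list raw batch -> tuple batch
class LengthEncoder(ToyEncoderBase[str, int, List[int], tuple]):
    def encode_sample(self, sample: str) -> int:
        return len(sample)

    def batch(self, samples: List[int]) -> List[int]:
        return list(samples)

    def encode_batch(self, batch_data: List[int]) -> tuple:
        return tuple(batch_data)


class UntypedEncoder(ToyEncoderBase):
    """Omitting the parameters also works; type checkers then treat them as Any."""

    def encode_sample(self, sample):
        return sample


# The four concrete types bound in LengthEncoder, as seen by static tooling
concrete_types = get_args(LengthEncoder.__orig_bases__[0])
```

Passing a list to LengthEncoder().encode_sample would be flagged by a type checker, yet it still runs, because the annotations are hints only.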

```python
from dataclasses import dataclass
from typing import Callable, List, Optional

import torch
from transformers import PreTrainedTokenizerBase

from megatron.energon import CaptioningSample, DefaultTaskEncoder, batch_list, batch_stack


# Type for the intermediate batch, after the batching operation
@dataclass
class CaptioningRawBatch:
    # (n,)
    __key__: List[str]
    # (n, c, h, w)
    image: torch.Tensor
    # (n,)
    caption: List[str]


# Type for the resulting batch data
@dataclass
class CaptioningBatch:
    # (n,)
    __keys__: List[str]
    # (n, c, h, w)
    images: torch.Tensor
    # (n, max_length)
    text_tokens: torch.Tensor
    # (n, max_length)
    text_attn_mask: torch.Tensor


# All the typing is optional
class CaptioningTaskEncoder(
    DefaultTaskEncoder[CaptioningSample, CaptioningSample, CaptioningRawBatch, CaptioningBatch]
):
    """A simple task encoder for captioning."""

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        image_transform: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        max_length: int = 128,
    ):
        # Specify the batch_type for default batching (here, batching is done "manually" by overriding
        # the `batch` method)
        super().__init__(batch_type=CaptioningRawBatch)
        self.tokenizer = tokenizer
        # Some tokenizers (e.g. GPT-2) define no padding token; reuse the EOS token so padding works
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.image_transform = image_transform
        self.max_length = max_length

    def encode_sample(self, sample: CaptioningSample) -> CaptioningSample:
        # Apply the optional per-sample image transform (e.g. augmentation)
        if self.image_transform is not None:
            sample.image = self.image_transform(sample.image)
        return sample

    def batch(self, samples: List[CaptioningSample]) -> CaptioningRawBatch:
        # Batch the samples.
        # The actions dict specifies how to batch each field of the sample. Besides these, you may also
        # use `batch_pad_stack`.
        # By default, `batch_pad_stack` is used for all tensor fields and `batch_list` for all non-tensor
        # fields. As long as all images in a batch have the same size (e.g. after a fixed-size crop),
        # `batch_stack` gives the same result, so this matches the default implementation (not
        # overriding `batch`).
        return self._batch(
            samples,
            result_type=CaptioningRawBatch,
            actions={"image": batch_stack, "caption": batch_list},
        )

    def encode_batch(self, batch_data: CaptioningRawBatch) -> CaptioningBatch:
        # Tokenize the whole batch of captions at once. Padding and truncating to max_length makes all
        # rows the same length, so the result can be returned as tensors.
        tokenized = self.tokenizer(
            batch_data.caption,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # Return the final batch, going into the network
        return CaptioningBatch(
            __keys__=batch_data.__key__,
            images=batch_data.image,
            text_tokens=tokenized["input_ids"],
            text_attn_mask=tokenized["attention_mask"],
        )
```
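
The comments in batch describe how each field is combined across samples. The following dependency-free sketch models that behavior with nested Python lists in place of tensors. pad_stack, as_list and batch_fields are hypothetical helpers that only mimic the described semantics of batch_pad_stack, batch_list and the actions mapping passed to _batch; they are not Energon's implementations. For simplicity, this sketch applies as_list to every field without an explicit action instead of choosing by field type.

```python
from dataclasses import dataclass, fields
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar

T = TypeVar("T")


def pad_stack(seqs: List[List[int]], pad_value: int = 0) -> List[List[int]]:
    """Right-pad variable-length sequences to the longest one, then 'stack' them as rows."""
    max_len = max(len(s) for s in seqs)
    return [s + [pad_value] * (max_len - len(s)) for s in seqs]


def as_list(values: List[Any]) -> List[Any]:
    """Keep per-sample values as a plain list (e.g. keys or caption strings)."""
    return list(values)


def batch_fields(
    samples: List[Any],
    result_type: Type[T],
    actions: Optional[Dict[str, Callable[[List[Any]], Any]]] = None,
    default: Callable[[List[Any]], Any] = as_list,
) -> T:
    """Gather each field of result_type across samples and combine it with its action."""
    actions = actions or {}
    combined = {}
    for f in fields(result_type):
        values = [getattr(s, f.name) for s in samples]
        combined[f.name] = actions.get(f.name, default)(values)
    return result_type(**combined)


@dataclass
class TokenSample:
    # Hypothetical sample: a key and a variable-length token list
    __key__: str
    tokens: List[int]


@dataclass
class TokenBatch:
    # Hypothetical raw batch with the same field names as TokenSample
    __key__: List[str]
    tokens: List[List[int]]
```

For example, batch_fields([TokenSample("a", [1, 2, 3]), TokenSample("b", [4])], result_type=TokenBatch, actions={"tokens": pad_stack}) collects the keys into a list and pads the shorter token list with zeros to the length of the longest one.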

Usage in your training script:

```python
from torchvision import transforms
from transformers import AutoTokenizer

from megatron.energon import get_loader, get_train_dataset

train_img_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop((224, 224)),
        transforms.RandomHorizontalFlip(),
    ]
)

train_loader = get_loader(get_train_dataset(
    '/my/dataset/path',
    batch_size=32,
    shuffle_buffer_size=100,
    max_samples_per_sequence=100,
    # CaptioningTaskEncoder is the class defined above
    task_encoder=CaptioningTaskEncoder(
        tokenizer=AutoTokenizer.from_pretrained('gpt2'),
        image_transform=train_img_transform,
    ),
    # Change this to set how images are decoded.
    # E.g. "pil" is another commonly used valid option.
    # See `webdataset.imagehandler` for more options.
    image_decode="torchrgb",
))

for data in train_loader:
    # data is a CaptioningBatch
    pass
```
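
When the captions are tokenized with padding to a fixed max_length and truncation, data.text_tokens and data.text_attn_mask in the loop both have shape (n, max_length). The dependency-free sketch below mimics that contract with a hypothetical whitespace tokenizer (toy_tokenize and its vocabulary are made up for illustration; a real Hugging Face tokenizer produces subword IDs). In this sketch, rows are right-padded, mask entries are 1 for real tokens and 0 for padding, and summing a mask row gives the number of real tokens in that caption.

```python
from typing import Dict, List


def toy_tokenize(
    captions: List[str],
    vocab: Dict[str, int],
    max_length: int,
    pad_id: int = 0,
    unk_id: int = 1,
) -> Dict[str, List[List[int]]]:
    """Hypothetical tokenizer: whitespace split, truncate to max_length, right-pad with pad_id."""
    input_ids, attention_mask = [], []
    for caption in captions:
        ids = [vocab.get(word, unk_id) for word in caption.split()][:max_length]
        n_pad = max_length - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)
        # 1 marks a real token, 0 marks padding
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


def caption_lengths(attention_mask: List[List[int]]) -> List[int]:
    """Number of real (non-padding) tokens per caption: the row sums of the mask."""
    return [sum(row) for row in attention_mask]
```

With vocab = {"a": 2, "cat": 3, "on": 4, "the": 5, "mat": 6} and max_length=4, the caption "a cat" becomes [2, 3, 0, 0] with mask [1, 1, 0, 0], while "a cat on the mat today" is truncated to [2, 3, 4, 5].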