Groupingο
Grouping allows for rule-based batching of samples into one batch on the fly.
Note how this is different from packing which joins multiple samples into one (and is done before batching). On the other hand, grouping is an alternative to standard batching.
Example use casesο
Select samples to batch based on image resolution, so that only samples of the same size are in one batch
Select blended samples based on their dataset origin, so that one batch does not mix different tasks or data types
How to groupο
To use grouping, you need to define the method batch_group_criterion
in your custom task encoder.
This method gets a sample and returns a hashable value that will be used to cluster/group the samples
and it also returns the batch size for that group.
Samples with the same batch group criterion will be batched together. Once enough samples for one group have been collected (reached the batch size for that group), they will be batched and pushed down the pipeline to the next processing step.
Hereβs an example task encoder that batches samples based on their image aspect ratios:
class GroupingTaskEncoder(DefaultTaskEncoder):
def batch_group_criterion(self, sample: CaptioningSample) -> Tuple[Hashable, Optional[int]]:
aspect_ratio = sample.image.shape[2] / sample.image.shape[1]
# Bin aspect ratios into 3 groups
if aspect_ratio < 0.8:
return "portrait", 8
elif aspect_ratio < 1.2:
return "square", 8
else:
return "landscape", 8
In the example, the aspect ratio is sorted into one of three bins and a string is used as the grouping key. The batch size used here is always 8.
Here is another example where each batch contains only images with the exact same size. Note how the image shape itself is used as the grouping key.
class GroupingTaskEncoder(DefaultTaskEncoder):
def batch_group_criterion(self, sample: CaptioningSample) -> Tuple[Hashable, Optional[int]]:
batch_size = 4 if sample.image.shape[1] < 512 else 2
return sample.image.shape, batch_size
For images with a height of less than 512 pixels, the batch size will be 4, for larger images itβs reduced to 2.
Fixed global batch sizeο
Instead of specifying the batch size for each group individually, you can also specify the batch size as usually when calling
get_train_dataset
. The batch_group_criterion
method should then return None
for the batch_size.