Skip to content

Collation

Batch collation utilities for inference prediction workflows.

This module is part of bionemo-recipeutils and MUST NOT import from megatron, nemo, or mbridge. It depends only on torch.

batch_collator(batches, batch_dim=0, seq_dim=1, batch_dim_key_defaults=None, seq_dim_key_defaults=None, preferred_gpu=0)

Collate multiple batches into a single batch by concatenating along the batch dimension.

This function handles nested structures (dicts, lists, tuples) containing tensors. Unlike PyTorch's default_collate, this assumes the batch dimension already exists (as when parallelizing across microbatches or DP ranks).

Parameters:

Name Type Description Default
batches Optional[Union[Tuple[ReductionT, ...], List[ReductionT]]]

Sequence of batches to collate. Each batch can be a tensor, dict, list, or tuple. The structure must be consistent across all batches.

required
batch_dim int

Dimension along which to concatenate tensors. Default 0.

0
seq_dim int

Sequence dimension, used for padding to max length. Default 1.

1
batch_dim_key_defaults Optional[dict[str, int]]

For dict batches, override batch_dim for specific keys. Default: {"token_logits": 1} (legacy compatibility, recommend passing {}).

None
seq_dim_key_defaults Optional[dict[str, int]]

For dict batches, override seq_dim for specific keys. Default: {"token_logits": 0} (legacy compatibility, recommend passing {}).

None
preferred_gpu int

If any tensor is on GPU, move all to this device. Default 0.

0

Returns:

Type Description
Optional[ReductionT]

Collated batch with same structure as input batches, or None if input contains None.

Raises:

Type Description
ValueError

If batches is empty or contains unsupported types.

Examples:

>>> # Collate dict batches
>>> batch1 = {"logits": torch.randn(2, 10, 512), "mask": torch.ones(2, 10)}
>>> batch2 = {"logits": torch.randn(3, 10, 512), "mask": torch.ones(3, 10)}
>>> result = batch_collator([batch1, batch2], batch_dim=0, seq_dim=1,
...                         batch_dim_key_defaults={}, seq_dim_key_defaults={})
>>> result["logits"].shape  # torch.Size([5, 10, 512])
>>> # Collate with padding (different sequence lengths)
>>> batch1 = {"tokens": torch.randn(2, 100)}
>>> batch2 = {"tokens": torch.randn(2, 150)}
>>> result = batch_collator([batch1, batch2], batch_dim=0, seq_dim=1,
...                         batch_dim_key_defaults={}, seq_dim_key_defaults={})
>>> result["tokens"].shape  # torch.Size([4, 150]) - padded to max length
Source code in bionemo/recipeutils/inference/collation.py
(lines 31–131)
def batch_collator(
    batches: Optional[Union[Tuple[ReductionT, ...], List[ReductionT]]],
    batch_dim: int = 0,
    seq_dim: int = 1,
    batch_dim_key_defaults: Optional[dict[str, int]] = None,
    seq_dim_key_defaults: Optional[dict[str, int]] = None,
    preferred_gpu: int = 0,
) -> Optional[ReductionT]:
    """Collate multiple batches into a single batch by concatenating along the batch dimension.

    This function handles nested structures (dicts, lists, tuples) containing tensors.
    Unlike PyTorch's default_collate, this assumes the batch dimension already exists
    (as when parallelizing across microbatches or DP ranks).

    Args:
        batches: Sequence of batches to collate. Each batch can be a tensor, dict, list, or tuple.
            The structure must be consistent across all batches.
        batch_dim: Dimension along which to concatenate tensors. Default 0.
        seq_dim: Sequence dimension, used for padding to max length. Default 1.
        batch_dim_key_defaults: For dict batches, override batch_dim for specific keys.
            Default: {"token_logits": 1} (legacy compatibility, recommend passing {}).
        seq_dim_key_defaults: For dict batches, override seq_dim for specific keys.
            Default: {"token_logits": 0} (legacy compatibility, recommend passing {}).
        preferred_gpu: If any tensor is on GPU, move all to this device. Default 0.

    Returns:
        Collated batch with same structure as input batches, or None if input contains None.

    Raises:
        ValueError: If batches is empty or contains unsupported types.

    Examples:
        >>> # Collate dict batches
        >>> batch1 = {"logits": torch.randn(2, 10, 512), "mask": torch.ones(2, 10)}
        >>> batch2 = {"logits": torch.randn(3, 10, 512), "mask": torch.ones(3, 10)}
        >>> result = batch_collator([batch1, batch2], batch_dim=0, seq_dim=1,
        ...                         batch_dim_key_defaults={}, seq_dim_key_defaults={})
        >>> result["logits"].shape  # torch.Size([5, 10, 512])

        >>> # Collate with padding (different sequence lengths)
        >>> batch1 = {"tokens": torch.randn(2, 100)}
        >>> batch2 = {"tokens": torch.randn(2, 150)}
        >>> result = batch_collator([batch1, batch2], batch_dim=0, seq_dim=1,
        ...                         batch_dim_key_defaults={}, seq_dim_key_defaults={})
        >>> result["tokens"].shape  # torch.Size([4, 150]) - padded to max length
    """
    if batch_dim_key_defaults is None:
        batch_dim_key_defaults = {"token_logits": 1}
    if seq_dim_key_defaults is None:
        seq_dim_key_defaults = {"token_logits": 0}

    match batches:
        case [None, *_]:
            return None

        case [Tensor(), *_]:
            return _collate_tensors(batches, batch_dim=batch_dim, seq_dim=seq_dim, preferred_gpu=preferred_gpu)

        case [dict(), *_]:
            return {
                key: batch_collator(
                    [batch[key] for batch in batches],
                    batch_dim=batch_dim_key_defaults.get(key, batch_dim),
                    seq_dim=seq_dim_key_defaults.get(key, seq_dim),
                    batch_dim_key_defaults=batch_dim_key_defaults,
                    seq_dim_key_defaults=seq_dim_key_defaults,
                    preferred_gpu=preferred_gpu,
                )
                for key in batches[0]
            }

        case [tuple(), *_]:
            return tuple(
                batch_collator(
                    [batch[i] for batch in batches],
                    batch_dim=batch_dim,
                    seq_dim=seq_dim,
                    batch_dim_key_defaults=batch_dim_key_defaults,
                    seq_dim_key_defaults=seq_dim_key_defaults,
                    preferred_gpu=preferred_gpu,
                )
                for i in range(len(batches[0]))
            )

        case [list(), *_]:
            return [
                batch_collator(
                    [batch[i] for batch in batches],
                    batch_dim=batch_dim,
                    seq_dim=seq_dim,
                    batch_dim_key_defaults=batch_dim_key_defaults,
                    seq_dim_key_defaults=seq_dim_key_defaults,
                    preferred_gpu=preferred_gpu,
                )
                for i in range(len(batches[0]))
            ]

        case []:
            raise ValueError("Cannot collate an empty sequence of batches")
        case _:
            raise ValueError(f"Unsupported batch type: {type(batches[0]) if batches else 'empty'}")