Collation
Batch collation utilities for inference prediction workflows.
This module is part of bionemo-recipeutils and MUST NOT import from megatron, nemo, or mbridge. It depends only on torch.
batch_collator(batches, batch_dim=0, seq_dim=1, batch_dim_key_defaults=None, seq_dim_key_defaults=None, preferred_gpu=0)
Collate multiple batches into a single batch by concatenating along the batch dimension.
This function handles nested structures (dicts, lists, tuples) containing tensors. Unlike PyTorch's default_collate, this assumes the batch dimension already exists (as when parallelizing across microbatches or DP ranks).
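The recursive traversal over nested structures can be sketched as follows. This is a minimal illustration assuming only `torch`; `collate` is a hypothetical helper written for this doc, not the module's actual implementation, and it omits padding, per-key dimension overrides, and device handling:

```python
import torch

def collate(batches, batch_dim=0):
    """Recursively concatenate same-structured batches along batch_dim."""
    first = batches[0]
    if isinstance(first, torch.Tensor):
        # Leaf case: the batch dimension already exists, so concatenate
        # rather than stacking (unlike default_collate).
        return torch.cat(list(batches), dim=batch_dim)
    if isinstance(first, dict):
        # Collate each key independently across all batches.
        return {k: collate([b[k] for b in batches], batch_dim) for k in first}
    if isinstance(first, (list, tuple)):
        # zip(*batches) regroups element-wise across batches.
        return type(first)(collate(items, batch_dim) for items in zip(*batches))
    raise ValueError(f"Unsupported type: {type(first)}")
```

For example, `collate([{"x": torch.ones(2, 4)}, {"x": torch.zeros(3, 4)}])["x"]` has shape `torch.Size([5, 4])`.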
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `batches` | `Optional[Union[Tuple[ReductionT, ...], List[ReductionT]]]` | Sequence of batches to collate. Each batch can be a tensor, dict, list, or tuple. The structure must be consistent across all batches. | *required* |
| `batch_dim` | `int` | Dimension along which to concatenate tensors. | `0` |
| `seq_dim` | `int` | Sequence dimension, used for padding to the max length. | `1` |
| `batch_dim_key_defaults` | `Optional[dict[str, int]]` | For dict batches, overrides `batch_dim` for specific keys. Defaults to `{"token_logits": 1}` for legacy compatibility; passing `{}` is recommended. | `None` |
| `seq_dim_key_defaults` | `Optional[dict[str, int]]` | For dict batches, overrides `seq_dim` for specific keys. Defaults to `{"token_logits": 0}` for legacy compatibility; passing `{}` is recommended. | `None` |
| `preferred_gpu` | `int` | If any input tensor is on a GPU, all tensors are moved to this device. | `0` |
Returns:
| Type | Description |
|---|---|
| `Optional[ReductionT]` | Collated batch with the same structure as the input batches, or `None` if the input contains `None`. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If `batches` is empty or contains unsupported types. |
Examples:
>>> # Collate dict batches
>>> batch1 = {"logits": torch.randn(2, 10, 512), "mask": torch.ones(2, 10)}
>>> batch2 = {"logits": torch.randn(3, 10, 512), "mask": torch.ones(3, 10)}
>>> result = batch_collator([batch1, batch2], batch_dim=0, seq_dim=1,
... batch_dim_key_defaults={}, seq_dim_key_defaults={})
>>> result["logits"].shape # torch.Size([5, 10, 512])
>>> # Collate with padding (different sequence lengths)
>>> batch1 = {"tokens": torch.randn(2, 100)}
>>> batch2 = {"tokens": torch.randn(2, 150)}
>>> result = batch_collator([batch1, batch2], batch_dim=0, seq_dim=1,
... batch_dim_key_defaults={}, seq_dim_key_defaults={})
>>> result["tokens"].shape # torch.Size([4, 150]) - padded to max length
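The pad-to-max-length step shown in the second example can be sketched like this. `pad_and_cat` is an illustrative name invented for this doc, not part of the module; it assumes zero-padding on the trailing side of `seq_dim`:

```python
import torch
import torch.nn.functional as F

def pad_and_cat(tensors, batch_dim=0, seq_dim=1, pad_value=0.0):
    """Pad tensors along seq_dim to a common length, then concatenate along batch_dim."""
    max_len = max(t.shape[seq_dim] for t in tensors)
    padded = []
    for t in tensors:
        # F.pad's pad argument lists (left, right) pairs starting from the
        # LAST dimension; set only the trailing pad of seq_dim.
        pad = [0] * (2 * t.dim())
        pad[2 * (t.dim() - 1 - seq_dim) + 1] = max_len - t.shape[seq_dim]
        padded.append(F.pad(t, pad, value=pad_value))
    return torch.cat(padded, dim=batch_dim)
```

With inputs of shape `(2, 100)` and `(2, 150)`, the result has shape `(4, 150)`, with zeros in the padded tail of the shorter batch.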
Source code in bionemo/recipeutils/inference/collation.py