Collate

bert_padding_collate_fn(batch, padding_value, min_length=None, max_length=None)

Padding collate function for BERT dataloaders.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch` | `list` | List of samples. | *required* |
| `padding_value` | `int` | The tokenizer's pad token ID. | *required* |
| `min_length` | `int \| None` | Minimum length of the output batch; tensors will be padded to this length. If not provided, no extra padding beyond the `max_length` will be added. | `None` |
| `max_length` | `int \| None` | Maximum length of the sequence. If not provided, tensors will be padded to the longest sequence in the batch. | `None` |
Source code in bionemo/llm/data/collate.py
def bert_padding_collate_fn(
    batch: Sequence[types.BertSample],
    padding_value: int,
    min_length: int | None = None,
    max_length: int | None = None,
) -> types.BertSample:
    """Padding collate function for BERT dataloaders.

    Args:
        batch: List of samples.
        padding_value: The tokenizer's pad token ID.
        min_length: Minimum length of the output batch; tensors will be padded to this length. If not
            provided, no extra padding beyond the max_length will be added.
        max_length: Maximum length of the sequence. If not provided, tensors will be padded to the
            longest sequence in the batch.
    """
    padding_values = {
        "text": padding_value,
        "types": 0,
        "attention_mask": False,
        "labels": -100,  # This should match the masked value used in the MLM loss mask.
        "loss_mask": False,
        "is_random": 0,
    }
    return padding_collate_fn(
        batch=batch,  # type: ignore[assignment]
        padding_values=padding_values,
        min_length=min_length,
        max_length=max_length,
    )
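
Example usage (a minimal sketch, not taken from the library docs): the snippet below wires `bert_padding_collate_fn` into a PyTorch `DataLoader` via `functools.partial`. The pad token ID of 0, the `min_length` of 8, and the sample contents are illustrative assumptions; only the dictionary keys mirror the `padding_values` used in the source above.

```python
import functools

import torch
from torch.utils.data import DataLoader

from bionemo.llm.data.collate import bert_padding_collate_fn

# Two variable-length samples with the keys this collate function expects.
# All values here are made up for illustration.
samples = [
    {
        "text": torch.tensor([5, 6, 7]),
        "types": torch.zeros(3, dtype=torch.long),
        "attention_mask": torch.ones(3, dtype=torch.bool),
        "labels": torch.tensor([-100, 6, -100]),
        "loss_mask": torch.tensor([False, True, False]),
        "is_random": torch.zeros(3, dtype=torch.long),
    },
    {
        "text": torch.tensor([5, 6, 7, 8, 9]),
        "types": torch.zeros(5, dtype=torch.long),
        "attention_mask": torch.ones(5, dtype=torch.bool),
        "labels": torch.tensor([-100, -100, 7, -100, -100]),
        "loss_mask": torch.tensor([False, False, True, False, False]),
        "is_random": torch.zeros(5, dtype=torch.long),
    },
]

# Bind the tokenizer-specific arguments so the DataLoader only passes the batch.
collate = functools.partial(bert_padding_collate_fn, padding_value=0, min_length=8)
loader = DataLoader(samples, batch_size=2, collate_fn=collate)

batch = next(iter(loader))
print(batch["text"].shape)  # torch.Size([2, 8]): padded up to min_length
```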

padding_collate_fn(batch, padding_values, min_length=None, max_length=None)

Collate function with padding.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch` | `Sequence[_T]` | List of samples, each of which is a dictionary of tensors. | *required* |
| `padding_values` | `dict[str, int]` | A dictionary of padding values for each tensor key. | *required* |
| `min_length` | `int \| None` | Minimum length of the output batch; tensors will be padded to this length. If not provided, no extra padding beyond the `max_length` will be added. | `None` |
| `max_length` | `int \| None` | Maximum length of the sequence. If not provided, tensors will be padded to the longest sequence in the batch. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `_T` | A collated batch with the same dictionary input structure. |

Source code in bionemo/llm/data/collate.py
def padding_collate_fn(
    batch: Sequence[_T],
    padding_values: dict[str, int],
    min_length: int | None = None,
    max_length: int | None = None,
) -> _T:
    """Collate function with padding.

    Args:
        batch: List of samples, each of which is a dictionary of tensors.
        padding_values: A dictionary of padding values for each tensor key.
        min_length: Minimum length of the output batch; tensors will be padded to this length. If not
            provided, no extra padding beyond the max_length will be added.
        max_length: Maximum length of the sequence. If not provided, tensors will be padded to the
            longest sequence in the batch.

    Returns:
        A collated batch with the same dictionary input structure.
    """
    global _warned_once
    keys: set[str] | None = None

    if len(batch) == 0:  # empty batches passed through in DDP inference
        return {}

    for entry in batch:
        # First check that we have sane batches where keys align with each other.
        if keys is None:
            keys = set(entry.keys())
        else:
            if set(entry.keys()) != keys:
                raise ValueError(f"All keys in inputs must match each other. Got: {[sorted(e.keys()) for e in batch]}")
        if entry.keys() != padding_values.keys():
            if not _warned_once:
                extra_keys = {k for k in entry.keys() if k not in padding_values}
                missing_keys = {k for k in padding_values.keys() if k not in entry}
                logger.warning(
                    f"Extra keys in batch that will not be padded: {extra_keys}. Missing keys in batch: {missing_keys}"
                )
                _warned_once = True

    def _pad(tensors, padding_value):
        if max_length is not None:
            tensors = [t[:max_length] for t in tensors]
        batched_tensors = torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True, padding_value=padding_value)
        if min_length is None:
            return batched_tensors
        return torch.nn.functional.pad(batched_tensors, (0, min_length - batched_tensors.size(1)), value=padding_value)

    return {
        k: _pad([s[k] for s in batch], padding_values[k])
        if k in padding_values
        else torch.stack([s[k] for s in batch])
        for k in batch[0].keys()
    }  # type: ignore[return-value]
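
A minimal sketch of calling the generic `padding_collate_fn` directly (illustrative keys and values, not from the library docs): keys listed in `padding_values` are padded to a common length, while any remaining keys are stacked as-is. A one-time warning is logged when the batch keys and `padding_values` keys do not match exactly.

```python
import torch

from bionemo.llm.data.collate import padding_collate_fn

# Hypothetical samples: variable-length token tensors plus a scalar length field.
batch = [
    {"tokens": torch.tensor([1, 2, 3]), "length": torch.tensor(3)},
    {"tokens": torch.tensor([4, 5]), "length": torch.tensor(2)},
]

collated = padding_collate_fn(
    batch,
    padding_values={"tokens": 0},  # "length" is not listed, so it is stacked without padding
    max_length=4,                  # sequences longer than 4 would be truncated first
)

print(collated["tokens"])  # tensor([[1, 2, 3], [4, 5, 0]])
print(collated["length"])  # tensor([3, 2])
```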