
Lightning

BertBatch

Bases: BertBatchCore

Input datatype for inference with BERT-like models.

Source code in bionemo/llm/model/biobert/lightning.py
class BertBatch(BertBatchCore, total=False):
    """Input datatype for inference with BERT-like models."""

    cu_seqlens: Tensor

BertBatchCore

Bases: TypedDict

Input datatype for inference with BERT-like models.

Source code in bionemo/llm/model/biobert/lightning.py
class BertBatchCore(TypedDict):
    """Input datatype for inference with BERT-like models."""

    text: Tensor
    attention_mask: Tensor
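
Because these are TypedDicts, a batch is just a plain dict carrying the expected keys. A minimal sketch (batch size 2 and sequence length 8 are assumed values, not requirements):

import torch

core_batch = {
    "text": torch.randint(0, 30, (2, 8)),                  # token IDs (BertBatchCore)
    "attention_mask": torch.ones(2, 8, dtype=torch.bool),  # attention mask (BertBatchCore)
}
# BertBatch additionally allows an optional "cu_seqlens" key (total=False) for packed sequences:
packed_batch = {**core_batch, "cu_seqlens": torch.tensor([0, 3, 8], dtype=torch.int32)}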

BertModel

Bases: Protocol[DataT]

Interface for BERT-like models.

Source code in bionemo/llm/model/biobert/lightning.py
class BertModel(Protocol[DataT]):
    """Interface for BERT-like models."""

    def forward(
        self, input_ids: Tensor, attention_mask: Tensor, packed_seq_params: Optional[PackedSeqParams] = None
    ) -> DataT:
        """Inference for BERT-like models.

        Inference for BERT-like models requires tokenized input IDs, an attention mask over the input,
        and the original sequence lengths if the sequences are packed into a dense batch.
        """
        ...

forward(input_ids, attention_mask, packed_seq_params=None)

Inference for BERT-like models.

Inference for BERT-like models requires tokenized input IDs, an attention mask over the input, and the original sequence lengths if the sequences are packed into a dense batch.

Source code in bionemo/llm/model/biobert/lightning.py
def forward(
    self, input_ids: Tensor, attention_mask: Tensor, packed_seq_params: Optional[PackedSeqParams] = None
) -> DataT:
    """Inference for BERT-like models.

    Inference for BERT-like models requires tokenized input IDs, an attention mask over the input,
    and the original sequence lengths if the sequences are packed into a dense batch.
    """
    ...
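
Because BertModel is a Protocol, any object whose forward method matches this signature is treated as a BertModel; no inheritance is needed. A minimal sketch using an assumed toy class (not a real BioNeMo model):

import torch
from torch import Tensor

class ToyBert:
    """Toy stand-in that structurally satisfies BertModel[Tensor]."""

    def forward(self, input_ids: Tensor, attention_mask: Tensor, packed_seq_params=None) -> Tensor:
        # Placeholder computation standing in for real BERT logits.
        return input_ids.float() * attention_mask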

BioBertLightningModule

Bases: BionemoLightningModule

Source code in bionemo/llm/model/biobert/lightning.py
class BioBertLightningModule(BionemoLightningModule):
    def __init__(
        self,
        *args,
        data_step_function: DataStepFunction = biobert_data_step,
        forward_step_function: ForwardStepFunction = bert_forward_step,
        **kwargs,
    ):
        """DEPRECATED! Please use BionemoLightningModule. This is here so we can load older checkpoints.
        This maps the old name `forward_step_function` to the new name `forward_step` and `data_step_function` to
        `data_step`.

        Args:
            *args: all args are passed through to BionemoLightningModule
            data_step_function (DataStepFunction, optional): The data step function. Defaults to biobert_data_step.
            forward_step_function (ForwardStepFunction, optional): The forward step function. Defaults to bert_forward_step.
            **kwargs: all other kwargs are passed through to BionemoLightningModule.
        """  # noqa: D205
        super().__init__(*args, forward_step=forward_step_function, data_step=data_step_function, **kwargs)

__init__(*args, data_step_function=biobert_data_step, forward_step_function=bert_forward_step, **kwargs)

DEPRECATED! Please use BionemoLightningModule. This is here so we can load older checkpoints. This maps the old name forward_step_function to the new name forward_step and data_step_function to data_step.

Parameters:

    *args: All positional args are passed through to BionemoLightningModule. Default: ()
    data_step_function (DataStepFunction): The data step function. Defaults to biobert_data_step.
    forward_step_function (ForwardStepFunction): The forward step function. Defaults to bert_forward_step.
    **kwargs: All other kwargs are passed through to BionemoLightningModule. Default: {}
Source code in bionemo/llm/model/biobert/lightning.py
def __init__(
    self,
    *args,
    data_step_function: DataStepFunction = biobert_data_step,
    forward_step_function: ForwardStepFunction = bert_forward_step,
    **kwargs,
):
    """DEPRECATED! Please use BionemoLightningModule. This is here so we can load older checkpoints.
    This maps the old name `forward_step_function` to the new name `forward_step` and `data_step_function` to
    `data_step`.

    Args:
        *args: all args are passed through to BionemoLightningModule
        data_step_function (DataStepFunction, optional): The data step function. Defaults to biobert_data_step.
        forward_step_function (ForwardStepFunction, optional): The forward step function. Defaults to bert_forward_step.
        **kwargs: all other kwargs are passed through to BionemoLightningModule.
    """  # noqa: D205
    super().__init__(*args, forward_step=forward_step_function, data_step=data_step_function, **kwargs)
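
The mapping is purely a rename of keyword arguments; a hedged sketch (my_config is a hypothetical placeholder for a concrete BioBertConfig):

# Old-style construction, still accepted so that older checkpoints load:
#   BioBertLightningModule(config=my_config, data_step_function=biobert_data_step,
#                          forward_step_function=bert_forward_step)
# is forwarded unchanged to the preferred form:
#   BionemoLightningModule(config=my_config, data_step=biobert_data_step,
#                          forward_step=bert_forward_step)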

SequenceBatch

Bases: SequenceBatchCore

Input datatype for inference with BERT-like models.

Source code in bionemo/llm/model/biobert/lightning.py
class SequenceBatch(SequenceBatchCore, total=False):
    """Input datatype for inference with BERT-like models."""

    cu_seqlens_argmin: Tensor
    max_seqlen: Tensor

SequenceBatchCore

Bases: TypedDict

Input datatype for inference with BERT-like models.

Source code in bionemo/llm/model/biobert/lightning.py
class SequenceBatchCore(TypedDict):
    """Input datatype for inference with BERT-like models."""

    cu_seqlens: Tensor
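
As with the BERT batch types above, these are plain dicts. A sketch of a packed batch with a leading micro-batch dimension of 1 (values assumed for illustration; the -1 entries stand for padding added by a collate function):

import torch

seq_batch = {
    "cu_seqlens": torch.tensor([[0, 3, 8, -1, -1]], dtype=torch.int32),  # required (SequenceBatchCore)
    "cu_seqlens_argmin": torch.tensor([3]),  # optional: index of the first -1 padding entry
    "max_seqlen": torch.tensor([5]),         # optional: length of the longest packed sequence
}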

bert_default_optimizer(model)

Returns the default optimizer for the BERT model.

Parameters:

    model (Module): The BERT model. Required.

Returns:

    FusedAdam: The default optimizer initialized for this BERT module's parameters. Uses a learning rate of 1e-4 and weight decay of 1e-2.

Source code in bionemo/llm/model/biobert/lightning.py
def bert_default_optimizer(model: torch.nn.Module) -> FusedAdam:
    """Returns the default optimizer for the BERT model.

    Args:
        model: The BERT model.

    Returns:
        The default optimizer initialized for this BERT module's parameters.
        Uses a learning rate of 1e-4 and weight decay of 1e-2.
    """
    return FusedAdam(model.parameters(), lr=1e-4, weight_decay=0.01)
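
A short usage sketch; FusedAdam is a fused GPU optimizer and generally needs a CUDA device, and the linear layer is only an assumed placeholder for a real BERT module:

import torch
from bionemo.llm.model.biobert.lightning import bert_default_optimizer

model = torch.nn.Linear(16, 16).cuda()     # placeholder module
optimizer = bert_default_optimizer(model)  # FusedAdam over model.parameters(), lr=1e-4, weight_decay=0.01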

bert_forward_step(model, batch)

Performs the model's forward pass using the batch, for Megatron compatibility.

This subsets the batch keys to the ones actually used by the model's forward pass and then calls it. If "cu_seqlens" is defined in the batch, the packed sequence parameters are also passed to the model for forward-pass efficiency.

Source code in bionemo/llm/model/biobert/lightning.py
def bert_forward_step(model: BertModel[DataT], batch: BertBatch) -> DataT:
    """Performs the model's forward pass using the batch, for Megatron compatibility.

    This subsets the batch keys to the ones actually used by the model's forward pass and then calls it. If
    "cu_seqlens" is defined in the batch, the packed sequence parameters are also passed to the model for
    forward-pass efficiency.
    """
    if "cu_seqlens" in batch:
        forward_results = model.forward(
            input_ids=batch["text"],
            attention_mask=batch["attention_mask"],
            packed_seq_params=get_packed_seq_params(cast(SequenceBatch, batch)),
        )
    else:
        forward_results = model.forward(input_ids=batch["text"], attention_mask=batch["attention_mask"])
    # TODO support losses that also include the binary head, this means doing something more fancy than the one
    #      default GPT reduction function above MaskedTokenLossReduction()
    return forward_results
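
Reusing the ToyBert sketch from the BertModel section above, a hedged example of the unpacked path (shapes assumed):

import torch
from bionemo.llm.model.biobert.lightning import bert_forward_step

batch = {
    "text": torch.randint(0, 30, (2, 8)),
    "attention_mask": torch.ones(2, 8, dtype=torch.bool),
}
# No "cu_seqlens" key, so the model is called without packed_seq_params.
out = bert_forward_step(ToyBert(), batch)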

biobert_data_step(dataloader_iter)

Preprocesses a batch of data for the GeneFormer model, ingesting a single batch from the dataloader iterator. Only the necessary batch keys are subsetted and passed to the model's forward pass and the loss forward pass, depending on the pipeline stage. TODO: document how parallel_state pipeline stages work.

Parameters:

    dataloader_iter: An iterator over the dataloader. Required.

Returns:

    output (Dict[str, Tensor]): A dictionary of this batch limited to the relevant keys.

Source code in bionemo/llm/model/biobert/lightning.py
def biobert_data_step(dataloader_iter) -> Dict[str, Tensor]:
    """Preprocesses a batch of data for the GeneFormer model, and ingest a single batch of data from the dataloader iterator.
        only necessary batch keys are subsetted and passed to the model's forward pass, and the loss forward pass, depending on stage.
        TODO document how parallel_state pipeline stages work.

    Args:
        dataloader_iter: An iterator over the dataloader.

    Returns:
        output: A dictionary of this batch limiting to relevant keys.

    """  # noqa: D205
    # Based on: https://github.com/NVIDIA/Megatron-LM/blob/main/pretrain_gpt.py#L87
    # https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L828-L842

    batch = next(dataloader_iter)

    if isinstance(batch, tuple) and len(batch) == 3:
        _batch: dict = batch[0]
    else:
        _batch = batch

    required_keys = set()
    required_keys.add("attention_mask")
    if parallel_state.is_pipeline_first_stage():
        required_keys.add("text")
    if parallel_state.is_pipeline_last_stage():
        required_keys.update(("labels", "loss_mask", "types", "is_random"))
    # if self.get_attention_mask_from_fusion:
    #     required_keys.remove('attention_mask')

    _batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in _batch.items()}
    # slice batch along sequence dimension for context parallelism
    output = get_batch_on_this_context_parallel_rank(_batch)

    return output
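
Running this step for real requires an initialized Megatron parallel_state and a CUDA device, so the sketch below only illustrates the batch layout it expects from the dataloader (shapes are assumed):

import torch

example_batch = {
    "text": torch.randint(0, 30, (2, 8)),                  # kept on the pipeline-first stage
    "attention_mask": torch.ones(2, 8, dtype=torch.bool),  # always kept
    "labels": torch.randint(0, 30, (2, 8)),                # kept on the pipeline-last stage
    "loss_mask": torch.ones(2, 8, dtype=torch.bool),
    "types": torch.zeros(2, 8, dtype=torch.long),
    "is_random": torch.zeros(2, dtype=torch.long),
}
# output = biobert_data_step(iter([example_batch]))
# Keys a stage does not need are set to None rather than dropped, and kept tensors are moved
# to the GPU with non_blocking=True.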

biobert_lightning_module(config, optimizer=None, tokenizer=None, data_step=biobert_data_step, forward_step=bert_forward_step, model_transform=None, **model_construct_args)

A PyTorch Lightning module for BioBert-derived models.

This module is designed to be used with the Megatron-LM strategy and NeMo 2.0 conventions. To change your loss, pass in a different config object that returns a different loss reduction class. To change your model and what it outputs, pass in a different config object that returns a different model. Do not modify this function unless you need to change higher-level logic. You may need to modify the various step and forward functions towards the bottom of this file to handle new or different keys in the batch. In the future, some of those functions may need to be refactored out into the config object or a different place so that they live closer to the model definition.

Source code in bionemo/llm/model/biobert/lightning.py
def biobert_lightning_module(
    config: BioBertConfig[MegatronBioBertModel, MegatronLossReduction],
    optimizer: Optional[MegatronOptimizerModule] = None,
    tokenizer: Optional[TokenizerSpec | PreTrainedTokenizerBase] = None,
    data_step: DataStep = biobert_data_step,
    forward_step: ForwardStep = bert_forward_step,
    model_transform: Optional[Callable] = None,
    **model_construct_args,
) -> BionemoLightningModule[MegatronBioBertModel, MegatronLossReduction]:
    """A pytorch lightning module for BioBert-derived models.

    This module is designed to be used with the Megatron-LM strategy and nemo 2.0 conventions.
    To change your loss, pass in a different config object that returns a different loss reduction class.
    To change your model and what it outputs, pass in a different config object that returns a different model.
    Do not modify this function unless you need to change higher level logic. You may need to modify the various step
    and forward functions towards the bottom of this file to handle new/different keys in the batch. In the future some
    of those functions may need to be refactored out into the config object or a different place so that they live
    closer to the model definition.
    """
    return BionemoLightningModule(
        config=config,
        optimizer=optimizer if optimizer is not None else default_megatron_optimizer(),
        data_step=data_step,
        forward_step=forward_step,
        tokenizer=tokenizer,
        model_transform=model_transform,
        **model_construct_args,
    )
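
A hedged construction sketch; my_config and my_tokenizer are hypothetical placeholders for a concrete BioBertConfig subclass and a tokenizer (model-specific, not shown here):

from bionemo.llm.model.biobert.lightning import biobert_lightning_module

module = biobert_lightning_module(
    config=my_config,        # placeholder: a concrete BioBertConfig instance
    tokenizer=my_tokenizer,  # placeholder: TokenizerSpec or Hugging Face tokenizer
)
# data_step and forward_step default to biobert_data_step and bert_forward_step, and a
# default Megatron optimizer is created when optimizer is None.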

get_batch_on_this_context_parallel_rank(batch, in_place=True)

Ensures that the input batch is in the right format for the context parallel rank.

Modifies the batch data based on the context parallel rank, if the context parallel world size is greater than 1. Otherwise, the batch is returned as-is.

Parameters:

    batch (Dict[str, Tensor]): The input batch data. Required.
    in_place (bool): If true, then the input is mutated and the returned dict is a reference to the input. Otherwise, the input data is shallow-copied and this copy is modified and returned. Defaults to True.

Returns:

    dict (Dict[str, Tensor]): The modified batch data based on the context parallel rank.

Source code in bionemo/llm/model/biobert/lightning.py
def get_batch_on_this_context_parallel_rank(batch: Dict[str, Tensor], in_place: bool = True) -> Dict[str, Tensor]:
    """Ensures that the input batch is in the right format for context parallel rank.

    Modifies the batch data based on the context parallel rank, if the context parallel world size is greater than 1.
    Otherwise, the batch is returned as-is.


    Args:
        batch: The input batch data.
        in_place: If true, then the input is mutated. The returned dict is a reference to the input.
                  Otherwise, the input data is always shallow-copied and this copy is modified and returned.

    Returns:
        dict: The modified batch data based on the context parallel rank.
    """
    if not in_place:
        batch: dict[str, Tensor] = dict(**batch)

    if (cp_size := parallel_state.get_context_parallel_world_size()) > 1:
        num_valid_tokens_in_ub: Tensor | None = None
        if "loss_mask" in batch and batch["loss_mask"] is not None:
            num_valid_tokens_in_ub = batch["loss_mask"].sum()

        cp_rank = parallel_state.get_context_parallel_rank()
        for key, val in batch.items():
            if val is not None:
                seq_dim = 1 if key != "attention_mask" else 2
                _val = val.view(
                    *val.shape[0:seq_dim],
                    2 * cp_size,
                    val.shape[seq_dim] // (2 * cp_size),
                    *val.shape[(seq_dim + 1) :],
                )
                index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda(
                    non_blocking=True
                )
                _val = _val.index_select(seq_dim, index)
                _val = _val.view(*val.shape[0:seq_dim], -1, *_val.shape[(seq_dim + 2) :])
                batch[key] = _val
        batch["num_valid_tokens_in_ub"] = num_valid_tokens_in_ub  # type: ignore

    return batch
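
The load-balanced split can be illustrated without Megatron by redoing the reshape and index_select on a small tensor (cp_size=2 and cp_rank=0 are assumed values):

import torch

cp_size, cp_rank, seq_dim = 2, 0, 1
val = torch.arange(16).view(1, 16)  # (batch=1, seq=16)
# Split the sequence into 2 * cp_size chunks, then keep one chunk from the front and its
# mirror from the back so each rank gets a balanced share of positions.
chunked = val.view(1, 2 * cp_size, val.shape[seq_dim] // (2 * cp_size))  # (1, 4, 4)
index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1])               # chunks 0 and 3 for rank 0
kept = chunked.index_select(seq_dim, index).reshape(1, -1)
# kept == tensor([[ 0,  1,  2,  3, 12, 13, 14, 15]])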

get_packed_seq_params(batch)

Get the packed sequence parameters for the given batch.

This function should only be called if cu_seqlens is defined in the batch.

Parameters:

    batch (SequenceBatch): The input batch to pack. Required.

Returns:

    PackedSeqParams: The packed sequence parameters containing the following attributes:
        - cu_seqlens_q (Tensor): The sequence lengths for query.
        - cu_seqlens_kv (Tensor): The sequence lengths for key and value.
        - max_seqlen_q (Tensor, optional): The maximum sequence length for query.
        - max_seqlen_kv (Tensor, optional): The maximum sequence length for key and value.
        - qkv_format (str): The format of the query, key, and value tensors.

Source code in bionemo/llm/model/biobert/lightning.py
def get_packed_seq_params(batch: SequenceBatch) -> PackedSeqParams:
    """Get the packed sequence parameters for the given batch.

    This function should only be called if `cu_seqlens` is defined in the batch.

    Args:
        batch: The input batch to pack.

    Returns:
        PackedSeqParams: The packed sequence parameters containing the following attributes:
            - cu_seqlens_q (Tensor): The sequence lengths for query.
            - cu_seqlens_kv (Tensor): The sequence lengths for key and value.
            - max_seqlen_q (Tensor, optional): The maximum sequence length for query.
            - max_seqlen_kv (Tensor, optional): The maximum sequence length for key and value.
            - qkv_format (str): The format of query, key, and value tensors.

    """
    cu_seqlens = batch["cu_seqlens"].squeeze()  # remove batch size dimension (mbs=1)
    # remove -1 "paddings" added in collate_fn
    if (cu_seqlens_argmin := batch.get("cu_seqlens_argmin", None)) is not None:
        # pre-compute cu_seqlens_argmin in dataset class for perf
        cu_seqlens = cu_seqlens[: cu_seqlens_argmin.item()]
    else:
        cu_seqlens = cu_seqlens[: torch.argmin(cu_seqlens)]

    # pre-compute max_seqlens in dataset class for perf
    max_seqlen = batch["max_seqlen"].squeeze() if "max_seqlen" in batch else None

    # these args are passed eventually into TEDotProductAttention.forward()
    return PackedSeqParams(
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_kv=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_kv=max_seqlen,
        qkv_format="thd",
    )
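
A hedged end-to-end sketch using the packed batch layout from the SequenceBatch section (two packed sequences of lengths 3 and 5; the trailing -1 entries are assumed collate padding):

import torch
from bionemo.llm.model.biobert.lightning import get_packed_seq_params

batch = {
    "cu_seqlens": torch.tensor([[0, 3, 8, -1, -1]], dtype=torch.int32),  # (mbs=1, ...) before squeeze
    "max_seqlen": torch.tensor([5]),
}
params = get_packed_seq_params(batch)
# params.cu_seqlens_q == params.cu_seqlens_kv == tensor([0, 3, 8], dtype=torch.int32)
# params.max_seqlen_q is tensor(5) and params.qkv_format == "thd"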