Lightning

DataStep = Callable[[Iterator[DataT]], DataT] module-attribute

Batches together an iterator of individual examples.

Necessary for compatibility with Megatron. This function type is similar to the collate function of PyTorch.

A DataStep function takes an iterator over individual examples. Each example may be a tensor, sequence of tensors, or a set of named tensors (provided as a dict mapping str names to each Tensor). Each iteration must yield the same type.

The output of this function mirrors the structure of each yielded example: it is the concatenation of all of the examples in the iterator.
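
As a rough illustration of this contract, the sketch below batches an iterator of dict-shaped examples. Plain Python lists stand in for tensors so that it runs without PyTorch; `concat_data_step` and the field names are hypothetical and not part of bionemo.

```python
from typing import Dict, Iterator, List

Example = Dict[str, List[int]]  # stand-in for a dict of named tensors


def concat_data_step(examples: Iterator[Example]) -> Example:
    """Concatenate every example key by key, mirroring the example structure."""
    batch: Example = {}
    for example in examples:
        for name, values in example.items():
            # Each named field grows along its (only) dimension.
            batch.setdefault(name, []).extend(values)
    return batch


examples = iter([
    {"tokens": [1, 2], "labels": [0, 1]},
    {"tokens": [3], "labels": [1]},
])
batch = concat_data_step(examples)
```

The result has the same keys as each example, with the values concatenated in iteration order.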

ForwardStep = Callable[[MegatronModelType, DataT], DataT] module-attribute

Megatron-compatible forward pass function.

BionemoLightningModule

Bases: Generic[MegatronModelType, MegatronLossType], LightningModule, IOMixin, ConnectorMixin, LightningPassthroughPredictionMixin

Reusable PyTorch Lightning module for Megatron models that is compatible with NeMo's conventions.

Source code in bionemo/llm/lightning.py
class BionemoLightningModule(
    Generic[MegatronModelType, MegatronLossType],
    pl.LightningModule,
    nlio.IOMixin,
    nlio.ConnectorMixin,
    LightningPassthroughPredictionMixin,
):
    """Reusable PyTorch Lightning module for Megatron models that is compatible with NeMo's conventions."""

    def __init__(
        self,
        config: BionemoTrainableModelConfig[MegatronModelType, MegatronLossType],
        forward_step: ForwardStep,
        data_step: DataStep,
        # TODO: Add transformer_layer_spec when we update mcore
        optimizer: MegatronOptimizerModule,
        model_transform: Optional[Callable[[MegatronModelType], MegatronModelType]] = None,
        **model_construct_args,
    ) -> None:
        """Constructor.

        Args:
            config: Serializable configuration object that allows one to construct a new model instance and loss
                function. Necessary for Megatron-based training as the model itself cannot be serialized and
                distributed to nodes. Instead, we serialize the procedure for making the model and distribute that.
            forward_step: Performs forward pass using the model and a batch of data.
            data_step: Custom batch-creating function for the model.
            optimizer: Megatron-compatible distributed optimizer instance. Required; `default_megatron_optimizer`
                provides Adam with a 1e-4 learning rate.
            model_transform: Optional. The model transform function.
            **model_construct_args: Optional. Arguments necessary for the supplied model configuration's
                `configure_model` method, which will make an instance of the model.
        """
        super().__init__()
        self.config = config
        self.module_construct_args: Optional[dict[str, Any]] = model_construct_args
        # **must** be set up in configure_model() -- megatron constraint
        # also, must be called `module`: nemo expects the actual model to be stored this way
        self.module: Optional[MegatronModelType] = None
        self.loss_reduction_class: type[MegatronLossType] = config.get_loss_reduction_class()
        self.optim = optimizer
        self.optim.connect(self)  # This will bind the `configure_optimizers` method
        self._data_step = data_step
        self._forward_step = forward_step
        self.model_transform = model_transform

    def configure_model(self) -> None:
        """Updates internal state: instantiates the model from the object's config, assigns to `module` attribute.

        NOTE: this method is idempotent; successive calls have no effect. The model is only initialized once.

        Raises:
            ValueError: If the internal config's configure_model method returns None.
        """
        if self.module is None:
            model: MegatronModelType = (
                self.config.configure_model(**self.module_construct_args)
                if self.module_construct_args is not None
                else self.config.configure_model()
            )
            self.module = model
        if self.module is None:
            raise ValueError("Invalid semantics: configure_model method **MUST** initialize the model.")

    def forward(self, *args, **kwargs) -> DataT:
        """Call the forward method of the underlying model, and return whatever it outputs."""
        # safe to do because configure_model is idempotent
        self.configure_model()
        assert self.module is not None, "ERROR: configure_model() method has been incorrectly overridden!"
        prediction = self.module(*args, **kwargs)  # for now just pass through to the underlying model
        return prediction

    def data_step(self, dataloader_iter: Iterator[DataT]) -> DataT:  # noqa: D102
        return self._data_step(dataloader_iter)

    def forward_step(self, batch) -> Tensor:
        """Megatron-required: the training forward step for the model, which is required to produce the loss.

        Normally, the forward pass of a model means its inference. Loss is computed using the predictions
        from the forward pass against labels. Megatron unfortunately conflates these two different concepts
        and instead has a model's "forward" method produce the loss. See the Megatron docs for details:
        https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py#L170

        To get actual predictions, use the :func:`forward` method instead.
        """
        # safe to do because configure_model is idempotent
        self.configure_model()
        assert self.module is not None
        return self._forward_step(self.module, batch)

    def training_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
        """In mcore the loss-function is part of the forward-pass when labels are provided."""
        return self.forward_step(batch)

    def validation_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
        """In mcore the loss-function is part of the forward-pass when labels are provided."""
        return self.forward_step(batch)

    def predict_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
        """Alias for forward_step."""
        return self.forward_step(batch)

    def training_loss_reduction(self) -> MegatronLossType:
        """This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss."""
        return self.loss_reduction_class()

    def validation_loss_reduction(self) -> MegatronLossType:  # noqa: D102
        return self.loss_reduction_class(validation_step=True)

    def test_loss_reduction(self) -> MegatronLossType:  # noqa: D102
        return self.loss_reduction_class(validation_step=True)

__init__(config, forward_step, data_step, optimizer, model_transform=None, **model_construct_args)

Constructor.

Parameters:

Name | Type | Description | Default
config | BionemoTrainableModelConfig[MegatronModelType, MegatronLossType] | Serializable configuration object that allows one to construct a new model instance and loss function. Necessary for Megatron-based training as the model itself cannot be serialized and distributed to nodes. Instead, we serialize the procedure for making the model and distribute that. | required
forward_step | ForwardStep | Performs forward pass using the model and a batch of data. | required
data_step | DataStep | Custom batch-creating function for the model. | required
optimizer | MegatronOptimizerModule | Megatron-compatible distributed optimizer instance. default_megatron_optimizer provides Adam with a 1e-4 learning rate. | required
model_transform | Optional[Callable[[MegatronModelType], MegatronModelType]] | Optional. The model transform function. | None
**model_construct_args | | Optional. Arguments necessary for the supplied model configuration's configure_model method, which will make an instance of the model. | {}

Source code in bionemo/llm/lightning.py
def __init__(
    self,
    config: BionemoTrainableModelConfig[MegatronModelType, MegatronLossType],
    forward_step: ForwardStep,
    data_step: DataStep,
    # TODO: Add transformer_layer_spec when we update mcore
    optimizer: MegatronOptimizerModule,
    model_transform: Optional[Callable[[MegatronModelType], MegatronModelType]] = None,
    **model_construct_args,
) -> None:
    """Constructor.

    Args:
        config: Serializable configuration object that allows one to construct a new model instance and loss
            function. Necessary for Megatron-based training as the model itself cannot be serialized and
            distributed to nodes. Instead, we serialize the procedure for making the model and distribute that.
        forward_step: Performs forward pass using the model and a batch of data.
        data_step: Custom batch-creating function for the model.
        optimizer: Megatron-compatible distributed optimizer instance. Required; `default_megatron_optimizer`
            provides Adam with a 1e-4 learning rate.
        model_transform: Optional. The model transform function.
        **model_construct_args: Optional. Arguments necessary for the supplied model configuration's
            `configure_model` method, which will make an instance of the model.
    """
    super().__init__()
    self.config = config
    self.module_construct_args: Optional[dict[str, Any]] = model_construct_args
    # **must** be set up in configure_model() -- megatron constraint
    # also, must be called `module`: nemo expects the actual model to be stored this way
    self.module: Optional[MegatronModelType] = None
    self.loss_reduction_class: type[MegatronLossType] = config.get_loss_reduction_class()
    self.optim = optimizer
    self.optim.connect(self)  # This will bind the `configure_optimizers` method
    self._data_step = data_step
    self._forward_step = forward_step
    self.model_transform = model_transform

configure_model()

Updates internal state: instantiates the model from the object's config, assigns to module attribute.

NOTE: this method is idempotent; successive calls have no effect. The model is only initialized once.

Source code in bionemo/llm/lightning.py
def configure_model(self) -> None:
    """Updates internal state: instantiates the model from the object's config, assigns to `module` attribute.

    NOTE: this method is idempotent; successive calls have no effect. The model is only initialized once.

    Raises:
        ValueError: If the internal config's configure_model method returns None.
    """
    if self.module is None:
        model: MegatronModelType = (
            self.config.configure_model(**self.module_construct_args)
            if self.module_construct_args is not None
            else self.config.configure_model()
        )
        self.module = model
    if self.module is None:
        raise ValueError("Invalid semantics: configure_model method **MUST** initialize the model.")
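
The lazy, idempotent initialization pattern above can be sketched in isolation. This is a minimal sketch, not the bionemo class: `CountingConfig` and `LazyModule` are hypothetical names, and a bare `object()` stands in for a Megatron model.

```python
from typing import Optional


class CountingConfig:
    """Hypothetical config that records how often it builds a model."""

    def __init__(self) -> None:
        self.calls = 0

    def configure_model(self) -> object:
        self.calls += 1
        return object()  # stand-in for a Megatron model


class LazyModule:
    """Mimics the build-once behaviour of configure_model shown above."""

    def __init__(self, config: CountingConfig) -> None:
        self.config = config
        self.module: Optional[object] = None

    def configure_model(self) -> None:
        if self.module is None:
            self.module = self.config.configure_model()
        if self.module is None:
            raise ValueError("configure_model must initialize the model.")


lazy = LazyModule(CountingConfig())
lazy.configure_model()
first = lazy.module
lazy.configure_model()  # second call finds an existing module and does nothing
```

Because the second call is a no-op, methods such as forward and forward_step can call configure_model defensively before touching the model.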

forward(*args, **kwargs)

Call the forward method of the underlying model, and return whatever it outputs.

Source code in bionemo/llm/lightning.py
def forward(self, *args, **kwargs) -> DataT:
    """Call the forward method of the underlying model, and return whatever it outputs."""
    # safe to do because configure_model is idempotent
    self.configure_model()
    assert self.module is not None, "ERROR: configure_model() method has been incorrectly overridden!"
    prediction = self.module(*args, **kwargs)  # for now just pass through to the underlying model
    return prediction

forward_step(batch)

Megatron-required: the training forward step for the model, which is required to produce the loss.

Normally, the forward pass of a model means its inference. Loss is computed using the predictions from the forward pass against labels. Megatron unfortunately conflates these two different concepts and instead has a model's "forward" method produce the loss. See the Megatron docs for details: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py#L170

To get actual predictions, use the forward method instead.

Source code in bionemo/llm/lightning.py
def forward_step(self, batch) -> Tensor:
    """Megatron-required: the training forward step for the model, which is required to produce the loss.

    Normally, the forward pass of a model means its inference. Loss is computed using the predictions
    from the forward pass against labels. Megatron unfortunately conflates these two different concepts
    and instead has a model's "forward" method produce the loss. See the Megatron docs for details:
    https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py#L170

    To get actual predictions, use the :func:`forward` method instead.
    """
    # safe to do because configure_model is idempotent
    self.configure_model()
    assert self.module is not None
    return self._forward_step(self.module, batch)

predict_step(batch, batch_idx=None)

Alias for forward_step.

Source code in bionemo/llm/lightning.py
def predict_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
    """Alias for forward_step."""
    return self.forward_step(batch)

training_loss_reduction()

This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss.

Source code in bionemo/llm/lightning.py
def training_loss_reduction(self) -> MegatronLossType:
    """This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss."""
    return self.loss_reduction_class()

training_step(batch, batch_idx=None)

In mcore the loss-function is part of the forward-pass when labels are provided.

Source code in bionemo/llm/lightning.py
def training_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
    """In mcore the loss-function is part of the forward-pass when labels are provided."""
    return self.forward_step(batch)

validation_step(batch, batch_idx=None)

In mcore the loss-function is part of the forward-pass when labels are provided.

Source code in bionemo/llm/lightning.py
def validation_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
    """In mcore the loss-function is part of the forward-pass when labels are provided."""
    return self.forward_step(batch)

LightningPassthroughPredictionMixin

A mixin that allows your model to do inference on the predict step by hijacking nemo's loss reduction mechanism.

Source code in bionemo/llm/lightning.py
class LightningPassthroughPredictionMixin:
    """A mixin that allows your model to do inference on the predict step by hijacking nemo's loss reduction mechanism."""

    def predict_loss_reduction(self) -> PassthroughLossReduction:
        """For the predict step, pass through the forward pass output."""
        return PassthroughLossReduction()

predict_loss_reduction()

For the predict step, pass through the forward pass output.

Source code in bionemo/llm/lightning.py
def predict_loss_reduction(self) -> PassthroughLossReduction:
    """For the predict step, pass through the forward pass output."""
    return PassthroughLossReduction()

PassthroughLossReduction

Bases: MegatronLossReduction, Generic[DataT]

A workaround for nemo/megatron to perform inference.

Internally in NeMo2.0 the forward step is always expected to return a loss reduction class, and forward is expected to return a loss. This class hijacks that mechanism to instead pass through the forward output unperturbed as the loss (to enable inference in the predict step), and then the reduce method is used to collate the batch of forward outputs into a single batch. This supports the model forward output being a tensor, dict, tuple, or list of tensors. The inner type must always be a Tensor.

Source code in bionemo/llm/lightning.py
class PassthroughLossReduction(MegatronLossReduction, Generic[DataT]):
    """A workaround for nemo/megatron to perform inference.

    Internally in NeMo2.0 the forward step is always expected to return a loss reduction class, and forward is
    expected to return a loss. This class hijacks that mechanism to instead pass through the forward output unperturbed
    as the loss (to enable inference in the predict step), and then the reduce method is used to collate the batch of
    forward outputs into a single batch. This supports the model forward output being a tensor, dict, tuple, or list of
    tensors. The inner type _must always be a Tensor_.
    """

    def forward(self, batch: DataT, forward_out: DataT) -> Tuple[Tensor, DataT]:
        """Passes through the `forward_out` value as the 2nd tuple element.

        Args:
            batch: The batch of data that was passed through the model to generate output. NOTE: this value is ignored.
            forward_out: The output from your model's forward pass.

        Returns:
            A tuple containing the loss tensor (dummy in this case) and the forward output (unmodified).
        """
        dtype, device = get_dtype_device(forward_out)
        return torch.zeros(1, device=device, dtype=dtype), forward_out

    def reduce(self, forward_out: List[DataT]) -> DataT:
        """Collates list of model's outputs into a single output."""
        return batch_collator(forward_out)
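
To make the two phases concrete, here is a hypothetical pure-Python analogue of the passthrough pattern. Lists of floats stand in for tensors, `PassthroughSketch` is not a bionemo class, and the driver loop at the bottom is a simplification of how NeMo and Megatron call these methods.

```python
from typing import Dict, List, Tuple

Output = Dict[str, List[float]]  # stand-in for a dict of tensors


class PassthroughSketch:
    """Hypothetical analogue: dummy loss in forward, collation in reduce."""

    def forward(self, batch: Output, forward_out: Output) -> Tuple[float, Output]:
        # The "loss" is a dummy value; the model output is returned untouched.
        return 0.0, forward_out

    def reduce(self, forward_outs: List[Output]) -> Output:
        # Concatenate the per-microbatch outputs key by key.
        merged: Output = {}
        for out in forward_outs:
            for key, values in out.items():
                merged.setdefault(key, []).extend(values)
        return merged


reduction = PassthroughSketch()
microbatch_outputs = [{"logits": [0.1, 0.2]}, {"logits": [0.3]}]
# Simplified driver: keep only the second tuple element from each forward call.
kept = [reduction.forward(batch={}, forward_out=out)[1] for out in microbatch_outputs]
predictions = reduction.reduce(kept)
```

The dummy loss is discarded; only the collated outputs are of interest in the predict step.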

forward(batch, forward_out)

Passes through the forward_out value as the 2nd tuple element.

Parameters:

Name | Type | Description | Default
batch | DataT | The batch of data that was passed through the model to generate output. NOTE: this value is ignored. | required
forward_out | DataT | The output from your model's forward pass. | required

Returns:

Type | Description
Tuple[Tensor, DataT] | A tuple containing the loss tensor (dummy in this case) and the forward output (unmodified).

Source code in bionemo/llm/lightning.py
def forward(self, batch: DataT, forward_out: DataT) -> Tuple[Tensor, DataT]:
    """Passes through the `forward_out` value as the 2nd tuple element.

    Args:
        batch: The batch of data that was passed through the model to generate output. NOTE: this value is ignored.
        forward_out: The output from your model's forward pass.

    Returns:
        A tuple containing the loss tensor (dummy in this case) and the forward output (unmodified).
    """
    dtype, device = get_dtype_device(forward_out)
    return torch.zeros(1, device=device, dtype=dtype), forward_out

reduce(forward_out)

Collates list of model's outputs into a single output.

Source code in bionemo/llm/lightning.py
def reduce(self, forward_out: List[DataT]) -> DataT:
    """Collates list of model's outputs into a single output."""
    return batch_collator(forward_out)

PerplexityLoggingCallback

Bases: Callback, CallbackMethods

Megatron Callback to log perplexity in validation and optionally training.

NeMo2.0 checks whether a callback is an instance of {LightningModule,LightningDataModule,Callback} but only megatron_hooks are useful.

Source code in bionemo/llm/lightning.py
class PerplexityLoggingCallback(pl.Callback, CallbackMethods):
    """Megatron Callback to log perplexity in validation and optionally training.

    NeMo2.0 checks whether a callback is an instance of {LightningModule,LightningDataModule,Callback} but only megatron_hooks are useful.
    """

    def __init__(self, log_train: bool = False, log_val: bool = True):
        """Initialize PerplexityLoggingCallback.

        Args:
            log_train: whether to log train perplexity. Defaults to False.
            log_val: whether to log validation perplexity. Defaults to True.
        """
        super().__init__()
        self.log_train = log_train
        self.log_val = log_val

    def _pad_to_max_length(
        self,
        microbatch_outputs: List[Dict[str, Dict[str, Tensor]]],
        key1: str,
        key2: str,
        pad_value: int = 0,
        seq_dim: int = 1,
        batch_dim: int = 0,
    ) -> Tensor:
        """Pad tensors to max length in microbatch_outputs."""
        assert seq_dim != batch_dim, "Forgot to set one of seq_dim, batch_dim, they are equal!"
        max_sequence_length: int = max(output[key1][key2].shape[seq_dim] for output in microbatch_outputs)

        tensors: List[Tensor] = []
        for microbatch_output in microbatch_outputs:
            tensor = microbatch_output[key1][key2]
            assert (
                tensor.dim() >= 2
            ), f"Tensor in microbatch_outputs must have at least 2 dimensions, but got {tensor.dim()} dimensions"
            pad_size = [(0, 0)] * tensor.dim()
            pad_size[seq_dim] = (0, max_sequence_length - tensor.shape[seq_dim])
            # Flatten pad size list for F.pad
            pad_size_flat = [item for sublist in reversed(pad_size) for item in sublist]
            tensors.append(
                torch.nn.functional.pad(  # padding reverse in order
                    tensor,
                    pad_size_flat,
                    mode="constant",
                    value=pad_value,
                )
            )

        return torch.cat(tensors, dim=batch_dim)  # concat on batch dim

    @override
    def on_megatron_reduce_microbatches_end(
        self,
        step: MegatronStep,
        microbatch_outputs: List[Any],
        loss_reduction: MegatronLossReduction,
        reduced: Tensor | dict[str, Tensor],
    ) -> None:
        """Log after MegatronLossReduction.reduce is called.

        Expects microbatch_outputs to be a list of dicts with the following keys:
            - batch: dict of tensors with the following keys:
                - labels: [b s]
                - loss_mask: [b s]; 1 means included, 0 means ignored
            - forward_out: dict of tensors with the following keys:
                - token_logits: [s b vocab]
        """
        if step.trainer.sanity_checking:  # skip sanity check
            return

        if step.trainer.training and not self.log_train:
            return

        if not parallel_state.is_pipeline_last_stage():
            return

        assert step.num_microbatches is not None, "num_microbatches must be initialized to non-None"
        assert step.num_microbatches > 0, "num_microbatches must be greater than 0"
        assert (
            len(microbatch_outputs) == step.num_microbatches
        ), "microbatch_outputs length does not match num_microbatches"
        labels = self._pad_to_max_length(microbatch_outputs, "batch", "labels", pad_value=-100)
        loss_mask = self._pad_to_max_length(microbatch_outputs, "batch", "loss_mask")
        token_logits = self._pad_to_max_length(
            microbatch_outputs, "forward_out", "token_logits", seq_dim=0, batch_dim=1
        )

        unreduced_token_loss = unreduced_token_loss_fn(
            token_logits.clone(),  # [s b vocab] as expected; cloned because unreduced_token_loss_fn modifies token_logits in place
            labels.clone(),  # [b,s] as expected
        )  # [b s] is the return

        cp_size = parallel_state.get_context_parallel_world_size()
        if cp_size == 1:
            ppl = torch.exp((unreduced_token_loss * loss_mask).sum() / loss_mask.sum())
        else:
            raise NotImplementedError("Context parallel perplexity logging is not supported yet")

        if self.log_val and not step.trainer.training:
            step.pl_module.log("val_ppl", ppl, prog_bar=True, on_epoch=True)
        elif self.log_train and step.trainer.training:
            step.pl_module.log("train_ppl", ppl, prog_bar=True, batch_size=1, sync_dist=False)
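
The flattening step in `_pad_to_max_length` exists because `torch.nn.functional.pad` takes its padding amounts starting from the last dimension. The sketch below reproduces only that index arithmetic in plain Python; `flat_pad_spec` is a hypothetical helper name and no tensors are involved.

```python
from typing import List, Tuple


def flat_pad_spec(shape: Tuple[int, ...], seq_dim: int, target_len: int) -> List[int]:
    """Build the flat padding list in last-dimension-first order."""
    pad_size = [(0, 0)] * len(shape)
    # Pad only on the right side of the sequence dimension.
    pad_size[seq_dim] = (0, target_len - shape[seq_dim])
    # Reverse the per-dimension pairs, then flatten them.
    return [item for pair in reversed(pad_size) for item in pair]


# labels shaped [b, s] = [2, 5], padded to s = 8 along seq_dim=1
labels_spec = flat_pad_spec((2, 5), seq_dim=1, target_len=8)
# token_logits shaped [s, b, vocab] = [5, 2, 33], padded to s = 8 along seq_dim=0
logits_spec = flat_pad_spec((5, 2, 33), seq_dim=0, target_len=8)
```

For the sequence-first logits, the non-zero entry lands at the end of the flat list because dimension 0 is the last one the padding list describes.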

__init__(log_train=False, log_val=True)

Initialize PerplexityLoggingCallback.

Parameters:

Name | Type | Description | Default
log_train | bool | Whether to log train perplexity. | False
log_val | bool | Whether to log validation perplexity. | True

Source code in bionemo/llm/lightning.py
def __init__(self, log_train: bool = False, log_val: bool = True):
    """Initialize PerplexityLoggingCallback.

    Args:
        log_train: whether to log train perplexity. Defaults to False.
        log_val: whether to log validation perplexity. Defaults to True.
    """
    super().__init__()
    self.log_train = log_train
    self.log_val = log_val

on_megatron_reduce_microbatches_end(step, microbatch_outputs, loss_reduction, reduced)

Log after MegatronLossReduction.reduce is called.

Expects microbatch_outputs to be a list of dicts with the following keys:
  • batch: dict of tensors with the following keys:
    • labels: [b s]
    • loss_mask: [b s]; 1 means included, 0 means ignored
  • forward_out: dict of tensors with the following keys:
    • token_logits: [s b vocab]

Source code in bionemo/llm/lightning.py
@override
def on_megatron_reduce_microbatches_end(
    self,
    step: MegatronStep,
    microbatch_outputs: List[Any],
    loss_reduction: MegatronLossReduction,
    reduced: Tensor | dict[str, Tensor],
) -> None:
    """Log after MegatronLossReduction.reduce is called.

    Expects microbatch_outputs to be a list of dicts with the following keys:
        - batch: dict of tensors with the following keys:
            - labels: [b s]
            - loss_mask: [b s]; 1 means included, 0 means ignored
        - forward_out: dict of tensors with the following keys:
            - token_logits: [s b vocab]
    """
    if step.trainer.sanity_checking:  # skip sanity check
        return

    if step.trainer.training and not self.log_train:
        return

    if not parallel_state.is_pipeline_last_stage():
        return

    assert step.num_microbatches is not None, "num_microbatches must be initialized to non-None"
    assert step.num_microbatches > 0, "num_microbatches must be greater than 0"
    assert (
        len(microbatch_outputs) == step.num_microbatches
    ), "microbatch_outputs length does not match num_microbatches"
    labels = self._pad_to_max_length(microbatch_outputs, "batch", "labels", pad_value=-100)
    loss_mask = self._pad_to_max_length(microbatch_outputs, "batch", "loss_mask")
    token_logits = self._pad_to_max_length(
        microbatch_outputs, "forward_out", "token_logits", seq_dim=0, batch_dim=1
    )

    unreduced_token_loss = unreduced_token_loss_fn(
        token_logits.clone(),  # [s b vocab] as expected; cloned because unreduced_token_loss_fn modifies token_logits in place
        labels.clone(),  # [b,s] as expected
    )  # [b s] is the return

    cp_size = parallel_state.get_context_parallel_world_size()
    if cp_size == 1:
        ppl = torch.exp((unreduced_token_loss * loss_mask).sum() / loss_mask.sum())
    else:
        raise NotImplementedError("Context parallel perplexity logging is not supported yet")

    if self.log_val and not step.trainer.training:
        step.pl_module.log("val_ppl", ppl, prog_bar=True, on_epoch=True)
    elif self.log_train and step.trainer.training:
        step.pl_module.log("train_ppl", ppl, prog_bar=True, batch_size=1, sync_dist=False)
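
The perplexity computed above is the exponential of the mean per-token loss over positions where loss_mask is 1. A small worked example in plain Python, with nested lists standing in for [b s] tensors, follows; `masked_perplexity` is a hypothetical helper and the numbers are made up for illustration.

```python
import math
from typing import List


def masked_perplexity(token_loss: List[List[float]], loss_mask: List[List[int]]) -> float:
    """exp of the mean per-token loss over positions where loss_mask == 1."""
    total = sum(
        loss * mask
        for loss_row, mask_row in zip(token_loss, loss_mask)
        for loss, mask in zip(loss_row, mask_row)
    )
    count = sum(mask for row in loss_mask for mask in row)
    return math.exp(total / count)


# Two sequences of length 3; the last position of the second sequence is padding.
token_loss = [[0.5, 1.0, 1.5], [2.0, 1.0, 9.0]]
loss_mask = [[1, 1, 1], [1, 1, 0]]
ppl = masked_perplexity(token_loss, loss_mask)
```

The masked position contributes nothing to either sum, so the large loss on the padding token does not inflate the result: the mean loss is 6.0 / 5 = 1.2 and the perplexity is exp(1.2).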

batch_collator(batches, batch_dim=0, batch_dim_key_defaults={'token_logits': 1})

Takes a sequence of batches and collates them into a single batch.

This is distinct from PyTorch's standard default_collate since it does not add the batch dimension; the batch dimension is assumed to already be present in the input, as would be the case when parallelizing across minibatches.

IMPORTANT: The underlying data primitive must be a torch Tensor. The input to this function is a recursive type: there can be any amount of nesting between dictionaries, tuples, and lists, as long as the inner type is an n-d Tensor.

Examples:

Outer container = Dict:
    [{'a': Tensor([1]), 'b': Tensor([2])}, {'a': Tensor([2]), 'b': Tensor([3])}] -> {'a': Tensor([1, 2]), 'b': Tensor([2, 3])}
Outer container = List:
    [[Tensor([1]), Tensor([2])], [Tensor([2]), Tensor([3])]] -> [Tensor([1, 2]), Tensor([2, 3])]
Outer container = Tuple:
    [(Tensor([1]), Tensor([2])), (Tensor([2]), Tensor([3]))] -> (Tensor([1, 2]), Tensor([2, 3]))

Parameters:

Name | Type | Description | Default
batches | Optional[Sequence[ReductionT]] | Sequence of batches to collate into a single batch. | required
batch_dim | int | If you know that the batch dim for the batch you are concatenating is not the 0th dimension (for example it is sequence first) then supply that dimension. | 0
batch_dim_key_defaults | dict[str, int] | If your batch is a dictionary and you know that some keys have a non-standard (non-zero) batch dimension, supply those here. By default "token_logits" has batch dim 1 and otherwise all keys are assumed to have batch dim 0. | {'token_logits': 1}

Returns:

Type | Description
Optional[ReductionT] | A single batch of the same type as the elements of your input sequence.

Source code in bionemo/llm/lightning.py
def batch_collator(
    batches: Optional[Union[Tuple[ReductionT], List[ReductionT]]],
    batch_dim: int = 0,
    batch_dim_key_defaults: dict[str, int] = {"token_logits": 1},
) -> Optional[ReductionT]:
    """Takes a sequence of batches and collates them into a single batch.

    This is distinct from PyTorch's standard default_collate since it does not add the batch dimension; the batch
    dimension is assumed to already be present in the input, as would be the case when parallelizing across
    minibatches.

    IMPORTANT: The underlying data primitive _must_ be a torch Tensor. The input to this function is a recursive type:
    there can be any amount of nesting between dictionaries, tuples, and lists, as long as the inner type is an n-d
    Tensor.

    Examples:
        Outer container = Dict:
            [{'a': Tensor([1]), 'b': Tensor([2])}, {'a': Tensor([2]), 'b': Tensor([3])}] -> {'a': Tensor([1, 2]), 'b': Tensor([2, 3])}
        Outer container = List:
            [[Tensor([1]), Tensor([2])], [Tensor([2]), Tensor([3])]] -> [Tensor([1, 2]), Tensor([2, 3])]
        Outer container = Tuple:
            [(Tensor([1]), Tensor([2])), (Tensor([2]), Tensor([3]))] -> (Tensor([1, 2]), Tensor([2, 3]))

    Args:
        batches (Optional[Sequence[ReductionT]]): sequence of batches to collate into a single batch.
        batch_dim: If you know that the batch dim for the batch you are concatenating is not the 0th dimension (for
            example it is sequence first) then supply that dimension.
        batch_dim_key_defaults (dictionary of keys to integers): If your batch is a dictionary and you know that some
            keys have a non-standard (non-zero) batch dimension, supply those here. By default "token_logits" has
            batch dim 1 and otherwise all keys are assumed to have batch dim 0.

    Returns:
        A single batch of the same type as the elements of your input sequence.
    """
    match batches:
        # Handle base-cases for batch concatenation, either a list of None or a list of tensors
        case [None, *_]:
            return None
        case [Tensor(), *_]:
            return torch.cat(batches, dim=batch_dim)
        # Next 3 calls are the recursive calls into the sub-structures of the batch. We handle dictionaries, tuples, and lists
        case [dict(), *_]:
            return {
                key: batch_collator(
                    [batch[key] for batch in batches],
                    batch_dim=batch_dim_key_defaults.get(key, 0),
                    batch_dim_key_defaults=batch_dim_key_defaults,
                )
                for key in batches[0]
            }
        case [tuple(), *_]:
            return tuple(
                batch_collator(
                    [batch[i] for batch in batches], batch_dim=batch_dim, batch_dim_key_defaults=batch_dim_key_defaults
                )
                for i in range(len(batches[0]))
            )
        case [list(), *_]:
            return [
                batch_collator(
                    [batch[i] for batch in batches], batch_dim=batch_dim, batch_dim_key_defaults=batch_dim_key_defaults
                )
                for i in range(len(batches[0]))
            ]
        # Final cases shouldn't happen, an empty sequence (no batches), or "other".
        case []:
            raise ValueError("Cannot process an empty sequence")
        case _:
            raise ValueError("Unsupported input structure in batch_collator")
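
To see how the per-key batch dimension affects the collated result, the sketch below tracks only tensor shapes, so it runs without PyTorch. `collated_shapes` is a hypothetical helper, not part of bionemo, and the shapes are made up; it mirrors the dictionary branch above, where each key is concatenated along `batch_dim_key_defaults.get(key, 0)`.

```python
from typing import Dict, List, Optional, Tuple

Shape = Tuple[int, ...]


def collated_shapes(
    batches: List[Dict[str, Shape]],
    batch_dim_key_defaults: Optional[Dict[str, int]] = None,
) -> Dict[str, Shape]:
    """Shape each key would have after concatenation along its batch dim."""
    defaults = {"token_logits": 1} if batch_dim_key_defaults is None else batch_dim_key_defaults
    result: Dict[str, Shape] = {}
    for key in batches[0]:
        dim = defaults.get(key, 0)  # same lookup as the dict case above
        shape = list(batches[0][key])
        # Concatenation sums the sizes along the batch dim and keeps the rest.
        shape[dim] = sum(batch[key][dim] for batch in batches)
        result[key] = tuple(shape)
    return result


microbatches = [
    {"token_logits": (16, 2, 33), "labels": (2, 16)},  # [s b vocab] and [b s]
    {"token_logits": (16, 4, 33), "labels": (4, 16)},
]
shapes = collated_shapes(microbatches)
```

Here the sequence-first logits grow along dimension 1 while the labels grow along dimension 0, so both end up with a batch size of 6.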

default_megatron_optimizer()

Default distributed optimizer uses Adam with a 1e-4 learning rate.

Source code in bionemo/llm/lightning.py
def default_megatron_optimizer() -> MegatronOptimizerModule:
    """Default distributed optimizer uses Adam with a 1e-4 learning rate."""
    return MegatronOptimizerModule(
        config=OptimizerConfig(lr=1e-4, optimizer="adam", use_distributed_optimizer=True),
    )

some_first(seq)

Returns the first non-None value from the sequence, or raises a ValueError if there is none.

Source code in bionemo/llm/lightning.py
def some_first(seq: Iterable[Optional[T]]) -> T:
    """Returns the first non-None value from the sequence, or raises a ValueError if there is none."""
    for s in seq:
        if s is not None:
            return s
    raise ValueError("non-None value not found")
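
A short usage sketch follows. The function is copied from the source above so that the example is self-contained; the inputs are made up for illustration.

```python
from typing import Iterable, Optional, TypeVar

T = TypeVar("T")


def some_first(seq: Iterable[Optional[T]]) -> T:
    """Copy of the function above, repeated so this example runs on its own."""
    for s in seq:
        if s is not None:
            return s
    raise ValueError("non-None value not found")


# 0 is falsy but not None, so it is the value returned here.
first = some_first([None, 0, 5])

# A sequence with no non-None value raises instead of returning None.
try:
    some_first([None, None])
    raised = False
except ValueError:
    raised = True
```

The check is `is not None` rather than truthiness, so falsy values such as 0 or an empty string count as found.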