
Lightning basic

This is intended to be a minimal, self-contained NeMo2 example.

BionemoLightningModule

Bases: LightningModule, IOMixin, LightningPassthroughPredictionMixin

A very basic lightning module for testing the megatron strategy and the megatron-nemo2-bionemo contract.

Source code in bionemo/example_model/lightning/lightning_basic.py
class BionemoLightningModule(pl.LightningModule, io.IOMixin, LightningPassthroughPredictionMixin):
    """A very basic lightning module for testing the megatron strategy and the megatron-nemo2-bionemo contract."""

    def __init__(self, config: MegatronBioNeMoTrainableModelConfig):
        """Initializes the model.

        Args:
            config: a Config object necessary to construct the actual nn.Module (the thing that has the parameters).
        """
        super().__init__()
        self.config = config
        self.optim = MegatronOptimizerModule(
            config=OptimizerConfig(
                lr=1e-4,
                optimizer="adam",
                use_distributed_optimizer=True,
                bf16=config.bf16,
                fp16=config.fp16,
                params_dtype=config.params_dtype,
            ),
        )
        # Bind the configure_optimizers method to the model
        self.optim.connect(self)

    def forward(self, batch: Dict, batch_idx: int) -> Any:
        """This forward will be called by the megatron scheduler and it will be wrapped.

        !!! note

            The `training_step` defines the training loop and is independent of the `forward` method here.

        Args:
            batch: A dictionary of data.
            batch_idx: The index of the batch.

        Returns:
            The output of the model.
        """
        x = batch["data"]
        return self.module(x)

    def training_step(self, batch, batch_idx: Optional[int] = None):
        """The training step is where the loss is calculated and the backpropagation is done.

        Background:
        - NeMo's Strategy overrides this method.
        - The strategies' training step will call the forward method of the model.
        - That forward method then calls the wrapped forward step of MegatronParallel which wraps the forward method of the model.
        - That wrapped forward step is then executed inside the Mcore scheduler, which calls the `_forward_step` method from the
            MegatronParallel class.
        - Which then calls the training_step function here.

        In this particular use case, we simply call the forward method of this class, the lightning module.

        Args:
            batch: A dictionary of data. requires `batch_idx` as default None.
            batch_idx: The index of the batch.
        """
        # Forward pass
        predictions = self(batch, batch_idx)

        # Calculate loss using the training loss reduction function
        loss_reduction = self.training_loss_reduction()
        loss_reduction.setup(batch)
        loss = loss_reduction(predictions)

        # Log the training loss
        self.log("train_loss", loss[1]["avg"], on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return predictions

    def validation_step(self, batch, batch_idx: Optional[int] = None):
        """Alias for forward step at validation."""
        predictions = self(batch, batch_idx)

        # Calculate loss using the validation loss reduction function
        loss_reduction = self.validation_loss_reduction()
        loss_reduction.setup(batch)
        loss = loss_reduction(predictions)
        # Log the validation loss
        self.log(
            "val_loss",
            loss[1]["avg"],
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        return predictions

    def predict_step(self, batch, batch_idx: Optional[int] = None):
        """Alias for forward step at prediction."""
        return self(batch, batch_idx)

    def training_loss_reduction(self) -> MegatronLossReduction:
        """This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss.

        Returns:
        A MegatronLossReduction
        """
        return self.loss_reduction_class()()

    def validation_loss_reduction(self) -> MegatronLossReduction:
        """This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss.

        Returns:
        A MegatronLossReduction
        """
        return self.loss_reduction_class()()

    def test_loss_reduction(self) -> MegatronLossReduction:
        """This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss.

        Returns:
        A MegatronLossReduction
        """
        return self.loss_reduction_class()()

    def configure_model(self) -> None:
        """This configures the model. It is called lazily by the megatron strategy."""
        self.module = self.config.configure_model()

    def loss_reduction_class(self) -> Type[MegatronLossReduction]:
        """Get the loss reduction class the user has specified in their config."""
        return self.config.get_loss_reduction_class()
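
A minimal usage sketch, not part of the module itself: it assumes the classes in this file are importable and that the Megatron strategy has already set up any required Megatron state, since the strategy is what normally calls configure_model lazily.

from bionemo.example_model.lightning.lightning_basic import (
    BionemoLightningModule,
    PretrainConfig,
)

# PretrainConfig (documented below) pairs ExampleModel with MSELossReduction;
# we assume its default field values are sufficient here.
config = PretrainConfig()
module = BionemoLightningModule(config=config)
module.configure_model()  # normally called lazily by the Megatron strategy

# training_step/validation_step expect dict batches with a "data" key (and a
# "label" key for the classifier losses), as produced by MNISTCustomDataset.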

__init__(config)

Initializes the model.

Parameters:

- config (MegatronBioNeMoTrainableModelConfig, required): a Config object necessary to construct the actual nn.Module (the thing that has the parameters).
Source code in bionemo/example_model/lightning/lightning_basic.py
def __init__(self, config: MegatronBioNeMoTrainableModelConfig):
    """Initializes the model.

    Args:
        config: a Config object necessary to construct the actual nn.Module (the thing that has the parameters).
    """
    super().__init__()
    self.config = config
    self.optim = MegatronOptimizerModule(
        config=OptimizerConfig(
            lr=1e-4,
            optimizer="adam",
            use_distributed_optimizer=True,
            bf16=config.bf16,
            fp16=config.fp16,
            params_dtype=config.params_dtype,
        ),
    )
    # Bind the configure_optimizers method to the model
    self.optim.connect(self)

configure_model()

This configures the model. It is called lazily by the megatron strategy.

Source code in bionemo/example_model/lightning/lightning_basic.py
def configure_model(self) -> None:
    """This configures the model. It is called lazily by the megatron strategy."""
    self.module = self.config.configure_model()

forward(batch, batch_idx)

This forward will be called by the megatron scheduler and it will be wrapped.

Note

The training_step defines the training loop and is independent of the forward method here.

Parameters:

- batch (Dict, required): A dictionary of data.
- batch_idx (int, required): The index of the batch.

Returns:

- Any: The output of the model.

Source code in bionemo/example_model/lightning/lightning_basic.py
def forward(self, batch: Dict, batch_idx: int) -> Any:
    """This forward will be called by the megatron scheduler and it will be wrapped.

    !!! note

        The `training_step` defines the training loop and is independent of the `forward` method here.

    Args:
        batch: A dictionary of data.
        batch_idx: The index of the batch.

    Returns:
        The output of the model.
    """
    x = batch["data"]
    return self.module(x)

loss_reduction_class()

Get the loss reduction class the user has specified in their config.

Source code in bionemo/example_model/lightning/lightning_basic.py
def loss_reduction_class(self) -> Type[MegatronLossReduction]:
    """Get the loss reduction class the user has specified in their config."""
    return self.config.get_loss_reduction_class()

predict_step(batch, batch_idx=None)

Alias for forward step at prediction.

Source code in bionemo/example_model/lightning/lightning_basic.py
def predict_step(self, batch, batch_idx: Optional[int] = None):
    """Alias for forward step at prediction."""
    return self(batch, batch_idx)

test_loss_reduction()

This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss.

Returns: A MegatronLossReduction

Source code in bionemo/example_model/lightning/lightning_basic.py
def test_loss_reduction(self) -> MegatronLossReduction:
    """This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss.

    Returns:
    A MegatronLossReduction
    """
    return self.loss_reduction_class()()

training_loss_reduction()

This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss.

Returns: A MegatronLossReduction

Source code in bionemo/example_model/lightning/lightning_basic.py
def training_loss_reduction(self) -> MegatronLossReduction:
    """This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss.

    Returns:
    A MegatronLossReduction
    """
    return self.loss_reduction_class()()

training_step(batch, batch_idx=None)

The training step is where the loss is calculated and the backpropagation is done.

Background:

- NeMo's Strategy overrides this method.
- The strategy's training step calls the forward method of the model.
- That forward method then calls the wrapped forward step of MegatronParallel, which wraps the forward method of the model.
- That wrapped forward step is then executed inside the Mcore scheduler, which calls the _forward_step method from the MegatronParallel class.
- That in turn calls the training_step function here.

In this particular use case, we simply call the forward method of this class, the lightning module.

Parameters:

- batch (required): A dictionary of data.
- batch_idx (Optional[int], default None): The index of the batch.
Source code in bionemo/example_model/lightning/lightning_basic.py
def training_step(self, batch, batch_idx: Optional[int] = None):
    """The training step is where the loss is calculated and the backpropagation is done.

    Background:
    - NeMo's Strategy overrides this method.
    - The strategies' training step will call the forward method of the model.
    - That forward method then calls the wrapped forward step of MegatronParallel which wraps the forward method of the model.
    - That wrapped forward step is then executed inside the Mcore scheduler, which calls the `_forward_step` method from the
        MegatronParallel class.
    - Which then calls the training_step function here.

    In this particular use case, we simply call the forward method of this class, the lightning module.

    Args:
        batch: A dictionary of data. requires `batch_idx` as default None.
        batch_idx: The index of the batch.
    """
    # Forward pass
    predictions = self(batch, batch_idx)

    # Calculate loss using the training loss reduction function
    loss_reduction = self.training_loss_reduction()
    loss_reduction.setup(batch)
    loss = loss_reduction(predictions)

    # Log the training loss
    self.log("train_loss", loss[1]["avg"], on_step=True, on_epoch=True, prog_bar=True, logger=True)

    return predictions

validation_loss_reduction()

This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss.

Returns: A MegatronLossReduction

Source code in bionemo/example_model/lightning/lightning_basic.py
def validation_loss_reduction(self) -> MegatronLossReduction:
    """This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss.

    Returns:
    A MegatronLossReduction
    """
    return self.loss_reduction_class()()

validation_step(batch, batch_idx=None)

Alias for forward step at validation.

Source code in bionemo/example_model/lightning/lightning_basic.py
def validation_step(self, batch, batch_idx: Optional[int] = None):
    """Alias for forward step at validation."""
    predictions = self(batch, batch_idx)

    # Calculate loss using the validation loss reduction function
    loss_reduction = self.validation_loss_reduction()
    loss_reduction.setup(batch)
    loss = loss_reduction(predictions)
    # Log the validation loss
    self.log(
        "val_loss",
        loss[1]["avg"],
        on_step=False,
        on_epoch=True,
        prog_bar=True,
        logger=True,
    )

    return predictions

ClassifierLossReduction

Bases: MegatronLossReduction

A class used for calculating the loss, and for logging the reduced loss across micro batches.

Source code in bionemo/example_model/lightning/lightning_basic.py
class ClassifierLossReduction(MegatronLossReduction):
    """A class used for calculating the loss, and for logging the reduced loss across micro batches."""

    def forward(self, batch: MnistItem, forward_out: Tensor) -> Tuple[Tensor, SameSizeLossDict]:
        """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

        Args:
            batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
            forward_out: the output of the forward method inside LitAutoEncoder.

        Returns:
            A tuple containing [<loss_tensor>, ReductionT] where the loss tensor will be used for
                backpropagation and the ReductionT will be passed to the reduce method
                (which currently only works for logging.).
        """
        digits = batch["label"]
        digit_logits = forward_out
        loss = nn.functional.cross_entropy(digit_logits, digits)
        return loss, {"avg": loss}

    def reduce(self, losses_reduced_per_micro_batch: Sequence[SameSizeLossDict]) -> Tensor:
        """Works across micro-batches. (data on single gpu).

        Note: This currently only works for logging and this loss will not be used for backpropagation.

        Args:
            losses_reduced_per_micro_batch: a list of the outputs of forward

        Returns:
            A tensor that is the mean of the losses. (used for logging).
        """
        mse_losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
        return mse_losses.mean()
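
A quick, torch-only sketch of how forward and reduce compose; the fake batch and random logits below stand in for a real MnistItem and the model's output.

import torch

from bionemo.example_model.lightning.lightning_basic import ClassifierLossReduction

loss_fn = ClassifierLossReduction()
batch = {"data": torch.randn(8, 1, 28, 28), "label": torch.randint(0, 10, (8,)), "idx": 0}
digit_logits = torch.randn(8, 10)  # stand-in for the model's classifier output

loss, reduction = loss_fn.forward(batch, digit_logits)  # per-micro-batch loss and its logging dict
mean_loss = loss_fn.reduce([reduction, reduction])      # mean across micro-batches (logging only)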

forward(batch, forward_out)

Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

Parameters:

- batch (MnistItem, required): A batch of data that gets passed to the original forward inside LitAutoEncoder.
- forward_out (Tensor, required): the output of the forward method inside LitAutoEncoder.

Returns:

- Tuple[Tensor, SameSizeLossDict]: A tuple containing [<loss_tensor>, ReductionT] where the loss tensor will be used for backpropagation and the ReductionT will be passed to the reduce method (which currently only works for logging).

Source code in bionemo/example_model/lightning/lightning_basic.py
def forward(self, batch: MnistItem, forward_out: Tensor) -> Tuple[Tensor, SameSizeLossDict]:
    """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

    Args:
        batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
        forward_out: the output of the forward method inside LitAutoEncoder.

    Returns:
        A tuple containing [<loss_tensor>, ReductionT] where the loss tensor will be used for
            backpropagation and the ReductionT will be passed to the reduce method
            (which currently only works for logging.).
    """
    digits = batch["label"]
    digit_logits = forward_out
    loss = nn.functional.cross_entropy(digit_logits, digits)
    return loss, {"avg": loss}

reduce(losses_reduced_per_micro_batch)

Works across micro-batches (data on a single GPU).

Note: This currently only works for logging and this loss will not be used for backpropagation.

Parameters:

- losses_reduced_per_micro_batch (Sequence[SameSizeLossDict], required): a list of the outputs of forward.

Returns:

- Tensor: A tensor that is the mean of the losses (used for logging).

Source code in bionemo/example_model/lightning/lightning_basic.py
def reduce(self, losses_reduced_per_micro_batch: Sequence[SameSizeLossDict]) -> Tensor:
    """Works across micro-batches. (data on single gpu).

    Note: This currently only works for logging and this loss will not be used for backpropagation.

    Args:
        losses_reduced_per_micro_batch: a list of the outputs of forward

    Returns:
        A tensor that is the mean of the losses. (used for logging).
    """
    mse_losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
    return mse_losses.mean()

ExampleFineTuneBothConfig dataclass

Bases: ExampleGenericConfig['ExampleFineTuneBothModel', 'MSEPlusClassifierLossReduction'], IOMixinWithGettersSetters

ExampleConfig is a dataclass that is used to configure the model.

Timers from ModelParallelConfig are required for megatron forward compatibility.

Source code in bionemo/example_model/lightning/lightning_basic.py
@dataclass
class ExampleFineTuneBothConfig(
    ExampleGenericConfig["ExampleFineTuneBothModel", "MSEPlusClassifierLossReduction"], iom.IOMixinWithGettersSetters
):
    """ExampleConfig is a dataclass that is used to configure the model.

    Timers from ModelParallelConfig are required for megatron forward compatibility.
    """

    model_cls: Type[ExampleFineTuneBothModel] = ExampleFineTuneBothModel
    loss_cls: Type[MSEPlusClassifierLossReduction] = MSEPlusClassifierLossReduction

ExampleFineTuneBothModel

Bases: ExampleModel

Example of taking the example model and adding an output task.

Source code in bionemo/example_model/lightning/lightning_basic.py
class ExampleFineTuneBothModel(ExampleModel):
    """Example of taking the example model and adding an output task."""

    def __init__(self, config: ModelParallelConfig):
        super().__init__(config)
        # 10 output digits, and use the latent output layer (z) for making predictions
        self.digit_classifier = nn.Linear(self.linear2.out_features, 10)

    def forward(self, x: Tensor) -> ExampleFineTuneOutput:
        parent_out: ExampleModelOutput = super().forward(x)
        digit_logits = self.digit_classifier(parent_out["z"])
        return {
            "x_hat": parent_out["x_hat"],
            "z": parent_out["z"],
            "digit_logits": digit_logits,
        }

ExampleFineTuneConfig dataclass

Bases: ExampleGenericConfig['ExampleFineTuneConfig', 'ClassifierLossReduction'], IOMixinWithGettersSetters

ExampleConfig is a dataclass that is used to configure the model.

Timers from ModelParallelConfig are required for megatron forward compatibility.

Source code in bionemo/example_model/lightning/lightning_basic.py
@dataclass
class ExampleFineTuneConfig(
    ExampleGenericConfig["ExampleFineTuneConfig", "ClassifierLossReduction"], iom.IOMixinWithGettersSetters
):
    """ExampleConfig is a dataclass that is used to configure the model.

    Timers from ModelParallelConfig are required for megatron forward compatibility.
    """

    model_cls: Type[ExampleFineTuneModel] = ExampleFineTuneModel
    loss_cls: Type[ClassifierLossReduction] = ClassifierLossReduction
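
Fine-tuning reuses the three-step configure_model flow defined on ExampleGenericConfig (documented below); a hedged sketch, where the checkpoint path is purely hypothetical and must point at a NeMo2 checkpoint produced with PretrainConfig.

from bionemo.example_model.lightning.lightning_basic import ExampleFineTuneConfig

# Hypothetical checkpoint directory from a previous pretraining run.
finetune_config = ExampleFineTuneConfig(initial_ckpt_path="/results/pretrain/checkpoints/last")

# Builds ExampleFineTuneModel and loads the matching trunk weights from the checkpoint.
# loss_cls stays ClassifierLossReduction because it is listed in override_parent_fields.
model = finetune_config.configure_model()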

ExampleFineTuneModel

Bases: ExampleModelTrunk

Example of taking the example model and replacing output task.

Source code in bionemo/example_model/lightning/lightning_basic.py
class ExampleFineTuneModel(ExampleModelTrunk):
    """Example of taking the example model and replacing output task."""

    def __init__(self, config: ModelParallelConfig):
        super().__init__(config)
        # 10 output digits, and use the latent output layer (z) for making predictions
        self.digit_classifier = nn.Linear(self.linear2.out_features, 10)

    def forward(self, x: Tensor) -> Tensor:
        z: Tensor = super().forward(x)
        digit_logits = self.digit_classifier(z)  # to demonstrate flexibility, in this case we return a tensor
        return digit_logits

ExampleFineTuneOutput

Bases: ExampleModelOutput

Output for the fine-tuned example model implementation.

Source code in bionemo/example_model/lightning/lightning_basic.py
class ExampleFineTuneOutput(ExampleModelOutput):
    """Output for the fine-tuned example model implementation."""

    digit_logits: Tensor

ExampleGenericConfig dataclass

Bases: Generic[ExampleModelT, MegatronLossType], MegatronBioNeMoTrainableModelConfig[ExampleModelT, MegatronLossType]

ExampleGenericConfig is a dataclass that is used to configure the model.

Timers from ModelParallelConfig are required for megatron forward compatibility.

Source code in bionemo/example_model/lightning/lightning_basic.py
@dataclass
class ExampleGenericConfig(
    Generic[ExampleModelT, MegatronLossType], MegatronBioNeMoTrainableModelConfig[ExampleModelT, MegatronLossType]
):
    """ExampleGenericConfig is a dataclass that is used to configure the model.

    Timers from ModelParallelConfig are required for megatron forward compatibility.
    """

    loss_cls: Type[MegatronLossType] = MSELossReduction  # type: ignore  # this will get overriden by children
    hidden_size: int = 64  # Needs to be set to avoid zero division error in megatron :(
    num_attention_heads: int = 1  # Needs to be set to avoid zero division error in megatron :(
    num_layers: int = 1  # Needs to be set to avoid zero division error in megatron :(
    # IMPORTANT: Since we're adding/overriding the loss_cls, and that's not how we generally track this, we need to
    #   add this into the list of config settings that we do not draw from the loaded checkpoint when restoring.
    override_parent_fields: List[str] = field(default_factory=lambda: OVERRIDE_BIONEMO_CONFIG_DEFAULTS + ["loss_cls"])

    def configure_model(self) -> ExampleModelT:
        """Uses model_cls and loss_cls to configure the model.

        Note: Must pass self into Model since model requires having a config object.

        Returns:
            The model object.
        """
        # 1. first load any settings that may exist in the checkpoint related to the model.
        if self.initial_ckpt_path:
            self.load_settings_from_checkpoint(self.initial_ckpt_path)
        # 2. then initialize the model
        model = self.model_cls(self)
        # 3. Load weights from the checkpoint into the model
        if self.initial_ckpt_path:
            self.update_model_from_checkpoint(model, self.initial_ckpt_path)
        return model

    def get_loss_reduction_class(self) -> Type[MegatronLossType]:
        """Use loss_cls to configure the loss, since we do not change the settings of the loss based on the config."""
        return self.loss_cls
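
A hedged sketch of the contract this generic config provides, assuming its default field values are sufficient (as they are for PretrainConfig, documented below):

from bionemo.example_model.lightning.lightning_basic import PretrainConfig

config = PretrainConfig()  # pairs ExampleModel with MSELossReduction
model = config.configure_model()  # no initial_ckpt_path set, so steps 1 and 3 are skipped
loss = config.get_loss_reduction_class()()  # how BionemoLightningModule builds its loss objects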

configure_model()

Uses model_cls and loss_cls to configure the model.

Note: Must pass self into Model since model requires having a config object.

Returns:

Type Description
ExampleModelT

The model object.

Source code in bionemo/example_model/lightning/lightning_basic.py
def configure_model(self) -> ExampleModelT:
    """Uses model_cls and loss_cls to configure the model.

    Note: Must pass self into Model since model requires having a config object.

    Returns:
        The model object.
    """
    # 1. first load any settings that may exist in the checkpoint related to the model.
    if self.initial_ckpt_path:
        self.load_settings_from_checkpoint(self.initial_ckpt_path)
    # 2. then initialize the model
    model = self.model_cls(self)
    # 3. Load weights from the checkpoint into the model
    if self.initial_ckpt_path:
        self.update_model_from_checkpoint(model, self.initial_ckpt_path)
    return model

get_loss_reduction_class()

Use loss_cls to configure the loss, since we do not change the settings of the loss based on the config.

Source code in bionemo/example_model/lightning/lightning_basic.py
def get_loss_reduction_class(self) -> Type[MegatronLossType]:
    """Use loss_cls to configure the loss, since we do not change the settings of the loss based on the config."""
    return self.loss_cls

ExampleModel

Bases: ExampleModelTrunk

An example model.

Source code in bionemo/example_model/lightning/lightning_basic.py
class ExampleModel(ExampleModelTrunk):
    """An example model."""

    def __init__(self, config: ModelParallelConfig) -> None:
        """Constructor of the model.

        Args:
            config: The config object is responsible for telling the strategy what model to create.
        """
        super().__init__(config)
        self.linear3 = nn.Linear(3, 64)
        self.relu2 = nn.ReLU()
        self.linear4 = nn.Linear(64, 28 * 28)

    def forward(self, x: Tensor) -> ExampleModelOutput:
        """Forward pass of the model.

        Args:
            x: The input data.

        Returns:
            x_hat: The result of the last linear layer of the network.
        """
        z: Tensor = super().forward(x)
        x_hat = self.linear3(z)
        x_hat = self.relu2(x_hat)
        x_hat = self.linear4(x_hat)
        return {"x_hat": x_hat, "z": z}

__init__(config)

Constructor of the model.

Parameters:

- config (ModelParallelConfig, required): The config object is responsible for telling the strategy what model to create.
Source code in bionemo/example_model/lightning/lightning_basic.py
def __init__(self, config: ModelParallelConfig) -> None:
    """Constructor of the model.

    Args:
        config: The config object is responsible for telling the strategy what model to create.
    """
    super().__init__(config)
    self.linear3 = nn.Linear(3, 64)
    self.relu2 = nn.ReLU()
    self.linear4 = nn.Linear(64, 28 * 28)

forward(x)

Forward pass of the model.

Parameters:

- x (Tensor, required): The input data.

Returns:

- x_hat (ExampleModelOutput): The result of the last linear layer of the network.

Source code in bionemo/example_model/lightning/lightning_basic.py
def forward(self, x: Tensor) -> ExampleModelOutput:
    """Forward pass of the model.

    Args:
        x: The input data.

    Returns:
        x_hat: The result of the last linear layer of the network.
    """
    z: Tensor = super().forward(x)
    x_hat = self.linear3(z)
    x_hat = self.relu2(x_hat)
    x_hat = self.linear4(x_hat)
    return {"x_hat": x_hat, "z": z}

ExampleModelOutput

Bases: TypedDict

Output for the example model implementation.

Source code in bionemo/example_model/lightning/lightning_basic.py
class ExampleModelOutput(TypedDict):
    """Output for the example model implementation."""

    x_hat: Tensor
    z: Tensor

ExampleModelTrunk

Bases: MegatronModule

Source code in bionemo/example_model/lightning/lightning_basic.py
class ExampleModelTrunk(MegatronModule):
    def __init__(self, config: ModelParallelConfig) -> None:
        """Constructor of the model.

        Args:
            config: The config object is responsible for telling the strategy what model to create.
        """
        super().__init__(config)
        # FIXME add an assertion that the user is not trying to do tensor parallelism since this doesn't use
        #  parallelizable megatron linear layers.
        self.model_type: ModelType = ModelType.encoder_or_decoder
        self.linear1 = nn.Linear(28 * 28, 64)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(64, 3)

    def forward(self, x: Tensor) -> Tensor:
        # we could return a dictionary of strings to tensors here, but let's demonstrate this is not necessary
        x = x.view(x.size(0), -1)
        z = self.linear1(x)
        z = self.relu(z)
        z = self.linear2(z)
        return z

    def set_input_tensor(self, input_tensor: Optional[Tensor]) -> None:
        """This _would_ be needed for model parallel and other kinds of more complicated forward passes in megatron."""
        pass

__init__(config)

Constructor of the model.

Parameters:

- config (ModelParallelConfig, required): The config object is responsible for telling the strategy what model to create.
Source code in bionemo/example_model/lightning/lightning_basic.py
def __init__(self, config: ModelParallelConfig) -> None:
    """Constructor of the model.

    Args:
        config: The config object is responsible for telling the strategy what model to create.
    """
    super().__init__(config)
    # FIXME add an assertion that the user is not trying to do tensor parallelism since this doesn't use
    #  parallelizable megatron linear layers.
    self.model_type: ModelType = ModelType.encoder_or_decoder
    self.linear1 = nn.Linear(28 * 28, 64)
    self.relu = nn.ReLU()
    self.linear2 = nn.Linear(64, 3)

set_input_tensor(input_tensor)

This would be needed for model parallel and other kinds of more complicated forward passes in megatron.

Source code in bionemo/example_model/lightning/lightning_basic.py
def set_input_tensor(self, input_tensor: Optional[Tensor]) -> None:
    """This _would_ be needed for model parallel and other kinds of more complicated forward passes in megatron."""
    pass

MNISTCustomDataset

Bases: MNIST

A Wrapper for the MNIST Dataset.

Source code in bionemo/example_model/lightning/lightning_basic.py
class MNISTCustomDataset(MNIST):
    """A Wrapper for the MNIST Dataset."""

    def __getitem__(self, idx: int) -> MnistItem:
        """Wraps the getitem method of the MNIST dataset such that we return a Dict.

        This is instead of a Tuple or tensor.

        Args:
            idx: The index we want to grab, an int.

        Returns:
            A dict containing the data ("x"), label ("y"), and index ("idx").
        """
        data, label = super().__getitem__(idx)

        return {
            "data": data,
            "label": label,
            "idx": idx,
        }
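
A small sanity check of the dict-shaped items; a sketch that assumes torchvision can download MNIST into the illustrative directory below and uses the same ToTensor transform as MNISTDataModule.

from torchvision import transforms

from bionemo.example_model.lightning.lightning_basic import MNISTCustomDataset

ds = MNISTCustomDataset("/tmp/mnist", download=True, transform=transforms.ToTensor(), train=False)
item = ds[0]
print(item["data"].shape, item["label"], item["idx"])  # torch.Size([1, 28, 28]), the digit label, 0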

__getitem__(idx)

Wraps the getitem method of the MNIST dataset such that we return a Dict.

This is instead of a Tuple or tensor.

Parameters:

- idx (int, required): The index we want to grab, an int.

Returns:

- MnistItem: A dict containing the data ("data"), label ("label"), and index ("idx").

Source code in bionemo/example_model/lightning/lightning_basic.py
def __getitem__(self, idx: int) -> MnistItem:
    """Wraps the getitem method of the MNIST dataset such that we return a Dict.

    This is instead of a Tuple or tensor.

    Args:
        idx: The index we want to grab, an int.

    Returns:
        A dict containing the data ("x"), label ("y"), and index ("idx").
    """
    data, label = super().__getitem__(idx)

    return {
        "data": data,
        "label": label,
        "idx": idx,
    }

MNISTDataModule

Bases: LightningDataModule

A Megatron Compatible Data Module for MNIST.

Attributes:

- data_dir: data directory
- micro_batch_size: micro-batch size
- global_batch_size: global batch size
- max_len: maximal sequence length for the megatron sampler
- rampup_batch_size: ramp-up batch size
- num_workers: number of workers
- data_sampler: a megatron-compatible data sampler

Source code in bionemo/example_model/lightning/lightning_basic.py
class MNISTDataModule(pl.LightningDataModule):
    """A Megatron Compatible Data Module for MNIST.

    Attributes:
    data_dir: data directory
    micro_batch_size: batch_size
    global_batch_size: global batch size
    max_len: maximal sequence length for megatron sampler
    rampup_batch_size: ramp up batch size
    num_workers: number of workers
    data_sampler: data_sampler set to be a megatron one
    """

    def __init__(
        self,
        data_dir: str | os.PathLike = str(BIONEMO_CACHE_DIR),
        batch_size: int = 32,
        num_workers: int = 0,
        global_batch_size: int | None = None,
        output_log: bool = True,
    ) -> None:
        """Initialize class.

        Args:
            data_dir: data directory
            batch_size: batch_size
            global_batch_size: global batch size
            num_workers: number of workers
            output_log: whether to output logs

        """
        super().__init__()
        self.data_dir = data_dir
        self.micro_batch_size = batch_size
        self.global_batch_size = global_batch_size or batch_size
        self.max_len = 1048
        self.rampup_batch_size = None
        self.num_workers = num_workers
        #  Note that this sampler is sequential, meaning it does not do any shuffling. Let's wrap our data in a shuffler.
        # Wraps the datasampler with the MegatronDataSampler. The MegatronDataSampler is a wrapper that allows the sampler
        # to be used with megatron. It sets up the capability to utilize micro-batching and gradient accumulation. It is also
        # the place where the global batch size is constructed.
        self.data_sampler = MegatronDataSampler(
            seq_len=self.max_len,
            micro_batch_size=self.micro_batch_size,
            global_batch_size=self.global_batch_size,
            rampup_batch_size=self.rampup_batch_size,
            output_log=output_log,
        )

    def setup(self, stage: str) -> None:
        """Sets up the datasets.

        Args:
            stage: can be one of train / test / predict.
        """
        self.mnist_test = MultiEpochDatasetResampler(
            IdentityMultiEpochDatasetWrapper(
                MNISTCustomDataset(self.data_dir, download=True, transform=transforms.ToTensor(), train=False)
            ),
            seed=43,
            shuffle=False,
        )
        mnist_full = MNISTCustomDataset(self.data_dir, download=True, transform=transforms.ToTensor(), train=True)
        mnist_train, mnist_val = torch.utils.data.random_split(
            mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
        )
        self.mnist_train = MultiEpochDatasetResampler(
            IdentityMultiEpochDatasetWrapper(mnist_train), seed=44, shuffle=True
        )

        self.mnist_val = MultiEpochDatasetResampler(
            IdentityMultiEpochDatasetWrapper(mnist_val),
            seed=45,
            shuffle=False,
        )

    def train_dataloader(self) -> DataLoader:
        """Returns the training dataloader."""
        return DataLoader(self.mnist_train, batch_size=self.micro_batch_size, num_workers=self.num_workers)

    def val_dataloader(self) -> DataLoader:
        """Returns the validation dataloader."""
        return DataLoader(self.mnist_val, batch_size=self.micro_batch_size, num_workers=self.num_workers)

    def predict_dataloader(self) -> DataLoader:
        """Returns the prediction dataloader."""
        return DataLoader(self.mnist_test, batch_size=self.micro_batch_size, num_workers=self.num_workers)
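
A rough usage sketch (the data directory is illustrative): with the MegatronDataSampler, the number of micro-batches accumulated per optimizer step works out to global_batch_size / (micro_batch_size * data_parallel_size).

from bionemo.example_model.lightning.lightning_basic import MNISTDataModule

# batch_size is the per-GPU micro-batch size; global_batch_size spans all
# data-parallel ranks plus gradient accumulation.
dm = MNISTDataModule(data_dir="/tmp/mnist", batch_size=32, global_batch_size=128, num_workers=4)
# e.g. with 2 data-parallel GPUs: 128 / (32 * 2) = 2 micro-batches per optimizer step.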

__init__(data_dir=str(BIONEMO_CACHE_DIR), batch_size=32, num_workers=0, global_batch_size=None, output_log=True)

Initialize class.

Parameters:

- data_dir (str | os.PathLike, default str(BIONEMO_CACHE_DIR)): data directory
- batch_size (int, default 32): micro-batch size
- global_batch_size (int | None, default None): global batch size
- num_workers (int, default 0): number of workers
- output_log (bool, default True): whether to output logs
Source code in bionemo/example_model/lightning/lightning_basic.py
def __init__(
    self,
    data_dir: str | os.PathLike = str(BIONEMO_CACHE_DIR),
    batch_size: int = 32,
    num_workers: int = 0,
    global_batch_size: int | None = None,
    output_log: bool = True,
) -> None:
    """Initialize class.

    Args:
        data_dir: data directory
        batch_size: batch_size
        global_batch_size: global batch size
        num_workers: number of workers
        output_log: whether to output logs

    """
    super().__init__()
    self.data_dir = data_dir
    self.micro_batch_size = batch_size
    self.global_batch_size = global_batch_size or batch_size
    self.max_len = 1048
    self.rampup_batch_size = None
    self.num_workers = num_workers
    #  Note that this sampler is sequential, meaning it does not do any shuffling. Let's wrap our data in a shuffler.
    # Wraps the datasampler with the MegatronDataSampler. The MegatronDataSampler is a wrapper that allows the sampler
    # to be used with megatron. It sets up the capability to utilize micro-batching and gradient accumulation. It is also
    # the place where the global batch size is constructed.
    self.data_sampler = MegatronDataSampler(
        seq_len=self.max_len,
        micro_batch_size=self.micro_batch_size,
        global_batch_size=self.global_batch_size,
        rampup_batch_size=self.rampup_batch_size,
        output_log=output_log,
    )

predict_dataloader()

Returns the prediction dataloader.

Source code in bionemo/example_model/lightning/lightning_basic.py
def predict_dataloader(self) -> DataLoader:
    """Returns the prediction dataloader."""
    return DataLoader(self.mnist_test, batch_size=self.micro_batch_size, num_workers=self.num_workers)

setup(stage)

Sets up the datasets.

Parameters:

- stage (str, required): can be one of train / test / predict.
Source code in bionemo/example_model/lightning/lightning_basic.py
def setup(self, stage: str) -> None:
    """Sets up the datasets.

    Args:
        stage: can be one of train / test / predict.
    """
    self.mnist_test = MultiEpochDatasetResampler(
        IdentityMultiEpochDatasetWrapper(
            MNISTCustomDataset(self.data_dir, download=True, transform=transforms.ToTensor(), train=False)
        ),
        seed=43,
        shuffle=False,
    )
    mnist_full = MNISTCustomDataset(self.data_dir, download=True, transform=transforms.ToTensor(), train=True)
    mnist_train, mnist_val = torch.utils.data.random_split(
        mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
    )
    self.mnist_train = MultiEpochDatasetResampler(
        IdentityMultiEpochDatasetWrapper(mnist_train), seed=44, shuffle=True
    )

    self.mnist_val = MultiEpochDatasetResampler(
        IdentityMultiEpochDatasetWrapper(mnist_val),
        seed=45,
        shuffle=False,
    )

train_dataloader()

Returns the training dataloader.

Source code in bionemo/example_model/lightning/lightning_basic.py
def train_dataloader(self) -> DataLoader:
    """Returns the training dataloader."""
    return DataLoader(self.mnist_train, batch_size=self.micro_batch_size, num_workers=self.num_workers)

val_dataloader()

Returns the validation dataloader.

Source code in bionemo/example_model/lightning/lightning_basic.py
def val_dataloader(self) -> DataLoader:
    """Returns the validation dataloader."""
    return DataLoader(self.mnist_val, batch_size=self.micro_batch_size, num_workers=self.num_workers)

MSELossReduction

Bases: MegatronLossReduction

A class used for calculating the loss, and for logging the reduced loss across micro batches.

Source code in bionemo/example_model/lightning/lightning_basic.py
class MSELossReduction(MegatronLossReduction):
    """A class used for calculating the loss, and for logging the reduced loss across micro batches."""

    def forward(self, batch: MnistItem, forward_out: Dict[str, Tensor]) -> Tuple[Tensor, SameSizeLossDict]:
        """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

        Args:
            batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
            forward_out: the output of the forward method inside LitAutoEncoder.

        Returns:
            A tuple containing [<loss_tensor>, ReductionT] where the loss tensor will be used for
                backpropagation and the ReductionT will be passed to the reduce method
                (which currently only works for logging.).
        """
        x = batch["data"]
        x_hat = forward_out["x_hat"]
        xview = x.view(x.size(0), -1).to(x_hat.dtype)
        loss = nn.functional.mse_loss(x_hat, xview)

        return loss, {"avg": loss}

    def reduce(self, losses_reduced_per_micro_batch: Sequence[SameSizeLossDict]) -> Tensor:
        """Works across micro-batches. (data on single gpu).

        Note: This currently only works for logging and this loss will not be used for backpropagation.

        Args:
            losses_reduced_per_micro_batch: a list of the outputs of forward

        Returns:
            A tensor that is the mean of the losses. (used for logging).
        """
        mse_losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
        return mse_losses.mean()

forward(batch, forward_out)

Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

Parameters:

- batch (MnistItem, required): A batch of data that gets passed to the original forward inside LitAutoEncoder.
- forward_out (Dict[str, Tensor], required): the output of the forward method inside LitAutoEncoder.

Returns:

- Tuple[Tensor, SameSizeLossDict]: A tuple containing [<loss_tensor>, ReductionT] where the loss tensor will be used for backpropagation and the ReductionT will be passed to the reduce method (which currently only works for logging).

Source code in bionemo/example_model/lightning/lightning_basic.py
def forward(self, batch: MnistItem, forward_out: Dict[str, Tensor]) -> Tuple[Tensor, SameSizeLossDict]:
    """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

    Args:
        batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
        forward_out: the output of the forward method inside LitAutoEncoder.

    Returns:
        A tuple containing [<loss_tensor>, ReductionT] where the loss tensor will be used for
            backpropagation and the ReductionT will be passed to the reduce method
            (which currently only works for logging.).
    """
    x = batch["data"]
    x_hat = forward_out["x_hat"]
    xview = x.view(x.size(0), -1).to(x_hat.dtype)
    loss = nn.functional.mse_loss(x_hat, xview)

    return loss, {"avg": loss}

reduce(losses_reduced_per_micro_batch)

Works across micro-batches (data on a single GPU).

Note: This currently only works for logging and this loss will not be used for backpropagation.

Parameters:

- losses_reduced_per_micro_batch (Sequence[SameSizeLossDict], required): a list of the outputs of forward.

Returns:

- Tensor: A tensor that is the mean of the losses (used for logging).

Source code in bionemo/example_model/lightning/lightning_basic.py
def reduce(self, losses_reduced_per_micro_batch: Sequence[SameSizeLossDict]) -> Tensor:
    """Works across micro-batches. (data on single gpu).

    Note: This currently only works for logging and this loss will not be used for backpropagation.

    Args:
        losses_reduced_per_micro_batch: a list of the outputs of forward

    Returns:
        A tensor that is the mean of the losses. (used for logging).
    """
    mse_losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
    return mse_losses.mean()

MSEPlusClassifierLossReduction

Bases: MegatronLossReduction

A class used for calculating the loss, and for logging the reduced loss across micro batches.

Source code in bionemo/example_model/lightning/lightning_basic.py
class MSEPlusClassifierLossReduction(MegatronLossReduction):
    """A class used for calculating the loss, and for logging the reduced loss across micro batches."""

    def forward(self, batch: MnistItem, forward_out: ExampleFineTuneOutput) -> Tuple[Tensor, SameSizeLossDict]:
        """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

        Args:
            batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
            forward_out: the output of the forward method inside LitAutoEncoder.

        Returns:
            A tuple containing [<loss_tensor>, ReductionT] where the loss tensor will be used for
                backpropagation and the ReductionT will be passed to the reduce method
                (which currently only works for logging.).
        """
        x = batch["data"]
        digits = batch["label"]
        x_hat = forward_out["x_hat"]
        digit_logits = forward_out["digit_logits"]
        xview = x.view(x.size(0), -1).to(x_hat.dtype)
        mse_loss = nn.functional.mse_loss(x_hat, xview)
        classifier_loss = nn.functional.cross_entropy(digit_logits, digits)
        loss = classifier_loss + mse_loss
        return loss, {"avg": loss}

    def reduce(self, losses_reduced_per_micro_batch: Sequence[SameSizeLossDict]) -> Tensor:
        """Works across micro-batches. (data on single gpu).

        Note: This currently only works for logging and this loss will not be used for backpropagation.

        Args:
            losses_reduced_per_micro_batch: a list of the outputs of forward

        Returns:
            A tensor that is the mean of the losses. (used for logging).
        """
        mse_losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
        return mse_losses.mean()

forward(batch, forward_out)

Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

Parameters:

- batch (MnistItem, required): A batch of data that gets passed to the original forward inside LitAutoEncoder.
- forward_out (ExampleFineTuneOutput, required): the output of the forward method inside LitAutoEncoder.

Returns:

- Tuple[Tensor, SameSizeLossDict]: A tuple containing [<loss_tensor>, ReductionT] where the loss tensor will be used for backpropagation and the ReductionT will be passed to the reduce method (which currently only works for logging).

Source code in bionemo/example_model/lightning/lightning_basic.py
def forward(self, batch: MnistItem, forward_out: ExampleFineTuneOutput) -> Tuple[Tensor, SameSizeLossDict]:
    """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

    Args:
        batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
        forward_out: the output of the forward method inside LitAutoEncoder.

    Returns:
        A tuple containing [<loss_tensor>, ReductionT] where the loss tensor will be used for
            backpropagation and the ReductionT will be passed to the reduce method
            (which currently only works for logging.).
    """
    x = batch["data"]
    digits = batch["label"]
    x_hat = forward_out["x_hat"]
    digit_logits = forward_out["digit_logits"]
    xview = x.view(x.size(0), -1).to(x_hat.dtype)
    mse_loss = nn.functional.mse_loss(x_hat, xview)
    classifier_loss = nn.functional.cross_entropy(digit_logits, digits)
    loss = classifier_loss + mse_loss
    return loss, {"avg": loss}

reduce(losses_reduced_per_micro_batch)

Works across micro-batches (data on a single GPU).

Note: This currently only works for logging and this loss will not be used for backpropagation.

Parameters:

- losses_reduced_per_micro_batch (Sequence[SameSizeLossDict], required): a list of the outputs of forward.

Returns:

- Tensor: A tensor that is the mean of the losses (used for logging).

Source code in bionemo/example_model/lightning/lightning_basic.py
def reduce(self, losses_reduced_per_micro_batch: Sequence[SameSizeLossDict]) -> Tensor:
    """Works across micro-batches. (data on single gpu).

    Note: This currently only works for logging and this loss will not be used for backpropagation.

    Args:
        losses_reduced_per_micro_batch: a list of the outputs of forward

    Returns:
        A tensor that is the mean of the losses. (used for logging).
    """
    mse_losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
    return mse_losses.mean()

MnistItem

Bases: TypedDict

Training input for the MNIST dataset.

Source code in bionemo/example_model/lightning/lightning_basic.py
class MnistItem(TypedDict):
    """Training input for the MNIST dataset."""

    data: Tensor
    label: Tensor
    idx: int

PretrainConfig dataclass

Bases: ExampleGenericConfig['ExampleModel', 'MSELossReduction'], IOMixinWithGettersSetters

PretrainConfig is a dataclass that is used to configure the model.

Timers from ModelParallelConfig are required for megatron forward compatibility.

Source code in bionemo/example_model/lightning/lightning_basic.py
@dataclass
class PretrainConfig(ExampleGenericConfig["ExampleModel", "MSELossReduction"], iom.IOMixinWithGettersSetters):
    """PretrainConfig is a dataclass that is used to configure the model.

    Timers from ModelParallelConfig are required for megatron forward compatibility.
    """

    model_cls: Type[ExampleModel] = ExampleModel
    loss_cls: Type[MSELossReduction] = MSELossReduction

SameSizeLossDict

Bases: TypedDict

This is the return type for a loss that is computed for the entire batch, where all microbatches are the same size.

Source code in bionemo/example_model/lightning/lightning_basic.py
class SameSizeLossDict(TypedDict):
    """This is the return type for a loss that is computed for the entire batch, where all microbatches are the same size."""

    avg: Tensor