Finetune token regressor

FineTuneSeqLenBioBertConfig dataclass

Bases: BioBertConfig[MegatronBioBertFineTuneSeqLengthModel, SequenceLengthRMSEPlusBERTMLMLossWithReduction], IOMixinWithGettersSetters

BioBert fine-tuning sequence length model configuration.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
@dataclass
class FineTuneSeqLenBioBertConfig(
    BioBertConfig[MegatronBioBertFineTuneSeqLengthModel, SequenceLengthRMSEPlusBERTMLMLossWithReduction],
    iom.IOMixinWithGettersSetters,
):
    """BioBert fine-tuning sequence length model configuration."""

    # When overriding fields in a dataclass _always_ declare types: https://github.com/python/cpython/issues/123269
    model_cls: Type[MegatronBioBertFineTuneSeqLengthModel] = MegatronBioBertFineTuneSeqLengthModel
    # The typical case is fine-tuning the base BioBERT, which doesn't have this head. If you are instead loading a
    # checkpoint that already has this head and want to keep its weights, drop the next line or set it to [].
    initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=lambda: ["regression_head"])

    def get_loss_reduction_class(self) -> Type[SequenceLengthRMSEPlusBERTMLMLossWithReduction]:
        """Loss function type."""
        return SequenceLengthRMSEPlusBERTMLMLossWithReduction

get_loss_reduction_class()

Loss function type.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
def get_loss_reduction_class(self) -> Type[SequenceLengthRMSEPlusBERTMLMLossWithReduction]:
    """Loss function type."""
    return SequenceLengthRMSEPlusBERTMLMLossWithReduction
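
Below is a minimal usage sketch. The `initial_ckpt_path` argument is an assumption about the surrounding BioBertConfig API and is not documented on this page; additional architecture fields may be required depending on your BioNeMo version.

from bionemo.geneformer.model.finetune_token_regressor import FineTuneSeqLenBioBertConfig

# Typical case: fine-tune a pretrained base BioBERT. The regression head is new, so its
# weights are skipped when loading the checkpoint (the documented default above).
config = FineTuneSeqLenBioBertConfig(
    initial_ckpt_path="/path/to/pretrained/geneformer",  # hypothetical path; field assumed from BioBertConfig
)

# Resuming from a checkpoint that already contains the regression head: keep its weights
# by clearing the skip list.
resume_config = FineTuneSeqLenBioBertConfig(
    initial_ckpt_path="/path/to/finetuned/checkpoint",  # hypothetical path
    initial_ckpt_skip_keys_with_these_prefixes=[],
)

# The loss class wired to this config.
loss_cls = config.get_loss_reduction_class()  # SequenceLengthRMSEPlusBERTMLMLossWithReduction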

LoRAForGeneFormerTokenRegressor

Bases: LoRA

LoRA for Geneformer Token Regression.

There are a few tricky things here to get everything to work right:

  1. Freezing logic for the transformer has to be updated in order to not freeze the new head layers.
  2. The LoRA adapter logic has to be updated to pull the input/output sizes of the layers to be adapted from the modules that are passed (the previous method was compatible with nn and TE, but not Megatron's tensor_parallel modules that are currently used by geneformer). This method contains a suggested refactor to make these methods a little more general and extensible with structural pattern matching as well. We should push this requirement onto NeMo, since we shouldn't duplicate the adapter method.
  3. There are a lot of assumptions in NeMo about which module is being called and that it inherits specific mixins. This could break if it is updated from a megatron module to a torch module or something else. Functional calls are generally favored for this reason and some have been made here to avoid updating inheritance throughout the code base.
Source code in bionemo/geneformer/model/finetune_token_regressor.py
class LoRAForGeneFormerTokenRegressor(LoRA):
    """LoRA for Genformer Token Regression.

    There are a few tricky things here to get everything to work right:

    1. Freezing logic for the transformer has to be updated in order to not freeze the new head layers.
    2. The LoRA adapter logic has to be updated to pull the input/output sizes of the layers to be adapted from the
       modules that are passed (the previous method was compatible with nn and TE, but not Megatron's tensor_parallel
       modules that are currently used by geneformer). This method contains a suggested refactor to make these methods
       a little more general and extensible with structural pattern matching as well. We should push this
       requirement onto NeMo, since we shouldn't duplicate the adapter method.
    3. There are a lot of assumptions in NeMo about which module is being called and that it inherits specific mixins.
       This could break if it is updated from a megatron module to a torch module or something else. Functional
       calls are generally favored for this reason and some have been made here to avoid updating inheritance throughout
       the code base.
    """

    def input_size_getter(self, m: nn.Module) -> int:
        """Gets the input size of the supplied model."""
        match m:
            case object(input_size=n):
                return n
            case object(in_features=n):
                return n
            case _:
                raise ValueError(f"Module {m} does not have a supported input size calculation.")

    def output_size_getter(self, m: nn.Module) -> int:
        """Gets the output size of the supplied model."""
        match m:
            case object(output_size=n):
                return n
            case object(out_features=n):
                return n
            case _:
                raise ValueError(f"Module {m} does not have a supported output size calculation.")

    def __call__(self, model: nn.Module) -> nn.Module:
        """Inference."""
        fn.walk(model, self.selective_freeze)
        fn.walk(model, self.transform)
        return model

    def selective_freeze(self, m: nn.Module, name: str | None = None, prefix: str | None = None) -> nn.Module:
        """Freezes either 'encoder' or 'embedding' parameters of the input model (`m`) iff name is one of these."""
        if name in ["encoder", "embedding"]:
            FNMixin.freeze(m)
        return m

    def transform(
        self, m: nn.Module, name: str | None = None, prefix: str | None = None
    ) -> nn.Module | AdapterParallelAdd:
        """Transforms the input model if the name is in the target modules."""
        tp_size = parallel_state.get_tensor_model_parallel_world_size()
        if name in self.target_modules:
            # m.in_features and m.out_features are divided by tp_size already,
            # but in_features and out_features passed to ParallelLinearAdapter are not.
            if prefix is not None and "regression_head" in prefix:
                return m
            if name in ["linear_qkv", "linear_fc1"]:
                # Column Parallel Linear
                input_is_parallel = False
                in_features = self.input_size_getter(
                    m
                )  # TODO(@georgea) note that this could break depending on the impl of `m`
                out_features = self.output_size_getter(m) * tp_size
                # LoRA is applied after layernorm, so layernorm output must be returned
                m.return_layernorm_output = True
                # perf optimization for LoRA + SP
                if m.config.sequence_parallel and not m.ub_overlap_ag:
                    m.return_layernorm_output_gathered = True
            else:  # name in ['linear_proj', 'linear_fc2']
                # Row Parallel Linear
                input_is_parallel = True
                in_features = (
                    self.input_size_getter(m) * tp_size
                )  # TODO(@georgea) note this could break depending on the impl of `m`
                out_features = self.output_size_getter(m)

            adapter = ParallelLinearAdapter(
                in_features,
                out_features,
                self.dim,
                activation="identity",
                norm_position=None,
                norm_type=None,
                column_init_method=self.lora_A_init_method,
                row_init_method=self.lora_B_init_method,
                gather_output=False,
                input_is_parallel=input_is_parallel,
                dropout=self.dropout,
                dropout_position=self.dropout_position,
                model_parallel_config=getattr(m, "config", None),
                alpha=self.alpha,
            )
            return AdapterParallelAdd(m, adapter)
        return m

__call__(model)

Inference.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
def __call__(self, model: nn.Module) -> nn.Module:
    """Inference."""
    fn.walk(model, self.selective_freeze)
    fn.walk(model, self.transform)
    return model

input_size_getter(m)

Gets the input size of the supplied model.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
def input_size_getter(self, m: nn.Module) -> int:
    """Gets the input size of the supplied model."""
    match m:
        case object(input_size=n):
            return n
        case object(in_features=n):
            return n
        case _:
            raise ValueError(f"Module {m} does not have a supported input size calculation.")

output_size_getter(m)

Gets the output size of the supplied model.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
def output_size_getter(self, m: nn.Module) -> int:
    """Gets the output size of the supplied model."""
    match m:
        case object(output_size=n):
            return n
        case object(out_features=n):
            return n
        case _:
            raise ValueError(f"Module {m} does not have a supported output size calculation.")

selective_freeze(m, name=None, prefix=None)

Freezes either 'encoder' or 'embedding' parameters of the input model (m) iff name is one of these.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
def selective_freeze(self, m: nn.Module, name: str | None = None, prefix: str | None = None) -> nn.Module:
    """Freezes either 'encoder' or 'embedding' parameters of the input model (`m`) iff name is one of these."""
    if name in ["encoder", "embedding"]:
        FNMixin.freeze(m)
    return m

transform(m, name=None, prefix=None)

Transforms the input model if the name is in the target modules.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
def transform(
    self, m: nn.Module, name: str | None = None, prefix: str | None = None
) -> nn.Module | AdapterParallelAdd:
    """Transforms the input model if the name is in the target modules."""
    tp_size = parallel_state.get_tensor_model_parallel_world_size()
    if name in self.target_modules:
        # m.in_features and m.out_features are divided by tp_size already,
        # but in_features and out_features passed to ParallelLinearAdapter are not.
        if prefix is not None and "regression_head" in prefix:
            return m
        if name in ["linear_qkv", "linear_fc1"]:
            # Column Parallel Linear
            input_is_parallel = False
            in_features = self.input_size_getter(
                m
            )  # TODO(@georgea) note that this could break depending on the impl of `m`
            out_features = self.output_size_getter(m) * tp_size
            # LoRA is applied after layernorm, so layernorm output must be returned
            m.return_layernorm_output = True
            # perf optimization for LoRA + SP
            if m.config.sequence_parallel and not m.ub_overlap_ag:
                m.return_layernorm_output_gathered = True
        else:  # name in ['linear_proj', 'linear_fc2']
            # Row Parallel Linear
            input_is_parallel = True
            in_features = (
                self.input_size_getter(m) * tp_size
            )  # TODO(@georgea) note this could break depending on the impl of `m`
            out_features = self.output_size_getter(m)

        adapter = ParallelLinearAdapter(
            in_features,
            out_features,
            self.dim,
            activation="identity",
            norm_position=None,
            norm_type=None,
            column_init_method=self.lora_A_init_method,
            row_init_method=self.lora_B_init_method,
            gather_output=False,
            input_is_parallel=input_is_parallel,
            dropout=self.dropout,
            dropout_position=self.dropout_position,
            model_parallel_config=getattr(m, "config", None),
            alpha=self.alpha,
        )
        return AdapterParallelAdd(m, adapter)
    return m
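
A minimal sketch of applying the adapter. The constructor arguments (`target_modules`, `dim`, `alpha`, `dropout`) follow the parent NeMo `LoRA` class and are assumptions here; `model` stands for an already-instantiated Geneformer/BioBERT module.

from bionemo.geneformer.model.finetune_token_regressor import LoRAForGeneFormerTokenRegressor

# Argument names come from the NeMo LoRA base class; defaults may differ across versions.
lora = LoRAForGeneFormerTokenRegressor(
    target_modules=["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"],
    dim=16,
    alpha=32,
    dropout=0.0,
)

# __call__ walks the model twice: selective_freeze freezes 'encoder' and 'embedding'
# (leaving the new regression head trainable), then transform wraps each target module
# in an AdapterParallelAdd containing a ParallelLinearAdapter.
model = lora(model)  # `model` is a placeholder for the instantiated fine-tuning model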

MegatronBioBertFineTuneSeqLengthModel

Bases: MegatronBioBertModel

Megatron model for biobert finetuning with sequence length.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
class MegatronBioBertFineTuneSeqLengthModel(MegatronBioBertModel):
    """Megatron model for biobert finetuning with sequence length."""

    def __init__(self, config, *args, include_hiddens: bool = False, post_process: bool = True, **kwargs):
        """Constructor."""
        super().__init__(config, *args, include_hiddens=True, post_process=post_process, **kwargs)
        self.include_hiddens_finetuning = (
            include_hiddens  # this include_hiddens is for the final output of fine-tuning
        )
        # If post_process is True that means that we are at the last megatron parallelism stage and we can
        #   apply the head.
        if post_process:
            # if we are doing post process (eg pipeline last stage) then we need to add the output layers
            self.regression_head = MegatronRegressionMLPHead(config)

    def forward(self, *args, **kwargs) -> MegatronFineTuneOutput | BioBertOutput | Tensor:
        """Inference."""
        output: MegatronFineTuneOutput | BioBertOutput | Tensor = super().forward(*args, **kwargs)
        # Stop early if we are not in post_process mode (for example if we are in the middle of model parallelism)
        if not self.post_process:
            return output  # we are not at the last pipeline stage so just return what the parent has
        # Double check that the output from the parent has everything we need to do prediction in this head.
        if not isinstance(output, dict) or ("hidden_states" not in output):
            raise ValueError(
                f"Expected to find 'hidden_states' in the output, and output to be dictionary-like, found {output},\n"
                "Make sure include_hiddens=True in the call to super().__init__"
            )
        # Get the hidden state from the parent output, and pull out the [CLS] token for this task
        hidden_states: Tensor = output["hidden_states"][:, 0]  # [b s h] => [b h], use [CLS] (first) token for reg
        # Predict our 1d regression target
        regression_output = self.regression_head(hidden_states)
        if not self.include_hiddens_finetuning:
            del output["hidden_states"]
        output["regression_output"] = regression_output
        return output

__init__(config, *args, include_hiddens=False, post_process=True, **kwargs)

Constructor.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
def __init__(self, config, *args, include_hiddens: bool = False, post_process: bool = True, **kwargs):
    """Constructor."""
    super().__init__(config, *args, include_hiddens=True, post_process=post_process, **kwargs)
    self.include_hiddens_finetuning = (
        include_hiddens  # this include_hiddens is for the final output of fine-tuning
    )
    # If post_process is True that means that we are at the last megatron parallelism stage and we can
    #   apply the head.
    if post_process:
        # if we are doing post process (eg pipeline last stage) then we need to add the output layers
        self.regression_head = MegatronRegressionMLPHead(config)

forward(*args, **kwargs)

Inference.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
def forward(self, *args, **kwargs) -> MegatronFineTuneOutput | BioBertOutput | Tensor:
    """Inference."""
    output: MegatronFineTuneOutput | BioBertOutput | Tensor = super().forward(*args, **kwargs)
    # Stop early if we are not in post_process mode (for example if we are in the middle of model parallelism)
    if not self.post_process:
        return output  # we are not at the last pipeline stage so just return what the parent has
    # Double check that the output from the parent has everything we need to do prediction in this head.
    if not isinstance(output, dict) or ("hidden_states" not in output):
        raise ValueError(
            f"Expected to find 'hidden_states' in the output, and output to be dictionary-like, found {output},\n"
            "Make sure include_hiddens=True in the call to super().__init__"
        )
    # Get the hidden state from the parent output, and pull out the [CLS] token for this task
    hidden_states: Tensor = output["hidden_states"][:, 0]  # [b s h] => [b h], use [CLS] (first) token for reg
    # Predict our 1d regression target
    regression_output = self.regression_head(hidden_states)
    if not self.include_hiddens_finetuning:
        del output["hidden_states"]
    output["regression_output"] = regression_output
    return output
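
A hedged sketch of reading the forward output at the last pipeline stage; `model` and `batch_inputs` are placeholders, not names from this module.

output = model(**batch_inputs)  # MegatronFineTuneOutput when post_process=True
seq_len_prediction = output["regression_output"]  # shape [batch, 1]
if "hidden_states" in output:  # only kept when include_hiddens=True was passed
    cls_embeddings = output["hidden_states"][:, 0]  # [batch, hidden], the [CLS] token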

MegatronFineTuneOutput

Bases: BioBertOutput

Inference output type for MegatronBioBertFineTuneSeqLengthModel.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
class MegatronFineTuneOutput(BioBertOutput):
    """Inference output type for MegatronBioBertFineTuneSeqLengthModel."""

    regression_output: Tensor

MegatronRegressionMLPHead

Bases: MegatronModule

A megatron MLP head.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
class MegatronRegressionMLPHead(MegatronModule):
    """A megatron MLP head."""

    def __init__(self, config: TransformerConfig):
        """Constructor."""
        super().__init__(config)
        # FC layer over just the [CLS] token embedding
        # TODO use bias/activation fusion if requested
        self.linear_fc1 = nn.Linear(in_features=config.hidden_size, out_features=config.ffn_hidden_size)
        self.activation_function = config.activation_func
        self.linear_fc2 = nn.Linear(in_features=config.ffn_hidden_size, out_features=1)

    def forward(self, hidden_states: Tensor) -> Tensor:
        """Inference."""
        return self.linear_fc2(self.activation_function(self.linear_fc1(hidden_states)))

__init__(config)

Constructor.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
def __init__(self, config: TransformerConfig):
    """Constructor."""
    super().__init__(config)
    # FC layer over just the [CLS] token embedding
    # TODO use bias/activation fusion if requested
    self.linear_fc1 = nn.Linear(in_features=config.hidden_size, out_features=config.ffn_hidden_size)
    self.activation_function = config.activation_func
    self.linear_fc2 = nn.Linear(in_features=config.ffn_hidden_size, out_features=1)

forward(hidden_states)

Inference.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
def forward(self, hidden_states: Tensor) -> Tensor:
    """Inference."""
    return self.linear_fc2(self.activation_function(self.linear_fc1(hidden_states)))
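
A torch-only sketch of the head's shape contract under assumed sizes (a real instance is built from a TransformerConfig and uses config.activation_func):

import torch
from torch import nn

hidden_size, ffn_hidden_size, batch = 256, 512, 4

fc1 = nn.Linear(hidden_size, ffn_hidden_size)
fc2 = nn.Linear(ffn_hidden_size, 1)
activation = nn.GELU()  # stands in for config.activation_func

cls_embedding = torch.randn(batch, hidden_size)   # pooled [CLS] token, [batch, hidden]
prediction = fc2(activation(fc1(cls_embedding)))  # [batch, 1] sequence-length estimate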

SequenceLengthRMSEPlusBERTMLMLossWithReduction

Bases: BERTMLMLossWithReduction

Loss function.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
class SequenceLengthRMSEPlusBERTMLMLossWithReduction(BERTMLMLossWithReduction):
    """Loss function."""

    def forward(
        self,
        batch: SeqLenRmsepBatch,
        forward_out: Dict[str, Tensor],
    ) -> Tuple[Tensor, PerTokenLossDict | SameSizeLossDict | DataParallelGroupLossAndIO]:
        """Computes loss of `labels` in the batch vs `token_logits` in the forward output currently.

        In the future this will be extended to handle other loss types like sequence loss if it is present in the
        forward_out and batch.

        Args:
            batch: The batch of data. Each tensor should be of shape [batch_size, *, *],
                and match the corresponding dimension for that particular key in the batch output.
                For example, the "labels" and "token_logits" key should have a tensor of shape [batch_size, sequence_length].
            forward_out: The forward output from the model. Each tensor should be of shape [batch_size, *, *]

        Taken from:
        https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L951-L976
        """
        if "labels" not in batch:
            raise ValueError("Labels not provided in the batch. These are required for this loss computation.")

        unreduced_token_loss = unreduced_token_loss_fn(forward_out["token_logits"], batch["labels"])
        regression_output = forward_out["regression_output"]
        n_tokens = batch["attention_mask"].sum(dim=-1, keepdim=True).to(dtype=regression_output.dtype)
        assert len(n_tokens.shape) == 2
        assert n_tokens.shape[-1] == 1
        rmse_loss = torch.nn.functional.mse_loss(regression_output, n_tokens)

        # TODO(@jstjohn) also handle different output keys, like the sequence loss.

        cp_size = parallel_state.get_context_parallel_world_size()
        if cp_size == 1:
            # reduce the loss across the micro batch
            loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"])
        else:
            # reduce the loss across the micro batch.
            # TODO(@jomitchell): Figure out who defines "num_valid_tokens_in_ub" in the batch and document/understand this.
            #  This has something to do with context parallel, and there is probably a megatron or nemo function that adds this and
            #  other necessary keys to the batch. Thanks!
            loss_for_microbatch = masked_token_loss_context_parallel(
                unreduced_token_loss, batch["loss_mask"], batch["num_valid_tokens_in_ub"]
            )

        # If we do not drop the last partial batch of validation, we need to do fancy reduction handling to support
        #  reducing the loss across the data parallel group.
        if self.validation_step and not self.val_drop_last:
            num_valid_tokens_in_microbatch = batch["loss_mask"].sum()
            if loss_for_microbatch.isnan():
                # TODO(@jomitchell): Add a unit test for this. This is the case where there are no valid tokens in the microbatch for the loss
                #  to be computed over, so we expect a NaN loss (divide by zero for a mean) but we make this an expected and non-breaking case,
                #  re-defining it as a 0 loss. This is standard in NeMo/NeMo2.
                if batch["loss_mask"].count_nonzero() != 0:
                    raise ValueError("Got NaN loss with non-empty input")
                loss_sum_for_microbatch = torch.zeros_like(num_valid_tokens_in_microbatch)
            else:
                loss_sum_for_microbatch = num_valid_tokens_in_microbatch * loss_for_microbatch

            # In this case we need to store the loss sum as well as the number of valid tokens in the microbatch.
            loss_sum_and_microbatch_size_all_gpu = torch.cat(
                [
                    loss_sum_for_microbatch.clone().detach().view(1),
                    torch.tensor([num_valid_tokens_in_microbatch]).cuda().clone().detach(),
                ]
            )
            torch.distributed.all_reduce(
                loss_sum_and_microbatch_size_all_gpu, group=parallel_state.get_data_parallel_group()
            )
            return loss_for_microbatch * cp_size, {
                "loss_sum_and_microbatch_size": loss_sum_and_microbatch_size_all_gpu
            }
        loss_for_microbatch = loss_for_microbatch + rmse_loss  # add in the RMSE loss after reducing the logit loss
        # average the losses across the data parallel group, but also return the unreduced loss
        reduced_loss: Tensor = average_losses_across_data_parallel_group([loss_for_microbatch])
        if (self.validation_step and self.send_val_output) or (not self.validation_step and self.send_train_output):
            return loss_for_microbatch * cp_size, {"avg": reduced_loss, "batch": batch, "forward_out": forward_out}
        else:
            return loss_for_microbatch * cp_size, {"avg": reduced_loss}

forward(batch, forward_out)

Computes loss of labels in the batch vs token_logits in the forward output currently.

In the future this will be extended to handle other loss types like sequence loss if it is present in the forward_out and batch.

Parameters:

    batch (SeqLenRmsepBatch, required):
        The batch of data. Each tensor should be of shape [batch_size, *, *] and match the corresponding dimension for that particular key in the batch output. For example, the "labels" and "token_logits" keys should each have a tensor of shape [batch_size, sequence_length].

    forward_out (Dict[str, Tensor], required):
        The forward output from the model. Each tensor should be of shape [batch_size, *, *].

Taken from: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L951-L976

Source code in bionemo/geneformer/model/finetune_token_regressor.py
def forward(
    self,
    batch: SeqLenRmsepBatch,
    forward_out: Dict[str, Tensor],
) -> Tuple[Tensor, PerTokenLossDict | SameSizeLossDict | DataParallelGroupLossAndIO]:
    """Computes loss of `labels` in the batch vs `token_logits` in the forward output currently.

    In the future this will be extended to handle other loss types like sequence loss if it is present in the
    forward_out and batch.

    Args:
        batch: The batch of data. Each tensor should be of shape [batch_size, *, *],
            and match the corresponding dimension for that particular key in the batch output.
            For example, the "labels" and "token_logits" key should have a tensor of shape [batch_size, sequence_length].
        forward_out: The forward output from the model. Each tensor should be of shape [batch_size, *, *]

    Taken from:
    https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L951-L976
    """
    if "labels" not in batch:
        raise ValueError("Labels not provided in the batch. These are required for this loss computation.")

    unreduced_token_loss = unreduced_token_loss_fn(forward_out["token_logits"], batch["labels"])
    regression_output = forward_out["regression_output"]
    n_tokens = batch["attention_mask"].sum(dim=-1, keepdim=True).to(dtype=regression_output.dtype)
    assert len(n_tokens.shape) == 2
    assert n_tokens.shape[-1] == 1
    rmse_loss = torch.nn.functional.mse_loss(regression_output, n_tokens)

    # TODO(@jstjohn) also handle different output keys, like the sequence loss.

    cp_size = parallel_state.get_context_parallel_world_size()
    if cp_size == 1:
        # reduce the loss across the micro batch
        loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"])
    else:
        # reduce the loss across the micro batch.
        # TODO(@jomitchell): Figure out who defines "num_valid_tokens_in_ub" in the batch and document/understand this.
        #  This has something to do with context parallel, and there is probably a megatron or nemo function that adds this and
        #  other necessary keys to the batch. Thanks!
        loss_for_microbatch = masked_token_loss_context_parallel(
            unreduced_token_loss, batch["loss_mask"], batch["num_valid_tokens_in_ub"]
        )

    # If we do not drop the last partial batch of validation, we need to do fancy reduction handling to support
    #  reducing the loss across the data parallel group.
    if self.validation_step and not self.val_drop_last:
        num_valid_tokens_in_microbatch = batch["loss_mask"].sum()
        if loss_for_microbatch.isnan():
            # TODO(@jomitchell): Add a unit test for this. This is the case where there are no valid tokens in the microbatch for the loss
            #  to be computed over, so we expect a NaN loss (divide by zero for a mean) but we make this an expected and non-breaking case,
            #  re-defining it as a 0 loss. This is standard in NeMo/NeMo2.
            if batch["loss_mask"].count_nonzero() != 0:
                raise ValueError("Got NaN loss with non-empty input")
            loss_sum_for_microbatch = torch.zeros_like(num_valid_tokens_in_microbatch)
        else:
            loss_sum_for_microbatch = num_valid_tokens_in_microbatch * loss_for_microbatch

        # In this case we need to store the loss sum as well as the number of valid tokens in the microbatch.
        loss_sum_and_microbatch_size_all_gpu = torch.cat(
            [
                loss_sum_for_microbatch.clone().detach().view(1),
                torch.tensor([num_valid_tokens_in_microbatch]).cuda().clone().detach(),
            ]
        )
        torch.distributed.all_reduce(
            loss_sum_and_microbatch_size_all_gpu, group=parallel_state.get_data_parallel_group()
        )
        return loss_for_microbatch * cp_size, {
            "loss_sum_and_microbatch_size": loss_sum_and_microbatch_size_all_gpu
        }
    loss_for_microbatch = loss_for_microbatch + rmse_loss  # add in the RMSE loss after reducing the logit loss
    # average the losses across the data parallel group, but also return the unreduced loss
    reduced_loss: Tensor = average_losses_across_data_parallel_group([loss_for_microbatch])
    if (self.validation_step and self.send_val_output) or (not self.validation_step and self.send_train_output):
        return loss_for_microbatch * cp_size, {"avg": reduced_loss, "batch": batch, "forward_out": forward_out}
    else:
        return loss_for_microbatch * cp_size, {"avg": reduced_loss}
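
A torch-only sketch of the sequence-length regression term computed above: the target is the per-sequence token count from the attention mask, and the penalty is the mean squared error against the head's prediction (the MLM token loss and parallel-state reductions are omitted):

import torch

batch, seq_len = 4, 2048
attention_mask = torch.randint(0, 2, (batch, seq_len))  # 1 marks a real (non-padding) token
regression_output = torch.randn(batch, 1)               # prediction from the fine-tune head

# Target: valid-token count per sequence, shape [batch, 1], cast to the prediction dtype.
n_tokens = attention_mask.sum(dim=-1, keepdim=True).to(dtype=regression_output.dtype)

# Despite the `rmse_loss` name in the source, this term is a plain MSE between the
# prediction and the token count; it is added to the reduced MLM token loss.
rmse_loss = torch.nn.functional.mse_loss(regression_output, n_tokens)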