Skip to content

Lightning

DataStep = Callable[[Iterator[DataT]], DataT] module-attribute

Batches together an iterator of individual examples.

Necessary for compatibility with Megatron. This function type is similar to the collate function of PyTorch.

A DataStep function takes an iterator over individual examples. Each example may be a tensor, sequence of tensors, or a set of named tensors (provided as a dict mapping str names to each Tensor). Each iteration must yield the same type.

The output of this function will mirror the same structure of each yielded example. It will be a concatenation of all of the examples in the iterator.

ForwardStep = Callable[[MegatronModelType, DataT], DataT] module-attribute

Megatron-compatible forward pass function.

BionemoLightningModule

Bases: Generic[MegatronModelType, MegatronLossType], LightningModule, IOMixin, ConnectorMixin, LightningPassthroughPredictionMixin

Reusable PyTorch Lightning module for Megatron models that is compatible with NeMo's conventions.

Source code in bionemo/llm/lightning.py
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
class BionemoLightningModule(
    Generic[MegatronModelType, MegatronLossType],
    pl.LightningModule,
    nlio.IOMixin,
    nlio.ConnectorMixin,
    LightningPassthroughPredictionMixin,
):
    """Reusable PyTorch Lightning module for Megatron models that is compatible with NeMo's conventions.

    The wrapped model is created lazily by `configure_model` and stored on the `module` attribute. The forward,
    data and loss-reduction behavior is supplied by the caller through `forward_step`, `data_step` and `config`.
    """

    def __init__(
        self,
        config: BionemoTrainableModelConfig[MegatronModelType, MegatronLossType],
        forward_step: ForwardStep,
        data_step: DataStep,
        # TODO: Add transformer_layer_spec when we update mcore
        optimizer: MegatronOptimizerModule,
        model_transform: Optional[Callable[[MegatronModelType], MegatronModelType]] = None,
        log_train_ppl: bool = False,
        log_val_ppl: bool = False,
        **model_construct_args,
    ) -> None:
        """Constructor.

        Args:
            config: Serializable configuration object that allows one to construct a new model instance and loss
                function. Necessary for Megatron-based training as the model itself cannot be serialized and
                distributed to nodes. Instead, we serialize the procedure for making the model and distribute that.
            forward_step: Performs forward pass using the model and a batch of data.
            data_step: Custom batch-creating function for the model.
            optimizer: Megatron-compatible distributed optimizer instance. Required (there is no default); it is
                connected to this module, which binds the `configure_optimizers` method.
            model_transform: Optional. The model transform function.
            log_train_ppl (bool): Log training perplexity.
            log_val_ppl (bool): Log validation perplexity.
            **model_construct_args: Optional. Arguments necessary for the supplied model configuration's
                `configure_model` method, which will make an instance of the model.
        """
        super().__init__()
        self.config = config
        self.module_construct_args: Optional[dict[str, Any]] = model_construct_args
        # ***must** be set up in configure_model() -- megatron constraint
        # also, must be called `module`: nemo expects the actual model to be stored this way
        self.module: Optional[MegatronModelType] = None
        self.loss_reduction_class: type[MegatronLossType] = config.get_loss_reduction_class()
        self.optim = optimizer
        self.optim.connect(self)  # This will bind the `configure_optimizers` method
        self._data_step = data_step
        self._forward_step = forward_step
        self.model_transform = model_transform

        # torchmetrics must init here for fiddle serialization
        self.train_ppl = torchmetrics.text.Perplexity(ignore_index=MLM_LOSS_IGNORE_INDEX) if log_train_ppl else None
        self.valid_ppl = torchmetrics.text.Perplexity(ignore_index=MLM_LOSS_IGNORE_INDEX) if log_val_ppl else None

    def configure_model(self) -> None:
        """Updates internal state: instantiates the model from the object's config, assigns to `module` attribute.

        NOTE: this method is idempotent; successive calls have no effect. The model is only initialized once.

        Raises:
            ValueError iff the internal config's configure_model method returns None.
        """
        if self.module is None:
            model: MegatronModelType = (
                self.config.configure_model(**self.module_construct_args)
                if self.module_construct_args is not None
                else self.config.configure_model()
            )
            self.module = model

        if self.module is None:
            raise ValueError("Invalid semantics: configure_model method **MUST** initialize the model.")

    def is_on_logging_device(self):
        """Return True if last stage of pipeline parallel and first tensor parallel rank."""
        return parallel_state.is_pipeline_last_stage() and parallel_state.get_tensor_model_parallel_rank() == 0

    def forward(self, *args, **kwargs) -> DataT:
        """Call the forward method of the underlying model, and return whatever it outputs."""
        # safe to do because configure_model is idempotent
        self.configure_model()
        assert self.module is not None, "ERROR: configure_model() method has been incorrectly overridden!"
        prediction = self.module(*args, **kwargs)  # for now just pass through to the underlying model
        return prediction

    def data_step(self, dataloader_iter: Iterator[DataT]) -> DataT:
        """Delegate batch creation to the `data_step` callable supplied to the constructor."""
        return self._data_step(dataloader_iter)

    def forward_step(self, batch) -> Tensor:
        """Megatron-required: the training forward step for the model, which is required to produce the loss.

        Normally, the forward pass of a model means its inference. Loss is computed using the predictions
        from the forward pass against labels. Megatron unfortunately conflates these two different concepts
        and instead has models "forward" method produce the loss. See the Megatron docs for details:
        https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py#L170

        To get actual predictions, use the :func:`forward` method instead.
        """
        # safe to do because configure_model is idempotent
        self.configure_model()
        assert self.module is not None
        return self._forward_step(self.module, batch)

    def training_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
        """In mcore the loss-function is part of the forward-pass when labels are provided.

        Runs `forward_step` and, when training perplexity logging is enabled, updates the metric on the logging
        device from `outputs["token_logits"]` and `batch["labels"]`.

        Args:
            batch: The batch handed to `forward_step`; `batch["labels"]` is read only when the metric is updated.
            batch_idx: Unused.

        Returns:
            The unmodified output of `forward_step`.
        """
        outputs = self.forward_step(batch)

        if self.train_ppl is not None:
            if self.is_on_logging_device():
                # Only build the detached logits copy where the metric is actually updated; doing it
                # unconditionally wastes memory and fails when the output carries no "token_logits".
                logits = outputs["token_logits"].detach().transpose(0, 1).clone()  #  [s, b, v] -> [b, s, v]
                self.train_ppl(logits, batch["labels"])

            self.log("train_ppl", self.train_ppl, on_step=True, on_epoch=False, prog_bar=True)

        return outputs

    def validation_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
        """In mcore the loss-function is part of the forward-pass when labels are provided.

        Runs `forward_step` and, when validation perplexity logging is enabled, accumulates the metric on the
        logging device. The accumulated value is logged in `on_validation_epoch_end`.

        Args:
            batch: The batch handed to `forward_step`; `batch["labels"]` is read only when the metric is updated.
            batch_idx: Unused.

        Returns:
            The unmodified output of `forward_step`.
        """
        outputs = self.forward_step(batch)

        if self.valid_ppl is not None and self.is_on_logging_device():
            # Only build the detached logits copy where the metric is actually updated.
            logits = outputs["token_logits"].detach().transpose(0, 1).clone()  #  [s, b, v] -> [b, s, v]
            self.valid_ppl.update(logits, batch["labels"])

        return outputs

    def predict_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
        """Run `forward_step` on the batch; an empty batch is skipped and yields None."""
        if len(batch) == 0:
            return
        return self.forward_step(batch)

    def training_loss_reduction(self) -> MegatronLossType:
        """This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss."""
        return self.loss_reduction_class()

    def validation_loss_reduction(self) -> MegatronLossType:
        """Create the loss reduction used for validation (constructed with `validation_step=True`)."""
        return self.loss_reduction_class(validation_step=True)

    def test_loss_reduction(self) -> MegatronLossType:
        """Create the loss reduction used for testing (constructed with `validation_step=True`)."""
        return self.loss_reduction_class(validation_step=True)

    def on_validation_epoch_end(self):
        """Log validation perplexity at epoch end, discarding values gathered during sanity checking."""
        if self.valid_ppl is None:
            return

        if self.trainer.sanity_checking:
            self.valid_ppl.reset()  # clean up sanity runs
            return

        self.log("valid_ppl", self.valid_ppl, on_step=False, on_epoch=True, prog_bar=True)

__init__(config, forward_step, data_step, optimizer, model_transform=None, log_train_ppl=False, log_val_ppl=False, **model_construct_args)

Constructor.

Parameters:

Name Type Description Default
config BionemoTrainableModelConfig[MegatronModelType, MegatronLossType]

Serializable configuration object that allows one to construct a new model instance and loss function. Necessary for Megatron-based training as the model itself cannot be serialized and distributed to nodes. Instead, we serialize the procedure for making the model and distribute that.

required
forward_step ForwardStep

Performs forward pass using the model and a batch of data.

required
data_step DataStep

Custom batch-creating function for the model.

required
optimizer MegatronOptimizerModule

Megatron-compatible distributed optimizer instance. This argument is required; default_megatron_optimizer() builds one that uses ADAM with a 1e-4 learning rate.

required
model_construct_args

Optional. Any arguments necessary to construct the model in the config's configure_model method.

{}
model_transform Optional[Callable[[MegatronModelType], MegatronModelType]]

Optional. The model transform function.

None
log_train_ppl bool

Log training perplexity.

False
log_val_ppl bool

Log validation perplexity.

False
**model_construct_args

Optional. Arguments necessary for the supplied model configuration's configure_model method, which will make an instance of the model.

{}
Source code in bionemo/llm/lightning.py
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
def __init__(
    self,
    config: BionemoTrainableModelConfig[MegatronModelType, MegatronLossType],
    forward_step: ForwardStep,
    data_step: DataStep,
    # TODO: Add transformer_layer_spec when we update mcore
    optimizer: MegatronOptimizerModule,
    model_transform: Optional[Callable[[MegatronModelType], MegatronModelType]] = None,
    log_train_ppl: bool = False,
    log_val_ppl: bool = False,
    **model_construct_args,
) -> None:
    """Constructor.

    Args:
        config: Serializable configuration object that allows one to construct a new model instance and loss
            function. Necessary for Megatron-based training as the model itself cannot be serialized and
            distributed to nodes. Instead, we serialize the procedure for making the model and distribute that.
        forward_step: Performs forward pass using the model and a batch of data.
        data_step: Custom batch-creating function for the model.
        optimizer: Megatron-compatible distributed optimizer instance. Required (there is no default); it is
            connected to this module, which binds the `configure_optimizers` method.
        model_transform: Optional. The model transform function.
        log_train_ppl (bool): Log training perplexity.
        log_val_ppl (bool): Log validation perplexity.
        **model_construct_args: Optional. Arguments necessary for the supplied model configuration's
            `configure_model` method, which will make an instance of the model.
    """
    super().__init__()
    self.config = config
    # Keyword arguments forwarded to `config.configure_model` when the model is built lazily.
    self.module_construct_args: Optional[dict[str, Any]] = model_construct_args
    # ***must** be set up in configure_model() -- megatron constraint
    # also, must be called `module`: nemo expects the actual model to be stored this way
    self.module: Optional[MegatronModelType] = None
    self.loss_reduction_class: type[MegatronLossType] = config.get_loss_reduction_class()
    self.optim = optimizer
    self.optim.connect(self)  # This will bind the `configure_optimizers` method
    self._data_step = data_step
    self._forward_step = forward_step
    self.model_transform = model_transform

    # torchmetrics must init here for fiddle serialization
    # Each metric is None when its logging flag is off; the step methods check for None before using it.
    self.train_ppl = torchmetrics.text.Perplexity(ignore_index=MLM_LOSS_IGNORE_INDEX) if log_train_ppl else None
    self.valid_ppl = torchmetrics.text.Perplexity(ignore_index=MLM_LOSS_IGNORE_INDEX) if log_val_ppl else None

configure_model()

Updates internal state: instantiates the model from the object's config, assigns to module attribute.

NOTE: this method is idempotent; successive calls have no effect. The model is only initialized once.

Source code in bionemo/llm/lightning.py
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
def configure_model(self) -> None:
    """Build the model from the stored config and keep it on the `module` attribute.

    The call is idempotent: once `module` is set, later calls return without rebuilding anything.

    Raises:
        ValueError: if the config's `configure_model` method returns None.
    """
    if self.module is not None:
        # Already built on an earlier call.
        return

    construct_args = self.module_construct_args
    if construct_args is None:
        built = self.config.configure_model()
    else:
        built = self.config.configure_model(**construct_args)
    self.module = built

    if self.module is None:
        raise ValueError("Invalid semantics: configure_model method **MUST** initialize the model.")

forward(*args, **kwargs)

Call the forward method of the underlying model, and return whatever it outputs.

Source code in bionemo/llm/lightning.py
289
290
291
292
293
294
295
def forward(self, *args, **kwargs) -> DataT:
    """Run the wrapped model on the given arguments and return its output unchanged."""
    # configure_model is idempotent, so calling it on every forward is safe.
    self.configure_model()
    model = self.module
    assert model is not None, "ERROR: configure_model() method has been incorrectly overridden!"
    # Pure pass-through to the underlying model for now.
    return model(*args, **kwargs)

forward_step(batch)

Megatron-required: the training forward step for the model, which is required to produce the loss.

Normally, the forward pass of a model means its inference. Loss is computed using the predictions from the forward pass against labels. Megatron unfortunately conflates these two different concepts and instead has models "forward" method produce the loss. See the Megatron docs for details: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py#L170

To get actual predictions, use the :func:forward method instead.

Source code in bionemo/llm/lightning.py
300
301
302
303
304
305
306
307
308
309
310
311
312
313
def forward_step(self, batch) -> Tensor:
    """Megatron-required training forward step, which is expected to produce the loss.

    A model's forward pass normally means inference, with the loss computed afterwards from predictions and
    labels. Megatron conflates the two and has the model's "forward" produce the loss. See the Megatron docs:
    https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py#L170

    This method hands the wrapped model and the batch to the `forward_step` callable supplied at construction.
    To get actual predictions, use the :func:`forward` method instead.
    """
    # configure_model is idempotent, so calling it on every step is safe.
    self.configure_model()
    model = self.module
    assert model is not None
    step_fn = self._forward_step
    return step_fn(model, batch)

is_on_logging_device()

Return True if last stage of pipeline parallel and first tensor parallel rank.

Source code in bionemo/llm/lightning.py
285
286
287
def is_on_logging_device(self):
    """Tell whether this rank is the last pipeline-parallel stage and has tensor-parallel rank 0."""
    on_last_stage = parallel_state.is_pipeline_last_stage()
    # The tensor-parallel rank is only queried when the pipeline-stage check passes.
    return on_last_stage and parallel_state.get_tensor_model_parallel_rank() == 0

predict_step(batch, batch_idx=None)

Alias for forward_step.

Source code in bionemo/llm/lightning.py
338
339
340
341
342
def predict_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
    """Run `forward_step` on the batch; an empty batch is skipped and yields None."""
    if len(batch) != 0:
        return self.forward_step(batch)
    return None

training_loss_reduction()

This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss.

Source code in bionemo/llm/lightning.py
344
345
346
def training_loss_reduction(self) -> MegatronLossType:
    """Create the training loss reduction: a fresh instance of the configured loss reduction class.

    That object is what takes batch['loss_mask'] and the logits output by the model and reduces the loss.
    """
    reduction_cls = self.loss_reduction_class
    return reduction_cls()

training_step(batch, batch_idx=None)

In mcore the loss-function is part of the forward-pass when labels are provided.

Source code in bionemo/llm/lightning.py
315
316
317
318
319
320
321
322
323
324
325
326
def training_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
    """In mcore the loss-function is part of the forward-pass when labels are provided.

    Runs `forward_step` and, when training perplexity logging is enabled, updates the metric on the logging
    device from `outputs["token_logits"]` and `batch["labels"]`.

    Args:
        batch: The batch handed to `forward_step`; `batch["labels"]` is read only when the metric is updated.
        batch_idx: Unused.

    Returns:
        The unmodified output of `forward_step`.
    """
    outputs = self.forward_step(batch)

    if self.train_ppl is not None:
        if self.is_on_logging_device():
            # Only build the detached logits copy where the metric is actually updated; doing it
            # unconditionally wastes memory and fails when the output carries no "token_logits".
            logits = outputs["token_logits"].detach().transpose(0, 1).clone()  #  [s, b, v] -> [b, s, v]
            self.train_ppl(logits, batch["labels"])

        self.log("train_ppl", self.train_ppl, on_step=True, on_epoch=False, prog_bar=True)

    return outputs

validation_step(batch, batch_idx=None)

In mcore the loss-function is part of the forward-pass when labels are provided.

Source code in bionemo/llm/lightning.py
328
329
330
331
332
333
334
335
336
def validation_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
    """In mcore the loss-function is part of the forward-pass when labels are provided.

    Runs `forward_step` and, when validation perplexity logging is enabled, accumulates the metric on the
    logging device from `outputs["token_logits"]` and `batch["labels"]`.

    Args:
        batch: The batch handed to `forward_step`; `batch["labels"]` is read only when the metric is updated.
        batch_idx: Unused.

    Returns:
        The unmodified output of `forward_step`.
    """
    outputs = self.forward_step(batch)

    if self.valid_ppl is not None and self.is_on_logging_device():
        # Only build the detached logits copy where the metric is actually updated; doing it
        # unconditionally wastes memory and fails when the output carries no "token_logits".
        logits = outputs["token_logits"].detach().transpose(0, 1).clone()  #  [s, b, v] -> [b, s, v]
        self.valid_ppl.update(logits, batch["labels"])

    return outputs

LightningPassthroughPredictionMixin

A mixin that allows your model to do inference on the predict step by hijacking nemo's loss reduction mechanism.

Source code in bionemo/llm/lightning.py
184
185
186
187
188
189
class LightningPassthroughPredictionMixin:
    """Mixin that enables inference in the predict step by reusing nemo's loss reduction mechanism."""

    def predict_loss_reduction(self) -> PassthroughLossReduction:
        """Return the reduction used for prediction, which passes the forward output through."""
        passthrough = PassthroughLossReduction()
        return passthrough

predict_loss_reduction()

For the predict step, pass through the forward pass output.

Source code in bionemo/llm/lightning.py
187
188
189
def predict_loss_reduction(self) -> PassthroughLossReduction:
    """Return the reduction used for prediction, which passes the forward output through."""
    passthrough = PassthroughLossReduction()
    return passthrough

PassthroughLossReduction

Bases: MegatronLossReduction, Generic[DataT]

A workaround for nemo/megatron to perform inference.

Internally in NeMo2.0 the forward step is always expected to return a loss reduction class, and forward is expected to return a loss. This class hijacks that mechanism to instead pass through the forward output unperturbed as the loss (to enable inference in the predict step), and then the reduce method is used to collate the batch of forward outputs into a single batch. This supports the model forward output being a tensor, dict, tuple, or list of tensors. The inner type must always be a Tensor.

Source code in bionemo/llm/lightning.py
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
class PassthroughLossReduction(MegatronLossReduction, Generic[DataT]):
    """A workaround that lets nemo/megatron perform inference.

    NeMo2.0 always expects the forward step to return a loss reduction class, and expects forward to return a
    loss. This class reuses that mechanism: `forward` hands the model output back untouched next to a dummy loss
    (which enables inference in the predict step), and `reduce` collates the collected outputs into one batch.
    The model output may be a tensor, dict, tuple, or list of tensors. The inner type _must always be a Tensor_.
    """

    def forward(self, batch: DataT, forward_out: DataT) -> Tuple[Tensor, DataT]:
        """Return a dummy loss together with the unmodified `forward_out`.

        Args:
            batch: The batch of data that was passed through the model to generate output. NOTE: this value is ignored.
            forward_out: The output from your model's forward pass.

        Returns:
            A tuple of a zero-filled (1, 1) placeholder loss tensor and the forward output (unmodified).
        """
        placeholder_loss = torch.zeros((1, 1))
        return placeholder_loss, forward_out

    def reduce(self, forward_out: List[DataT]) -> DataT:
        """Collate the list of model outputs into a single output via `batch_collator`."""
        collated = batch_collator(forward_out)
        return collated

forward(batch, forward_out)

Passes through the forward_out value as the 2nd tuple element.

Parameters:

Name Type Description Default
batch DataT

The batch of data that was passed through the model to generate output. NOTE: this value is ignored.

required
forward_out DataT

The output from your model's forward pass.

required

Returns:

Type Description
Tuple[Tensor, DataT]

A tuple containing the loss tensor (dummy in this case) and the forward output (unmodified).

Source code in bionemo/llm/lightning.py
167
168
169
170
171
172
173
174
175
176
177
def forward(self, batch: DataT, forward_out: DataT) -> Tuple[Tensor, DataT]:
    """Return a dummy loss together with the unmodified `forward_out`.

    Args:
        batch: The batch of data that was passed through the model to generate output. NOTE: this value is ignored.
        forward_out: The output from your model's forward pass.

    Returns:
        A tuple of a zero-filled (1, 1) placeholder loss tensor and the forward output (unmodified).
    """
    placeholder_loss = torch.zeros((1, 1))
    return placeholder_loss, forward_out

reduce(forward_out)

Collates list of model's outputs into a single output.

Source code in bionemo/llm/lightning.py
179
180
181
def reduce(self, forward_out: List[DataT]) -> DataT:
    """Collate the list of model outputs into a single output via `batch_collator`."""
    collated = batch_collator(forward_out)
    return collated

batch_collator(batches, batch_dim=0, batch_dim_key_defaults={'token_logits': 1})

Takes a sequence of batches and collates them into a single batch.

This is distinct from the standard pytorch default_collator since it does
not add the batch dimension, it's assumed the batch
dimension is already present in the input, as would be the case when
parallelizing across minibatches.

IMPORTANT: The underlying data primitive must be a torch Tensor. The input to this function is a recursive type: there can be any amount of nesting between dictionaries, tuples, and lists, as long as the inner type is an n-d Tensor.

Examples:

Outer container = Dict: [{'a': Tensor([1]), 'b': Tensor([2])}, {'a': Tensor([2]), 'b': Tensor([3])}] -> {'a': Tensor([1, 2]), 'b': Tensor([2, 3])}. Outer container = List: [[Tensor([1]), Tensor([2])], [Tensor([2]), Tensor([3])]] -> [Tensor([1, 2]), Tensor([2, 3])]. Outer container = Tuple: [(Tensor([1]), Tensor([2])), (Tensor([2]), Tensor([3]))] -> (Tensor([1, 2]), Tensor([2, 3])).

Parameters:

Name Type Description Default
batches Optional[Sequence[ReductionT]]

sequence of batches to collate into a single batch.

required
batch_dim int

If you know that the batch dim for the batch you are concatenating is not the 0th dimension (for example it is sequence first) then supply that dimension.

0
batch_dim_key_defaults dictionary of keys to integers

If your batch is a dictionary and you know that some keys have non-standard (0) batch dimensions, supply those here. By default "token_logits" has batch dim 1 and otherwise all keys are assumed to have batch dim 0.

{'token_logits': 1}

Returns:

Type Description
Optional[ReductionT]

A single batch of the same type as the elements of your input sequence.

Source code in bionemo/llm/lightning.py
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
def batch_collator(
    batches: Optional[Union[Tuple[ReductionT], List[ReductionT]]],
    batch_dim: int = 0,
    batch_dim_key_defaults: dict[str, int] = {"token_logits": 1},
) -> Optional[ReductionT]:
    """Takes a sequence of batches and collates them into a single batch.

        This is distinct from the standard pytorch default_collator since it does
        not add the batch dimension, it's assumed the batch
        dimension is already present in the input, as would be the case when
        parallelizing across minibatches.

    IMPORTANT: The underlying data primitive _must_ be a torch Tensor. The input to this function is a recurisve type,
    there can be any amount of nesting between dictionaries, tuples, and lists, as long as the inner type is a n-d Tensor.

    Examples:
        Outer container = Dict:
            [{'a': Tensor([1]), 'b': Tensor([2])}, {'a': Tensor([2]), 'b': Tensor([3])}] -> {'a': Tensor([1, 2]), 'b': Tensor([2, 3])}
        Outer container = List:
            [[Tensor([1]), Tensor([2])], [Tensor([2]), Tensor([3])]] -> [Tensor([1, 2]), Tensor([2, 3])]
        Outer container = Tuple:
            ([Tensor([1]), Tensor([2])], [Tensor([2]), Tensor([3])]) -> (Tensor([1, 2]), Tensor([2, 3]))

    Args:
        batches (Optional[Sequence[ReductionT]]): sequence of batches to collate into a single batch.
        batch_dim: If you know that the batch dim for the batch you are concatenating is not the 0th dimension (for
            example it is sequence first) then supply that dimension.
        batch_dim_key_defaults (dictionary of keys to integers): If your batch is a dictionary and you know that some
            keys have non-standard (0) batch dimensions, supply those here. By default "token_logits" has batch dim 1
            and otherwise all keys are assumed to have batch dim 0.

    Returns:
        A single batch of the same type as the elements of your input sequence.
    """
    match batches:
        # Handle base-cases for batch concatenation, either a list of None or a list of tensors
        case [None, *_]:
            return None
        case [Tensor(), *_]:
            return torch.cat(batches, dim=batch_dim)
        # Next 3 calls are the recursive calls into the sub-structures of the batch. We handle dictionaries, tuples, and lists
        case [dict(), *_]:
            return {
                key: batch_collator(
                    [batch[key] for batch in batches],
                    batch_dim=batch_dim_key_defaults.get(key, 0),
                    batch_dim_key_defaults=batch_dim_key_defaults,
                )
                for key in batches[0]
            }
        case [tuple(), *_]:
            return tuple(
                batch_collator(
                    [batch[i] for batch in batches], batch_dim=batch_dim, batch_dim_key_defaults=batch_dim_key_defaults
                )
                for i in range(len(batches[0]))
            )
        case [list(), *_]:
            return [
                batch_collator(
                    [batch[i] for batch in batches], batch_dim=batch_dim, batch_dim_key_defaults=batch_dim_key_defaults
                )
                for i in range(len(batches[0]))
            ]
        # Final cases shouldn't happen, an empty sequence (no batches), or "other".
        case []:
            raise ValueError("Cannot process an empty sequence")
        case _:
            raise ValueError("Unsupported input structure in batch_collator")

default_megatron_optimizer()

Default distributed optimizer uses Adam with a 1e-4 learning rate.

Source code in bionemo/llm/lightning.py
365
366
367
368
369
def default_megatron_optimizer() -> MegatronOptimizerModule:
    """Build the default distributed optimizer: Adam with a 1e-4 learning rate."""
    optimizer_config = OptimizerConfig(lr=1e-4, optimizer="adam", use_distributed_optimizer=True)
    return MegatronOptimizerModule(config=optimizer_config)

some_first(seq)

Returns the first non-None value from the sequence or fails

Source code in bionemo/llm/lightning.py
51
52
53
54
55
56
def some_first(seq: Iterable[Optional[T]]) -> T:
    """Return the first value in `seq` that is not None.

    Raises:
        ValueError: if `seq` is empty or contains only None values.
    """
    # A private sentinel separates "nothing found" from legitimate falsy values such as 0 or "".
    missing = object()
    found = next((item for item in seq if item is not None), missing)
    if found is missing:
        raise ValueError("non-None value not found")
    return found