Lightning

DataStep = Callable[[Iterator[DataT]], DataT] module-attribute

Batches together an iterator of individual examples.

Necessary for compatibility with Megatron. This function type is similar to PyTorch's collate function.

A DataStep function takes an iterator over individual examples. Each example may be a tensor, sequence of tensors, or a set of named tensors (provided as a dict mapping str names to each Tensor). Each iteration must yield the same type.

The output of this function mirrors the structure of each yielded example and is the concatenation of all examples in the iterator.
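
A minimal DataStep sketch for dict-shaped examples is shown below; it is illustrative only, and the key handling and torch.stack strategy are assumptions rather than part of this API:

from typing import Dict, Iterator

import torch
from torch import Tensor


def stack_dict_examples(example_iter: Iterator[Dict[str, Tensor]]) -> Dict[str, Tensor]:
    """Hypothetical DataStep: stacks same-shaped per-example tensors into one batch."""
    examples = list(example_iter)
    if not examples:
        raise ValueError("Cannot batch an empty iterator of examples")
    # Every example must yield the same keys; stacking adds a leading batch dimension.
    return {key: torch.stack([example[key] for example in examples], dim=0) for key in examples[0]}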

ForwardStep = Callable[[MegatronModelType, DataT], DataT] module-attribute

Megatron-compatible forward pass function.
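
For illustration, a possible ForwardStep could look like the following sketch; the batch keys and keyword arguments are placeholders and must match whatever your model's forward signature actually expects:

from typing import Dict

from torch import Tensor


def simple_forward_step(model, batch: Dict[str, Tensor]) -> Dict[str, Tensor]:
    """Hypothetical ForwardStep: unpacks a dict batch into the model's keyword arguments."""
    # "input_ids" and "attention_mask" are assumed key names, used here for illustration only.
    return model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])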

BionemoLightningModule

Bases: Generic[MegatronModelType, MegatronLossType], LightningModule, IOMixin, ConnectorMixin, LightningPassthroughPredictionMixin

Reusable PyTorch Lightning module for Megatron models that is compatible with NeMo's conventions.

Source code in bionemo/llm/lightning.py
class BionemoLightningModule(
    Generic[MegatronModelType, MegatronLossType],
    pl.LightningModule,
    nlio.IOMixin,
    nlio.ConnectorMixin,
    LightningPassthroughPredictionMixin,
):
    """Reusable PyTorch Lightning module for Megatron models that is compatible with NeMo's conventions."""

    def __init__(
        self,
        config: BionemoTrainableModelConfig[MegatronModelType, MegatronLossType],
        forward_step: ForwardStep,
        data_step: DataStep,
        optimizer: MegatronOptimizerModule,
        model_transform: Optional[Callable[[MegatronModelType], MegatronModelType]] = None,
        **model_construct_args,
    ) -> None:
        """Constructor.

        Args:
            config: Serializable configuration object that allows one to construct a new model instance and loss
                function. Necessary for Megatron-based training as the model itself cannot be serialized and
                distributed to nodes. Instead, we serialize the procedure for making the model and distribute that.
            forward_step: Performs forward pass using the model and a batch of data.
            data_step: Custom batch-creating function for the model.
            optimizer: Megatron-compatible distributed optimizer instance. Defaults to using ADAM with a 1e-4 learning
                rate.
            model_transform: Optional. The model transform function.
            **model_construct_args: Optional. Any arguments necessary for the supplied model configuration's
                `configure_model` method, which will make an instance of the model.
        """
        super().__init__()
        self.config = config
        self.module_construct_args: Optional[dict[str, Any]] = model_construct_args
        # **must** be set up in configure_model() -- megatron constraint
        # also, must be called `module`: nemo expects the actual model to be stored this way
        self.module: Optional[MegatronModelType] = None
        self.loss_reduction_class: type[MegatronLossType] = config.get_loss_reduction_class()
        self.optim = optimizer
        self.optim.connect(self)  # This will bind the `configure_optimizers` method
        self._data_step = data_step
        self._forward_step = forward_step
        self.model_transform = model_transform

        # configure metrics
        self.train_metric = self.config.train_metric.get_instance() if self.config.train_metric else None
        self.valid_metric = self.config.valid_metric.get_instance() if self.config.valid_metric else None

    def configure_model(self) -> None:
        """Updates internal state: instantiates the model from the object's config, assigns to `model` attribute.

        NOTE: this method is idempotent; successive calls have no effect. The model is only initialized once.

        Raises:
            ValueError: If the internal config's `configure_model` method returns None.
        """
        if self.module is None:
            model: MegatronModelType = (
                self.config.configure_model(**self.module_construct_args)
                if self.module_construct_args is not None
                else self.config.configure_model()
            )
            self.module = model

        if self.module is None:
            raise ValueError("Invalid semantics: configure_model method **MUST** initialize the model.")

    def is_on_logging_device(self):
        """Return True if last stage of pipeline parallel and first tensor parallel rank."""
        return parallel_state.is_pipeline_last_stage() and parallel_state.get_tensor_model_parallel_rank() == 0

    def forward(self, *args, **kwargs) -> DataT:
        """Call the forward method of the underlying model, and return whatever it outputs."""
        # safe to do because configure_model is idempotent
        self.configure_model()
        assert self.module is not None, "ERROR: configure_model() method has been incorrectly overridden!"
        prediction = self.module(*args, **kwargs)  # for now just pass through to the underlying model
        return prediction

    def data_step(self, dataloader_iter: Iterator[DataT]) -> DataT:  # noqa: D102
        return self._data_step(dataloader_iter)

    def forward_step(self, batch) -> Tensor:
        """Megatron-required: the training forward step for the model, which is required to produce the loss.

        Normally, the forward pass of a model means its inference. Loss is computed using the predictions
        from the forward pass against labels. Megatron unfortunately conflates these two different concepts
        and instead has the model's "forward" method produce the loss. See the Megatron docs for details:
        https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py#L170

        To get actual predictions, use the :func:`forward` method instead.
        """
        # safe to do because configure_model is idempotent
        self.configure_model()
        assert self.module is not None
        return self._forward_step(self.module, batch)

    def update_metric(
        self, batch, outputs, metric, task: Literal["pretraining", "classification", "regression"]
    ) -> None:
        """Update metric for logging."""
        match task:
            case "pretraining":
                logits = outputs["token_logits"].detach().transpose(0, 1)  #  [s, b, v] -> [b, s, v]
                metric(logits, batch["labels"])
            case "classification":
                classification_output = outputs["classification_output"]
                num_classes = classification_output.shape[-1]
                metric(
                    classification_output.reshape(-1, num_classes),
                    batch["labels"].reshape(-1),
                )
            case "regression":
                regression_output = outputs["regression_output"]
                metric(regression_output, batch["labels"])
            case _:
                raise NotImplementedError(f"unrecognized task {task}")

    def training_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
        """In mcore the loss-function is part of the forward-pass when labels are provided."""
        outputs = self.forward_step(batch)
        if self.train_metric is not None:
            if self.is_on_logging_device():
                self.update_metric(batch, outputs, self.train_metric, self.config.train_metric.task)

            self.log(
                self.config.train_metric.metric_name,
                self.train_metric,
                on_step=True,
                on_epoch=False,
                prog_bar=True,
            )

        return outputs

    def validation_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
        """In mcore the loss-function is part of the forward-pass when labels are provided."""
        outputs = self.forward_step(batch)
        if self.valid_metric is not None and self.is_on_logging_device():
            self.update_metric(batch, outputs, self.valid_metric, self.config.valid_metric.task)

        return outputs

    def predict_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
        """Alias for forward_step."""
        if len(batch) == 0:
            return
        return self.forward_step(batch)

    def training_loss_reduction(self) -> MegatronLossType:
        """This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss."""
        return self.loss_reduction_class()

    def validation_loss_reduction(self) -> MegatronLossType:  # noqa: D102
        return self.loss_reduction_class(validation_step=True)

    def test_loss_reduction(self) -> MegatronLossType:  # noqa: D102
        return self.loss_reduction_class(validation_step=True)

    def on_validation_epoch_end(self):  # noqa: D102
        if self.valid_metric is None:
            return

        if self.trainer.sanity_checking:
            self.valid_metric.reset()  # clean up sanity runs
            return

        self.log(
            self.config.valid_metric.metric_name,
            self.valid_metric,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )

__init__(config, forward_step, data_step, optimizer, model_transform=None, **model_construct_args)

Constructor.

Parameters:

    config (BionemoTrainableModelConfig[MegatronModelType, MegatronLossType], required):
        Serializable configuration object that allows one to construct a new model instance and loss function.
        Necessary for Megatron-based training as the model itself cannot be serialized and distributed to nodes.
        Instead, we serialize the procedure for making the model and distribute that.

    forward_step (ForwardStep, required):
        Performs the forward pass using the model and a batch of data.

    data_step (DataStep, required):
        Custom batch-creating function for the model.

    optimizer (MegatronOptimizerModule, required):
        Megatron-compatible distributed optimizer instance. Defaults to using Adam with a 1e-4 learning rate.

    model_transform (Optional[Callable[[MegatronModelType], MegatronModelType]], default None):
        Optional. The model transform function.

    **model_construct_args (default {}):
        Optional. Any arguments necessary for the supplied model configuration's configure_model method, which
        will make an instance of the model.
Source code in bionemo/llm/lightning.py
def __init__(
    self,
    config: BionemoTrainableModelConfig[MegatronModelType, MegatronLossType],
    forward_step: ForwardStep,
    data_step: DataStep,
    optimizer: MegatronOptimizerModule,
    model_transform: Optional[Callable[[MegatronModelType], MegatronModelType]] = None,
    **model_construct_args,
) -> None:
    """Constructor.

    Args:
        config: Serializable configuration object that allows one to construct a new model instance and loss
            function. Necessary for Megatron-based training as the model itself cannot be serialized and
            distributed to nodes. Instead, we serialize the procedure for making the model and distribute that.
        forward_step: Performs forward pass using the model and a batch of data.
        data_step: Custom batch-creating function for the model.
        optimizer: Megatron-compatible distributed optimizer instance. Defaults to using ADAM with a 1e-4 learning
            rate.
        model_transform: Optional. The model transform function.
        **model_construct_args: Optional. Any arguments necessary for the supplied model configuration's
            `configure_model` method, which will make an instance of the model.
    """
    super().__init__()
    self.config = config
    self.module_construct_args: Optional[dict[str, Any]] = model_construct_args
    # **must** be set up in configure_model() -- megatron constraint
    # also, must be called `module`: nemo expects the actual model to be stored this way
    self.module: Optional[MegatronModelType] = None
    self.loss_reduction_class: type[MegatronLossType] = config.get_loss_reduction_class()
    self.optim = optimizer
    self.optim.connect(self)  # This will bind the `configure_optimizers` method
    self._data_step = data_step
    self._forward_step = forward_step
    self.model_transform = model_transform

    # configure metrics
    self.train_metric = self.config.train_metric.get_instance() if self.config.train_metric else None
    self.valid_metric = self.config.valid_metric.get_instance() if self.config.valid_metric else None

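A construction sketch under stated assumptions: MyModelConfig is a placeholder for a concrete BionemoTrainableModelConfig subclass from your own project, and the step functions are the sketches shown earlier on this page.

from bionemo.llm.lightning import BionemoLightningModule, default_megatron_optimizer

# Hypothetical wiring of the pieces documented on this page.
module = BionemoLightningModule(
    config=MyModelConfig(),  # placeholder: any concrete BionemoTrainableModelConfig subclass
    forward_step=simple_forward_step,
    data_step=stack_dict_examples,
    optimizer=default_megatron_optimizer(),
)
# The Megatron model itself is built lazily: NeMo/Lightning calls configure_model() later.
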
configure_model()

Updates internal state: instantiates the model from the object's config and assigns it to the module attribute.

NOTE: this method is idempotent; successive calls have no effect. The model is only initialized once.

Source code in bionemo/llm/lightning.py
def configure_model(self) -> None:
    """Updates internal state: instantiates the model from the object's config, assigns to `model` attribute.

    NOTE: this method is idempotent; successive calls have no effect. The model is only initialized once.

    Raises:
        ValueError: If the internal config's `configure_model` method returns None.
    """
    if self.module is None:
        model: MegatronModelType = (
            self.config.configure_model(**self.module_construct_args)
            if self.module_construct_args is not None
            else self.config.configure_model()
        )
        self.module = model

    if self.module is None:
        raise ValueError("Invalid semantics: configure_model method **MUST** initialize the model.")

forward(*args, **kwargs)

Call the forward method of the underlying model, and return whatever it outputs.

Source code in bionemo/llm/lightning.py
def forward(self, *args, **kwargs) -> DataT:
    """Call the forward method of the underlying model, and return whatever it outputs."""
    # safe to do because configure_model is idempotent
    self.configure_model()
    assert self.module is not None, "ERROR: configure_model() method has been incorrectly overridden!"
    prediction = self.module(*args, **kwargs)  # for now just pass through to the underlying model
    return prediction

forward_step(batch)

Megatron-required: the training forward step for the model, which is required to produce the loss.

Normally, the forward pass of a model means its inference. Loss is computed using the predictions from the forward pass against labels. Megatron unfortunately conflates these two different concepts and instead has the model's "forward" method produce the loss. See the Megatron docs for details: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py#L170

To get actual predictions, use the forward method instead.

Source code in bionemo/llm/lightning.py
def forward_step(self, batch) -> Tensor:
    """Megatron-required: the training forward step for the model, which is required to produce the loss.

    Normally, the forward pass of a model means its inference. Loss is computed using the predictions
    from the forward pass against labels. Megatron unfortunately conflates these two different concepts
    and instead has the model's "forward" method produce the loss. See the Megatron docs for details:
    https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py#L170

    To get actual predictions, use the :func:`forward` method instead.
    """
    # safe to do because configure_model is idempotent
    self.configure_model()
    assert self.module is not None
    return self._forward_step(self.module, batch)

is_on_logging_device()

Return True if last stage of pipeline parallel and first tensor parallel rank.

Source code in bionemo/llm/lightning.py
def is_on_logging_device(self):
    """Return True if last stage of pipeline parallel and first tensor parallel rank."""
    return parallel_state.is_pipeline_last_stage() and parallel_state.get_tensor_model_parallel_rank() == 0

predict_step(batch, batch_idx=None)

Alias for forward_step.

Source code in bionemo/llm/lightning.py
def predict_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
    """Alias for forward_step."""
    if len(batch) == 0:
        return
    return self.forward_step(batch)

training_loss_reduction()

This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss.

Source code in bionemo/llm/lightning.py
def training_loss_reduction(self) -> MegatronLossType:
    """This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss."""
    return self.loss_reduction_class()

training_step(batch, batch_idx=None)

In mcore the loss-function is part of the forward-pass when labels are provided.

Source code in bionemo/llm/lightning.py
def training_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
    """In mcore the loss-function is part of the forward-pass when labels are provided."""
    outputs = self.forward_step(batch)
    if self.train_metric is not None:
        if self.is_on_logging_device():
            self.update_metric(batch, outputs, self.train_metric, self.config.train_metric.task)

        self.log(
            self.config.train_metric.metric_name,
            self.train_metric,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
        )

    return outputs

update_metric(batch, outputs, metric, task)

Update metric for logging.

Source code in bionemo/llm/lightning.py
def update_metric(
    self, batch, outputs, metric, task: Literal["pretraining", "classification", "regression"]
) -> None:
    """Update metric for logging."""
    match task:
        case "pretraining":
            logits = outputs["token_logits"].detach().transpose(0, 1)  #  [s, b, v] -> [b, s, v]
            metric(logits, batch["labels"])
        case "classification":
            classification_output = outputs["classification_output"]
            num_classes = classification_output.shape[-1]
            metric(
                classification_output.reshape(-1, num_classes),
                batch["labels"].reshape(-1),
            )
        case "regression":
            regression_output = outputs["regression_output"]
            metric(regression_output, batch["labels"])
        case _:
            raise NotImplementedError(f"unrecognized task {task}")

validation_step(batch, batch_idx=None)

In mcore the loss-function is part of the forward-pass when labels are provided.

Source code in bionemo/llm/lightning.py
def validation_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
    """In mcore the loss-function is part of the forward-pass when labels are provided."""
    outputs = self.forward_step(batch)
    if self.valid_metric is not None and self.is_on_logging_device():
        self.update_metric(batch, outputs, self.valid_metric, self.config.valid_metric.task)

    return outputs

LightningPassthroughPredictionMixin

A mixin that allows your model to do inference on the predict step by hijacking nemo's loss reduction mechanism.

Source code in bionemo/llm/lightning.py
class LightningPassthroughPredictionMixin:
    """A mixin that allows your model to do inference on the predict step by hijacking nemo's loss reduction mechanism."""

    def predict_loss_reduction(self) -> PassthroughLossReduction:
        """For the predict step, pass through the forward pass output."""
        return PassthroughLossReduction()

predict_loss_reduction()

For the predict step, pass through the forward pass output.

Source code in bionemo/llm/lightning.py
def predict_loss_reduction(self) -> PassthroughLossReduction:
    """For the predict step, pass through the forward pass output."""
    return PassthroughLossReduction()

PassthroughLossReduction

Bases: MegatronLossReduction, Generic[DataT]

A workaround for nemo/megatron to perform inference.

Internally in NeMo2.0 the forward step is always expected to return a loss reduction class, and forward is expected to return a loss. This class hijacks that mechanism to instead pass through the forward output unperturbed as the loss (to enable inference in the predict step), and then the reduce method is used to collate the batch of forward outputs into a single batch. This supports the model forward output being a tensor, dict, tuple, or list of tensors. The inner type must always be a Tensor.

Source code in bionemo/llm/lightning.py
class PassthroughLossReduction(MegatronLossReduction, Generic[DataT]):
    """A workaround for nemo/megatron to perform inference.

    Internally in NeMo2.0 the forward step is always expected to return a loss reduction class, and forward is
    expected to return a loss. This class hijacks that mechanism to instead pass through the forward output unperturbed
    as the loss (to enable inference in the predict step), and then the reduce method is used to collate the batch of
    forward outputs into a single batch. This supports the model forward output being a tensor, dict, tuple, or list of
    tensors. The inner type _must always be a Tensor_.
    """

    def forward(self, batch: DataT, forward_out: DataT) -> Tuple[Tensor, DataT]:
        """Passes through the `forward_out` value as the 2nd tuple element.

        Args:
            batch: The batch of data that was passed through the model to generate output. NOTE: this value is ignored.
            forward_out: The output from your model's forward pass.

        Returns:
            A tuple containing the loss tensor (dummy in this case) and the forward output (unmodified).
        """
        return torch.zeros((1, 1)), forward_out

    def reduce(self, forward_out: List[DataT]) -> DataT:
        """Collates list of model's outputs into a single output."""
        return batch_collator(forward_out)

forward(batch, forward_out)

Passes through the forward_out value as the 2nd tuple element.

Parameters:

    batch (DataT, required):
        The batch of data that was passed through the model to generate output. NOTE: this value is ignored.

    forward_out (DataT, required):
        The output from your model's forward pass.

Returns:

    Tuple[Tensor, DataT]: A tuple containing the loss tensor (dummy in this case) and the forward output (unmodified).

Source code in bionemo/llm/lightning.py
def forward(self, batch: DataT, forward_out: DataT) -> Tuple[Tensor, DataT]:
    """Passes through the `forward_out` value as the 2nd tuple element.

    Args:
        batch: The batch of data that was passed through the model to generate output. NOTE: this value is ignored.
        forward_out: The output from your model's forward pass.

    Returns:
        A tuple containing the loss tensor (dummy in this case) and the forward output (unmodified).
    """
    return torch.zeros((1, 1)), forward_out

reduce(forward_out)

Collates list of model's outputs into a single output.

Source code in bionemo/llm/lightning.py
def reduce(self, forward_out: List[DataT]) -> DataT:
    """Collates list of model's outputs into a single output."""
    return batch_collator(forward_out)

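A toy sketch of the pass-through behavior, run outside of a real NeMo/Megatron job and assuming dict outputs whose tensors carry the batch dimension at dim 0 and the sequence dimension at dim 1:

import torch

from bionemo.llm.lightning import PassthroughLossReduction

reduction = PassthroughLossReduction()

# forward() returns a dummy loss and the untouched model output.
dummy_loss, passed_through = reduction.forward(batch=None, forward_out={"logits": torch.ones(2, 4)})
assert torch.equal(passed_through["logits"], torch.ones(2, 4))

# reduce() concatenates per-microbatch outputs along the batch dimension via batch_collator.
merged = reduction.reduce([{"logits": torch.ones(2, 4)}, {"logits": torch.zeros(3, 4)}])
assert merged["logits"].shape == (5, 4)
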
batch_collator(batches, batch_dim=0, seq_dim=1, batch_dim_key_defaults={'token_logits': 1}, seq_dim_key_defaults={'token_logits': 0})

Takes a sequence of batches and collates them into a single batch.

This is distinct from the standard PyTorch default collate function since it does not add a batch dimension; the batch dimension is assumed to already be present in the input, as would be the case when parallelizing across minibatches.

IMPORTANT: The underlying data primitive must be a torch Tensor. The input to this function is a recursive type; there can be any amount of nesting between dictionaries, tuples, and lists, as long as the inner type is an n-D Tensor.

Examples:

    Outer container = Dict:
        [{'a': Tensor([1]), 'b': Tensor([2])}, {'a': Tensor([2]), 'b': Tensor([3])}] -> {'a': Tensor([1, 2]), 'b': Tensor([2, 3])}
    Outer container = List:
        [[Tensor([1]), Tensor([2])], [Tensor([2]), Tensor([3])]] -> [Tensor([1, 2]), Tensor([2, 3])]
    Outer container = Tuple:
        ([Tensor([1]), Tensor([2])], [Tensor([2]), Tensor([3])]) -> (Tensor([1, 2]), Tensor([2, 3]))

Parameters:

    batches (Optional[Sequence[ReductionT]], required):
        Sequence of batches to collate into a single batch.

    batch_dim (int, default 0):
        If the batch dimension of the batches you are concatenating is not the 0th dimension (for example, the
        data is sequence-first), supply that dimension here.

    seq_dim (int, default 1):
        If the sequence dimension of the batches you are concatenating is not the 1st dimension (for example,
        the data is sequence-first), supply that dimension here. This is used for padding to the max length.

    batch_dim_key_defaults (dict mapping str keys to int, default {'token_logits': 1}):
        If your batch is a dictionary and some keys have a non-standard (not 0) batch dimension, supply those
        here. By default "token_logits" has batch dim 1; all other keys are assumed to have batch dim 0.

    seq_dim_key_defaults (dict mapping str keys to int, default {'token_logits': 0}):
        If your batch is a dictionary and some keys have a non-standard (not 1) sequence dimension, supply those
        here. By default "token_logits" has seq dim 0; all other keys are assumed to have seq dim 1.

Returns:

    Optional[ReductionT]: A single batch of the same type as the elements of your input sequence.

Source code in bionemo/llm/lightning.py
def batch_collator(
    batches: Optional[Union[Tuple[ReductionT], List[ReductionT]]],
    batch_dim: int = 0,
    seq_dim: int = 1,
    batch_dim_key_defaults: dict[str, int] = {"token_logits": 1},
    seq_dim_key_defaults: dict[str, int] = {"token_logits": 0},
) -> Optional[ReductionT]:
    """Takes a sequence of batches and collates them into a single batch.

    This is distinct from the standard PyTorch default collate function since it does not add a batch
    dimension; the batch dimension is assumed to already be present in the input, as would be the case
    when parallelizing across minibatches.

    IMPORTANT: The underlying data primitive _must_ be a torch Tensor. The input to this function is a recursive
    type; there can be any amount of nesting between dictionaries, tuples, and lists, as long as the inner type
    is an n-D Tensor.

    Examples:
        Outer container = Dict:
            [{'a': Tensor([1]), 'b': Tensor([2])}, {'a': Tensor([2]), 'b': Tensor([3])}] -> {'a': Tensor([1, 2]), 'b': Tensor([2, 3])}
        Outer container = List:
            [[Tensor([1]), Tensor([2])], [Tensor([2]), Tensor([3])]] -> [Tensor([1, 2]), Tensor([2, 3])]
        Outer container = Tuple:
            ([Tensor([1]), Tensor([2])], [Tensor([2]), Tensor([3])]) -> (Tensor([1, 2]), Tensor([2, 3]))

    Args:
        batches (Optional[Sequence[ReductionT]]): sequence of batches to collate into a single batch.
        batch_dim: If you know that the batch dim for the batch you are concatenating is not the 0th dimension (for
            example it is sequence first) then supply that dimension.
        seq_dim: If you know that the sequence dim for the batch you are concatenating is not the 1st dimension (for
            example it is sequence first) then supply that dimension. This is used for padding to the max length.
        batch_dim_key_defaults (dictionary of keys to integers): If your batch is a dictionary and you know that some
            keys have non-standard (0) batch dimensions, supply those here. By default "token_logits" has batch dim 1
            and otherwise all keys are assumed to have batch dim 0.
        seq_dim_key_defaults (dictionary of keys to integers): If your batch is a dictionary and you know that some
            keys have non-standard (1) sequence dimensions, supply those here. By default "token_logits" has seq dim 0
            and otherwise all keys are assumed to have seq dim 1.

    Returns:
        A single batch of the same type as the elements of your input sequence.
    """
    match batches:
        # Handle base-cases for batch concatenation, either a list of None or a list of tensors
        case [None, *_]:
            return None
        case [Tensor(), *_]:
            # First shortcut if all tensors are 1D (they have at least one batch dim, and it must be at 0)
            if len(batches) > 0 and isinstance(batches[0], Tensor) and batches[0].ndim == 1:
                return torch.cat(batches, dim=0)
            # Find max sequence length across all tensors
            max_seq_len = max(batch.size(seq_dim) for batch in batches)
            # Pad each tensor to max length along seq_dim
            padded_batches = []
            for batch in batches:
                # Initialize padding tuple - needs 2 values per dim, starting from last dim
                # e.g. for 3D tensor: [left_pad_dim2, right_pad_dim2, left_pad_dim1, right_pad_dim1, left_pad_dim0, right_pad_dim0]
                pad_size = [0] * (2 * batch.ndim)
                # Calculate padding needed at end of sequence dimension
                pad_amount = max_seq_len - batch.size(seq_dim)
                # Pad end of sequence dimension by putting padding amount in correct position
                # For seq_dim=1 in 3D tensor: [0, 0, 0, pad_amount, 0, 0]
                pad_size[2 * (batch.ndim - 1 - seq_dim) + 1] = pad_amount
                padded_batch = torch.nn.functional.pad(batch, tuple(pad_size))
                padded_batches.append(padded_batch)
            padded_batch = torch.cat(padded_batches, dim=batch_dim)
            assert padded_batch.size(seq_dim) == max_seq_len
            return padded_batch
        # Next 3 calls are the recursive calls into the sub-structures of the batch. We handle dictionaries, tuples, and lists
        case [dict(), *_]:
            return {
                key: batch_collator(
                    [batch[key] for batch in batches],
                    batch_dim=batch_dim_key_defaults.get(key, batch_dim),
                    seq_dim=seq_dim_key_defaults.get(key, seq_dim),
                    batch_dim_key_defaults=batch_dim_key_defaults,
                    seq_dim_key_defaults=seq_dim_key_defaults,
                )
                for key in batches[0]
            }
        case [tuple(), *_]:
            return tuple(
                batch_collator(
                    [batch[i] for batch in batches],
                    batch_dim=batch_dim,
                    seq_dim=seq_dim,
                    batch_dim_key_defaults=batch_dim_key_defaults,
                    seq_dim_key_defaults=seq_dim_key_defaults,
                )
                for i in range(len(batches[0]))
            )
        case [list(), *_]:
            return [
                batch_collator(
                    [batch[i] for batch in batches],
                    batch_dim=batch_dim,
                    seq_dim=seq_dim,
                    batch_dim_key_defaults=batch_dim_key_defaults,
                    seq_dim_key_defaults=seq_dim_key_defaults,
                )
                for i in range(len(batches[0]))
            ]
        # Final cases shouldn't happen, an empty sequence (no batches), or "other".
        case []:
            raise ValueError("Cannot process an empty sequence")
        case _:
            raise ValueError("Unsupported input structure in batch_collator")

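A small usage sketch of batch_collator on dict microbatches; the second batch is right-padded along the sequence dimension to the maximum length before concatenation:

import torch

from bionemo.llm.lightning import batch_collator

merged = batch_collator(
    [
        {"labels": torch.ones(2, 4, dtype=torch.long)},   # batch dim 0 has size 2, seq dim 1 has size 4
        {"labels": torch.zeros(1, 3, dtype=torch.long)},  # shorter sequence: padded from length 3 to 4
    ]
)
assert merged["labels"].shape == (3, 4)
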
default_megatron_optimizer()

Default distributed optimizer uses Adam with a 1e-4 learning rate.

Source code in bionemo/llm/lightning.py
def default_megatron_optimizer() -> MegatronOptimizerModule:
    """Default distributed optimizer uses Adam with a 1e-4 learning rate."""
    return MegatronOptimizerModule(
        config=OptimizerConfig(lr=1e-4, optimizer="adam", use_distributed_optimizer=True),
    )

some_first(seq)

Returns the first non-None value from the sequence or fails

Source code in bionemo/llm/lightning.py
def some_first(seq: Iterable[Optional[T]]) -> T:
    """Returns the first non-None value from the sequence or fails"""  # noqa: D415
    for s in seq:
        if s is not None:
            return s
    raise ValueError("non-None value not found")