Infer

infer_model(config, data_module, tokenizer=get_tokenizer())

Runs inference with a BioNeMo ESM2 model using PyTorch Lightning.

Parameters:

| Name | Type | Description | Default |
| ------------- | --------------------- | -------------------------------------------- | ----------------- |
| `config` | `ESM2GenericConfig` | The configuration for the ESM2 model. | required |
| `data_module` | `LightningDataModule` | The data module for training and validation. | required |
| `tokenizer` | `BioNeMoESMTokenizer` | The tokenizer to use. | `get_tokenizer()` |

Returns:

| Type | Description |
| -------------- | ------------------------------------------------------------------------------------- |
| `list[Tensor]` | A list of tensors containing the predictions for `predict_dataset` in the data module. |

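A minimal usage sketch follows. Only `infer_model` (from this module) and the `get_tokenizer()` default are taken from this page; the import path for `get_tokenizer` and the way `config` and `data_module` are constructed are assumptions, so substitute the `ESM2GenericConfig` subclass and `LightningDataModule` from your own fine-tuning setup.

```python
# Hedged usage sketch: import paths other than infer_model's own module are
# assumptions; replace the placeholder config/data_module with the
# ESM2GenericConfig subclass and LightningDataModule used during fine-tuning.
from bionemo.esm2.model.finetune.infer import infer_model
from bionemo.esm2.data.tokenizer import get_tokenizer  # assumed location of get_tokenizer

config = ...       # an ESM2GenericConfig (or subclass) pointing at your checkpoint
data_module = ...  # a LightningDataModule whose predict_dataset holds the sequences to score

predictions = infer_model(
    config=config,
    data_module=data_module,
    tokenizer=get_tokenizer(),  # optional; this is already the default
)

# predictions is a list of tensors with the predictions for predict_dataset
for tensor in predictions:
    print(tensor.shape)
```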
Source code in bionemo/esm2/model/finetune/infer.py
def infer_model(
    config: ESM2GenericConfig,
    data_module: pl.LightningDataModule,
    tokenizer: BioNeMoESMTokenizer = get_tokenizer(),
) -> list[Tensor]:
    """Infers a BioNeMo ESM2 model using PyTorch Lightning.

    Parameters:
        config: The configuration for the ESM2 model.
        data_module: The data module for training and validation.
        tokenizer: The tokenizer to use. Defaults to `get_tokenizer()`.

    Returns:
        A list of tensors containing the predictions of predict_dataset in datamodule
    """
    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=1, pipeline_model_parallel_size=1, ddp="megatron", find_unused_parameters=True
    )

    trainer = nl.Trainer(
        accelerator="gpu",
        devices=1,
        strategy=strategy,
        num_nodes=1,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
    )
    module = biobert_lightning_module(config=config, tokenizer=tokenizer)
    results = batch_collator(trainer.predict(module, datamodule=data_module))

    return results
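As the listing shows, inference runs on a single GPU and a single node: the `MegatronStrategy` is created with tensor and pipeline parallel sizes of 1, the trainer uses bf16 mixed precision via `MegatronMixedPrecision`, and the per-batch outputs of `trainer.predict` are merged with `batch_collator` before being returned.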