Skip to content

Infer esm2

get_parser()

Return the cli parser for this tool.

Source code in bionemo/esm2/scripts/infer_esm2.py
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
def get_parser():
    """Return the cli parser for this tool.

    The parser requires ``--checkpoint-path``, ``--data-path`` and ``--results-path``.
    All remaining options are optional and fall back to the defaults declared below.

    ``--config-class`` values are looked up by name in ``SUPPORTED_CONFIGS``; an
    unknown name is rejected with an ``argparse.ArgumentTypeError``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(description="Infer ESM2.")
    parser.add_argument(
        "--checkpoint-path",
        type=Path,
        required=True,
        help="Path to the ESM2 pretrained checkpoint",
    )
    parser.add_argument(
        "--data-path",
        type=Path,
        required=True,
        help="Path to the CSV file containing sequences and label columns",
    )
    parser.add_argument("--results-path", type=Path, required=True, help="Path to the results file.")

    parser.add_argument(
        "--precision",
        type=str,
        choices=get_args(PrecisionTypes),
        required=False,
        default="bf16-mixed",
        help="Precision type to use for inference.",
    )
    parser.add_argument(
        "--num-gpus",
        type=int,
        required=False,
        default=1,
        help="Number of GPUs to use for inference. Default is 1.",
    )
    parser.add_argument(
        "--num-nodes",
        type=int,
        required=False,
        default=1,
        help="Number of nodes to use for inference. Default is 1.",
    )
    parser.add_argument(
        "--micro-batch-size",
        type=int,
        required=False,
        default=2,
        help="Micro-batch size. Global batch size is inferred from this.",
    )
    parser.add_argument(
        "--pipeline-model-parallel-size",
        type=int,
        required=False,
        default=1,
        help="Pipeline model parallel size. Default is 1.",
    )
    parser.add_argument(
        "--tensor-model-parallel-size",
        type=int,
        required=False,
        default=1,
        help="Tensor model parallel size. Default is 1.",
    )
    parser.add_argument(
        "--include-hiddens",
        action="store_true",
        default=False,
        help="Include hiddens in output of inference",
    )
    parser.add_argument(
        "--include-embeddings",
        action="store_true",
        default=False,
        help="Include embeddings in output of inference",
    )
    parser.add_argument(
        "--include-logits", action="store_true", default=False, help="Include per-token logits in output."
    )
    config_class_options: Dict[str, Type[BioBertConfig]] = SUPPORTED_CONFIGS

    def config_class_type(desc: str) -> Type[BioBertConfig]:
        """Map a config class name given on the command line to the config class itself."""
        try:
            return config_class_options[desc]
        except KeyError as err:
            # Chain explicitly so the failed lookup stays attached as the cause.
            raise argparse.ArgumentTypeError(
                f"Do not recognize key {desc}, valid options are: {config_class_options.keys()}"
            ) from err

    parser.add_argument(
        "--config-class",
        type=config_class_type,
        # The default is given by name; argparse runs string defaults through `type`,
        # so the parsed value is the config class, not the string.
        default="ESM2Config",
        help="Model configs link model classes with losses, and handle model initialization (including from a prior "
        "checkpoint). This is how you can fine-tune a model. First train with one config class that points to one model "
        "class and loss, then implement and provide an alternative config class that points to a variant of that model "
        "and alternative loss. In the future this script should also provide similar support for picking different data "
        f"modules for fine-tuning with different data types. Choices: {config_class_options.keys()}",
    )
    return parser

infer_esm2_entrypoint()

Entrypoint for running inference on an ESM2 checkpoint and data.

Source code in bionemo/esm2/scripts/infer_esm2.py
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
def infer_esm2_entrypoint():
    """Entrypoint for running inference on an ESM2 checkpoint and data.

    Parses the command line with the parser from ``get_parser`` and forwards every
    parsed option to ``infer_model``.
    """
    cli_args = get_parser().parse_args()

    # The CLI calls the device count `--num-gpus`; `infer_model` takes it as `devices`.
    infer_model(
        checkpoint_path=cli_args.checkpoint_path,
        data_path=cli_args.data_path,
        results_path=cli_args.results_path,
        config_class=cli_args.config_class,
        precision=cli_args.precision,
        micro_batch_size=cli_args.micro_batch_size,
        devices=cli_args.num_gpus,
        num_nodes=cli_args.num_nodes,
        tensor_model_parallel_size=cli_args.tensor_model_parallel_size,
        pipeline_model_parallel_size=cli_args.pipeline_model_parallel_size,
        include_hiddens=cli_args.include_hiddens,
        include_embeddings=cli_args.include_embeddings,
        include_logits=cli_args.include_logits,
    )

infer_model(data_path, checkpoint_path, results_path, min_seq_length=1024, include_hiddens=False, include_embeddings=False, include_logits=False, micro_batch_size=64, precision='bf16-mixed', tensor_model_parallel_size=1, pipeline_model_parallel_size=1, devices=1, num_nodes=1, config_class=ESM2Config)

Runs inference on a BioNeMo ESM2 model using PyTorch Lightning.

Parameters:

Name Type Description Default
data_path Path

Path to the input data.

required
checkpoint_path Path

Path to the model checkpoint.

required
results_path Path

Path to save the inference results.

required
min_seq_length int

Minimum sequence length to pad to. This should be at least equal to the length of the longest sequence in the dataset.

1024
include_hiddens bool

Whether to include hidden states in the output. Defaults to False.

False
include_embeddings bool

Whether to include embeddings in the output. Defaults to False.

False
include_logits (bool, Optional)

Whether to include token logits in the output. Defaults to False.

False
micro_batch_size int

Micro batch size for inference. Defaults to 64.

64
precision PrecisionTypes

Precision type for inference. Defaults to "bf16-mixed".

'bf16-mixed'
tensor_model_parallel_size int

Tensor model parallel size for distributed inference. Defaults to 1.

1
pipeline_model_parallel_size int

Pipeline model parallel size for distributed inference. Defaults to 1.

1
devices int

Number of devices to use for inference. Defaults to 1.

1
num_nodes int

Number of nodes to use for distributed inference. Defaults to 1.

1
config_class Type[BioBertConfig]

The config class used to configure the model from the provided checkpoint.

ESM2Config
Source code in bionemo/esm2/scripts/infer_esm2.py
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
def infer_model(
    data_path: Path,
    checkpoint_path: Path,
    results_path: Path,
    min_seq_length: int = 1024,
    include_hiddens: bool = False,
    include_embeddings: bool = False,
    include_logits: bool = False,
    micro_batch_size: int = 64,
    precision: PrecisionTypes = "bf16-mixed",
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    devices: int = 1,
    num_nodes: int = 1,
    config_class: Type[BioBertConfig] = ESM2Config,
) -> None:
    """Runs inference on a BioNeMo ESM2 model using PyTorch Lightning.

    The collated predictions are written with ``torch.save``. The output location is
    derived from ``results_path`` as follows:

    * an existing directory -> ``<results_path>/esm2_inference_results.pt``
    * a path ending in ``.pt`` -> used unchanged
    * anything else -> ``.pt`` is appended to the file name

    Args:
        data_path (Path): Path to the input data.
        checkpoint_path (Path): Path to the model checkpoint.
        results_path (Path): Path to save the inference results.
        min_seq_length (int): Minimum sequence length to pad to. This should be at least equal to the length of the
            longest sequence in the dataset.
        include_hiddens (bool, optional): Whether to include hidden states in the output. Defaults to False.
        include_embeddings (bool, optional): Whether to include embeddings in the output. Defaults to False.
        include_logits (bool, optional): Whether to include token logits in the output. Defaults to False.
        micro_batch_size (int, optional): Micro batch size for inference. Defaults to 64.
        precision (PrecisionTypes, optional): Precision type for inference. Defaults to "bf16-mixed".
        tensor_model_parallel_size (int, optional): Tensor model parallel size for distributed inference. Defaults to 1.
        pipeline_model_parallel_size (int, optional): Pipeline model parallel size for distributed inference. Defaults to 1.
        devices (int, optional): Number of devices to use for inference. Defaults to 1.
        num_nodes (int, optional): Number of nodes to use for distributed inference. Defaults to 1.
        config_class (Type[BioBertConfig]): The config class used to configure the model from the provided checkpoint.
    """
    results_path = Path(results_path)
    if results_path.is_dir():
        results_path = results_path / "esm2_inference_results.pt"
    elif results_path.suffix != ".pt":
        # Append the extension to the file name. Joining with "/" would instead point at a
        # hidden ".pt" file inside a directory that does not exist, and the save would fail
        # only after inference had already run.
        results_path = results_path.with_name(results_path.name + ".pt")

    # Setup the strategy and trainer
    global_batch_size = infer_global_batch_size(
        micro_batch_size=micro_batch_size,
        num_nodes=num_nodes,
        devices=devices,
        tensor_model_parallel_size=tensor_model_parallel_size,
        pipeline_model_parallel_size=pipeline_model_parallel_size,
    )

    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=tensor_model_parallel_size,
        pipeline_model_parallel_size=pipeline_model_parallel_size,
        ddp="megatron",
        find_unused_parameters=True,
    )

    trainer = nl.Trainer(
        accelerator="gpu",
        devices=devices,
        strategy=strategy,
        num_nodes=num_nodes,
        callbacks=[],  # TODO: @farhadr Add PredictionWriter for DDP
        plugins=nl.MegatronMixedPrecision(precision=precision),
    )

    dataset = InMemoryCSVDataset(data_path=data_path)
    datamodule = ESM2FineTuneDataModule(
        predict_dataset=dataset,
        micro_batch_size=micro_batch_size,
        global_batch_size=global_batch_size,
        min_seq_length=min_seq_length,
    )

    # The same dtype is used for parameters, pipeline communication and autocast.
    autocast_dtype = get_autocast_dtype(precision)
    config = config_class(
        params_dtype=autocast_dtype,
        pipeline_dtype=autocast_dtype,
        autocast_dtype=autocast_dtype,
        include_hiddens=include_hiddens,
        include_embeddings=include_embeddings,
        skip_logits=not include_logits,
        tensor_model_parallel_size=tensor_model_parallel_size,
        pipeline_model_parallel_size=pipeline_model_parallel_size,
        initial_ckpt_path=str(checkpoint_path),
        initial_ckpt_skip_keys_with_these_prefixes=[],  # load everything from the checkpoint.
    )

    tokenizer = get_tokenizer()
    module = biobert_lightning_module(config=config, tokenizer=tokenizer)

    predictions = trainer.predict(module, datamodule=datamodule, return_predictions=True)
    results_dict = batch_collator(predictions)
    # Only the keys that hold a value are reported; the full dict is what gets saved.
    non_none_keys = [key for key, val in results_dict.items() if val is not None]
    print(f"Writing output {str(non_none_keys)} into {results_path}")
    torch.save(results_dict, results_path)