Skip to content

Infer esm2

get_parser()

Return the cli parser for this tool.

Source code in bionemo/esm2/scripts/infer_esm2.py (lines 168–277)
def get_parser():
    """Return the command-line parser for the ESM2 inference tool.

    Returns:
        argparse.ArgumentParser: A parser for checkpoint/data/results paths, parallelism and
        precision settings, output-selection flags, and the model config class.
    """
    parser = argparse.ArgumentParser(description="Infer ESM2.")
    parser.add_argument(
        "--checkpoint-path",
        type=Path,
        required=True,
        help="Path to the ESM2 pretrained checkpoint",
    )
    parser.add_argument(
        "--data-path",
        type=Path,
        required=True,
        help="Path to the CSV file containing sequences and label columns",
    )
    parser.add_argument("--results-path", type=Path, required=True, help="Path to the results directory.")

    parser.add_argument(
        "--precision",
        type=str,
        choices=get_args(PrecisionTypes),
        required=False,
        default="bf16-mixed",
        help="Precision type to use for inference.",
    )
    parser.add_argument(
        "--num-gpus",
        type=int,
        required=False,
        default=1,
        help="Number of GPUs to use for inference. Default is 1.",
    )
    parser.add_argument(
        "--num-nodes",
        type=int,
        required=False,
        default=1,
        help="Number of nodes to use for inference. Default is 1.",
    )
    parser.add_argument(
        "--micro-batch-size",
        type=int,
        required=False,
        default=2,
        help="Micro-batch size. Global batch size is inferred from this.",
    )
    parser.add_argument(
        "--pipeline-model-parallel-size",
        type=int,
        required=False,
        default=1,
        help="Pipeline model parallel size. Default is 1.",
    )
    parser.add_argument(
        "--tensor-model-parallel-size",
        type=int,
        required=False,
        default=1,
        help="Tensor model parallel size. Default is 1.",
    )
    parser.add_argument(
        "--prediction-interval",
        type=str,
        required=False,
        choices=get_args(IntervalT),
        default="epoch",
        help="Interval at which DDP predictions are written to disk.",
    )
    parser.add_argument(
        "--include-hiddens",
        action="store_true",
        default=False,
        help="Include hiddens in output of inference",
    )
    parser.add_argument(
        "--include-input-ids",
        action="store_true",
        default=False,
        help="Include input_ids in output of inference",
    )
    parser.add_argument(
        "--include-embeddings",
        action="store_true",
        default=False,
        help="Include embeddings in output of inference",
    )
    parser.add_argument(
        "--include-logits", action="store_true", default=False, help="Include per-token logits in output."
    )
    config_class_options: Dict[str, Type[BioBertConfig]] = SUPPORTED_CONFIGS
    # Render the valid keys as a plain list rather than a ``dict_keys([...])`` repr.
    valid_config_names = list(config_class_options)

    def config_class_type(desc: str) -> Type[BioBertConfig]:
        """Map a config-class name given on the command line to its config class."""
        try:
            return config_class_options[desc]
        except KeyError:
            # ``from None`` keeps the internal KeyError out of the user-facing argparse error.
            raise argparse.ArgumentTypeError(
                f"Do not recognize key {desc}, valid options are: {valid_config_names}"
            ) from None

    parser.add_argument(
        "--config-class",
        type=config_class_type,
        # argparse runs string defaults through ``type``, so this resolves via config_class_type.
        default="ESM2Config",
        help="Model configs link model classes with losses, and handle model initialization (including from a prior "
        "checkpoint). This is how you can fine-tune a model. First train with one config class that points to one model "
        "class and loss, then implement and provide an alternative config class that points to a variant of that model "
        "and alternative loss. In the future this script should also provide similar support for picking different data "
        f"modules for fine-tuning with different data types. Choices: {valid_config_names}",
    )
    return parser

infer_esm2_entrypoint()

Entrypoint for running inference on an ESM2 checkpoint and data.

Source code in bionemo/esm2/scripts/infer_esm2.py (lines 144–165)
def infer_esm2_entrypoint():
    """Entrypoint for running inference on an ESM2 checkpoint and data.

    Parses the command line with ``get_parser`` and forwards every parsed option to ``infer_model``.
    """
    # 1. get arguments
    parser = get_parser()
    args = parser.parse_args()
    # 2. Call infer with args
    infer_model(
        data_path=args.data_path,
        checkpoint_path=args.checkpoint_path,
        results_path=args.results_path,
        include_hiddens=args.include_hiddens,
        include_embeddings=args.include_embeddings,
        include_logits=args.include_logits,
        include_input_ids=args.include_input_ids,
        micro_batch_size=args.micro_batch_size,
        precision=args.precision,
        tensor_model_parallel_size=args.tensor_model_parallel_size,
        pipeline_model_parallel_size=args.pipeline_model_parallel_size,
        devices=args.num_gpus,
        num_nodes=args.num_nodes,
        # Previously parsed but never forwarded, so --prediction-interval had no effect.
        prediction_interval=args.prediction_interval,
        config_class=args.config_class,
    )

infer_model(data_path, checkpoint_path, results_path, min_seq_length=1024, include_hiddens=False, include_embeddings=False, include_logits=False, include_input_ids=False, micro_batch_size=64, precision='bf16-mixed', tensor_model_parallel_size=1, pipeline_model_parallel_size=1, devices=1, num_nodes=1, prediction_interval='epoch', config_class=ESM2Config)

Runs inference on a BioNeMo ESM2 model using PyTorch Lightning.

Parameters:

Name Type Description Default
data_path Path

Path to the input data.

required
checkpoint_path Path

Path to the model checkpoint.

required
results_path Path

Path to save the inference results.

required
min_seq_length int

Minimum sequence length to pad to. This should be at least the length of the longest sequence in the dataset.

1024
include_hiddens bool

Whether to include hidden states in the output. Defaults to False.

False
include_embeddings bool

Whether to include embeddings in the output. Defaults to False.

False
include_logits (bool, Optional)

Whether to include token logits in the output. Defaults to False.

False
include_input_ids (bool, Optional)

Whether to include input_ids in the output. Defaults to False.

False
micro_batch_size int

Micro batch size for inference. Defaults to 64.

64
precision PrecisionTypes

Precision type for inference. Defaults to "bf16-mixed".

'bf16-mixed'
tensor_model_parallel_size int

Tensor model parallel size for distributed inference. Defaults to 1.

1
pipeline_model_parallel_size int

Pipeline model parallel size for distributed inference. Defaults to 1.

1
devices int

Number of devices to use for inference. Defaults to 1.

1
num_nodes int

Number of nodes to use for distributed inference. Defaults to 1.

1
prediction_interval IntervalT

Interval at which predict-method output is written to disk during DDP inference. Defaults to epoch.

'epoch'
config_class Type[BioBertConfig]

The config class for configuring the model using checkpoint provided

ESM2Config
Source code in bionemo/esm2/scripts/infer_esm2.py (lines 46–141)
def infer_model(
    data_path: Path,
    checkpoint_path: Path,
    results_path: Path,
    min_seq_length: int = 1024,
    include_hiddens: bool = False,
    include_embeddings: bool = False,
    include_logits: bool = False,
    include_input_ids: bool = False,
    micro_batch_size: int = 64,
    precision: PrecisionTypes = "bf16-mixed",
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    devices: int = 1,
    num_nodes: int = 1,
    prediction_interval: IntervalT = "epoch",
    config_class: Type[BioBertConfig] = ESM2Config,
) -> None:
    """Runs inference on a BioNeMo ESM2 model using PyTorch Lightning.

    Side effects: creates ``results_path`` if it does not exist, and writes predictions there
    through a ``PredictionWriter`` callback.

    Args:
        data_path (Path): Path to the input CSV data, loaded with ``InMemoryProteinDataset.from_csv``.
        checkpoint_path (Path): Path to the model checkpoint. All keys are loaded from it.
        results_path (Path): Directory in which to save the inference results.
        min_seq_length (int): Minimum sequence length to pad to. This should be at least the length of
            the longest sequence in the dataset.
        include_hiddens (bool, optional): Whether to include hidden states in the output. Defaults to False.
        include_embeddings (bool, optional): Whether to include embeddings in the output. Defaults to False.
        include_logits (bool, Optional): Whether to include token logits in the output. Defaults to False.
        include_input_ids (bool, Optional): Whether to include input_ids in the output. Defaults to False.
        micro_batch_size (int, optional): Micro batch size for inference. Defaults to 64.
        precision (PrecisionTypes, optional): Precision type for inference. Defaults to "bf16-mixed".
        tensor_model_parallel_size (int, optional): Tensor model parallel size for distributed inference. Defaults to 1.
        pipeline_model_parallel_size (int, optional): Pipeline model parallel size for distributed inference. Defaults to 1.
        devices (int, optional): Number of devices to use for inference. Defaults to 1.
        num_nodes (int, optional): Number of nodes to use for distributed inference. Defaults to 1.
        prediction_interval (IntervalT, optional): Interval at which predict-method output is written to disk
            during DDP inference. Defaults to "epoch".
        config_class (Type[BioBertConfig]): The config class used to configure the model from the provided checkpoint.
    """
    # create the directory to save the inference results
    os.makedirs(results_path, exist_ok=True)

    # Setup the strategy and trainer
    global_batch_size = infer_global_batch_size(
        micro_batch_size=micro_batch_size,
        num_nodes=num_nodes,
        devices=devices,
        tensor_model_parallel_size=tensor_model_parallel_size,
        pipeline_model_parallel_size=pipeline_model_parallel_size,
    )

    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=tensor_model_parallel_size,
        pipeline_model_parallel_size=pipeline_model_parallel_size,
        ddp="megatron",
        find_unused_parameters=True,
    )

    prediction_writer = PredictionWriter(output_dir=results_path, write_interval=prediction_interval)

    trainer = nl.Trainer(
        accelerator="gpu",
        devices=devices,
        strategy=strategy,
        num_nodes=num_nodes,
        callbacks=[prediction_writer],
        plugins=nl.MegatronMixedPrecision(precision=precision),
    )

    dataset = InMemoryProteinDataset.from_csv(data_path)
    datamodule = ESM2FineTuneDataModule(
        predict_dataset=dataset,
        micro_batch_size=micro_batch_size,
        global_batch_size=global_batch_size,
        min_seq_length=min_seq_length,
    )

    # The same dtype is used for params, pipeline communication and autocast; compute it once.
    model_dtype = get_autocast_dtype(precision)
    config = config_class(
        params_dtype=model_dtype,
        pipeline_dtype=model_dtype,
        autocast_dtype=model_dtype,
        include_hiddens=include_hiddens,
        include_embeddings=include_embeddings,
        include_input_ids=include_input_ids,
        skip_logits=not include_logits,
        tensor_model_parallel_size=tensor_model_parallel_size,
        pipeline_model_parallel_size=pipeline_model_parallel_size,
        initial_ckpt_path=str(checkpoint_path),
        initial_ckpt_skip_keys_with_these_prefixes=[],  # load everything from the checkpoint.
    )

    tokenizer = get_tokenizer()
    module = biobert_lightning_module(config=config, tokenizer=tokenizer)

    # datamodule is responsible for transforming dataloaders by adding MegatronDataSampler. Alternatively, to
    # directly use dataloader in predict method, the data sampler should be included in MegatronStrategy
    trainer.predict(module, datamodule=datamodule)  # return_predictions=False failing due to a lightning bug