ESM-2 Inference
This tutorial demonstrates ESM-2 inference on a CSV file with a sequences column. To pre-train the ESM-2 model, refer to the ESM-2 Pretraining tutorial.
Some cells below produce long text output. We use %%capture --no-display --no-stderr cell_output to suppress this output. Comment out or delete this line in those cells to restore the full output.
Setup and Assumptions
In this tutorial, we demonstrate how to download an ESM-2 checkpoint, create a CSV file of protein sequences, and run inference with an ESM-2 model.
All commands should be executed inside the BioNeMo Docker container, which has all ESM-2 dependencies pre-installed. For more information on how to build or pull the BioNeMo2 container, refer to the Initialization Guide.
Import Required Libraries
%%capture --no-display --no-stderr cell_output
import os
import torch
import shutil
import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
Work Directory
Set the work directory to store data and results:
cleanup: bool = True  # set to False to keep results from a previous run
work_dir = "/workspace/bionemo2/esm2_inference_tutorial"
# Start from a clean directory if cleanup is enabled
if cleanup and os.path.exists(work_dir):
    shutil.rmtree(work_dir)
if not os.path.exists(work_dir):
    os.makedirs(work_dir)
    print(f"Directory '{work_dir}' created.")
else:
    print(f"Directory '{work_dir}' already exists.")
Directory '/workspace/bionemo2/esm2_inference_tutorial' created.
Download Model Checkpoints
The following code downloads the pre-trained model, esm2/650m:2.0, from the NGC registry:
from bionemo.core.data.load import load
checkpoint_path = load("esm2/650m:2.0")
print(checkpoint_path)
/home/bionemo/.cache/bionemo/0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997-esm2_650M_nemo2.tar.gz.untar
Data
We use the InMemoryCSVDataset class to load the protein sequence data from a .csv file. This file must have at least a sequences column and can optionally have a labels column, which is used for fine-tuning applications. Here is an example of how to create your own inference input data from a list of sequences in Python:
import pandas as pd
artificial_sequence_data = [
"TLILGWSDKLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI",
"LYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
"GRFNVWLGGNESKIRQVLKAVKEIGVSPTLFAVYEKN",
"DELTALGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
"KLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI",
"LFGAIGNAISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP",
"LGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
"LYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
"ISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP",
"SGSKASSDSQDANQCCTSCEDNAPATSYCVECSEPLCETCVEAHQRVKYTKDHTVRSTGPAKT",
]
# Create a DataFrame
df = pd.DataFrame(artificial_sequence_data, columns=["sequences"])
# Save the DataFrame to a CSV file
data_path = os.path.join(work_dir, "sequences.csv")
df.to_csv(data_path, index=False)
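Before running inference, it can help to confirm that the CSV has the layout described above. Below is a minimal sketch of such a check with pandas. The helper check_inference_csv is hypothetical and not part of BioNeMo. It requires a non-empty sequences column and ignores the optional labels column. The usage writes a throwaway file so the snippet is self-contained. In this notebook you would pass data_path instead.

```python
import os
import tempfile

import pandas as pd


def check_inference_csv(path: str) -> pd.DataFrame:
    """Hypothetical helper: load a CSV and check it has a usable 'sequences' column.

    A 'labels' column is optional (only needed for fine-tuning), so it is not required here.
    """
    df = pd.read_csv(path)
    if "sequences" not in df.columns:
        raise ValueError(f"{path} has no 'sequences' column; found {list(df.columns)}")
    seqs = df["sequences"]
    # Empty cells are read back as NaN; also reject empty strings defensively.
    if seqs.isna().any() or (seqs.astype(str).str.len() == 0).any():
        raise ValueError("the 'sequences' column contains missing or empty entries")
    return df


# Usage with a small throwaway file (in this notebook, pass data_path instead).
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "sequences.csv")
    pd.DataFrame(
        {"sequences": ["LYSGDHSTQGARFLRDLAENTGRAEYELLSLF", "GRFNVWLGGNESKIRQVLKAVKEIGVSPTLFAVYEKN"]}
    ).to_csv(path, index=False)
    checked = check_inference_csv(path)
    print(len(checked), list(checked.columns))
```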
Run Inference
As in PyTorch Lightning, ESM-2 inference relies on a few key classes:
- MegatronStrategy: launches and sets up parallelism for NeMo and Megatron-LM.
- Trainer: sets training configurations and logging.
- ESMFineTuneDataModule: loads sequence data for both fine-tuning and inference.
- ESM2Config: configures the ESM-2 model as a BionemoLightningModule.
Refer to the ESM-2 Pretraining and ESM-2 Fine-Tuning tutorials for a detailed description of these classes.
To run inference on the data created in the previous step, we can use the infer_esm2 executable, which calls bionemo-framework/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/infer_esm2.py. To get a full description of the inference arguments, pass --help in the following command:
! infer_esm2 --help
2024-12-16 20:19:23 - faiss.loader - INFO - Loading faiss with AVX512 support.
2024-12-16 20:19:23 - faiss.loader - INFO - Successfully loaded faiss with AVX512 support.
[NeMo W 2024-12-16 20:19:24 nemo_logging:361] /usr/local/lib/python3.10/dist-packages/pydub/utils.py:170: RuntimeWarning: Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work
      warn("Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work", RuntimeWarning)
[NeMo W 2024-12-16 20:19:24 nemo_logging:361] /usr/local/lib/python3.10/dist-packages/pyannote/core/notebook.py:134: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed two minor releases later. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap(obj)`` instead.
      cm = get_cmap("Set1")
usage: infer_esm2 [-h] --checkpoint-path CHECKPOINT_PATH --data-path DATA_PATH
                  --results-path RESULTS_PATH
                  [--precision {fp16,bf16,fp32,bf16-mixed,fp32-mixed,16-mixed,fp16-mixed,16,32}]
                  [--num-gpus NUM_GPUS] [--num-nodes NUM_NODES]
                  [--micro-batch-size MICRO_BATCH_SIZE]
                  [--pipeline-model-parallel-size PIPELINE_MODEL_PARALLEL_SIZE]
                  [--tensor-model-parallel-size TENSOR_MODEL_PARALLEL_SIZE]
                  [--prediction-interval {epoch,batch}] [--include-hiddens]
                  [--include-input-ids] [--include-embeddings]
                  [--include-logits] [--config-class CONFIG_CLASS]
Infer ESM2.
options:
  -h, --help            show this help message and exit
  --checkpoint-path CHECKPOINT_PATH
                        Path to the ESM2 pretrained checkpoint
  --data-path DATA_PATH
                        Path to the CSV file containing sequences and label columns
  --results-path RESULTS_PATH
                        Path to the results directory.
  --precision {fp16,bf16,fp32,bf16-mixed,fp32-mixed,16-mixed,fp16-mixed,16,32}
                        Precision type to use for training.
  --num-gpus NUM_GPUS   Number of GPUs to use for training. Default is 1.
  --num-nodes NUM_NODES
                        Number of nodes to use for training. Default is 1.
  --micro-batch-size MICRO_BATCH_SIZE
                        Micro-batch size. Global batch size is inferred from this.
  --pipeline-model-parallel-size PIPELINE_MODEL_PARALLEL_SIZE
                        Pipeline model parallel size. Default is 1.
  --tensor-model-parallel-size TENSOR_MODEL_PARALLEL_SIZE
                        Tensor model parallel size. Default is 1.
  --prediction-interval {epoch,batch}
                        Intervals to write DDP predictions into disk
  --include-hiddens     Include hiddens in output of inference
  --include-input-ids   Include input_ids in output of inference
  --include-embeddings  Include embeddings in output of inference
  --include-logits      Include per-token logits in output.
  --config-class CONFIG_CLASS
                        Model configs link model classes with losses, and handle model initialization (including from a prior checkpoint). This is how you can fine-tune a model. First train with one config class that points to one model class and loss, then implement and provide an alternative config class that points to a variant of that model and alternative loss. In the future this script should also provide similar support for picking different data modules for fine-tuning with different data types. Choices: dict_keys(['ESM2Config', 'ESM2FineTuneSeqConfig', 'ESM2FineTuneTokenConfig'])
The hidden states (the per-token output vectors of a layer in a neural network) can be obtained by passing the --include-hiddens argument when running ESM-2 inference in BioNeMo Framework.
The hidden states can be converted into fixed-size vector embeddings by removing the hidden state vectors that correspond to padding tokens and averaging the remaining ones. This is a common way to build a single vector representation from a model's hidden states, which can then be used for sequence-level downstream tasks such as classification (e.g. subcellular localization) or regression (e.g. melting temperature prediction). To obtain the embeddings, pass the --include-embeddings argument.
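The pooling described above can be sketched as a masked mean in PyTorch. This is a minimal illustration and not necessarily BioNeMo's exact implementation. For example, the text does not say whether the <cls> and <eos> positions are also excluded. The function name masked_mean_pool and the toy tensors are hypothetical. The value pad_id=1 assumes the index of '<pad>' in the ESM-2 tokenizer vocabulary listed later in this tutorial.

```python
import torch


def masked_mean_pool(hidden_states: torch.Tensor, input_ids: torch.Tensor, pad_id: int) -> torch.Tensor:
    """Hypothetical helper: average per-token hidden states over non-padding positions.

    hidden_states: [batch, seq_len, hidden]; input_ids: [batch, seq_len].
    Returns a [batch, hidden] tensor with one fixed-size vector per sequence.
    """
    keep = (input_ids != pad_id).unsqueeze(-1).to(hidden_states.dtype)  # [b, s, 1], 1.0 for real tokens
    summed = (hidden_states * keep).sum(dim=1)  # zero out padding vectors, then sum over positions
    counts = keep.sum(dim=1).clamp(min=1.0)  # number of non-padding tokens per sequence
    return summed / counts


# Toy example: 2 sequences, max length 5, hidden size 4 (pad_id=1 is assumed to be '<pad>').
torch.manual_seed(0)
hidden = torch.randn(2, 5, 4)
ids = torch.tensor([[0, 5, 6, 2, 1],
                    [0, 7, 2, 1, 1]])
emb = masked_mean_pool(hidden, ids, pad_id=1)
print(emb.shape)  # torch.Size([2, 4])
```

Averaging only over real tokens keeps short and long sequences comparable. A plain mean over all 1024 positions would dilute short sequences with padding vectors.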
By passing the hidden state of an amino acid sequence through the BERT language model head, we obtain output logits at each position, which can then be transformed into probabilities. To return the logits, pass the --include-logits argument. Logits are the raw, unnormalized scores for each token class. They are not probabilities themselves and can be any real number, including negative values.
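One standard way to turn logits into probabilities is a softmax over the vocabulary dimension. Below is a minimal sketch on toy tensors. The shapes are assumptions chosen to mimic batch-first logits that have already been sliced to the 33-token vocabulary, as done later in this tutorial. Slicing first avoids assigning probability mass to the padded columns.

```python
import torch

# Toy logits: 2 sequences x 3 positions x 33 vocabulary entries (hypothetical values;
# real ones would come from token_logits after transposing and slicing to the vocabulary).
torch.manual_seed(0)
toy_logits = torch.randn(2, 3, 33)

# Softmax over the last (vocabulary) dimension turns raw scores into per-position probabilities.
toy_probs = torch.softmax(toy_logits, dim=-1)

# Each position's probabilities are non-negative and sum to 1.
print(toy_probs.sum(dim=-1))
# The most likely token index per position is unchanged by softmax (same as argmax of the logits).
print(toy_probs.argmax(dim=-1))
```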
Now let's call the infer_esm2 executable with the relevant arguments to compute and return embeddings, hidden states, and logits, along with the input IDs.
%%capture --no-display --no-stderr cell_output
! infer_esm2 --checkpoint-path {checkpoint_path} \
--data-path {data_path} \
--results-path {work_dir} \
--micro-batch-size 3 \
--num-gpus 1 \
--precision "bf16-mixed" \
--include-hiddens \
--include-embeddings \
--include-logits \
--include-input-ids
Inference Results
Inference predictions are stored in .pt files, one per device. Since we used only one device (--num-gpus 1) in the previous step, the results were written to {work_dir}/predictions__rank_0.pt under this notebook's work directory (defined above). The .pt file contains a dictionary of {'result_key': torch.Tensor} that can be loaded with PyTorch:
import torch
# Load the dictionary of result tensors written by infer_esm2 on rank 0
results = torch.load(f"{work_dir}/predictions__rank_0.pt")
for key, val in results.items():
    if val is not None:
        print(f'{key}\t{val.shape}')
token_logits	torch.Size([1024, 10, 128])
hidden_states	torch.Size([10, 1024, 1280])
input_ids	torch.Size([10, 1024])
embeddings	torch.Size([10, 1280])
In this example, results is a Python dict with the following keys: ['token_logits', 'hidden_states', 'input_ids', 'embeddings']. The logits (token_logits) tensor has dimensions [sequence, batch, hidden] to improve training performance. Below, we transpose the first two dimensions to get a batch-first shape like the rest of the output tensors.
logits = results['token_logits'].transpose(0, 1) # s, b, h -> b, s, h
print(logits.shape)
torch.Size([10, 1024, 128])
The last dimension of token_logits is 128. The first 33 positions correspond to the tokenizer vocabulary and are followed by 95 padding positions. We use tokenizer.vocab_size to filter out the padding and keep only the 33 vocabulary positions.
from bionemo.esm2.data.tokenizer import get_tokenizer
tokenizer = get_tokenizer()
tokens = tokenizer.all_tokens
print(f"There are {tokenizer.vocab_size} unique tokens: {tokens}.")
aa_logits = logits[..., :tokenizer.vocab_size] # filter out the 95 paddings and only keep 33 vocab positions
print(f"Logits shape after removing the paddings in hidden dimension: {aa_logits.shape}")
There are 33 unique tokens: ['<cls>', '<pad>', '<eos>', '<unk>', 'L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', '.', '-', '<null_1>', '<mask>'].
Logits shape after removing the paddings in hidden dimension: torch.Size([10, 1024, 33])
Let's set aside the tokens corresponding to the 20 standard amino acids.
aa_tokens = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']
aa_indices = [i for i, token in enumerate(tokens) if token in aa_tokens]
extra_indices = [i for i, token in enumerate(tokens) if token not in aa_tokens]
The sequence dimension in this example (1024) represents the maximum sequence length, which includes padding, EOS, and BOS tokens. To keep only the relevant amino acid information, we can use the input sequence IDs in the results to create a mask for extracting the relevant entries of aa_logits:
input_ids = results['input_ids'] # b, s
# mask where non-amino acid tokens are True
mask = torch.isin(input_ids, torch.tensor(extra_indices))
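The steps above stop after building the mask. One way to apply it is sketched below on toy tensors shaped like input_ids [b, s] and aa_logits [b, s, 33]. Because sequences have different lengths, the sketch collects one [residues, 20] tensor per sequence. It drops masked positions (<cls>, <eos>, <pad>, and other non-amino-acid tokens) and keeps only the 20 amino-acid columns. The toy_* names are hypothetical stand-ins. The token list is the one printed by the tokenizer above.

```python
import torch

# Token list as printed by the ESM-2 tokenizer above (index = token ID).
toy_tokens = ['<cls>', '<pad>', '<eos>', '<unk>', 'L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D',
              'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', '.', '-',
              '<null_1>', '<mask>']
standard_aa = set("LAGVSERTIDPKQNFYMHWC")
aa_idx = [i for i, t in enumerate(toy_tokens) if t in standard_aa]
extra_idx = [i for i, t in enumerate(toy_tokens) if t not in standard_aa]

# Toy batch: "<cls> L A G <eos> <pad>" and "<cls> W C <eos> <pad> <pad>" (batch-first).
toy_ids = torch.tensor([[0, 4, 5, 6, 2, 1],
                        [0, 22, 23, 2, 1, 1]])
torch.manual_seed(0)
toy_aa_logits = torch.randn(2, 6, len(toy_tokens))  # [b, s, vocab]

toy_mask = torch.isin(toy_ids, torch.tensor(extra_idx))  # True at non-amino-acid tokens

per_seq = []
for b in range(toy_ids.shape[0]):
    residue_rows = toy_aa_logits[b][~toy_mask[b]]  # keep only amino-acid positions
    per_seq.append(residue_rows[:, aa_idx])  # keep only the 20 amino-acid columns
print([t.shape for t in per_seq])  # [torch.Size([3, 20]), torch.Size([2, 20])]
```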
DDP Inference Support
Although this tutorial uses one device to run inference, BioNeMo Framework supports distributed inference for ESM-2. Set --num-gpus n to run distributed inference on n devices. The output predictions are written to predictions__rank_<0...n-1>.pt under the provided --results-path. Moreover, by including the input token IDs with --include-input-ids, we can ensure a 1:1 mapping between input sequences and output predictions.
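As a sketch of how the input IDs provide that 1:1 mapping, the snippet below decodes each row of input_ids back into a residue string using the tokenizer's token list printed above. It then reorders an output tensor to match the input order. The decode helper, the choice of special tokens to drop, and the toy tensors are hypothetical. The sketch also assumes that sequences were not truncated. It assumes duplicate sequences (the example CSV contains one) can share a single output row.

```python
import torch

# ESM-2 token list as printed by the tokenizer above (index = token ID).
vocab = ['<cls>', '<pad>', '<eos>', '<unk>', 'L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D',
         'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', '.', '-',
         '<null_1>', '<mask>']
special = {'<cls>', '<pad>', '<eos>', '<unk>', '<null_1>', '<mask>'}


def decode(ids: torch.Tensor) -> str:
    """Hypothetical helper: turn one row of input_ids back into a residue string."""
    return "".join(vocab[i] for i in ids.tolist() if vocab[i] not in special)


# Toy outputs whose row order differs from the input order (as can happen across ranks).
inputs = ["LAG", "WC", "KQN"]
toy_input_ids = torch.tensor([[0, 22, 23, 2, 1], [0, 15, 16, 17, 2], [0, 4, 5, 6, 2]])
toy_embeddings = torch.arange(3.0).unsqueeze(1).repeat(1, 4)  # row r is filled with value r

# Map each decoded sequence to its output row, then reorder rows to follow the inputs.
row_of = {decode(row): r for r, row in enumerate(toy_input_ids)}
order = torch.tensor([row_of[s] for s in inputs])
aligned = toy_embeddings[order]
print(order)  # tensor([2, 0, 1])
```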
The following snippet can be used to load and collate the predictions into a single dictionary.
import glob
from bionemo.llm.lightning import batch_collator
# Load every per-rank prediction file and merge them into a single dictionary
collated_predictions = batch_collator([torch.load(path) for path in sorted(glob.glob(f"{work_dir}/predictions__rank_*.pt"))])
for key, val in collated_predictions.items():
    if val is not None:
        print(f'{key}\t{val.shape}')
# token_logits torch.Size([1024, 10, 128])
# hidden_states torch.Size([10, 1024, 1280])
# input_ids torch.Size([10, 1024])
# embeddings torch.Size([10, 1280])
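The outputs above use different layouts. token_logits is sequence-first, while the other tensors are batch-first. Merging per-rank files by hand therefore needs a different concatenation dimension per key. The sketch below shows one way to do this with torch.cat, assuming each rank file has the keys and layouts shown above. It is an alternative written for illustration, not a description of how batch_collator works internally. The helper collate_rank_outputs and the toy tensors are hypothetical. The merged rows follow rank order rather than the original CSV order, so the input IDs are still needed to realign them.

```python
import torch


def collate_rank_outputs(per_rank: list[dict]) -> dict:
    """Hypothetical helper: concatenate per-rank prediction dicts along each key's batch dimension.

    Assumes the layouts shown in this tutorial: token_logits is [seq, batch, vocab]
    (batch is dim 1) and every other tensor is batch-first (batch is dim 0).
    """
    out = {}
    for key in per_rank[0]:
        tensors = [d[key] for d in per_rank if d.get(key) is not None]
        if not tensors:
            out[key] = None  # output not requested at inference time
            continue
        out[key] = torch.cat(tensors, dim=1 if key == "token_logits" else 0)
    return out


# Toy outputs from two ranks holding 3 and 2 sequences, max length 8.
rank0 = {"token_logits": torch.zeros(8, 3, 128), "input_ids": torch.zeros(3, 8, dtype=torch.long),
         "embeddings": torch.zeros(3, 16)}
rank1 = {"token_logits": torch.ones(8, 2, 128), "input_ids": torch.ones(2, 8, dtype=torch.long),
         "embeddings": torch.ones(2, 16)}
merged = collate_rank_outputs([rank0, rank1])
for key, val in merged.items():
    print(key, tuple(val.shape))
# token_logits (8, 5, 128)
# input_ids (5, 8)
# embeddings (5, 16)
```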
For a more in-depth example of inference and of converting logits to probabilities, refer to the ESM-2 Mutant Design Tutorial.