Zero-shot prediction of BRCA1 variant effects with Evo 2¶
Note: this notebook is a reproduction of the Arc Institute's notebook of the same title (available here), using the BioNeMo 2 implementation of Evo 2.
Evo 2 is a foundation AI model trained on 9.3 trillion DNA base pairs that predicts variant effects without any task-specific training. Although it was never explicitly trained on BRCA1 variants, we show below that Evo 2 can predict their effects zero-shot, demonstrating the generalization ability of a model trained on sequences from all forms of life.
The human BRCA1 gene encodes for a protein that repairs damaged DNA (Moynahan et al., 1999). Certain variants of this gene have been associated with an increased risk of breast and ovarian cancers (Miki et al., 1994). Using Evo 2, we can predict whether a particular single nucleotide variant (SNV) of the BRCA1 gene is likely to be harmful to the protein's function, and thus potentially increase the risk of cancer for the patient with the genetic variant.
%%capture
!pip install biopython openpyxl
import os
# Runs a subset of the model layers to test that the notebook runs in CI, but the output will be incorrect.
FAST_CI_MODE: bool = os.environ.get("FAST_CI_MODE", False)
import glob
import gzip
import json
import math
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from Bio import SeqIO
from sklearn.metrics import auc, roc_auc_score, roc_curve

# Helper functions used throughout this notebook, provided by the accompanying brca1_utils scripts
from brca1_utils import (
    check_fp8_support,
    download_data,
    generate_fasta_files,
    load_brca1_data,
    load_genome_sequence,
    plot_roc_curve,
    plot_strip_with_means,
    sample_data,
)
We start by loading a dataset from Findlay et al. (2018), which contains experimentally measured function scores of 3,893 BRCA1 SNVs. These function scores reflect the extent to which a genetic variant disrupts the protein's function, with lower scores indicating greater disruption. In this dataset, the SNVs are classified into three categories based on their function scores: `LOF` (loss-of-function), `INT` (intermediate), and `FUNC` (functional). We first read in this dataset.
To keep the notebook streamlined, we've abstracted much of the preprocessing logic into accompanying scripts located in `brca1_utils`. The full notebook can be viewed here.
%%capture
# Configuration parameters
DATA_DIR = 'brca1'
SAMPLE_CONFIG = {
    'sample_frac': 0.05,
    'balanced': True,
    'disable': False,
    'random_state': 42
}
# 1. Download the necessary data files if not present
excel_path, genome_path = download_data(DATA_DIR)
seq_chr17 = load_genome_sequence(genome_path)
# 2. Load and preprocess BRCA1 data
brca1_df = load_brca1_data(excel_path)
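For orientation, `load_genome_sequence` can be thought of as reading the (possibly gzipped) chromosome 17 FASTA into a plain string. Below is a minimal sketch of the idea, not the actual `brca1_utils` implementation (the name `load_genome_sequence_sketch` is ours):

```python
import gzip
from Bio import SeqIO

def load_genome_sequence_sketch(fasta_path: str) -> str:
    """Illustrative only: read the first record of a (gzipped) FASTA as an upper-case string."""
    opener = gzip.open if str(fasta_path).endswith(".gz") else open
    with opener(fasta_path, "rt") as handle:
        record = next(SeqIO.parse(handle, "fasta"))
    return str(record.seq).upper()
```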
We then group the `FUNC` and `INT` classes of SNVs together into a single category (`FUNC/INT`).
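Conceptually, this grouping is a simple relabeling of the class column. A one-line sketch of the idea (the real logic lives in `brca1_utils`):

```python
# Illustrative only: collapse the FUNC and INT labels into a single FUNC/INT category.
brca1_df["class"] = brca1_df["class"].replace(["FUNC", "INT"], "FUNC/INT")
```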
We build a function to parse the reference and variant sequences of an 8,192-bp window around the genomic position of each SNV, using the reference sequence of human chromosome 17, where BRCA1 is located.
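A minimal sketch of such a window-parsing function is shown below; the version used in this notebook lives in `brca1_utils` and may differ in details such as edge handling (the name `parse_sequences_sketch` is ours):

```python
WINDOW_SIZE = 8192

def parse_sequences_sketch(pos: int, ref: str, alt: str, seq_chr17: str, window_size: int = WINDOW_SIZE):
    """Illustrative only: return the reference window around a 1-based SNV position
    and a copy of that window with the alternate base substituted in."""
    p = pos - 1                              # 0-based index into the chromosome string
    half = window_size // 2
    start, end = max(0, p - half), min(len(seq_chr17), p + half)
    ref_seq = seq_chr17[start:end]
    snv_idx = p - start                      # position of the SNV inside the window
    assert ref_seq[snv_idx].upper() == ref.upper(), "reference base mismatch"
    var_seq = ref_seq[:snv_idx] + alt + ref_seq[snv_idx + 1:]
    return ref_seq, var_seq
```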
To make things run faster, we'll just look at a balanced sample of our data. If you want to run on the full dataset, set `'disable': True` in `SAMPLE_CONFIG` above.
OUTPUT_DIR = "brca1_fasta_files"
brca1_df = sample_data(
    brca1_df,
    sample_frac=SAMPLE_CONFIG['sample_frac'],
    balanced=SAMPLE_CONFIG['balanced'],
    disable=SAMPLE_CONFIG['disable'],
    random_state=SAMPLE_CONFIG['random_state']
)
brca1_df.head(5)
| | chrom | pos | ref | alt | score | class |
|---|---|---|---|---|---|---|
| 0 | 17 | 41199726 | T | C | 0.159762 | FUNC/INT |
| 1 | 17 | 41209074 | T | A | -2.065569 | LOF |
| 2 | 17 | 41256913 | A | C | -0.847753 | FUNC/INT |
| 3 | 17 | 41219631 | T | A | -2.053739 | LOF |
| 4 | 17 | 41215965 | G | A | -1.671525 | LOF |
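For intuition, the class-balanced subsampling above behaves roughly like the sketch below, which draws a similar number of rows from each class. This is illustrative only (the function name `balanced_subsample_sketch` is ours); `sample_data` in `brca1_utils` is the authoritative implementation, and its exact balancing strategy may differ:

```python
import pandas as pd

def balanced_subsample_sketch(df: pd.DataFrame, frac: float, class_col: str = "class", random_state: int = 42) -> pd.DataFrame:
    """Illustrative only: draw roughly the same number of rows from each class."""
    n_per_class = max(1, int(len(df) * frac) // df[class_col].nunique())
    return (
        df.groupby(class_col, group_keys=False)
        .apply(lambda g: g.sample(n=min(n_per_class, len(g)), random_state=random_state))
        .reset_index(drop=True)
    )
```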
Next, we'll write these to local `.fasta` files so we can use them for prediction below.
brca1_df = generate_fasta_files(
    brca1_df,
    seq_chr17,
    output_dir=OUTPUT_DIR
)
Total unique reference sequences: 79
Total unique variant sequences: 84
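Under the hood, `generate_fasta_files` does something along the lines of the sketch below: build the windowed reference and variant sequence for every SNV, de-duplicate them, write the two FASTA files, and record each row's FASTA record name so predictions can be looked up later. This is illustrative only (it reuses the hypothetical `parse_sequences_sketch` from above); the real implementation is in `brca1_utils`:

```python
from pathlib import Path

def generate_fasta_files_sketch(df, seq_chr17, output_dir="brca1_fasta_files"):
    """Illustrative only: write windowed reference/variant sequences to FASTA files."""
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)
    ref_records, var_records, ref_names, var_names = {}, {}, [], []
    for _, row in df.iterrows():
        ref_seq, var_seq = parse_sequences_sketch(row["pos"], row["ref"], row["alt"], seq_chr17)
        ref_name = f"BRCA1_ref_pos_{row['pos']}_{row['ref']}_class_{row['class']}"
        var_name = f"BRCA1_var_pos_{row['pos']}_{row['ref']}to{row['alt']}_class_{row['class']}"
        ref_records[ref_name], var_records[var_name] = ref_seq, var_seq  # dict keys keep entries unique
        ref_names.append(ref_name)
        var_names.append(var_name)
    with open(out / "brca1_reference_sequences.fasta", "w") as f:
        f.writelines(f">{name}\n{seq}\n" for name, seq in ref_records.items())
    with open(out / "brca1_variant_sequences.fasta", "w") as f:
        f.writelines(f">{name}\n{seq}\n" for name, seq in var_records.items())
    return df.assign(ref_fasta_name=ref_names, var_fasta_name=var_names)
```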
Load Evo 2 Checkpoints¶
Then, we load the Evo 2 1B model, converting the Evo 2 weights from Hugging Face into the NeMo 2 checkpoint format.

Note: for better performance, load the 7B model by setting `MODEL_SIZE="7b"`, which also works well on GPUs that do not support FP8.
%%capture
EXPERIMENTAL_1b_CHECKPOINT: bool = False
MODEL_SIZE = "1b"  # also try 7b if you have a GPU with more than 32GB of memory

# Define checkpoint path
if EXPERIMENTAL_1b_CHECKPOINT and MODEL_SIZE == "1b":
    from bionemo.core.data.load import load

    # This is a new 1b checkpoint that has been fine-tuned on BF16 hardware. It should be able to handle FP8 as well.
    # This line will download the checkpoint from NGC to your $HOME/.cache/bionemo directory and return the path.
    # Alternatively you can use `CHECKPOINT_PATH=$(download_bionemo_data evo2/1b-8k-bf16:1.0)` to do the same on the
    # command line.
    checkpoint_path = load("evo2/1b-8k-bf16:1.0")
else:
    checkpoint_path = Path(f"nemo2_evo2_{MODEL_SIZE}_8k")
    # Check if the directory does not exist or is empty
    if not checkpoint_path.exists() or not any(checkpoint_path.iterdir()):
        !evo2_convert_to_nemo2 --model-path hf://arcinstitute/savanna_evo2_{MODEL_SIZE}_base --model-size {MODEL_SIZE} --output-dir nemo2_evo2_{MODEL_SIZE}_8k
    else:
        print("Checkpoint directory is not empty. Skipping command.")
Score Sequences¶
Next, we score the likelihoods of the reference and variant sequences of each SNV.
# Define output directories for prediction results
output_dir = Path("brca1_fasta_files")
output_dir.mkdir(parents=True, exist_ok=True)
# Paths to the reference and variant FASTA files written by generate_fasta_files above
ref_fasta_path = output_dir / "brca1_reference_sequences.fasta"
var_fasta_path = output_dir / "brca1_variant_sequences.fasta"
predict_ref_dir = output_dir / "reference_predictions"
predict_var_dir = output_dir / "variant_predictions"
predict_ref_dir.mkdir(parents=True, exist_ok=True)
predict_var_dir.mkdir(parents=True, exist_ok=True)
fp8_supported, gpu_info = check_fp8_support()
print(f"FP8 Support: {fp8_supported}")
print(gpu_info)
# Note: If FP8 is not supported, you may want to disable it in the model config
# The Evo2 config has 'use_fp8_input_projections: True' by default
if FAST_CI_MODE:
    model_subset_option = "--num-layers 4 --hybrid-override-pattern SDH*"
else:
    model_subset_option = ""
fp8_option = "--fp8" if fp8_supported else ""
# Update predict commands to run on the full dataset
predict_ref_command = (
    f"predict_evo2 --fasta {ref_fasta_path} --ckpt-dir {checkpoint_path} "
    f"--output-dir {predict_ref_dir} --model-size {MODEL_SIZE} --tensor-parallel-size 1 {model_subset_option} "
    f"--pipeline-model-parallel-size 1 --context-parallel-size 1 --output-log-prob-seqs {fp8_option}"
)
predict_var_command = (
    f"predict_evo2 --fasta {var_fasta_path} --ckpt-dir {checkpoint_path} "
    f"--output-dir {predict_var_dir} --model-size {MODEL_SIZE} --tensor-parallel-size 1 {model_subset_option} "
    f"--pipeline-model-parallel-size 1 --context-parallel-size 1 --output-log-prob-seqs {fp8_option}"
)
FP8 Support: False
Device: NVIDIA RTX A6000, Compute Capability: 8.6
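The `check_fp8_support` helper used above is conceptually just a CUDA compute-capability check: FP8 kernels require compute capability 8.9 (Ada) or 9.0+ (Hopper). A sketch of the idea, not the `brca1_utils` implementation:

```python
import torch

def check_fp8_support_sketch():
    """Illustrative only: report whether the current GPU supports FP8."""
    if not torch.cuda.is_available():
        return False, "No CUDA device available"
    major, minor = torch.cuda.get_device_capability()
    info = f"Device: {torch.cuda.get_device_name()}, Compute Capability: {major}.{minor}"
    return (major, minor) >= (8, 9), info
```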
Score reference sequences:
%%capture
print(f"Running command: {predict_ref_command}")
!{predict_ref_command}
Score variant sequences:
%%capture
print(f"Running command: {predict_var_command}")
!{predict_var_command}
We calculate the change in likelihood for each variant relative to the likelihood of its respective wild-type (reference) sequence.
First, we load the prediction files and sequence id maps:
# Find and load prediction files
ref_pred_files = glob.glob(os.path.join(predict_ref_dir, "predictions__rank_*.pt"))
var_pred_files = glob.glob(os.path.join(predict_var_dir, "predictions__rank_*.pt"))
# Load sequence ID maps (maps sequence ID -> prediction index)
with open(os.path.join(predict_ref_dir, "seq_idx_map.json"), "r") as f:
    ref_seq_idx_map = json.load(f)
with open(os.path.join(predict_var_dir, "seq_idx_map.json"), "r") as f:
    var_seq_idx_map = json.load(f)
# Load predictions
ref_preds = torch.load(ref_pred_files[0])
var_preds = torch.load(var_pred_files[0])
Then, calculate the delta score:
# next, calculate change in likelihoods
ref_log_probs = []
var_log_probs = []
for _, row in brca1_df.iterrows():
    ref_name = row['ref_fasta_name']
    var_name = row['var_fasta_name']
    ref_log_probs.append(ref_preds['log_probs_seqs'][ref_seq_idx_map[ref_name]].item())
    var_log_probs.append(var_preds['log_probs_seqs'][var_seq_idx_map[var_name]].item())
brca1_df['ref_log_probs'] = ref_log_probs
brca1_df['var_log_probs'] = var_log_probs
# Ideally, a disruptive variant has a lower likelihood than its functional reference, so (var - ref) is negative for damaging variants.
brca1_df['evo2_delta_score'] = brca1_df['var_log_probs'] - brca1_df['ref_log_probs']
brca1_df.head()
| | chrom | pos | ref | alt | score | class | ref_fasta_name | var_fasta_name | ref_log_probs | var_log_probs | evo2_delta_score |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 17 | 41199726 | T | C | 0.159762 | FUNC/INT | BRCA1_ref_pos_41199726_T_class_FUNC/INT | BRCA1_var_pos_41199726_TtoC_class_FUNC/INT | -0.952916 | -0.953258 | -0.000342 |
| 1 | 17 | 41209074 | T | A | -2.065569 | LOF | BRCA1_ref_pos_41209074_T_class_LOF | BRCA1_var_pos_41209074_TtoA_class_LOF | -0.750398 | -0.750437 | -0.000039 |
| 2 | 17 | 41256913 | A | C | -0.847753 | FUNC/INT | BRCA1_ref_pos_41256913_A_class_FUNC/INT | BRCA1_var_pos_41256913_AtoC_class_FUNC/INT | -0.798164 | -0.799011 | -0.000847 |
| 3 | 17 | 41219631 | T | A | -2.053739 | LOF | BRCA1_ref_pos_41219631_T_class_LOF | BRCA1_var_pos_41219631_TtoA_class_LOF | -1.032126 | -1.032697 | -0.000571 |
| 4 | 17 | 41215965 | G | A | -1.671525 | LOF | BRCA1_ref_pos_41215965_G_class_LOF | BRCA1_var_pos_41215965_GtoA_class_LOF | -0.860847 | -0.861287 | -0.000441 |
This delta likelihood should be predictive of how disruptive the SNV is to the protein's function: the lower the delta, the more likely that the SNV is disruptive. We can show this by comparing the distributions of delta likelihoods for the two classes of SNVs (functional/intermediate vs loss-of-function).
plot_strip_with_means(brca1_df, x_col="evo2_delta_score", class_col="class")
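`plot_strip_with_means` draws a strip plot of the per-variant delta scores split by class, with the class means overlaid. A rough seaborn equivalent (a sketch, not the `brca1_utils` helper):

```python
import matplotlib.pyplot as plt
import seaborn as sns

def plot_strip_with_means_sketch(df, x_col="evo2_delta_score", class_col="class"):
    """Illustrative only: strip plot of delta scores by class, with per-class means marked."""
    order = sorted(df[class_col].unique())
    fig, ax = plt.subplots(figsize=(8, 3))
    sns.stripplot(data=df, x=x_col, y=class_col, order=order, alpha=0.5, ax=ax)
    for i, label in enumerate(order):
        mean_val = df.loc[df[class_col] == label, x_col].mean()
        ax.plot([mean_val, mean_val], [i - 0.2, i + 0.2], color="black", lw=2)  # class mean marker
    ax.set_xlabel("Evo 2 delta likelihood score")
    plt.show()
```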
We can also calculate the area under the receiver operating characteristic curve (AUROC) of this zero-shot prediction method. Note that the results are nearly random unless you are on one of the following configurations:

- `--fp8` on an FP8-enabled GPU, with either the 1b or 7b models. The 40b likely works as well.
- The 7b model uniquely seems to work well without `--fp8`, so if you are on an older device, the 7b model should produce robust results. In that case, change `MODEL_SIZE` earlier in this tutorial and rerun.
# Calculate AUROC of zero-shot predictions
# Class 1 (the positive class) is LOF, i.e. the damaging variants; we expect their delta scores to be more negative, so we negate the score for AUROC.
y_true = (brca1_df['class'] == 'LOF')
auroc = roc_auc_score(y_true, -brca1_df['evo2_delta_score'])
print(f'Zero-shot prediction AUROC: {auroc:.2}')
Zero-shot prediction AUROC: 0.77
plot_roc_curve(brca1_df)
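`plot_roc_curve` draws the corresponding ROC curve. With `roc_curve` and `auc` already imported above, a minimal sketch looks like this (illustrative, not the `brca1_utils` implementation):

```python
import matplotlib.pyplot as plt
from sklearn.metrics import auc, roc_curve

def plot_roc_curve_sketch(df):
    """Illustrative only: ROC curve for LOF vs FUNC/INT using the negated delta score."""
    y_true = (df["class"] == "LOF").astype(int)
    fpr, tpr, _ = roc_curve(y_true, -df["evo2_delta_score"])
    plt.figure(figsize=(4, 4))
    plt.plot(fpr, tpr, label=f"AUROC = {auc(fpr, tpr):.2f}")
    plt.plot([0, 1], [0, 1], linestyle="--", color="grey")  # chance level
    plt.xlabel("False positive rate")
    plt.ylabel("True positive rate")
    plt.legend()
    plt.show()
```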
Full Sample Performance¶
The above analysis may have been performed on a subset of the available data.
For comparison, the table below presents the AUROC scores for different model sizes evaluated on the full dataset (100% sample fraction).
| Model Size | Dataset Sample Fraction | AUROC |
|---|---|---|
| Evo 2 1B | 100% | 0.76 |
| Evo 2 7B | 100% | 0.87 |