# Genetic code lookup: upper-case DNA codon -> one-letter amino acid ("*" = stop codon).
# Key order is kept stable; lookups elsewhere use dna_code.get(codon, "?").
dna_code = {
    "ATA": "I", "ATC": "I", "ATT": "I", "ATG": "M",
    "ACA": "T", "ACC": "T", "ACG": "T", "ACT": "T",
    "AAC": "N", "AAT": "N", "AAA": "K", "AAG": "K",
    "AGC": "S", "AGT": "S", "AGA": "R", "AGG": "R",
    "CTA": "L", "CTC": "L", "CTG": "L", "CTT": "L",
    "CCA": "P", "CCC": "P", "CCG": "P", "CCT": "P",
    "CAC": "H", "CAT": "H", "CAA": "Q", "CAG": "Q",
    "CGA": "R", "CGC": "R", "CGG": "R", "CGT": "R",
    "GTA": "V", "GTC": "V", "GTG": "V", "GTT": "V",
    "GCA": "A", "GCC": "A", "GCG": "A", "GCT": "A",
    "GAC": "D", "GAT": "D", "GAA": "E", "GAG": "E",
    "GGA": "G", "GGC": "G", "GGG": "G", "GGT": "G",
    "TCA": "S", "TCC": "S", "TCG": "S", "TCT": "S",
    "TTC": "F", "TTT": "F", "TTA": "L", "TTG": "L",
    "TAC": "Y", "TAT": "Y", "TAA": "*", "TAG": "*",
    "TGC": "C", "TGT": "C", "TGA": "*", "TGG": "W",
}
def translate(seq):
    """
    Translate a nucleotide sequence into a one-letter protein sequence.

    The sequence is read in frame 0, three bases at a time, using ``dna_code``.
    Trailing bases that do not form a complete codon are ignored. RNA input is
    accepted: ``U`` is read as ``T`` before lookup (``dna_code`` is keyed by
    DNA codons). Codons not found in ``dna_code`` (e.g. lower-case or
    containing ``N``) translate to ``"?"``.

    Parameters:
        seq (str): nucleotide sequence (upper-case DNA or RNA).

    Returns:
        str: protein sequence, one character per complete codon.
    """
    dna = seq.replace("U", "T")
    return "".join(dna_code.get(dna[i : i + 3], "?") for i in range(0, len(dna) - 2, 3))
def reverse_complement_dna(seq):
    """
    Reverse-complement an upper-case DNA string.

    Parameters:
        seq (str): bases drawn from A, C, G, T and N, upper case only (e.g. "ATCG").

    Returns:
        str: the complementary strand, read in reverse; N maps to N.

    Raises:
        KeyError: for any other character, including lower-case bases.
    """
    pairs = {"A": "T", "T": "A", "G": "C", "C": "G", "N": "N"}
    return "".join([pairs[base] for base in reversed(seq)])
def _parse_coord_list(value):
    """
    Parse one exonStarts/exonEnds cell into a tuple of ints.

    Accepts a UCSC-style comma-separated string (with or without a trailing comma,
    e.g. "100,200,"), a Python literal such as "[100, 200]", an already-parsed
    sequence, or a single coordinate (e.g. "100", which literal_eval turns into an
    int, or an int cell produced by pd.read_csv).
    """
    if isinstance(value, str):
        value = ast.literal_eval(value)
    try:
        return tuple(int(v) for v in value)
    except TypeError:  # a lone coordinate is not iterable
        return (int(value),)


def process_gtf(gtf_path, fasta_path):
    """
    Build coding sequences (CDS) for transcripts from a tab-separated transcript
    annotation table and a reference FASTA file.

    Parameters:
        gtf_path (str): tab-separated table read with ``pd.read_csv``. It must have
            columns ``name`` (or ``#name``), ``chrom``, ``strand``, ``cdsStart``,
            ``cdsEnd`` (0-based, half-open), ``exonStarts`` and ``exonEnds``.
            Transcripts with ``cdsStart == cdsEnd`` are dropped.
        fasta_path (str): reference FASTA opened with ``pyfaidx.Fasta``.
            Transcripts on chromosomes absent from it are dropped, with a printed
            warning.

    Returns:
        tuple: ``(gtf_s, fasta)``.
            ``gtf_s`` has one row per transcript with these columns:
            - name, truncated at the first ".";
            - chrom, strand, cdsStart, cdsEnd;
            - cds_starts / cds_ends: tuples of exon bounds clipped to the CDS;
            - cds_length;
            - cds: strand-aware sequence, gene 5'->3'.
            Rows are sorted by chrom, cdsStart, cdsEnd.
            ``fasta`` maps each used chromosome to its upper-cased sequence string.
    """
    gtf = pd.read_csv(gtf_path, sep="\t")
    if "#name" in gtf.columns:
        gtf = gtf.rename({"#name": "name"}, axis=1)
    # Keep transcripts with a non-empty coding region.
    gtf = gtf.loc[gtf["cdsStart"] != gtf["cdsEnd"]].reset_index(drop=True).copy()
    gtf["exonStarts_arr"] = gtf["exonStarts"].map(_parse_coord_list)
    gtf["exonEnds_arr"] = gtf["exonEnds"].map(_parse_coord_list)

    fasta = {}
    with pyfaidx.Fasta(fasta_path) as f:
        available_chroms = set(f.keys())
        # Only keep transcripts whose chromosome exists in the FASTA.
        missing_chroms = set(gtf["chrom"].unique()) - available_chroms
        if missing_chroms:
            print(
                f"Warning: {len(missing_chroms)} chromosome(s) in GTF not found in FASTA, filtering them out: {sorted(missing_chroms)}"
            )
        gtf = gtf[gtf["chrom"].isin(available_chroms)].reset_index(drop=True).copy()
        # Load each needed chromosome once (upper-cased) so slicing below is a plain str slice.
        for chrom in gtf["chrom"].unique():
            fasta[chrom] = f[chrom][:].seq.upper()

    cds_starts = []
    cds_ends = []
    lengths = []
    seqs = []
    for i in tqdm(range(gtf.shape[0]), desc="Processing transcripts", total=gtf.shape[0]):
        t = gtf.iloc[i]
        chrom_seq = fasta[t["chrom"]]
        cs, ce = t["cdsStart"], t["cdsEnd"]  # 0-based, half-open
        cds_s = []
        cds_e = []
        pieces = []
        length = 0
        for a, b in zip(t["exonStarts_arr"], t["exonEnds_arr"]):
            # Clip the exon to the coding region (drops UTR parts).
            v1 = max(a, cs)
            v2 = min(b, ce)
            if v1 < v2:  # the exon overlaps the coding region
                cds_s.append(v1)
                cds_e.append(v2)
                length += v2 - v1
                pieces.append(chrom_seq[v1:v2])
        cds_starts.append(tuple(cds_s))
        cds_ends.append(tuple(cds_e))
        lengths.append(length)
        # Pieces are in genomic forward orientation; flip minus-strand CDS to gene 5'->3'.
        curr_seq = "".join(pieces)
        if t["strand"] == "-":
            curr_seq = reverse_complement_dna(curr_seq)
        seqs.append(curr_seq)

    gtf["cds_starts"] = cds_starts
    gtf["cds_ends"] = cds_ends
    gtf["cds_length"] = lengths
    gtf["cds"] = seqs  # strand-aware: always gene 5'->3'
    gtf_s = gtf[
        ["name", "chrom", "strand", "cdsStart", "cdsEnd", "cds_starts", "cds_ends", "cds_length", "cds"]
    ].copy()
    gtf_s["name"] = gtf_s["name"].str.split(".").str[0]
    # Sort by genomic (forward, 0-based) coordinates.
    gtf_s = gtf_s.sort_values(by=["chrom", "cdsStart", "cdsEnd"]).reset_index(drop=True).copy()
    return gtf_s, fasta
def _genomic_to_cds_offset(pos, seg_starts, seg_ends):
    """
    Map a 0-based genomic position to a 0-based index in the concatenated CDS
    segments, counted in forward (genomic) orientation.

    Segments are walked in the given order. Each segment that ends at or before
    ``pos`` contributes its full length, and the first segment containing ``pos``
    yields the result. Returns None when ``pos`` lies in no segment (e.g. an intron).
    """
    offset = 0
    for seg_start, seg_end in zip(seg_starts, seg_ends):
        if pos >= seg_end:
            offset += seg_end - seg_start
        elif seg_start <= pos:
            return offset + pos - seg_start
    return None


def process_a_chrom(chrom_variants, chrom_refseq, return_alt_cds=False):
    """
    Annotate single-nucleotide variants on one chromosome with transcript CDS context.

    For each transcript and each variant inside one of its CDS segments:
    - locate the variant within the coding sequence (0-based coordinate),
    - extract the ref codon and build the alt codon,
    - translate both codons to amino acids,
    - optionally, build the full alternate CDS carrying the mutation.

    Parameters:
        chrom_variants (pd.DataFrame): columns ``variant_id``, ``pos`` (1-based),
            ``ref`` and ``alt``. ``pos`` must be sorted ascending because
            ``np.searchsorted`` is used. The frame is not modified; alleles are
            upper-cased into local arrays.
        chrom_refseq (pd.DataFrame): transcripts of one chromosome, as produced
            by ``process_gtf``. May be empty, in which case an empty result is
            returned.
        return_alt_cds (bool): also emit the ``alt_seq`` column.

    Returns:
        pd.DataFrame with these output columns:
        - pos: 1-based (forward direction, genomic coordinates)
        - ref: genomic ref allele (as in the reference genome, forward orientation)
        - alt: genomic alt allele (as in the reference genome, forward orientation)
        - cdsStart: 0-based genomic start of CDS, half-open interval; genomic axis.
        - cdsEnd: 0-based genomic end of CDS, half-open interval; genomic axis.
        - var_rel_dist_in_cds: 0-based index within the CDS in transcript 5'->3'
          orientation (strand-aware).
        - ref_seq: full CDS string in transcript 5'->3' orientation
        - ref_codon: codon from ref_seq at the variant position; strand-aware.
        - alt_codon: codon after the single-base change; strand-aware.
        - ref_aa: ref amino acid at the variant position
        - alt_aa: alt amino acid at the variant position
        - alt_seq (if requested): full alt CDS after the single-base change,
          transcript 5'->3' orientation; strand-aware.
        - codon_position: 0-based codon index within the CDS.
        Also chrom, variant_id (chrom_pos_ref_alt), tx_name and tx_strand.

    Raises:
        AssertionError: a CDS length disagrees with its sequence or is not a
            multiple of 3, or a variant's ref allele does not match the CDS base.
    """
    # Normalize alleles into local arrays; assigning into the caller's frame would
    # mutate it (and warn on DataFrame slices).
    var_ids = chrom_variants["variant_id"].values
    var_pos = chrom_variants["pos"].values - 1  # variants are reported 1-based; work 0-based
    var_ref = chrom_variants["ref"].str.upper().values
    var_alt = chrom_variants["alt"].str.upper().values

    # With no transcripts there is nothing to annotate and no row to read chrom from.
    chrom = chrom_refseq["chrom"].iloc[0] if len(chrom_refseq) else None

    # Per-transcript arrays built by `process_gtf`.
    cds_strands = chrom_refseq["strand"].values
    cds_starts = chrom_refseq["cdsStart"].values
    cds_ends = chrom_refseq["cdsEnd"].values
    cds_lengths = chrom_refseq["cds_length"].values
    rec_cds_starts = chrom_refseq["cds_starts"].values  # CDS-clipped exon starts per transcript
    rec_cds_ends = chrom_refseq["cds_ends"].values  # CDS-clipped exon ends per transcript
    rec_cds = chrom_refseq["cds"].values  # strand-aware CDS (always gene 5'->3')
    rec_names = chrom_refseq["name"].values

    # For transcript j, var_pos[s1[j]:s2[j]] holds exactly the variants with
    # cdsStart[j] <= pos < cdsEnd[j]. cdsEnd is exclusive, so both bounds use side="left".
    s1 = np.searchsorted(var_pos, cds_starts, side="left")
    s2 = np.searchsorted(var_pos, cds_ends, side="left")

    results = []
    for j, (lo, hi) in enumerate(zip(s1, s2)):
        cds_seq = rec_cds[j]
        cds_len = cds_lengths[j]
        # Sanity checks run for every transcript, including those without variants.
        assert cds_len == len(cds_seq), f"CDS length mismatch for {rec_names[j]}"
        assert cds_len % 3 == 0, f"CDS length not multiple of 3 for {rec_names[j]}"
        minus = cds_strands[j] == "-"
        for i in range(lo, hi):
            pos = var_pos[i]
            offset = _genomic_to_cds_offset(pos, rec_cds_starts[j], rec_cds_ends[j])
            if offset is None:
                continue  # inside [cdsStart, cdsEnd) but between CDS segments
            curr_ref = var_ref[i]
            curr_alt = var_alt[i]
            if minus:
                # The stored CDS is reverse-complemented: flip the index and the ref allele.
                offset = cds_len - 1 - offset
                ref_base = reverse_complement_dna(curr_ref)
            else:
                ref_base = curr_ref
            codon_start = offset // 3 * 3
            ref_codon = cds_seq[codon_start : codon_start + 3]
            # The genomic ref allele must agree with the reference-derived CDS base.
            assert cds_seq[offset] == ref_base, f"{'-' if minus else '+'}, {ref_codon} {var_ids[i]}"
            alt_base = reverse_complement_dna(curr_alt) if minus else curr_alt
            alt_codon = ref_codon[: offset % 3] + alt_base + ref_codon[offset % 3 + 1 :]
            row = [
                chrom,
                pos,
                f"{chrom}_{pos + 1}_{curr_ref}_{curr_alt}",
                curr_ref,
                curr_alt,
                rec_names[j],
                cds_starts[j],
                cds_ends[j],
                cds_strands[j],
                offset,
                cds_seq,
                ref_codon,
                alt_codon,
                translate(ref_codon),
                translate(alt_codon),
            ]
            if return_alt_cds:
                row.append(cds_seq[:offset] + alt_base + cds_seq[offset + 1 :])
            results.append(row)

    columns = [
        "chrom",
        "pos",
        "variant_id",
        "ref",
        "alt",
        "tx_name",
        "cdsStart",
        "cdsEnd",
        "tx_strand",
        "var_rel_dist_in_cds",
        "ref_seq",
        "ref_codon",
        "alt_codon",
        "ref_aa",
        "alt_aa",
    ]
    if return_alt_cds:
        columns.append("alt_seq")
    if results:
        out = pd.DataFrame(results, columns=columns)
        out["pos"] += 1  # back to 1-based
        out["codon_position"] = out["var_rel_dist_in_cds"] // 3
    else:
        # Empty frame with the same columns.
        out = pd.DataFrame(columns=columns)
        out["codon_position"] = pd.Series(dtype="int64")
    return out
def plot_transcript_distribution(variants):
    """
    Plot a histogram of how many transcripts each variant is annotated against.

    Parameters:
        variants (pd.DataFrame): one row per (variant, transcript) annotation, with a
            ``variant_id`` column (e.g. concatenated ``process_a_chrom`` output).

    Draws a bar chart with a log-scaled y axis, labels each non-empty bar with its
    count, and displays it with ``plt.show()``. Returns None.
    """
    # Rows per variant_id == transcripts per variant. dropna=False keeps a missing id
    # as its own group, as polars' group_by did.
    tx_per_variant = variants.groupby("variant_id", dropna=False).size()
    # Then count how many variants share each transcript count.
    df = (
        tx_per_variant.value_counts()
        .sort_index()
        .rename_axis("n_transcripts")
        .reset_index(name="n_variants")
        .astype({"n_transcripts": "int64", "n_variants": "int64"})
    )
    plt.figure(figsize=(20, 4))
    ax = sns.barplot(data=df, x="n_transcripts", y="n_variants", color="#4c5a88")
    ax.set_xlabel("Number of transcripts associated with a single variant")
    ax.set_ylabel("Number of unique variants")
    ax.set_title("Distribution of transcripts per variant")
    ax.set_yscale("log")
    ax.yaxis.set_major_locator(LogLocator(base=10))  # ticks at powers of 10
    for p in ax.patches:
        h = p.get_height()
        if h > 0:
            ax.annotate(
                f"{int(h):,}",
                (p.get_x() + p.get_width() / 2, h),
                ha="center",
                va="bottom",
                fontsize=9,
                xytext=(0, 3),
                textcoords="offset points",
            )
    sns.despine()
    plt.tight_layout()
    plt.show()
def check_mutation_positions(df: pd.DataFrame, adjusted_context_length: int) -> pd.DataFrame:
    """
    Flag, for each row, where its codon sits relative to the CDS and the model context.

    adjusted_context_length = max_length - 2 (room reserved for [CLS] and [SEP]).

    Returns one row per input row, with these columns:
    - id, variant_id: copied from the input (None if the column is missing);
    - total_codons: len(ref_seq) // 3, or 0 when ref_seq is not a str;
    - codon_position;
    - in_bounds: codon_position < total_codons;
    - needs_centering: total_codons > adjusted_context_length;
    - out_of_bounds: codon_position >= adjusted_context_length.
    """

    def _summarize(record):
        seq = record["ref_seq"]
        codon_idx = int(record["codon_position"])
        n_codons = 0
        if isinstance(seq, str):
            n_codons = len(seq) // 3
        summary = {}
        summary["id"] = record.get("id", None)
        summary["variant_id"] = record.get("variant_id", None)
        summary["total_codons"] = n_codons
        summary["codon_position"] = codon_idx
        summary["in_bounds"] = codon_idx < n_codons
        summary["needs_centering"] = n_codons > adjusted_context_length
        summary["out_of_bounds"] = codon_idx >= adjusted_context_length
        return pd.Series(summary)

    return df.apply(_summarize, axis=1)
def get_reverse_complement(seq):
    """Return the reverse complement of ``seq``, upper-casing it first (A/C/G/T/N only; others raise KeyError)."""
    mapping = {"A": "T", "T": "A", "G": "C", "C": "G", "N": "N"}
    flipped = seq[::-1].upper()
    return "".join(map(mapping.__getitem__, flipped))
def extract_cds_sequence(row, fasta):
    """
    Extract the coding sequence of one transcript from a reference genome.

    Parameters:
        row: mapping with ``chrom``, ``strand``, ``cdsStart``, ``cdsEnd`` (0-based,
            half-open), ``exonStarts`` and ``exonEnds``. Exon coordinates may be
            comma-separated strings (``"100,200,"``) or sequences of ints.
        fasta: mapping from chromosome name to a sliceable sequence (e.g. ``str``).
            Slices are converted with ``str()`` and upper-cased.

    Returns:
        str: the exon pieces clipped to [cdsStart, cdsEnd) and joined in genomic
        order. For ``"-"`` strand transcripts the result is reverse-complemented.
    """

    def _coords(value):
        # Accept both the raw comma-separated text and an already-parsed sequence.
        if isinstance(value, str):
            return [int(x) for x in value.rstrip(",").split(",")]
        return [int(x) for x in value]

    chrom = row["chrom"]
    strand = row["strand"]
    cds_start = row["cdsStart"]
    cds_end = row["cdsEnd"]
    exon_starts = _coords(row["exonStarts"])
    exon_ends = _coords(row["exonEnds"])

    pieces = []
    for start, end in zip(exon_starts, exon_ends):
        # Overlap between this exon and the CDS.
        overlap_start = max(start, cds_start)
        overlap_end = min(end, cds_end)
        if overlap_start < overlap_end:
            pieces.append(str(fasta[chrom][overlap_start:overlap_end]).upper())
    cds_sequence = "".join(pieces)

    if strand == "-":
        cds_sequence = get_reverse_complement(cds_sequence)
    return cds_sequence
def process_dset(dset, refseq, remove_non_pli=False):
    """
    Add per-variant features to a polars variant table.

    Adds:
    - ref_aa / alt_aa: codon translations via ``Seq(...).translate()``. All rows
      must be synonymous (asserted), and rows whose ref_aa is "*" are dropped.
    - ref_codon_freq / alt_codon_freq from the "Primates" entry of
      codon_counts_nopathogen.json, and codon_freq_ratio = log(ref/alt).
    - gene_name, looked up from refseq (name -> name2) via the ``tx`` column
      when the column is absent.
    - pli (-1000 when the gene has no pLI entry) and pli_bin = int(pli * 10).
    - phylop (NaN replaced by -1000) and phylop_bin = round(phylop).
    - cds_length / cds_offset_frac (when cds_offset_frac is absent) and
      cds_offset_frac_bin.

    Parameters:
        dset (pl.DataFrame): variant table with ref_codon, alt_codon, chrom,
            pos (1-based, as used for the bigWig lookup), ref_seq and
            var_rel_dist_in_cds, plus tx or gene_name.
        refseq (pl.DataFrame): transcript table with name and name2 columns.
        remove_non_pli (bool): drop rows whose gene has no pLI value.

    Returns:
        pl.DataFrame: the augmented table.

    Raises:
        AssertionError: if any row has ref_aa != alt_aa.
        KeyError: if a codon or transcript is missing from its lookup table.
    """
    # Amino-acid translations of both codons.
    dset = dset.with_columns(
        [
            pl.col("ref_codon")
            .map_elements(lambda x: str(Seq(x).translate()), return_dtype=pl.String)
            .alias("ref_aa"),
            pl.col("alt_codon")
            .map_elements(lambda x: str(Seq(x).translate()), return_dtype=pl.String)
            .alias("alt_aa"),
        ]
    )
    # This pipeline expects synonymous variants only.
    assert dset.filter(pl.col("ref_aa") != pl.col("alt_aa")).height == 0
    dset = dset.filter(pl.col("ref_aa") != "*")

    # Use a context manager so the JSON file handle is closed deterministically.
    with open(f"{DATA_DIR}/codon_counts_nopathogen.json", encoding="utf-8") as fh:
        codon_freqs = json.load(fh)["Primates"]
    dset = dset.with_columns(
        pl.col("ref_codon").map_elements(lambda x: codon_freqs[x], return_dtype=pl.Float64).alias("ref_codon_freq"),
        pl.col("alt_codon").map_elements(lambda x: codon_freqs[x], return_dtype=pl.Float64).alias("alt_codon_freq"),
    )
    dset = dset.with_columns((pl.col("ref_codon_freq") / pl.col("alt_codon_freq")).log().alias("codon_freq_ratio"))

    if "gene_name" not in dset.columns:
        # Only build the transcript -> gene map when it is actually needed.
        tx_to_name = {row["name"]: row["name2"] for row in refseq.rows(named=True)}
        dset = dset.with_columns(
            pl.col("tx").map_elements(lambda x: tx_to_name[x], return_dtype=pl.String).alias("gene_name")
        )

    pli = pl.read_csv(f"{DATA_DIR}/ucsc_pliByGene_hg38.tsv", separator="\t")
    gene_to_pli = {row["geneName"]: row["_pli"] for row in pli.rows(named=True)}
    dset = dset.with_columns(
        pl.col("gene_name").map_elements(lambda x: gene_to_pli.get(x, -1000), return_dtype=pl.Float64).alias("pli")
    )
    if remove_non_pli:
        dset = dset.filter(pl.col("pli") != -1000)
    dset = dset.with_columns((pl.col("pli") * 10).cast(pl.Int32).alias("pli_bin"))

    # Per-base PhyloP score at each (1-based) variant position.
    bw = pyBigWig.open(f"{DATA_DIR}/hg38.phyloP447way.bw")
    try:
        phylop = [bw.values(row["chrom"], row["pos"] - 1, row["pos"])[0] for row in tqdm(dset.rows(named=True))]
    finally:
        bw.close()  # release the bigWig handle even if a lookup fails
    # Explicit Float64 keeps fill_nan/round valid even when dset is empty.
    dset = dset.with_columns(pl.Series(name="phylop", values=phylop, dtype=pl.Float64).fill_nan(-1000))
    dset = dset.with_columns(pl.col("phylop").round().cast(pl.Int32).alias("phylop_bin"))

    if "cds_offset_frac" not in dset.columns:
        dset = dset.with_columns(pl.col("ref_seq").str.len_chars().alias("cds_length"))
        dset = dset.with_columns((pl.col("var_rel_dist_in_cds") / pl.col("cds_length")).alias("cds_offset_frac"))
    dset = dset.with_columns((pl.col("cds_offset_frac") * 10).cast(pl.Int32).alias("cds_offset_frac_bin"))
    return dset
# Genetic code lookup: upper-case DNA codon -> one-letter amino acid ("*" = stop codon).
# Key order is kept stable; lookups elsewhere use dna_code.get(codon, "?").
dna_code = {
    "ATA": "I", "ATC": "I", "ATT": "I", "ATG": "M",
    "ACA": "T", "ACC": "T", "ACG": "T", "ACT": "T",
    "AAC": "N", "AAT": "N", "AAA": "K", "AAG": "K",
    "AGC": "S", "AGT": "S", "AGA": "R", "AGG": "R",
    "CTA": "L", "CTC": "L", "CTG": "L", "CTT": "L",
    "CCA": "P", "CCC": "P", "CCG": "P", "CCT": "P",
    "CAC": "H", "CAT": "H", "CAA": "Q", "CAG": "Q",
    "CGA": "R", "CGC": "R", "CGG": "R", "CGT": "R",
    "GTA": "V", "GTC": "V", "GTG": "V", "GTT": "V",
    "GCA": "A", "GCC": "A", "GCG": "A", "GCT": "A",
    "GAC": "D", "GAT": "D", "GAA": "E", "GAG": "E",
    "GGA": "G", "GGC": "G", "GGG": "G", "GGT": "G",
    "TCA": "S", "TCC": "S", "TCG": "S", "TCT": "S",
    "TTC": "F", "TTT": "F", "TTA": "L", "TTG": "L",
    "TAC": "Y", "TAT": "Y", "TAA": "*", "TAG": "*",
    "TGC": "C", "TGT": "C", "TGA": "*", "TGG": "W",
}
def translate(seq):
    """
    Translate a nucleotide sequence into a one-letter protein sequence.

    The sequence is read in frame 0, three bases at a time, using ``dna_code``.
    Trailing bases that do not form a complete codon are ignored. RNA input is
    accepted: ``U`` is read as ``T`` before lookup (``dna_code`` is keyed by
    DNA codons). Codons not found in ``dna_code`` (e.g. lower-case or
    containing ``N``) translate to ``"?"``.

    Parameters:
        seq (str): nucleotide sequence (upper-case DNA or RNA).

    Returns:
        str: protein sequence, one character per complete codon.
    """
    dna = seq.replace("U", "T")
    return "".join(dna_code.get(dna[i : i + 3], "?") for i in range(0, len(dna) - 2, 3))
def reverse_complement_dna(seq):
    """
    Reverse-complement an upper-case DNA string.

    Parameters:
        seq (str): bases drawn from A, C, G, T and N, upper case only (e.g. "ATCG").

    Returns:
        str: the complementary strand, read in reverse; N maps to N.

    Raises:
        KeyError: for any other character, including lower-case bases.
    """
    pairs = {"A": "T", "T": "A", "G": "C", "C": "G", "N": "N"}
    return "".join([pairs[base] for base in reversed(seq)])
def _parse_coord_list(value):
    """
    Parse one exonStarts/exonEnds cell into a tuple of ints.

    Accepts a UCSC-style comma-separated string (with or without a trailing comma,
    e.g. "100,200,"), a Python literal such as "[100, 200]", an already-parsed
    sequence, or a single coordinate (e.g. "100", which literal_eval turns into an
    int, or an int cell produced by pd.read_csv).
    """
    if isinstance(value, str):
        value = ast.literal_eval(value)
    try:
        return tuple(int(v) for v in value)
    except TypeError:  # a lone coordinate is not iterable
        return (int(value),)


def process_gtf(gtf_path, fasta_path):
    """
    Build coding sequences (CDS) for transcripts from a tab-separated transcript
    annotation table and a reference FASTA file.

    Parameters:
        gtf_path (str): tab-separated table read with ``pd.read_csv``. It must have
            columns ``name`` (or ``#name``), ``chrom``, ``strand``, ``cdsStart``,
            ``cdsEnd`` (0-based, half-open), ``exonStarts`` and ``exonEnds``.
            Transcripts with ``cdsStart == cdsEnd`` are dropped.
        fasta_path (str): reference FASTA opened with ``pyfaidx.Fasta``.
            Transcripts on chromosomes absent from it are dropped, with a printed
            warning.

    Returns:
        tuple: ``(gtf_s, fasta)``.
            ``gtf_s`` has one row per transcript with these columns:
            - name, truncated at the first ".";
            - chrom, strand, cdsStart, cdsEnd;
            - cds_starts / cds_ends: tuples of exon bounds clipped to the CDS;
            - cds_length;
            - cds: strand-aware sequence, gene 5'->3'.
            Rows are sorted by chrom, cdsStart, cdsEnd.
            ``fasta`` maps each used chromosome to its upper-cased sequence string.
    """
    gtf = pd.read_csv(gtf_path, sep="\t")
    if "#name" in gtf.columns:
        gtf = gtf.rename({"#name": "name"}, axis=1)
    # Keep transcripts with a non-empty coding region.
    gtf = gtf.loc[gtf["cdsStart"] != gtf["cdsEnd"]].reset_index(drop=True).copy()
    gtf["exonStarts_arr"] = gtf["exonStarts"].map(_parse_coord_list)
    gtf["exonEnds_arr"] = gtf["exonEnds"].map(_parse_coord_list)

    fasta = {}
    with pyfaidx.Fasta(fasta_path) as f:
        available_chroms = set(f.keys())
        # Only keep transcripts whose chromosome exists in the FASTA.
        missing_chroms = set(gtf["chrom"].unique()) - available_chroms
        if missing_chroms:
            print(
                f"Warning: {len(missing_chroms)} chromosome(s) in GTF not found in FASTA, filtering them out: {sorted(missing_chroms)}"
            )
        gtf = gtf[gtf["chrom"].isin(available_chroms)].reset_index(drop=True).copy()
        # Load each needed chromosome once (upper-cased) so slicing below is a plain str slice.
        for chrom in gtf["chrom"].unique():
            fasta[chrom] = f[chrom][:].seq.upper()

    cds_starts = []
    cds_ends = []
    lengths = []
    seqs = []
    for i in tqdm(range(gtf.shape[0]), desc="Processing transcripts", total=gtf.shape[0]):
        t = gtf.iloc[i]
        chrom_seq = fasta[t["chrom"]]
        cs, ce = t["cdsStart"], t["cdsEnd"]  # 0-based, half-open
        cds_s = []
        cds_e = []
        pieces = []
        length = 0
        for a, b in zip(t["exonStarts_arr"], t["exonEnds_arr"]):
            # Clip the exon to the coding region (drops UTR parts).
            v1 = max(a, cs)
            v2 = min(b, ce)
            if v1 < v2:  # the exon overlaps the coding region
                cds_s.append(v1)
                cds_e.append(v2)
                length += v2 - v1
                pieces.append(chrom_seq[v1:v2])
        cds_starts.append(tuple(cds_s))
        cds_ends.append(tuple(cds_e))
        lengths.append(length)
        # Pieces are in genomic forward orientation; flip minus-strand CDS to gene 5'->3'.
        curr_seq = "".join(pieces)
        if t["strand"] == "-":
            curr_seq = reverse_complement_dna(curr_seq)
        seqs.append(curr_seq)

    gtf["cds_starts"] = cds_starts
    gtf["cds_ends"] = cds_ends
    gtf["cds_length"] = lengths
    gtf["cds"] = seqs  # strand-aware: always gene 5'->3'
    gtf_s = gtf[
        ["name", "chrom", "strand", "cdsStart", "cdsEnd", "cds_starts", "cds_ends", "cds_length", "cds"]
    ].copy()
    gtf_s["name"] = gtf_s["name"].str.split(".").str[0]
    # Sort by genomic (forward, 0-based) coordinates.
    gtf_s = gtf_s.sort_values(by=["chrom", "cdsStart", "cdsEnd"]).reset_index(drop=True).copy()
    return gtf_s, fasta
def _genomic_to_cds_offset(pos, seg_starts, seg_ends):
    """
    Map a 0-based genomic position to a 0-based index in the concatenated CDS
    segments, counted in forward (genomic) orientation.

    Segments are walked in the given order. Each segment that ends at or before
    ``pos`` contributes its full length, and the first segment containing ``pos``
    yields the result. Returns None when ``pos`` lies in no segment (e.g. an intron).
    """
    offset = 0
    for seg_start, seg_end in zip(seg_starts, seg_ends):
        if pos >= seg_end:
            offset += seg_end - seg_start
        elif seg_start <= pos:
            return offset + pos - seg_start
    return None


def process_a_chrom(chrom_variants, chrom_refseq, return_alt_cds=False):
    """
    Annotate single-nucleotide variants on one chromosome with transcript CDS context.

    For each transcript and each variant inside one of its CDS segments:
    - locate the variant within the coding sequence (0-based coordinate),
    - extract the ref codon and build the alt codon,
    - translate both codons to amino acids,
    - optionally, build the full alternate CDS carrying the mutation.

    Parameters:
        chrom_variants (pd.DataFrame): columns ``variant_id``, ``pos`` (1-based),
            ``ref`` and ``alt``. ``pos`` must be sorted ascending because
            ``np.searchsorted`` is used. The frame is not modified; alleles are
            upper-cased into local arrays.
        chrom_refseq (pd.DataFrame): transcripts of one chromosome, as produced
            by ``process_gtf``. May be empty, in which case an empty result is
            returned.
        return_alt_cds (bool): also emit the ``alt_seq`` column.

    Returns:
        pd.DataFrame with these output columns:
        - pos: 1-based (forward direction, genomic coordinates)
        - ref: genomic ref allele (as in the reference genome, forward orientation)
        - alt: genomic alt allele (as in the reference genome, forward orientation)
        - cdsStart: 0-based genomic start of CDS, half-open interval; genomic axis.
        - cdsEnd: 0-based genomic end of CDS, half-open interval; genomic axis.
        - var_rel_dist_in_cds: 0-based index within the CDS in transcript 5'->3'
          orientation (strand-aware).
        - ref_seq: full CDS string in transcript 5'->3' orientation
        - ref_codon: codon from ref_seq at the variant position; strand-aware.
        - alt_codon: codon after the single-base change; strand-aware.
        - ref_aa: ref amino acid at the variant position
        - alt_aa: alt amino acid at the variant position
        - alt_seq (if requested): full alt CDS after the single-base change,
          transcript 5'->3' orientation; strand-aware.
        - codon_position: 0-based codon index within the CDS.
        Also chrom, variant_id (chrom_pos_ref_alt), tx_name and tx_strand.

    Raises:
        AssertionError: a CDS length disagrees with its sequence or is not a
            multiple of 3, or a variant's ref allele does not match the CDS base.
    """
    # Normalize alleles into local arrays; assigning into the caller's frame would
    # mutate it (and warn on DataFrame slices).
    var_ids = chrom_variants["variant_id"].values
    var_pos = chrom_variants["pos"].values - 1  # variants are reported 1-based; work 0-based
    var_ref = chrom_variants["ref"].str.upper().values
    var_alt = chrom_variants["alt"].str.upper().values

    # With no transcripts there is nothing to annotate and no row to read chrom from.
    chrom = chrom_refseq["chrom"].iloc[0] if len(chrom_refseq) else None

    # Per-transcript arrays built by `process_gtf`.
    cds_strands = chrom_refseq["strand"].values
    cds_starts = chrom_refseq["cdsStart"].values
    cds_ends = chrom_refseq["cdsEnd"].values
    cds_lengths = chrom_refseq["cds_length"].values
    rec_cds_starts = chrom_refseq["cds_starts"].values  # CDS-clipped exon starts per transcript
    rec_cds_ends = chrom_refseq["cds_ends"].values  # CDS-clipped exon ends per transcript
    rec_cds = chrom_refseq["cds"].values  # strand-aware CDS (always gene 5'->3')
    rec_names = chrom_refseq["name"].values

    # For transcript j, var_pos[s1[j]:s2[j]] holds exactly the variants with
    # cdsStart[j] <= pos < cdsEnd[j]. cdsEnd is exclusive, so both bounds use side="left".
    s1 = np.searchsorted(var_pos, cds_starts, side="left")
    s2 = np.searchsorted(var_pos, cds_ends, side="left")

    results = []
    for j, (lo, hi) in enumerate(zip(s1, s2)):
        cds_seq = rec_cds[j]
        cds_len = cds_lengths[j]
        # Sanity checks run for every transcript, including those without variants.
        assert cds_len == len(cds_seq), f"CDS length mismatch for {rec_names[j]}"
        assert cds_len % 3 == 0, f"CDS length not multiple of 3 for {rec_names[j]}"
        minus = cds_strands[j] == "-"
        for i in range(lo, hi):
            pos = var_pos[i]
            offset = _genomic_to_cds_offset(pos, rec_cds_starts[j], rec_cds_ends[j])
            if offset is None:
                continue  # inside [cdsStart, cdsEnd) but between CDS segments
            curr_ref = var_ref[i]
            curr_alt = var_alt[i]
            if minus:
                # The stored CDS is reverse-complemented: flip the index and the ref allele.
                offset = cds_len - 1 - offset
                ref_base = reverse_complement_dna(curr_ref)
            else:
                ref_base = curr_ref
            codon_start = offset // 3 * 3
            ref_codon = cds_seq[codon_start : codon_start + 3]
            # The genomic ref allele must agree with the reference-derived CDS base.
            assert cds_seq[offset] == ref_base, f"{'-' if minus else '+'}, {ref_codon} {var_ids[i]}"
            alt_base = reverse_complement_dna(curr_alt) if minus else curr_alt
            alt_codon = ref_codon[: offset % 3] + alt_base + ref_codon[offset % 3 + 1 :]
            row = [
                chrom,
                pos,
                f"{chrom}_{pos + 1}_{curr_ref}_{curr_alt}",
                curr_ref,
                curr_alt,
                rec_names[j],
                cds_starts[j],
                cds_ends[j],
                cds_strands[j],
                offset,
                cds_seq,
                ref_codon,
                alt_codon,
                translate(ref_codon),
                translate(alt_codon),
            ]
            if return_alt_cds:
                row.append(cds_seq[:offset] + alt_base + cds_seq[offset + 1 :])
            results.append(row)

    columns = [
        "chrom",
        "pos",
        "variant_id",
        "ref",
        "alt",
        "tx_name",
        "cdsStart",
        "cdsEnd",
        "tx_strand",
        "var_rel_dist_in_cds",
        "ref_seq",
        "ref_codon",
        "alt_codon",
        "ref_aa",
        "alt_aa",
    ]
    if return_alt_cds:
        columns.append("alt_seq")
    if results:
        out = pd.DataFrame(results, columns=columns)
        out["pos"] += 1  # back to 1-based
        out["codon_position"] = out["var_rel_dist_in_cds"] // 3
    else:
        # Empty frame with the same columns.
        out = pd.DataFrame(columns=columns)
        out["codon_position"] = pd.Series(dtype="int64")
    return out
def plot_transcript_distribution(variants):
    """
    Plot a histogram of how many transcripts each variant is annotated against.

    Parameters:
        variants (pd.DataFrame): one row per (variant, transcript) annotation, with a
            ``variant_id`` column (e.g. concatenated ``process_a_chrom`` output).

    Draws a bar chart with a log-scaled y axis, labels each non-empty bar with its
    count, and displays it with ``plt.show()``. Returns None.
    """
    # Rows per variant_id == transcripts per variant. dropna=False keeps a missing id
    # as its own group, as polars' group_by did.
    tx_per_variant = variants.groupby("variant_id", dropna=False).size()
    # Then count how many variants share each transcript count.
    df = (
        tx_per_variant.value_counts()
        .sort_index()
        .rename_axis("n_transcripts")
        .reset_index(name="n_variants")
        .astype({"n_transcripts": "int64", "n_variants": "int64"})
    )
    plt.figure(figsize=(20, 4))
    ax = sns.barplot(data=df, x="n_transcripts", y="n_variants", color="#4c5a88")
    ax.set_xlabel("Number of transcripts associated with a single variant")
    ax.set_ylabel("Number of unique variants")
    ax.set_title("Distribution of transcripts per variant")
    ax.set_yscale("log")
    ax.yaxis.set_major_locator(LogLocator(base=10))  # ticks at powers of 10
    for p in ax.patches:
        h = p.get_height()
        if h > 0:
            ax.annotate(
                f"{int(h):,}",
                (p.get_x() + p.get_width() / 2, h),
                ha="center",
                va="bottom",
                fontsize=9,
                xytext=(0, 3),
                textcoords="offset points",
            )
    sns.despine()
    plt.tight_layout()
    plt.show()
def check_mutation_positions(df: pd.DataFrame, adjusted_context_length: int) -> pd.DataFrame:
    """
    Flag, for each row, where its codon sits relative to the CDS and the model context.

    adjusted_context_length = max_length - 2 (room reserved for [CLS] and [SEP]).

    Returns one row per input row, with these columns:
    - id, variant_id: copied from the input (None if the column is missing);
    - total_codons: len(ref_seq) // 3, or 0 when ref_seq is not a str;
    - codon_position;
    - in_bounds: codon_position < total_codons;
    - needs_centering: total_codons > adjusted_context_length;
    - out_of_bounds: codon_position >= adjusted_context_length.
    """

    def _summarize(record):
        seq = record["ref_seq"]
        codon_idx = int(record["codon_position"])
        n_codons = 0
        if isinstance(seq, str):
            n_codons = len(seq) // 3
        summary = {}
        summary["id"] = record.get("id", None)
        summary["variant_id"] = record.get("variant_id", None)
        summary["total_codons"] = n_codons
        summary["codon_position"] = codon_idx
        summary["in_bounds"] = codon_idx < n_codons
        summary["needs_centering"] = n_codons > adjusted_context_length
        summary["out_of_bounds"] = codon_idx >= adjusted_context_length
        return pd.Series(summary)

    return df.apply(_summarize, axis=1)
def get_reverse_complement(seq):
    """Return the reverse complement of ``seq``, upper-casing it first (A/C/G/T/N only; others raise KeyError)."""
    mapping = {"A": "T", "T": "A", "G": "C", "C": "G", "N": "N"}
    flipped = seq[::-1].upper()
    return "".join(map(mapping.__getitem__, flipped))
def extract_cds_sequence(row, fasta):
    """
    Extract the coding sequence of one transcript from a reference genome.

    Parameters:
        row: mapping with ``chrom``, ``strand``, ``cdsStart``, ``cdsEnd`` (0-based,
            half-open), ``exonStarts`` and ``exonEnds``. Exon coordinates may be
            comma-separated strings (``"100,200,"``) or sequences of ints.
        fasta: mapping from chromosome name to a sliceable sequence (e.g. ``str``).
            Slices are converted with ``str()`` and upper-cased.

    Returns:
        str: the exon pieces clipped to [cdsStart, cdsEnd) and joined in genomic
        order. For ``"-"`` strand transcripts the result is reverse-complemented.
    """

    def _coords(value):
        # Accept both the raw comma-separated text and an already-parsed sequence.
        if isinstance(value, str):
            return [int(x) for x in value.rstrip(",").split(",")]
        return [int(x) for x in value]

    chrom = row["chrom"]
    strand = row["strand"]
    cds_start = row["cdsStart"]
    cds_end = row["cdsEnd"]
    exon_starts = _coords(row["exonStarts"])
    exon_ends = _coords(row["exonEnds"])

    pieces = []
    for start, end in zip(exon_starts, exon_ends):
        # Overlap between this exon and the CDS.
        overlap_start = max(start, cds_start)
        overlap_end = min(end, cds_end)
        if overlap_start < overlap_end:
            pieces.append(str(fasta[chrom][overlap_start:overlap_end]).upper())
    cds_sequence = "".join(pieces)

    if strand == "-":
        cds_sequence = get_reverse_complement(cds_sequence)
    return cds_sequence
def process_dset(dset, refseq, remove_non_pli=False):
    """
    Add per-variant features to a polars variant table.

    Adds:
    - ref_aa / alt_aa: codon translations via ``Seq(...).translate()``. All rows
      must be synonymous (asserted), and rows whose ref_aa is "*" are dropped.
    - ref_codon_freq / alt_codon_freq from the "Primates" entry of
      codon_counts_nopathogen.json, and codon_freq_ratio = log(ref/alt).
    - gene_name, looked up from refseq (name -> name2) via the ``tx`` column
      when the column is absent.
    - pli (-1000 when the gene has no pLI entry) and pli_bin = int(pli * 10).
    - phylop (NaN replaced by -1000) and phylop_bin = round(phylop).
    - cds_length / cds_offset_frac (when cds_offset_frac is absent) and
      cds_offset_frac_bin.

    Parameters:
        dset (pl.DataFrame): variant table with ref_codon, alt_codon, chrom,
            pos (1-based, as used for the bigWig lookup), ref_seq and
            var_rel_dist_in_cds, plus tx or gene_name.
        refseq (pl.DataFrame): transcript table with name and name2 columns.
        remove_non_pli (bool): drop rows whose gene has no pLI value.

    Returns:
        pl.DataFrame: the augmented table.

    Raises:
        AssertionError: if any row has ref_aa != alt_aa.
        KeyError: if a codon or transcript is missing from its lookup table.
    """
    # Amino-acid translations of both codons.
    dset = dset.with_columns(
        [
            pl.col("ref_codon")
            .map_elements(lambda x: str(Seq(x).translate()), return_dtype=pl.String)
            .alias("ref_aa"),
            pl.col("alt_codon")
            .map_elements(lambda x: str(Seq(x).translate()), return_dtype=pl.String)
            .alias("alt_aa"),
        ]
    )
    # This pipeline expects synonymous variants only.
    assert dset.filter(pl.col("ref_aa") != pl.col("alt_aa")).height == 0
    dset = dset.filter(pl.col("ref_aa") != "*")

    # Use a context manager so the JSON file handle is closed deterministically.
    with open(f"{DATA_DIR}/codon_counts_nopathogen.json", encoding="utf-8") as fh:
        codon_freqs = json.load(fh)["Primates"]
    dset = dset.with_columns(
        pl.col("ref_codon").map_elements(lambda x: codon_freqs[x], return_dtype=pl.Float64).alias("ref_codon_freq"),
        pl.col("alt_codon").map_elements(lambda x: codon_freqs[x], return_dtype=pl.Float64).alias("alt_codon_freq"),
    )
    dset = dset.with_columns((pl.col("ref_codon_freq") / pl.col("alt_codon_freq")).log().alias("codon_freq_ratio"))

    if "gene_name" not in dset.columns:
        # Only build the transcript -> gene map when it is actually needed.
        tx_to_name = {row["name"]: row["name2"] for row in refseq.rows(named=True)}
        dset = dset.with_columns(
            pl.col("tx").map_elements(lambda x: tx_to_name[x], return_dtype=pl.String).alias("gene_name")
        )

    pli = pl.read_csv(f"{DATA_DIR}/ucsc_pliByGene_hg38.tsv", separator="\t")
    gene_to_pli = {row["geneName"]: row["_pli"] for row in pli.rows(named=True)}
    dset = dset.with_columns(
        pl.col("gene_name").map_elements(lambda x: gene_to_pli.get(x, -1000), return_dtype=pl.Float64).alias("pli")
    )
    if remove_non_pli:
        dset = dset.filter(pl.col("pli") != -1000)
    dset = dset.with_columns((pl.col("pli") * 10).cast(pl.Int32).alias("pli_bin"))

    # Per-base PhyloP score at each (1-based) variant position.
    bw = pyBigWig.open(f"{DATA_DIR}/hg38.phyloP447way.bw")
    try:
        phylop = [bw.values(row["chrom"], row["pos"] - 1, row["pos"])[0] for row in tqdm(dset.rows(named=True))]
    finally:
        bw.close()  # release the bigWig handle even if a lookup fails
    # Explicit Float64 keeps fill_nan/round valid even when dset is empty.
    dset = dset.with_columns(pl.Series(name="phylop", values=phylop, dtype=pl.Float64).fill_nan(-1000))
    dset = dset.with_columns(pl.col("phylop").round().cast(pl.Int32).alias("phylop_bin"))

    if "cds_offset_frac" not in dset.columns:
        dset = dset.with_columns(pl.col("ref_seq").str.len_chars().alias("cds_length"))
        dset = dset.with_columns((pl.col("var_rel_dist_in_cds") / pl.col("cds_length")).alias("cds_offset_frac"))
    dset = dset.with_columns((pl.col("cds_offset_frac") * 10).cast(pl.Int32).alias("cds_offset_frac_bin"))
    return dset