
FASTA dataset

SimpleFastaDataset

Bases: Dataset

A simple dataset for Evo2 prediction.

Currently, this will not work for pre-training or fine-tuning, as that would require: 1) including "labels" in the input and 2) offsetting/rolling either the labels or input_ids to handle the off-by-one token prediction alignment.
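For reference, a minimal sketch of the offsetting the docstring refers to: in causal training each position predicts the next token, so labels are the input ids shifted left by one. The token ids below are illustrative, not part of this class.

import torch

# Hypothetical token ids: [EOS, A, C, G, T]
tokens = torch.tensor([0, 7, 8, 9, 5], dtype=torch.long)
input_ids = tokens[:-1]  # model input: [EOS, A, C, G]
labels = tokens[1:]      # targets:     [A,   C, G, T] (shifted left by one)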

Source code in bionemo/evo2/data/fasta_dataset.py
class SimpleFastaDataset(torch.utils.data.Dataset):
    """A simple dataset for Evo2 prediction.

    Currently, this will not work for pre-training or fine-tuning, as that would require:
    1) including "labels" in the input and 2) offsetting/rolling either the labels or
    input_ids to handle the off-by-one token prediction alignment.
    """

    def __init__(self, fasta_path: Path, tokenizer, prepend_bos: bool = True):
        """Initialize the dataset."""
        super().__init__()
        self.fasta = NvFaidx(fasta_path)
        self.seqids = sorted(self.fasta.keys())
        self.tokenizer = tokenizer
        self.prepend_bos = prepend_bos  # needed for getting predictions for the requested set of tokens.

    def write_idx_map(self, output_dir: Path):
        """Write the index map to the output directory."""
        with open(output_dir / "seq_idx_map.json", "w") as f:
            json.dump({seqid: idx for idx, seqid in enumerate(self.seqids)}, f)

    def __len__(self):
        """Get the length of the dataset."""
        return len(self.seqids)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Get an item from the dataset."""
        sequence = self.fasta[self.seqids[idx]].sequence().upper()
        tokenized_seq = self.tokenizer.text_to_ids(sequence)
        if self.prepend_bos:  # in pretraining we use EOS to start new sequences.
            tokens: list[int] = [self.tokenizer.eod] + tokenized_seq
        else:
            tokens: list[int] = tokenized_seq
        loss_mask = torch.ones(len(tokens), dtype=torch.long)
        if self.prepend_bos:
            # Mask the EOS token, which we use for causal offsetting. Later, in predict,
            # we take the output for the first [:-1] tokens, which align with the
            # sequence starting after the EOS.
            loss_mask[0] = 0
        return {
            "tokens": torch.tensor(tokens, dtype=torch.long),
            "position_ids": torch.arange(len(tokens), dtype=torch.long),
            "seq_idx": torch.tensor(idx, dtype=torch.long),
            "loss_mask": loss_mask,
        }
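A usage sketch, assuming only what the class itself requires of the tokenizer (a text_to_ids method and an eod token id); the ByteTokenizer class and the genomes.fa path below are hypothetical stand-ins.

from pathlib import Path

from torch.utils.data import DataLoader


class ByteTokenizer:
    """Hypothetical stand-in; any object with `text_to_ids` and `eod` works here."""

    eod = 0

    def text_to_ids(self, text: str) -> list[int]:
        # Shift byte values by one so id 0 stays reserved for EOD.
        return [b + 1 for b in text.encode("utf-8")]


dataset = SimpleFastaDataset(Path("genomes.fa"), ByteTokenizer(), prepend_bos=True)
loader = DataLoader(dataset, batch_size=1)  # batch_size=1 avoids padding/collation of variable-length sequences
for batch in loader:
    print(batch["tokens"].shape, batch["loss_mask"][0, :3])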

__getitem__(idx)

Get an item from the dataset.

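Continuing the sketch above, the returned mapping can be inspected directly; with prepend_bos=True the first loss-mask entry is zero (token ids shown are illustrative):

item = dataset[0]
# e.g. tokens       = [EOS, A, C, G, T]   shape (len(sequence) + 1,)
#      position_ids = [0, 1, 2, 3, 4]
#      loss_mask    = [0, 1, 1, 1, 1]     EOS position masked out
assert item["tokens"].shape == item["loss_mask"].shape
assert item["loss_mask"][0].item() == 0  # holds when prepend_bos=True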

__init__(fasta_path, tokenizer, prepend_bos=True)

Initialize the dataset.


__len__()

Get the length of the dataset.


write_idx_map(output_dir)

Write the sequence-ID-to-index map to seq_idx_map.json in the output directory.

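A quick sketch of the output, continuing the example above; the record names "chr1" and "chr2" are hypothetical (sequence IDs are sorted at construction time, so indices follow sorted order):

from pathlib import Path

out_dir = Path("./out")
out_dir.mkdir(exist_ok=True)
dataset.write_idx_map(out_dir)
# ./out/seq_idx_map.json now contains, e.g., {"chr1": 0, "chr2": 1}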