Skip to content

Dataset

AMPLIFYMaskedResidueDataset

Bases: Dataset

Dataset class for AMPLIFY pretraining that implements sampling of UR100P sequences.

Source code in bionemo/amplify/dataset.py
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
class AMPLIFYMaskedResidueDataset(Dataset):
    """Dataset class for AMPLIFY pretraining that implements sampling of UR100P sequences.

    Every item is tokenized, cropped to at most ``max_seq_length`` tokens and BERT-style masked. The randomness
    used for cropping and masking is seeded from ``(seed, epoch, idx)``.
    """

    def __init__(
        self,
        hf_dataset: HFAmplifyDataset,
        seed: int = 42,
        max_seq_length: int = 512,
        mask_prob: float = 0.15,
        mask_token_prob: float = 0.8,
        mask_random_prob: float = 0.1,
        random_mask_strategy: RandomMaskStrategy = RandomMaskStrategy.AMINO_ACIDS_ONLY,
        tokenizer: BioNeMoAMPLIFYTokenizer | None = None,
    ) -> None:
        """Initializes the dataset.

        The number of samples is not configurable: it is always ``len(hf_dataset)``.

        Args:
            hf_dataset: HuggingFace dataset containing AMPLIFY protein sequences. This should likely be created via a
                call like `datasets.load_dataset("chandar-lab/UR100P", split="train")`.
            seed: Random seed for reproducibility. This seed is mixed with the epoch and the index of the sample to
                retrieve so that __getitem__ is deterministic for a given seed.
            max_seq_length: Crop long sequences to a maximum of this length, including BOS and EOS tokens.
            mask_prob: The overall probability a token is included in the loss function. Defaults to 0.15.
            mask_token_prob: Proportion of masked tokens that get assigned the <MASK> id. Defaults to 0.8.
            mask_random_prob: Proportion of masked tokens that get replaced by a random token drawn from the set
                selected by `random_mask_strategy`. Defaults to 0.1.
            random_mask_strategy: Whether to replace random masked tokens with all tokens or amino acids only.
                Defaults to RandomMaskStrategy.AMINO_ACIDS_ONLY.
            tokenizer: The input AMPLIFY tokenizer. Defaults to the standard AMPLIFY tokenizer.

        Raises:
            ValueError: If the tokenizer does not define a mask token.
        """
        self.protein_dataset = hf_dataset
        self.total_samples = len(self.protein_dataset)
        self.seed = seed
        self.max_seq_length = max_seq_length
        self.random_mask_strategy = random_mask_strategy
        self.tokenizer = BioNeMoAMPLIFYTokenizer() if tokenizer is None else tokenizer

        # Masking replaces tokens with the mask id, so a tokenizer without one cannot be used here.
        if self.tokenizer.mask_token_id is None:
            raise ValueError("Tokenizer does not have a mask token.")

        # NOTE(review): the non-ALL_TOKENS branch skips ids 0-5, presumably because those are the tokenizer's
        # special tokens -- confirm against the AMPLIFY vocabulary.
        self.mask_config = masking.BertMaskConfig(
            tokenizer=self.tokenizer,
            random_tokens=range(self.tokenizer.vocab_size)
            if self.random_mask_strategy == RandomMaskStrategy.ALL_TOKENS
            else range(6, self.tokenizer.vocab_size),
            mask_prob=mask_prob,
            mask_token_prob=mask_token_prob,
            random_token_prob=mask_random_prob,
        )

    def __len__(self) -> int:
        """Returns the total number of sequences in the dataset."""
        return self.total_samples

    def __getitem__(self, index: EpochIndex) -> BertSample:
        """Deterministically masks and returns a protein sequence from the dataset.

        This function is largely copied from the ESM2 dataset.

        Args:
            index: The current epoch and the index of the sequence to retrieve.

        Returns:
            A (possibly-truncated), masked protein sequence together with its mask fields: ``text`` (masked token
            ids), ``labels``, ``loss_mask``, an all-ones ``attention_mask`` and all-zero ``types`` / ``is_random``
            tensors shaped like ``text``.

        Raises:
            IndexError: If ``index.idx`` is negative or not smaller than ``len(self)``.
        """
        # Validate before seeding: numpy refuses negative seed entries with a ValueError, which would otherwise
        # mask the real problem (an out-of-range index).
        if index.idx < 0 or index.idx >= len(self):
            raise IndexError(f"Index {index.idx} out of range [0, {len(self)}).")

        # Initialize a random number generator with a seed that is a combination of the dataset seed, epoch, and index.
        rng = np.random.default_rng([self.seed, index.epoch, index.idx])

        sequence = self.protein_dataset[int(index.idx)]["sequence"]

        # _tokenize asks the tokenizer for special tokens (add_special_tokens=True), so cropping and masking below
        # operate on the sequence including whatever special tokens the tokenizer added.
        tokenized_sequence = self._tokenize(sequence)

        # If the sequence is too long, we crop it to the max sequence length by randomly selecting a starting position.
        cropped_sequence = _random_crop(tokenized_sequence, self.max_seq_length, rng)

        # Get a single integer seed for torch from our rng, since the index tuple is hard to pass directly to torch.
        torch_seed = random_utils.get_seed_from_rng(rng)
        masked_sequence, labels, loss_mask = masking.apply_bert_pretraining_mask(
            tokenized_sequence=cropped_sequence,  # type: ignore
            random_seed=torch_seed,
            mask_config=self.mask_config,
        )

        return {
            "text": masked_sequence,
            "types": torch.zeros_like(masked_sequence, dtype=torch.int64),
            "attention_mask": torch.ones_like(masked_sequence, dtype=torch.int64),
            "labels": labels,
            "loss_mask": loss_mask,
            "is_random": torch.zeros_like(masked_sequence, dtype=torch.int64),
        }

    def _tokenize(self, sequence: str) -> torch.Tensor:
        """Tokenize a protein sequence.

        Args:
            sequence: The protein sequence.

        Returns:
            The tokenized sequence as a flattened tensor, encoded with ``add_special_tokens=True``.
        """
        tensor = self.tokenizer.encode(sequence, add_special_tokens=True, return_tensors="pt")
        return tensor.flatten()  # type: ignore

__getitem__(index)

Deterministically masks and returns a protein sequence from the dataset.

This function is largely copied from the ESM2 dataset.

Parameters:

Name Type Description Default
index EpochIndex

The current epoch and the index of the sequence to retrieve.

required

Returns:

Type Description
BertSample

A (possibly-truncated), masked protein sequence with CLS and EOS tokens and associated mask fields.

Source code in bionemo/amplify/dataset.py
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
def __getitem__(self, index: EpochIndex) -> BertSample:
    """Deterministically masks and returns a protein sequence from the dataset.

    This function is largely copied from the ESM2 dataset.

    Args:
        index: The current epoch and the index of the sequence to retrieve.

    Returns:
        A (possibly-truncated), masked protein sequence together with its mask fields: ``text`` (masked token
        ids), ``labels``, ``loss_mask``, an all-ones ``attention_mask`` and all-zero ``types`` / ``is_random``
        tensors shaped like ``text``.

    Raises:
        IndexError: If ``index.idx`` is negative or not smaller than ``len(self)``.
    """
    # Validate before seeding: numpy refuses negative seed entries with a ValueError, which would otherwise
    # mask the real problem (an out-of-range index).
    if index.idx < 0 or index.idx >= len(self):
        raise IndexError(f"Index {index.idx} out of range [0, {len(self)}).")

    # Initialize a random number generator with a seed that is a combination of the dataset seed, epoch, and index.
    rng = np.random.default_rng([self.seed, index.epoch, index.idx])

    sequence = self.protein_dataset[int(index.idx)]["sequence"]

    # _tokenize asks the tokenizer for special tokens (add_special_tokens=True), so cropping and masking below
    # operate on the sequence including whatever special tokens the tokenizer added.
    tokenized_sequence = self._tokenize(sequence)

    # If the sequence is too long, we crop it to the max sequence length by randomly selecting a starting position.
    cropped_sequence = _random_crop(tokenized_sequence, self.max_seq_length, rng)

    # Get a single integer seed for torch from our rng, since the index tuple is hard to pass directly to torch.
    torch_seed = random_utils.get_seed_from_rng(rng)
    masked_sequence, labels, loss_mask = masking.apply_bert_pretraining_mask(
        tokenized_sequence=cropped_sequence,  # type: ignore
        random_seed=torch_seed,
        mask_config=self.mask_config,
    )

    return {
        "text": masked_sequence,
        "types": torch.zeros_like(masked_sequence, dtype=torch.int64),
        "attention_mask": torch.ones_like(masked_sequence, dtype=torch.int64),
        "labels": labels,
        "loss_mask": loss_mask,
        "is_random": torch.zeros_like(masked_sequence, dtype=torch.int64),
    }

__init__(hf_dataset, seed=42, max_seq_length=512, mask_prob=0.15, mask_token_prob=0.8, mask_random_prob=0.1, random_mask_strategy=RandomMaskStrategy.AMINO_ACIDS_ONLY, tokenizer=None)

Initializes the dataset.

Parameters:

Name Type Description Default
hf_dataset HFAmplifyDataset

HuggingFace dataset containing AMPLIFY protein sequences. This should likely be created via a call like datasets.load_dataset("chandar-lab/UR100P", split="train").

required
total_samples

Stale entry: `__init__` takes no `total_samples` argument; the number of samples is always `len(hf_dataset)`.

required
seed int

Random seed for reproducibility. This seed is mixed with the epoch and the index of the sample to retrieve so that `__getitem__` is deterministic for a given seed. The seed must be an integer; `None` is not handled.

42
max_seq_length int

Crop long sequences to a maximum of this length, including BOS and EOS tokens.

512
mask_prob float

The overall probability a token is included in the loss function. Defaults to 0.15.

0.15
mask_token_prob float

Proportion of masked tokens that get assigned the `<MASK>` id. Defaults to 0.8.

0.8
mask_random_prob float

Proportion of masked tokens that get replaced by a random token drawn from the set selected by `random_mask_strategy` (amino acids only by default). Defaults to 0.1.

0.1
random_mask_strategy RandomMaskStrategy

Whether to replace random masked tokens with all tokens or amino acids only. Defaults to RandomMaskStrategy.AMINO_ACIDS_ONLY.

AMINO_ACIDS_ONLY
tokenizer BioNeMoAMPLIFYTokenizer | None

The input AMPLIFY tokenizer. Defaults to the standard AMPLIFY tokenizer.

None
Source code in bionemo/amplify/dataset.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
def __init__(
    self,
    hf_dataset: HFAmplifyDataset,
    seed: int = 42,
    max_seq_length: int = 512,
    mask_prob: float = 0.15,
    mask_token_prob: float = 0.8,
    mask_random_prob: float = 0.1,
    random_mask_strategy: RandomMaskStrategy = RandomMaskStrategy.AMINO_ACIDS_ONLY,
    tokenizer: BioNeMoAMPLIFYTokenizer | None = None,
) -> None:
    """Initializes the dataset.

    The number of samples is not configurable: it is always ``len(hf_dataset)``.

    Args:
        hf_dataset: HuggingFace dataset containing AMPLIFY protein sequences. This should likely be created via a
            call like `datasets.load_dataset("chandar-lab/UR100P", split="train")`.
        seed: Random seed for reproducibility. This seed is mixed with the epoch and the index of the sample to
            retrieve so that __getitem__ is deterministic for a given seed.
        max_seq_length: Crop long sequences to a maximum of this length, including BOS and EOS tokens.
        mask_prob: The overall probability a token is included in the loss function. Defaults to 0.15.
        mask_token_prob: Proportion of masked tokens that get assigned the <MASK> id. Defaults to 0.8.
        mask_random_prob: Proportion of masked tokens that get replaced by a random token drawn from the set
            selected by `random_mask_strategy`. Defaults to 0.1.
        random_mask_strategy: Whether to replace random masked tokens with all tokens or amino acids only.
            Defaults to RandomMaskStrategy.AMINO_ACIDS_ONLY.
        tokenizer: The input AMPLIFY tokenizer. Defaults to the standard AMPLIFY tokenizer.

    Raises:
        ValueError: If the tokenizer does not define a mask token.
    """
    self.protein_dataset = hf_dataset
    self.total_samples = len(self.protein_dataset)
    self.seed = seed
    self.max_seq_length = max_seq_length
    self.random_mask_strategy = random_mask_strategy
    self.tokenizer = BioNeMoAMPLIFYTokenizer() if tokenizer is None else tokenizer

    # Masking replaces tokens with the mask id, so a tokenizer without one cannot be used here.
    if self.tokenizer.mask_token_id is None:
        raise ValueError("Tokenizer does not have a mask token.")

    # Candidate ids for the "replace with a random token" branch of the masking scheme.
    # NOTE(review): the non-ALL_TOKENS branch skips ids 0-5, presumably because those are the tokenizer's
    # special tokens -- confirm against the AMPLIFY vocabulary.
    if self.random_mask_strategy == RandomMaskStrategy.ALL_TOKENS:
        replacement_ids = range(self.tokenizer.vocab_size)
    else:
        replacement_ids = range(6, self.tokenizer.vocab_size)

    self.mask_config = masking.BertMaskConfig(
        tokenizer=self.tokenizer,
        random_tokens=replacement_ids,
        mask_prob=mask_prob,
        mask_token_prob=mask_token_prob,
        random_token_prob=mask_random_prob,
    )

__len__()

Returns the total number of sequences in the dataset.

Source code in bionemo/amplify/dataset.py
96
97
98
def __len__(self) -> int:
    """Reports the dataset size, i.e. the sequence count stored in ``total_samples``."""
    n_sequences = self.total_samples
    return n_sequences

HFAmplifyDataset

Bases: Protocol

Protocol for HuggingFace datasets containing AMPLIFY protein sequences.

Source code in bionemo/amplify/dataset.py
37
38
39
40
class HFAmplifyDataset(Protocol):
    """Protocol for HuggingFace datasets containing AMPLIFY protein sequences.

    Describes the minimal surface ``AMPLIFYMaskedResidueDataset`` relies on: integer indexing that yields a row
    with a ``sequence`` field, and ``len()`` for the number of rows.
    """

    def __getitem__(self, index: int) -> HFDatasetRow: ...  # noqa: D105

    # The dataset wrapper calls len() on the wrapped dataset, so the protocol must require it as well.
    def __len__(self) -> int: ...  # noqa: D105

HFDatasetRow

Bases: TypedDict

TypedDict for HuggingFace dataset rows.

Source code in bionemo/amplify/dataset.py
31
32
33
34
class HFDatasetRow(TypedDict):
    """Shape of a single row returned by the HuggingFace AMPLIFY dataset."""

    # The protein sequence as a plain string.
    sequence: str