Dataset

ESMMaskedResidueDataset

Bases: Dataset

Dataset class for ESM pretraining that implements cluster sampling of UniRef50 and UniRef90 sequences.

Megatron-LM expects the input datasets to be indexable, and for the output of the dataset for a given index to be deterministic. In cluster sampling, this can be tricky, since we need to perform weighted sampling over UniRef50 clusters.

Here, __getitem__(i) returns a randomly sampled UniRef90 sequence from the (i % len(dataset))-th UniRef50 cluster, with i controlling the random seed used for selecting the UniRef90 sequence and performing the masking.

Multi-epoch training

Currently, this class owns the logic for upsampling proteins for multi-epoch training by directly passing a total_samples that's larger than the number of clusters provided. This is done because Megatron training assumes that dataset[i] will always return the exact same tensors in distributed training. Because we want to vary mask patterns and cluster sampling each time a given cluster is sampled, we create our own pseudo-epochs inside the dataset itself. Eventually we'd like to move away from this paradigm and allow multi-epoch training to vary the dataset's random state through a callback, and allow megatron samplers to handle the epoch-to-epoch shuffling of sample order.

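For example, a minimal sketch of epoch-aware indexing (the import path for EpochIndex and the two input objects, protein_dataset and clusters, are assumptions, not part of this module):

import torch

from bionemo.core.data.multi_epoch_dataset import EpochIndex  # assumed import path
from bionemo.esm2.data.dataset import ESMMaskedResidueDataset

dataset = ESMMaskedResidueDataset(
    protein_dataset=protein_dataset,  # e.g., a ProteinSQLiteDataset (see below)
    clusters=clusters,  # one list of UniRef90 ids per UniRef50 cluster
    seed=42,
)

# The same (epoch, idx) pair always yields identical tensors, as Megatron-LM requires ...
first = dataset[EpochIndex(epoch=0, idx=7)]
assert torch.equal(first["text"], dataset[EpochIndex(epoch=0, idx=7)]["text"])

# ... while a different epoch re-draws the cluster member and the mask pattern.
second = dataset[EpochIndex(epoch=1, idx=7)]
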
Source code in bionemo/esm2/data/dataset.py (lines 92–220)
class ESMMaskedResidueDataset(Dataset):
    """Dataset class for ESM pretraining that implements cluster sampling of UniRef50 and UniRef90 sequences.

    Megatron-LM expects the input datasets to be indexable, and for the output of the dataset for a given index to be
    deterministic. In cluster sampling, this can be tricky, since we need to perform weighted sampling over UniRef50
    clusters.

    Here, __getitem__(i) returns a randomly sampled UniRef90 sequence from the (i % len(dataset))-th UniRef50 cluster,
    with i controlling the random seed used for selecting the UniRef90 sequence and performing the masking.

    !!! note "Multi-epoch training"

        Currently, this class owns the logic for upsampling proteins for multi-epoch training by directly passing a
        total_samples that's larger than the number of clusters provided. This is done because megatron training assumes
        that `dataset[i]` will always return the exact same tensors in distributed training. Because we want to vary
        mask patterns and cluster sampling each time a given cluster is sampled, we create our own pseudo-epochs inside
        the dataset itself. Eventually we'd like to move away from this paradigm and allow multi-epoch training to vary
        the dataset's random state through a callback, and allow megatron samplers to handle the epoch-to-epoch
        shuffling of sample order.

    """

    def __init__(
        self,
        protein_dataset: Dataset,
        clusters: Sequence[Sequence[str]],
        seed: int = np.random.SeedSequence().entropy,  # type: ignore
        max_seq_length: int = 1024,
        mask_prob: float = 0.15,
        mask_token_prob: float = 0.8,
        mask_random_prob: float = 0.1,
        random_mask_strategy: RandomMaskStrategy = RandomMaskStrategy.ALL_TOKENS,
        tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
    ) -> None:
        """Initializes the dataset.

        Args:
            protein_dataset: Dataset containing protein sequences, indexed by UniRef90 ids.
            clusters: UniRef90 ids for all training sequences, bucketed by UniRef50 cluster. Alternatively for
                validation, this can also just be a list of UniRef50 ids, with each entry being a length-1 list with a
                single UniRef50 id.
            seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure
                that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is
                generated.
            max_seq_length: Crop long sequences to a maximum of this length, including BOS and EOS tokens.
            mask_prob: The overall probability a token is included in the loss function. Defaults to 0.15.
            mask_token_prob: Proportion of masked tokens that get assigned the <MASK> id. Defaults to 0.8.
            mask_random_prob: Proportion of tokens that get assigned a random natural amino acid. Defaults to 0.1.
            random_mask_strategy: Whether to replace random masked tokens with all tokens or amino acids only. Defaults to RandomMaskStrategy.ALL_TOKENS.
            tokenizer: The input ESM tokenizer. Defaults to the standard ESM tokenizer.
        """
        self.protein_dataset = protein_dataset
        self.clusters = clusters
        self.seed = seed
        self.max_seq_length = max_seq_length
        self.random_mask_strategy = random_mask_strategy

        if tokenizer.mask_token_id is None:
            raise ValueError("Tokenizer does not have a mask token.")

        self.mask_config = masking.BertMaskConfig(
            tokenizer=tokenizer,
            random_tokens=range(len(tokenizer.all_tokens))
            if self.random_mask_strategy == RandomMaskStrategy.ALL_TOKENS
            else range(4, 24),
            mask_prob=mask_prob,
            mask_token_prob=mask_token_prob,
            random_token_prob=mask_random_prob,
        )

        self.tokenizer = tokenizer

    def __len__(self) -> int:
        """Returns the number of clusters, which constitutes a single epoch."""
        return len(self.clusters)

    def __getitem__(self, index: EpochIndex) -> BertSample:
        """Deterministically masks and returns a protein sequence from the dataset.

        This method samples the cluster at index.idx from the input clusters list. Repeated draws of the same cluster,
        with a fresh UniRef90 selection and mask pattern, can be obtained by passing the same index.idx with a
        different index.epoch.

        Args:
            index: The current epoch and the index of the cluster to sample.

        Returns:
            A (possibly truncated) masked protein sequence with CLS and EOS tokens and associated mask fields.
        """
        # Initialize a random number generator with a seed that is a combination of the dataset seed, epoch, and index.
        rng = np.random.default_rng([self.seed, index.epoch, index.idx])
        if not len(self.clusters[index.idx]):
            raise ValueError(f"Cluster {index.idx} is empty.")

        sequence_id = rng.choice(self.clusters[index.idx])
        sequence = self.protein_dataset[sequence_id]

        # We don't want special tokens before we pass the input to the masking function; we add these in the collate_fn.
        tokenized_sequence = self._tokenize(sequence)
        cropped_sequence = _random_crop(tokenized_sequence, self.max_seq_length, rng)

        # Get a single integer seed for torch from our rng, since the index tuple is hard to pass directly to torch.
        torch_seed = random_utils.get_seed_from_rng(rng)
        masked_sequence, labels, loss_mask = masking.apply_bert_pretraining_mask(
            tokenized_sequence=cropped_sequence,  # type: ignore
            random_seed=torch_seed,
            mask_config=self.mask_config,
        )

        return {
            "text": masked_sequence,
            "types": torch.zeros_like(masked_sequence, dtype=torch.int64),
            "attention_mask": torch.ones_like(masked_sequence, dtype=torch.int64),
            "labels": labels,
            "loss_mask": loss_mask,
            "is_random": torch.zeros_like(masked_sequence, dtype=torch.int64),
        }

    def _tokenize(self, sequence: str) -> torch.Tensor:
        """Tokenize a protein sequence.

        Args:
            sequence: The protein sequence.

        Returns:
            The tokenized sequence.
        """
        tensor = self.tokenizer.encode(sequence, add_special_tokens=True, return_tensors="pt")
        return tensor.flatten()  # type: ignore

__getitem__(index)

Deterministically masks and returns a protein sequence from the dataset.

This method samples the cluster at index.idx from the input clusters list. Repeated draws of the same cluster, with a fresh UniRef90 selection and mask pattern, can be obtained by passing the same index.idx with a different index.epoch.

Parameters:

Name Type Description Default
index EpochIndex

The current epoch and the index of the cluster to sample.

required

Returns:

Type Description
BertSample

A (possibly truncated) masked protein sequence with CLS and EOS tokens and associated mask fields.

Source code in bionemo/esm2/data/dataset.py (lines 169–208)
def __getitem__(self, index: EpochIndex) -> BertSample:
    """Deterministically masks and returns a protein sequence from the dataset.

    This method samples the cluster at index.idx from the input clusters list. Repeated draws of the same cluster,
    with a fresh UniRef90 selection and mask pattern, can be obtained by passing the same index.idx with a
    different index.epoch.

    Args:
        index: The current epoch and the index of the cluster to sample.

    Returns:
        A (possibly truncated) masked protein sequence with CLS and EOS tokens and associated mask fields.
    """
    # Initialize a random number generator with a seed that is a combination of the dataset seed, epoch, and index.
    rng = np.random.default_rng([self.seed, index.epoch, index.idx])
    if not len(self.clusters[index.idx]):
        raise ValueError(f"Cluster {index.idx} is empty.")

    sequence_id = rng.choice(self.clusters[index.idx])
    sequence = self.protein_dataset[sequence_id]

    # We don't want special tokens before we pass the input to the masking function; we add these in the collate_fn.
    tokenized_sequence = self._tokenize(sequence)
    cropped_sequence = _random_crop(tokenized_sequence, self.max_seq_length, rng)

    # Get a single integer seed for torch from our rng, since the index tuple is hard to pass directly to torch.
    torch_seed = random_utils.get_seed_from_rng(rng)
    masked_sequence, labels, loss_mask = masking.apply_bert_pretraining_mask(
        tokenized_sequence=cropped_sequence,  # type: ignore
        random_seed=torch_seed,
        mask_config=self.mask_config,
    )

    return {
        "text": masked_sequence,
        "types": torch.zeros_like(masked_sequence, dtype=torch.int64),
        "attention_mask": torch.ones_like(masked_sequence, dtype=torch.int64),
        "labels": labels,
        "loss_mask": loss_mask,
        "is_random": torch.zeros_like(masked_sequence, dtype=torch.int64),
    }

__init__(protein_dataset, clusters, seed=np.random.SeedSequence().entropy, max_seq_length=1024, mask_prob=0.15, mask_token_prob=0.8, mask_random_prob=0.1, random_mask_strategy=RandomMaskStrategy.ALL_TOKENS, tokenizer=tokenizer.get_tokenizer())

Initializes the dataset.

Parameters:

Name Type Description Default
protein_dataset Dataset

Dataset containing protein sequences, indexed by UniRef90 ids.

required
clusters Sequence[Sequence[str]]

UniRef90 ids for all training sequences, bucketed by UniRef50 cluster. Alternatively for validation, this can also just be a list of UniRef50 ids, with each entry being a length-1 list with a single UniRef50 id.

required
seed int

Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is generated.

entropy
max_seq_length int

Crop long sequences to a maximum of this length, including BOS and EOS tokens.

1024
mask_prob float

The overall probability a token is included in the loss function. Defaults to 0.15.

0.15
mask_token_prob float

Proportion of masked tokens that get assigned the <MASK> id. Defaults to 0.8.

0.8
mask_random_prob float

Proportion of tokens that get assigned a random natural amino acid. Defaults to 0.1.

0.1
random_mask_strategy RandomMaskStrategy

Whether to replace random masked tokens with all tokens or amino acids only. Defaults to RandomMaskStrategy.ALL_TOKENS.

ALL_TOKENS
tokenizer BioNeMoESMTokenizer

The input ESM tokenizer. Defaults to the standard ESM tokenizer.

get_tokenizer()
Source code in bionemo/esm2/data/dataset.py (lines 114–163)
def __init__(
    self,
    protein_dataset: Dataset,
    clusters: Sequence[Sequence[str]],
    seed: int = np.random.SeedSequence().entropy,  # type: ignore
    max_seq_length: int = 1024,
    mask_prob: float = 0.15,
    mask_token_prob: float = 0.8,
    mask_random_prob: float = 0.1,
    random_mask_strategy: RandomMaskStrategy = RandomMaskStrategy.ALL_TOKENS,
    tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
) -> None:
    """Initializes the dataset.

    Args:
        protein_dataset: Dataset containing protein sequences, indexed by UniRef90 ids.
        clusters: UniRef90 ids for all training sequences, bucketed by UniRef50 cluster. Alternatively for
            validation, this can also just be a list of UniRef50 ids, with each entry being a length-1 list with a
            single UniRef50 id.
        seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure
            that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is
            generated.
        max_seq_length: Crop long sequences to a maximum of this length, including BOS and EOS tokens.
        mask_prob: The overall probability a token is included in the loss function. Defaults to 0.15.
        mask_token_prob: Proportion of masked tokens that get assigned the <MASK> id. Defaults to 0.8.
        mask_random_prob: Proportion of tokens that get assigned a random natural amino acid. Defaults to 0.1.
        random_mask_strategy: Whether to replace random masked tokens with all tokens or amino acids only. Defaults to RandomMaskStrategy.ALL_TOKENS.
        tokenizer: The input ESM tokenizer. Defaults to the standard ESM tokenizer.
    """
    self.protein_dataset = protein_dataset
    self.clusters = clusters
    self.seed = seed
    self.max_seq_length = max_seq_length
    self.random_mask_strategy = random_mask_strategy

    if tokenizer.mask_token_id is None:
        raise ValueError("Tokenizer does not have a mask token.")

    self.mask_config = masking.BertMaskConfig(
        tokenizer=tokenizer,
        random_tokens=range(len(tokenizer.all_tokens))
        if self.random_mask_strategy == RandomMaskStrategy.ALL_TOKENS
        else range(4, 24),
        mask_prob=mask_prob,
        mask_token_prob=mask_token_prob,
        random_token_prob=mask_random_prob,
    )

    self.tokenizer = tokenizer

__len__()

Returns the number of clusters, which constitutes a single epoch.

Source code in bionemo/esm2/data/dataset.py (lines 165–167)
def __len__(self) -> int:
    """Returns the number of clusters, which constitutes a single epoch."""
    return len(self.clusters)

ProteinSQLiteDataset

Bases: Dataset

Dataset for protein sequences stored in a SQLite database.

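A minimal sketch of the database layout this class expects, inferred from the queries in the source (a protein table with id and sequence columns); the id and sequence below are placeholders:

import sqlite3

from bionemo.esm2.data.dataset import ProteinSQLiteDataset

# Build a tiny database with the schema implied by the class's SELECT statements.
conn = sqlite3.connect("proteins.db")
conn.execute("CREATE TABLE IF NOT EXISTS protein (id TEXT PRIMARY KEY, sequence TEXT)")
conn.execute(
    "INSERT OR REPLACE INTO protein VALUES (?, ?)",
    ("UniRef90_EXAMPLE", "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"),
)
conn.commit()
conn.close()

dataset = ProteinSQLiteDataset("proteins.db")
assert len(dataset) == 1
assert dataset["UniRef90_EXAMPLE"].startswith("MKTAY")
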
Source code in bionemo/esm2/data/dataset.py (lines 49–89)
class ProteinSQLiteDataset(Dataset):
    """Dataset for protein sequences stored in a SQLite database."""

    def __init__(self, db_path: str | os.PathLike):
        """Initializes the dataset.

        Args:
            db_path: Path to the SQLite database.
        """
        self.conn = sqlite3.connect(str(db_path))
        self.cursor = self.conn.cursor()
        self._len = None

    def __len__(self) -> int:
        """Returns the number of proteins in the dataset.

        Returns:
            Number of proteins in the dataset.
        """
        if self._len is None:
            self.cursor.execute("SELECT COUNT(*) FROM protein")
            self._len = int(self.cursor.fetchone()[0])
        return self._len

    def __getitem__(self, idx: str) -> str:
        """Returns the sequence of a protein at a given index.

        TODO: This method may want to support batched indexing for improved performance.

        Args:
            idx: An identifier for the protein sequence. For training data, these are UniRef90 IDs, while for validation
                data, they are UniRef50 IDs.

        Returns:
            The protein sequence as a string.
        """
        if not isinstance(idx, str):
            raise TypeError(f"Expected string, got {type(idx)}: {idx}.")

        self.cursor.execute("SELECT sequence FROM protein WHERE id = ?", (idx,))
        return self.cursor.fetchone()[0]

__getitem__(idx)

Returns the sequence of a protein at a given index.

TODO: This method may want to support batched indexing for improved performance.

Parameters:

Name Type Description Default
idx str

An identifier for the protein sequence. For training data, these are UniRef90 IDs, while for validation data, they are UniRef50 IDs.

required

Returns:

Type Description
str

The protein sequence as a string.

Source code in bionemo/esm2/data/dataset.py (lines 73–89)
def __getitem__(self, idx: str) -> str:
    """Returns the sequence of a protein at a given index.

    TODO: This method may want to support batched indexing for improved performance.

    Args:
        idx: An identifier for the protein sequence. For training data, these are UniRef90 IDs, while for validation
            data, they are UniRef50 IDs.

    Returns:
        The protein sequence as a string.
    """
    if not isinstance(idx, str):
        raise TypeError(f"Expected string, got {type(idx)}: {idx}.")

    self.cursor.execute("SELECT sequence FROM protein WHERE id = ?", (idx,))
    return self.cursor.fetchone()[0]

__init__(db_path)

Initializes the dataset.

Parameters:

Name Type Description Default
db_path str | PathLike

Path to the SQLite database.

required
Source code in bionemo/esm2/data/dataset.py (lines 52–60)
def __init__(self, db_path: str | os.PathLike):
    """Initializes the dataset.

    Args:
        db_path: Path to the SQLite database.
    """
    self.conn = sqlite3.connect(str(db_path))
    self.cursor = self.conn.cursor()
    self._len = None

__len__()

Returns the number of proteins in the dataset.

Returns:

Type Description
int

Number of proteins in the dataset.

Source code in bionemo/esm2/data/dataset.py (lines 62–71)
def __len__(self) -> int:
    """Returns the number of proteins in the dataset.

    Returns:
        Number of proteins in the dataset.
    """
    if self._len is None:
        self.cursor.execute("SELECT COUNT(*) FROM protein")
        self._len = int(self.cursor.fetchone()[0])
    return self._len

RandomMaskStrategy

Bases: str, Enum

Enum for different random masking strategies.

In ESM2 pretraining, 15% of all tokens are masked, of which 10% are replaced with a random token. This class controls the set of random tokens to choose from.

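To make the effect concrete, the following mirrors the branch in ESMMaskedResidueDataset.__init__ that this enum drives (the 4..24 range is taken from the source; reading it as the 20 canonical amino acid ids is an inference):

from bionemo.esm2.data import tokenizer
from bionemo.esm2.data.dataset import RandomMaskStrategy

tok = tokenizer.get_tokenizer()
strategy = RandomMaskStrategy.AMINO_ACIDS_ONLY

# ALL_TOKENS draws from the full vocabulary, including special and non-canonical
# tokens; AMINO_ACIDS_ONLY restricts draws to token ids 4..23.
random_tokens = (
    range(len(tok.all_tokens))
    if strategy == RandomMaskStrategy.ALL_TOKENS
    else range(4, 24)
)
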
Source code in bionemo/esm2/data/dataset.py (lines 35–46)
class RandomMaskStrategy(str, Enum):
    """Enum for different random masking strategies.

    In ESM2 pretraining, 15% of all tokens are masked, of which 10% are replaced with a random token. This class controls the set of random tokens to choose from.

    """

    AMINO_ACIDS_ONLY = "amino_acids_only"
    """Mask only with amino acid tokens."""

    ALL_TOKENS = "all_tokens"
    """Mask with all tokens in the tokenizer, including special tokens, padding and non-canonical amino acid tokens."""

ALL_TOKENS = 'all_tokens' class-attribute instance-attribute

Mask with all tokens in the tokenizer, including special tokens, padding and non-canonical amino acid tokens.

AMINO_ACIDS_ONLY = 'amino_acids_only' class-attribute instance-attribute

Mask only with amino acid tokens.

create_train_dataset(cluster_file, db_path, total_samples, seed, max_seq_length=1024, mask_prob=0.15, mask_token_prob=0.8, mask_random_prob=0.1, random_mask_strategy=RandomMaskStrategy.ALL_TOKENS, tokenizer=tokenizer.get_tokenizer())

Creates a training dataset for ESM pretraining.

Parameters:

Name Type Description Default
cluster_file str | PathLike

Path to the cluster file. The file should contain a "ur90_id" column, where each row contains a list of UniRef90 ids for a single UniRef50 cluster.

required
db_path str | PathLike

Path to the SQLite database.

required
total_samples int

Total number of samples to draw from the dataset.

required
seed int

Random seed for reproducibility.

required
max_seq_length int

Crop long sequences to a maximum of this length, including BOS and EOS tokens.

1024
mask_prob float

The overall probability a token is included in the loss function. Defaults to 0.15.

0.15
mask_token_prob float

Proportion of masked tokens that get assigned the <MASK> id. Defaults to 0.8.

0.8
mask_random_prob float

Proportion of tokens that get assigned a random natural amino acid. Defaults to 0.1.

0.1
random_mask_strategy RandomMaskStrategy

Whether to replace random masked tokens with all tokens or amino acids only. Defaults to RandomMaskStrategy.ALL_TOKENS.

ALL_TOKENS
tokenizer BioNeMoESMTokenizer

The input ESM tokenizer. Defaults to the standard ESM tokenizer.

get_tokenizer()

Returns:

Type Description

A dataset for ESM pretraining.

Raises:

Type Description
ValueError

If the cluster file does not exist, the database file does not exist, or the cluster file does not contain a "ur90_id" column.

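A hedged usage sketch; both file names are placeholders for artifacts produced by your own preprocessing:

from bionemo.esm2.data.dataset import create_train_dataset

train_dataset = create_train_dataset(
    cluster_file="train_clusters.parquet",  # parquet with a "ur90_id" column
    db_path="proteins.db",  # SQLite database of UniRef90 sequences
    total_samples=1_000_000,  # exceeding the cluster count triggers pseudo-epoch upsampling
    seed=42,
    max_seq_length=1024,
)

sample = train_dataset[0]  # a BertSample dict with "text", "labels", "loss_mask", ...
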
Source code in bionemo/esm2/data/dataset.py (lines 223–280)
def create_train_dataset(
    cluster_file: str | os.PathLike,
    db_path: str | os.PathLike,
    total_samples: int,
    seed: int,
    max_seq_length: int = 1024,
    mask_prob: float = 0.15,
    mask_token_prob: float = 0.8,
    mask_random_prob: float = 0.1,
    random_mask_strategy: RandomMaskStrategy = RandomMaskStrategy.ALL_TOKENS,
    tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
):
    """Creates a training dataset for ESM pretraining.

    Args:
        cluster_file: Path to the cluster file. The file should contain a "ur90_id" column, where each row contains a
            list of UniRef90 ids for a single UniRef50 cluster.
        db_path: Path to the SQLite database.
        total_samples: Total number of samples to draw from the dataset.
        seed: Random seed for reproducibility.
        max_seq_length: Crop long sequences to a maximum of this length, including BOS and EOS tokens.
        mask_prob: The overall probability a token is included in the loss function. Defaults to 0.15.
        mask_token_prob: Proportion of masked tokens that get assigned the <MASK> id. Defaults to 0.8.
        mask_random_prob: Proportion of tokens that get assigned a random natural amino acid. Defaults to 0.1.
        random_mask_strategy: Whether to replace random masked tokens with all tokens or amino acids only. Defaults to RandomMaskStrategy.ALL_TOKENS.
        tokenizer: The input ESM tokenizer. Defaults to the standard ESM tokenizer.

    Returns:
        A dataset for ESM pretraining.

    Raises:
        ValueError: If the cluster file does not exist, the database file does not exist, or the cluster file does not
            contain a "ur90_id" column.
    """
    if not Path(cluster_file).exists():
        raise ValueError(f"Cluster file {cluster_file} not found.")

    if not Path(db_path).exists():
        raise ValueError(f"Database file {db_path} not found.")

    cluster_df = pd.read_parquet(cluster_file)
    if "ur90_id" not in cluster_df.columns:
        raise ValueError(f"Training cluster file must contain a 'ur90_id' column. Found columns {cluster_df.columns}.")

    protein_dataset = ProteinSQLiteDataset(db_path)
    masked_cluster_dataset = ESMMaskedResidueDataset(
        protein_dataset=protein_dataset,
        clusters=cluster_df["ur90_id"],
        seed=seed,
        max_seq_length=max_seq_length,
        mask_prob=mask_prob,
        mask_token_prob=mask_token_prob,
        mask_random_prob=mask_random_prob,
        random_mask_strategy=random_mask_strategy,
        tokenizer=tokenizer,
    )

    return MultiEpochDatasetResampler(masked_cluster_dataset, num_samples=total_samples, shuffle=True, seed=seed)

create_valid_clusters(cluster_file)

Create a pandas series of UniRef50 cluster IDs from a cluster parquet file.

Parameters:

Name Type Description Default
cluster_file str | PathLike

Path to the cluster file. The file should contain a single column named "ur50_id" with UniRef50 IDs, one UniRef50 ID per row.

required

Returns:

Type Description
Series

A pandas series of UniRef50 cluster IDs.

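A minimal sketch of producing a conforming cluster file with pandas (the UniRef50 ids are placeholders):

import pandas as pd

from bionemo.esm2.data.dataset import create_valid_clusters

# One UniRef50 id per row, in a single "ur50_id" column.
pd.DataFrame({"ur50_id": ["UniRef50_EXAMPLE1", "UniRef50_EXAMPLE2"]}).to_parquet(
    "valid_clusters.parquet"
)

clusters = create_valid_clusters("valid_clusters.parquet")
assert clusters.iloc[0] == ["UniRef50_EXAMPLE1"]  # each entry is a length-1 cluster
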
Source code in bionemo/esm2/data/dataset.py (lines 283–302)
def create_valid_clusters(cluster_file: str | os.PathLike) -> pd.Series:
    """Create a pandas series of UniRef50 cluster IDs from a cluster parquet file.

    Args:
        cluster_file: Path to the cluster file. The file should contain a single column named "ur50_id" with UniRef50
            IDs, with one UniRef50 ID per row.

    Returns:
        A pandas series of UniRef50 cluster IDs.
    """
    if not Path(cluster_file).exists():
        raise ValueError(f"Cluster file {cluster_file} not found.")

    cluster_df = pd.read_parquet(cluster_file)
    if "ur50_id" not in cluster_df.columns:
        raise ValueError(
            f"Validation cluster file must contain a 'ur50_id' column. Found columns {cluster_df.columns}."
        )
    clusters = cluster_df["ur50_id"].apply(lambda x: [x])
    return clusters

create_valid_dataset(clusters, db_path, seed, total_samples=None, max_seq_length=1024, mask_prob=0.15, mask_token_prob=0.8, mask_random_prob=0.1, random_mask_strategy=RandomMaskStrategy.ALL_TOKENS, tokenizer=tokenizer.get_tokenizer())

Creates a validation dataset for ESM pretraining.

Parameters:

Name Type Description Default
clusters pd.Series | str | PathLike

Clusters as a pd.Series, or a path to the cluster file. The file should contain a single column named "ur50_id" with UniRef50 IDs, one UniRef50 ID per row.

required
db_path str | PathLike

Path to the SQLite database.

required
total_samples int | None

Total number of samples to draw from the dataset.

None
seed int

Random seed for reproducibility.

required
max_seq_length int

Crop long sequences to a maximum of this length, including BOS and EOS tokens.

1024
mask_prob float

The overall probability a token is included in the loss function. Defaults to 0.15.

0.15
mask_token_prob float

Proportion of masked tokens that get assigned the <MASK> id. Defaults to 0.8.

0.8
mask_random_prob float

Proportion of tokens that get assigned a random natural amino acid. Defaults to 0.1.

0.1
random_mask_strategy RandomMaskStrategy

Whether to replace random masked tokens with all tokens or amino acids only. Defaults to RandomMaskStrategy.ALL_TOKENS.

ALL_TOKENS
tokenizer BioNeMoESMTokenizer

The input ESM tokenizer. Defaults to the standard ESM tokenizer.

get_tokenizer()

Raises:

Type Description
ValueError

If the cluster file does not exist, the database file does not exist, or the cluster file does not contain a "ur50_id" column.

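A hedged usage sketch mirroring the training helper; the paths are placeholders, and treating total_samples=None as a single pass over the clusters is an assumption:

from bionemo.esm2.data.dataset import create_valid_dataset

valid_dataset = create_valid_dataset(
    clusters="valid_clusters.parquet",  # or the pd.Series from create_valid_clusters
    db_path="proteins.db",
    seed=42,
    total_samples=None,
)
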
Source code in bionemo/esm2/data/dataset.py (lines 305–357)
def create_valid_dataset(  # noqa: D417
    clusters: pd.Series | str | os.PathLike,
    db_path: str | os.PathLike,
    seed: int,
    total_samples: int | None = None,
    max_seq_length: int = 1024,
    mask_prob: float = 0.15,
    mask_token_prob: float = 0.8,
    mask_random_prob: float = 0.1,
    random_mask_strategy: RandomMaskStrategy = RandomMaskStrategy.ALL_TOKENS,
    tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
):
    """Creates a validation dataset for ESM pretraining.

    Args:
        clusters: Clusters as a pd.Series, or a path to the cluster file. The file should contain a single column
            named "ur50_id" with UniRef50 IDs, one UniRef50 ID per row.
        db_path: Path to the SQLite database.
        total_samples: Total number of samples to draw from the dataset.
        seed: Random seed for reproducibility.
        max_seq_length: Crop long sequences to a maximum of this length, including BOS and EOS tokens.
        mask_prob: The overall probability a token is included in the loss function. Defaults to 0.15.
        mask_token_prob: Proportion of masked tokens that get assigned the <MASK> id. Defaults to 0.8.
        mask_random_prob: Proportion of tokens that get assigned a random natural amino acid. Defaults to 0.1.
        random_mask_strategy: Whether to replace random masked tokens with all tokens or amino acids only. Defaults to RandomMaskStrategy.ALL_TOKENS.
        tokenizer: The input ESM tokenizer. Defaults to the standard ESM tokenizer.

    Raises:
        ValueError: If the cluster file does not exist, the database file does not exist, or the cluster file does not
            contain a "ur50_id" column.
    """
    if isinstance(clusters, (str, os.PathLike)):
        clusters = create_valid_clusters(clusters)

    elif not isinstance(clusters, pd.Series):
        raise ValueError(f"Clusters must be a pandas Series. Got {type(clusters)}.")

    if not Path(db_path).exists():
        raise ValueError(f"Database file {db_path} not found.")

    protein_dataset = ProteinSQLiteDataset(db_path)
    masked_dataset = ESMMaskedResidueDataset(
        protein_dataset=protein_dataset,
        clusters=clusters,
        seed=seed,
        max_seq_length=max_seq_length,
        mask_prob=mask_prob,
        mask_token_prob=mask_token_prob,
        mask_random_prob=mask_random_prob,
        random_mask_strategy=random_mask_strategy,
        tokenizer=tokenizer,
    )

    return MultiEpochDatasetResampler(masked_dataset, num_samples=total_samples, shuffle=True, seed=seed)