Datamodule

ESM2FineTuneDataModule

Bases: MegatronDataModule

A PyTorch Lightning DataModule for fine-tuning ESM2 models.

This DataModule is designed to handle the data preparation and loading for fine-tuning ESM2 models. It provides a flexible way to create and manage datasets, data loaders, and sampling strategies.
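
The snippet below is a minimal usage sketch rather than code from the package: the dataset objects are placeholders and the trainer wiring follows the usual PyTorch Lightning pattern, so adapt both to your actual fine-tuning data.

```python
# Hypothetical fine-tuning wiring (sketch). Only the module path of
# ESM2FineTuneDataModule is documented on this page; the dataset objects and
# their contents are placeholders.
from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule

train_ds = ...  # e.g. an in-memory dataset of sequences + labels (placeholder)
valid_ds = ...  # same dataset type as train_ds (placeholder)

data_module = ESM2FineTuneDataModule(
    train_dataset=train_ds,
    valid_dataset=valid_ds,
    max_seq_length=1024,   # passed to the data sampler and the padding collate function
    micro_batch_size=4,    # per-device batch size
    global_batch_size=8,   # total batch size across data-parallel ranks
)
# The module is then handed to a Lightning-style trainer, e.g.
# trainer.fit(model, datamodule=data_module), with trainer.max_steps > 0 (see setup()).
```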

Source code in bionemo/esm2/model/finetune/datamodule.py
class ESM2FineTuneDataModule(MegatronDataModule):
    """A PyTorch Lightning DataModule for fine-tuning ESM2 models.

    This DataModule is designed to handle the data preparation and loading for fine-tuning ESM2 models.
    It provides a flexible way to create and manage datasets, data loaders, and sampling strategies.
    """

    def __init__(
        self,
        train_dataset: DATASET_TYPES = None,
        valid_dataset: DATASET_TYPES = None,
        predict_dataset: DATASET_TYPES = None,
        seed: int = 42,
        min_seq_length: int | None = None,
        max_seq_length: int = 1024,
        micro_batch_size: int = 4,
        global_batch_size: int = 8,
        num_workers: int = 2,
        persistent_workers: bool = True,
        pin_memory: bool = True,
        rampup_batch_size: list[int] | None = None,
        tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
    ) -> None:
        """Initialize the ESM2FineTuneDataModule.

        Args:
            train_dataset: The training dataset.
            valid_dataset: The validation dataset.
            predict_dataset: The prediction dataset. Should not be set together with train/valid datasets.
            seed: The random seed to use for shuffling the datasets. Defaults to 42.
            min_seq_length: The minimum sequence length for the datasets. Defaults to None.
            max_seq_length: The maximum sequence length for the datasets. Defaults to 1024.
            micro_batch_size: The micro-batch size for the data loader. Defaults to 4.
            global_batch_size: The global batch size for the data loader. Defaults to 8.
            num_workers: The number of worker processes for the data loader. Defaults to 2.
            persistent_workers: Whether to persist the worker processes. Defaults to True.
            pin_memory: Whether to pin the data in memory. Defaults to True.
            rampup_batch_size: The batch size ramp-up schedule. Defaults to None.
            tokenizer: The tokenizer to use for tokenization. Defaults to the BioNeMoESMTokenizer.

        Returns:
            None
        """
        super().__init__()
        self.train_dataset = train_dataset
        self.valid_dataset = valid_dataset
        self.predict_dataset = predict_dataset
        if predict_dataset is not None:
            assert train_dataset is None, "Datamodule expects either train/valid dataset or predict dataset"
        self._seed = seed
        self._min_seq_length = min_seq_length
        self._max_seq_length = max_seq_length
        self._tokenizer = tokenizer

        self._micro_batch_size = micro_batch_size
        self._num_workers = num_workers
        self._persistent_workers = persistent_workers
        self._pin_memory = pin_memory

        self.data_sampler = MegatronDataSampler(
            seq_len=max_seq_length,
            micro_batch_size=micro_batch_size,
            global_batch_size=global_batch_size,
            dataloader_type="single",  # `MegatronPretrainingRandomSampler` from "cyclic" is failing.
            rampup_batch_size=rampup_batch_size,
            output_log=predict_dataset is None,  # logging does not work with predict step
        )

    def setup(self, stage: str) -> None:
        """Setup the ESMDataModule.

        Args:
            stage: Unused.

        Raises:
            RuntimeError: If the trainer is not attached, or if the trainer's max_steps is not set.
        """
        del stage  # Unused.

        if not hasattr(self, "trainer") or self.trainer is None:
            raise RuntimeError("Setup should be completed when trainer and config are attached.")

        if self.trainer.max_epochs is not None and self.trainer.max_epochs > 1:
            logging.warning(
                "Trainer is set to run for multiple epochs. This is not recommended due to the same shuffle being used "
                "in each. Instead set max_epochs to 1 and increase the number of max_steps."
            )

        # Create training dataset
        if self.train_dataset is not None:
            max_train_steps = self.trainer.max_steps
            if max_train_steps <= 0:
                raise RuntimeError("Please specify trainer.max_steps")

            num_train_samples = int(max_train_steps * self.data_sampler.global_batch_size)
            self._train_ds = self._create_epoch_based_dataset(self.train_dataset, num_train_samples)

        # Create validation dataset
        if self.valid_dataset is not None and self.trainer.limit_val_batches != 0:
            num_val_samples = infer_num_samples(
                limit_batches=self.trainer.limit_val_batches,
                num_samples_in_dataset=len(self.valid_dataset),
                global_batch_size=self.data_sampler.global_batch_size,
                stage="val",
            )
            self._valid_ds = self._create_epoch_based_dataset(self.valid_dataset, num_val_samples)

        assert (
            hasattr(self, "trainer") and self.trainer is not None
        ), "Setup should be completed when trainer and config are attached."

    def _create_epoch_based_dataset(
        self,
        dataset: InMemoryPerTokenValueDataset | InMemorySingleValueDataset,
        total_samples: int,
    ):
        return MultiEpochDatasetResampler(
            IdentityMultiEpochDatasetWrapper(dataset),
            num_samples=total_samples,
            shuffle=self.predict_dataset is None,
            seed=self._seed,
        )

    def _create_dataloader(self, dataset, mode: Mode, **kwargs) -> WrappedDataLoader:
        """Create dataloader for train, validation, and test stages.

        Args:
            dataset: The dataset to create the dataloader for.
            mode: Stage of training, which is used to determine if consumed_samples in MegatronPretrainingSampler should be initialized to 0 (validation/test), or be set to the previous value from state_dict in case of checkpoint resumption (train).
            **kwargs: Additional arguments to pass to the dataloader.
        """
        if mode not in ["predict", "test"]:
            self.update_init_global_step()
        assert self._tokenizer.pad_token_id is not None, "Tokenizer must have a pad token id."

        return WrappedDataLoader(
            mode=mode,
            dataset=dataset,
            num_workers=self._num_workers,
            pin_memory=self._pin_memory,
            persistent_workers=self._persistent_workers,
            collate_fn=functools.partial(
                collate.bert_padding_collate_fn,
                padding_value=self._tokenizer.pad_token_id,
                min_length=self._min_seq_length,
                max_length=self._max_seq_length,
            ),
            **kwargs,
        )

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        """Returns the dataloader for training data."""
        assert self._train_ds is not None, "train_dataset is not provided to ESM2FineTuneDataModule"
        return self._create_dataloader(self._train_ds, mode="train")

    def val_dataloader(self) -> EVAL_DATALOADERS:
        """Returns the dataloader for validation data."""
        assert self._valid_ds is not None, "valid_dataset is not provided to ESM2FineTuneDataModule"
        return self._create_dataloader(self._valid_ds, mode="validation")

    def predict_dataloader(self) -> EVAL_DATALOADERS:
        """Returns the dataloader for prediction data."""
        assert self.predict_dataset is not None, "predict_dataset is not provided to ESM2FineTuneDataModule"
        return self._create_dataloader(self.predict_dataset, mode="predict")

    def test_dataloader(self) -> EVAL_DATALOADERS:
        """Raises a not implemented error."""
        raise NotImplementedError("No test dataset provided for ESM2")

__init__(train_dataset=None, valid_dataset=None, predict_dataset=None, seed=42, min_seq_length=None, max_seq_length=1024, micro_batch_size=4, global_batch_size=8, num_workers=2, persistent_workers=True, pin_memory=True, rampup_batch_size=None, tokenizer=tokenizer.get_tokenizer())

Initialize the ESM2FineTuneDataModule.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| train_dataset | DATASET_TYPES | The training dataset. | None |
| valid_dataset | DATASET_TYPES | The validation dataset. | None |
| predict_dataset | DATASET_TYPES | The prediction dataset. Should not be set together with train/valid datasets. | None |
| seed | int | The random seed to use for shuffling the datasets. Defaults to 42. | 42 |
| min_seq_length | int \| None | The minimum sequence length for the datasets. Defaults to None. | None |
| max_seq_length | int | The maximum sequence length for the datasets. Defaults to 1024. | 1024 |
| micro_batch_size | int | The micro-batch size for the data loader. Defaults to 4. | 4 |
| global_batch_size | int | The global batch size for the data loader. Defaults to 8. | 8 |
| num_workers | int | The number of worker processes for the data loader. Defaults to 2. | 2 |
| persistent_workers | bool | Whether to persist the worker processes. Defaults to True. | True |
| pin_memory | bool | Whether to pin the data in memory. Defaults to True. | True |
| rampup_batch_size | list[int] \| None | The batch size ramp-up schedule. Defaults to None. | None |
| tokenizer | BioNeMoESMTokenizer | The tokenizer to use for tokenization. Defaults to the BioNeMoESMTokenizer. | get_tokenizer() |

Returns:

| Type | Description |
|------|-------------|
| None | None |
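
For inference-only use, only predict_dataset is supplied; the constructor asserts that no train/valid dataset is passed alongside it and disables output logging for the predict step. A hedged sketch, where predict_ds is a placeholder dataset object:

```python
# Inference-only configuration (sketch). predict_ds stands in for any supported
# DATASET_TYPES instance holding the sequences to run prediction on.
predict_module = ESM2FineTuneDataModule(
    predict_dataset=predict_ds,  # train_dataset / valid_dataset must remain None
    micro_batch_size=4,
    global_batch_size=8,
)
# Supplying train_dataset together with predict_dataset trips the assertion in __init__,
# and output_log is set to False because logging is not supported during the predict step.
```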

Source code in bionemo/esm2/model/finetune/datamodule.py
def __init__(
    self,
    train_dataset: DATASET_TYPES = None,
    valid_dataset: DATASET_TYPES = None,
    predict_dataset: DATASET_TYPES = None,
    seed: int = 42,
    min_seq_length: int | None = None,
    max_seq_length: int = 1024,
    micro_batch_size: int = 4,
    global_batch_size: int = 8,
    num_workers: int = 2,
    persistent_workers: bool = True,
    pin_memory: bool = True,
    rampup_batch_size: list[int] | None = None,
    tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
) -> None:
    """Initialize the ESM2FineTuneDataModule.

    Args:
        train_dataset: The training dataset.
        valid_dataset: The validation dataset.
        predict_dataset: The prediction dataset. Should not be set together with train/valid datasets.
        seed: The random seed to use for shuffling the datasets. Defaults to 42.
        min_seq_length: The minimum sequence length for the datasets. Defaults to None.
        max_seq_length: The maximum sequence length for the datasets. Defaults to 1024.
        micro_batch_size: The micro-batch size for the data loader. Defaults to 4.
        global_batch_size: The global batch size for the data loader. Defaults to 8.
        num_workers: The number of worker processes for the data loader. Defaults to 2.
        persistent_workers: Whether to persist the worker processes. Defaults to True.
        pin_memory: Whether to pin the data in memory. Defaults to True.
        rampup_batch_size: The batch size ramp-up schedule. Defaults to None.
        tokenizer: The tokenizer to use for tokenization. Defaults to the BioNeMoESMTokenizer.

    Returns:
        None
    """
    super().__init__()
    self.train_dataset = train_dataset
    self.valid_dataset = valid_dataset
    self.predict_dataset = predict_dataset
    if predict_dataset is not None:
        assert train_dataset is None, "Datamodule expects either train/valid dataset or predict dataset"
    self._seed = seed
    self._min_seq_length = min_seq_length
    self._max_seq_length = max_seq_length
    self._tokenizer = tokenizer

    self._micro_batch_size = micro_batch_size
    self._num_workers = num_workers
    self._persistent_workers = persistent_workers
    self._pin_memory = pin_memory

    self.data_sampler = MegatronDataSampler(
        seq_len=max_seq_length,
        micro_batch_size=micro_batch_size,
        global_batch_size=global_batch_size,
        dataloader_type="single",  # `MegatronPretrainingRandomSampler` from "cyclic" is failing.
        rampup_batch_size=rampup_batch_size,
        output_log=predict_dataset is None,  # logging does not work with predict step
    )

predict_dataloader()

Returns the dataloader for prediction data.

Source code in bionemo/esm2/model/finetune/datamodule.py
def predict_dataloader(self) -> EVAL_DATALOADERS:
    """Returns the dataloader for prediction data."""
    assert self.predict_dataset is not None, "predict_dataset is not provided to ESM2FineTuneDataModule"
    return self._create_dataloader(self.predict_dataset, mode="predict")
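
In practice this dataloader is usually consumed through the trainer's predict loop rather than called directly; a hedged sketch, reusing the predict-only module from the constructor example above:

```python
# Sketch: running inference via the Lightning predict loop. `model` and `trainer`
# are placeholders configured elsewhere; predict_module is the predict-only
# ESM2FineTuneDataModule shown earlier.
predictions = trainer.predict(model, datamodule=predict_module)

# Calling predict_module.predict_dataloader() directly also works once a trainer
# is attached; it asserts that predict_dataset was provided at construction time.
```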

setup(stage)

Set up the ESM2FineTuneDataModule.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| stage | str | Unused. | required |

Raises:

| Type | Description |
|------|-------------|
| RuntimeError | If the trainer is not attached, or if the trainer's max_steps is not set. |

Source code in bionemo/esm2/model/finetune/datamodule.py
def setup(self, stage: str) -> None:
    """Setup the ESMDataModule.

    Args:
        stage: Unused.

    Raises:
        RuntimeError: If the trainer is not attached, or if the trainer's max_steps is not set.
    """
    del stage  # Unused.

    if not hasattr(self, "trainer") or self.trainer is None:
        raise RuntimeError("Setup should be completed when trainer and config are attached.")

    if self.trainer.max_epochs is not None and self.trainer.max_epochs > 1:
        logging.warning(
            "Trainer is set to run for multiple epochs. This is not recommended due to the same shuffle being used "
            "in each. Instead set max_epochs to 1 and increase the number of max_steps."
        )

    # Create training dataset
    if self.train_dataset is not None:
        max_train_steps = self.trainer.max_steps
        if max_train_steps <= 0:
            raise RuntimeError("Please specify trainer.max_steps")

        num_train_samples = int(max_train_steps * self.data_sampler.global_batch_size)
        self._train_ds = self._create_epoch_based_dataset(self.train_dataset, num_train_samples)

    # Create validation dataset
    if self.valid_dataset is not None and self.trainer.limit_val_batches != 0:
        num_val_samples = infer_num_samples(
            limit_batches=self.trainer.limit_val_batches,
            num_samples_in_dataset=len(self.valid_dataset),
            global_batch_size=self.data_sampler.global_batch_size,
            stage="val",
        )
        self._valid_ds = self._create_epoch_based_dataset(self.valid_dataset, num_val_samples)

    assert (
        hasattr(self, "trainer") and self.trainer is not None
    ), "Setup should be completed when trainer and config are attached."

test_dataloader()

Raises a not implemented error.

Source code in bionemo/esm2/model/finetune/datamodule.py
def test_dataloader(self) -> EVAL_DATALOADERS:
    """Raises a not implemented error."""
    raise NotImplementedError("No test dataset provided for ESM2")

train_dataloader()

Returns the dataloader for training data.

Source code in bionemo/esm2/model/finetune/datamodule.py
def train_dataloader(self) -> TRAIN_DATALOADERS:
    """Returns the dataloader for training data."""
    assert self._train_ds is not None, "train_dataset is not provided to ESM2FineTuneDataModule"
    return self._create_dataloader(self._train_ds, mode="train")

val_dataloader()

Returns the dataloader for validation data.

Source code in bionemo/esm2/model/finetune/datamodule.py
def val_dataloader(self) -> EVAL_DATALOADERS:
    """Returns the dataloader for validation data."""
    assert self._valid_ds is not None, "valid_dataset is not provided to ESM2FineTuneDataModule"
    return self._create_dataloader(self._valid_ds, mode="validation")