Datamodule

MegatronDataModule

Bases: LightningDataModule

A mixin that adds a state_dict and load_state_dict method for datamodule training resumption in NeMo.

Source code in bionemo/llm/data/datamodule.py
class MegatronDataModule(pl.LightningDataModule):
    """A mixin that adds a `state_dict` and `load_state_dict` method for datamodule training resumption in NeMo."""

    def __init__(self, *args, **kwargs):
        """Set init_global_step to 0 for datamodule resumption."""
        super().__init__(*args, **kwargs)
        self.init_global_step = 0

    def update_init_global_step(self):
        """Please always call this when you get a new dataloader... if you forget, your resumption will not work."""
        self.init_global_step = self.trainer.global_step  # Update the init_global_step whenever we re-init training
        self.data_sampler.init_global_step = (
            self.init_global_step
        )  # Update the init_global_step whenever we re-init training

    def state_dict(self) -> Dict[str, Any]:
        """Called when saving a checkpoint, implement to generate and save datamodule state.

        Returns:
            A dictionary containing datamodule state.

        """
        consumed_samples = self.data_sampler.compute_consumed_samples(self.trainer.global_step - self.init_global_step)
        return {"consumed_samples": consumed_samples}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Called when loading a checkpoint, implement to reload datamodule state given datamodule stat.

        Args:
            state_dict: the datamodule state returned by ``state_dict``.

        """
        try:
            from megatron.core.num_microbatches_calculator import update_num_microbatches

        except (ImportError, ModuleNotFoundError):
            logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
            from apex.transformer.pipeline_parallel.utils import update_num_microbatches

        consumed_samples = state_dict["consumed_samples"]
        self.data_sampler.init_consumed_samples = consumed_samples
        self.data_sampler.prev_consumed_samples = consumed_samples

        update_num_microbatches(
            consumed_samples=consumed_samples,
            consistency_check=False,
        )
        self.data_sampler.if_first_step = 1
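
In practice this class is used as a base for concrete data modules: the subclass creates a data sampler, stores it on self.data_sampler, and calls update_init_global_step every time a training or validation dataloader is built, so that state_dict/load_state_dict can track consumed samples across restarts. Below is a minimal sketch of such a subclass; the MegatronDataSampler import path and the dataset arguments are assumptions for illustration, mirroring what MockDataModule further down does.

from nemo.lightning.pytorch.plugins import MegatronDataSampler  # assumed import path
from torch.utils.data import DataLoader

from bionemo.llm.data.datamodule import MegatronDataModule


class MyDataModule(MegatronDataModule):
    """Illustrative subclass; train_dataset and valid_dataset are user-provided Datasets."""

    def __init__(self, train_dataset, valid_dataset, micro_batch_size=4, global_batch_size=8):
        super().__init__()
        self.train_dataset = train_dataset
        self.valid_dataset = valid_dataset
        self.micro_batch_size = micro_batch_size
        # state_dict()/load_state_dict() read and write this sampler's counters.
        self.data_sampler = MegatronDataSampler(
            seq_len=512,
            micro_batch_size=micro_batch_size,
            global_batch_size=global_batch_size,
            dataloader_type="single",
        )

    def train_dataloader(self) -> DataLoader:
        # Required for correct resumption: sync init_global_step before handing out a new loader.
        self.update_init_global_step()
        return DataLoader(self.train_dataset, batch_size=self.micro_batch_size)

    def val_dataloader(self) -> DataLoader:
        self.update_init_global_step()
        return DataLoader(self.valid_dataset, batch_size=self.micro_batch_size)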

__init__(*args, **kwargs)

Set init_global_step to 0 for datamodule resumption.

Source code in bionemo/llm/data/datamodule.py
def __init__(self, *args, **kwargs):
    """Set init_global_step to 0 for datamodule resumption."""
    super().__init__(*args, **kwargs)
    self.init_global_step = 0

load_state_dict(state_dict)

Called when loading a checkpoint, implement to reload datamodule state given datamodule state.

Parameters:

Name Type Description Default
state_dict Dict[str, Any]

the datamodule state returned by state_dict.

required
Source code in bionemo/llm/data/datamodule.py
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    """Called when loading a checkpoint, implement to reload datamodule state given datamodule stat.

    Args:
        state_dict: the datamodule state returned by ``state_dict``.

    """
    try:
        from megatron.core.num_microbatches_calculator import update_num_microbatches

    except (ImportError, ModuleNotFoundError):
        logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
        from apex.transformer.pipeline_parallel.utils import update_num_microbatches

    consumed_samples = state_dict["consumed_samples"]
    self.data_sampler.init_consumed_samples = consumed_samples
    self.data_sampler.prev_consumed_samples = consumed_samples

    update_num_microbatches(
        consumed_samples=consumed_samples,
        consistency_check=False,
    )
    self.data_sampler.if_first_step = 1
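
During checkpointing, Lightning persists the dictionary returned by state_dict and passes it back to load_state_dict on restore, so the sampler continues from the recorded consumed_samples instead of replaying data from the beginning. A hand-driven round trip would look roughly like the sketch below; it assumes a datamodule already attached to a trainer with a data_sampler set, plus a Megatron-Core or Apex installation providing update_num_microbatches.

# Sketch of the save/restore round trip normally driven by the Trainer.
saved_state = datamodule.state_dict()              # e.g. {"consumed_samples": 102400}
# ... training stops; later a fresh process restores from the checkpoint ...
restored_datamodule.load_state_dict(saved_state)
assert restored_datamodule.data_sampler.init_consumed_samples == saved_state["consumed_samples"]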

state_dict()

Called when saving a checkpoint, implement to generate and save datamodule state.

Returns:

Type Description
Dict[str, Any]

A dictionary containing datamodule state.

Source code in bionemo/llm/data/datamodule.py
def state_dict(self) -> Dict[str, Any]:
    """Called when saving a checkpoint, implement to generate and save datamodule state.

    Returns:
        A dictionary containing datamodule state.

    """
    consumed_samples = self.data_sampler.compute_consumed_samples(self.trainer.global_step - self.init_global_step)
    return {"consumed_samples": consumed_samples}

update_init_global_step()

Please always call this when you get a new dataloader... if you forget, your resumption will not work.

Source code in bionemo/llm/data/datamodule.py
def update_init_global_step(self):
    """Please always call this when you get a new dataloader... if you forget, your resumption will not work."""
    self.init_global_step = self.trainer.global_step  # Update the init_global_step whenever we re-init training
    self.data_sampler.init_global_step = (
        self.init_global_step
    )  # Update the init_global_step whenever we re-init training

MockDataModule

Bases: MegatronDataModule

A simple data module that just wraps input datasets with dataloaders.

Source code in bionemo/llm/data/datamodule.py
class MockDataModule(MegatronDataModule):
    """A simple data module that just wraps input datasets with dataloaders."""

    def __init__(
        self,
        train_dataset: Dataset | None = None,
        valid_dataset: Dataset | None = None,
        test_dataset: Dataset | None = None,
        predict_dataset: Dataset | None = None,
        pad_token_id: int = 0,
        min_seq_length: int | None = None,
        max_seq_length: int = 512,
        micro_batch_size: int = 16,
        global_batch_size: int = 16,
        num_workers: int = 4,
    ) -> None:
        """Initialize the MockDataModule."""
        super().__init__()
        self.train_dataset = train_dataset
        self.valid_dataset = valid_dataset
        self.test_dataset = test_dataset
        self.predict_dataset = predict_dataset
        self.pad_token_id = pad_token_id
        self.min_seq_length = min_seq_length
        self.max_seq_length = max_seq_length
        self.batch_size = micro_batch_size
        self.num_workers = num_workers
        self.data_sampler = MegatronDataSampler(
            seq_len=max_seq_length,
            micro_batch_size=micro_batch_size,
            global_batch_size=global_batch_size,
            dataloader_type="single",
            output_log=False,
        )

    def setup(self, stage: str | None = None) -> None:  # noqa: D102
        pass

    def _make_dataloader(
        self, dataset: Dataset, mode: Literal["train", "validation", "test", "predict"]
    ) -> WrappedDataLoader:
        if mode not in ["predict", "test"]:
            self.update_init_global_step()

        return WrappedDataLoader(
            mode=mode,
            dataset=dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=functools.partial(
                collate.bert_padding_collate_fn,
                padding_value=self.pad_token_id,
                min_length=self.min_seq_length,
                max_length=self.max_seq_length,
            ),
        )

    def train_dataloader(self) -> DataLoader:  # noqa: D102
        if self.train_dataset is None:
            raise ValueError("No train_dataset was provided")
        return self._make_dataloader(
            self.train_dataset,
            mode="train",
        )

    def val_dataloader(self) -> DataLoader:  # noqa: D102
        if self.valid_dataset is None:
            raise ValueError("No valid_dataset was provided")
        return self._make_dataloader(
            self.valid_dataset,
            mode="validation",
        )

    def test_dataloader(self) -> DataLoader:  # noqa: D102
        if self.test_dataset is None:
            raise ValueError("No test_dataset was provided")
        return self._make_dataloader(
            self.test_dataset,
            mode="test",
        )

    def predict_dataloader(self) -> DataLoader:  # noqa: D102
        if self.predict_dataset is None:
            raise ValueError("No predict_dataset was provided")
        return self._make_dataloader(
            self.predict_dataset,
            mode="predict",
        )
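
MockDataModule is convenient for unit tests and small experiments: wrap whatever map-style Datasets you have and hand the module to a Trainer, which attaches itself and requests the dataloaders. The sketch below is illustrative; my_train_dataset, my_valid_dataset, model, and trainer are placeholders, and each dataset item must be a dictionary of tensors that collate.bert_padding_collate_fn knows how to pad.

from bionemo.llm.data.datamodule import MockDataModule

data_module = MockDataModule(
    train_dataset=my_train_dataset,   # placeholder Dataset of BERT-style sample dicts
    valid_dataset=my_valid_dataset,
    pad_token_id=0,
    max_seq_length=256,
    micro_batch_size=8,
    global_batch_size=8,
)
# The Trainer attaches itself to the datamodule and calls train_dataloader()/val_dataloader().
trainer.fit(model, datamodule=data_module)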

__init__(train_dataset=None, valid_dataset=None, test_dataset=None, predict_dataset=None, pad_token_id=0, min_seq_length=None, max_seq_length=512, micro_batch_size=16, global_batch_size=16, num_workers=4)

Initialize the MockDataModule.

Source code in bionemo/llm/data/datamodule.py
def __init__(
    self,
    train_dataset: Dataset | None = None,
    valid_dataset: Dataset | None = None,
    test_dataset: Dataset | None = None,
    predict_dataset: Dataset | None = None,
    pad_token_id: int = 0,
    min_seq_length: int | None = None,
    max_seq_length: int = 512,
    micro_batch_size: int = 16,
    global_batch_size: int = 16,
    num_workers: int = 4,
) -> None:
    """Initialize the MockDataModule."""
    super().__init__()
    self.train_dataset = train_dataset
    self.valid_dataset = valid_dataset
    self.test_dataset = test_dataset
    self.predict_dataset = predict_dataset
    self.pad_token_id = pad_token_id
    self.min_seq_length = min_seq_length
    self.max_seq_length = max_seq_length
    self.batch_size = micro_batch_size
    self.num_workers = num_workers
    self.data_sampler = MegatronDataSampler(
        seq_len=max_seq_length,
        micro_batch_size=micro_batch_size,
        global_batch_size=global_batch_size,
        dataloader_type="single",
        output_log=False,
    )