Datamodule

MegatronDataModule

Bases: LightningDataModule

A mixin that adds state_dict and load_state_dict methods for datamodule training resumption in NeMo.

Source code in bionemo/llm/data/datamodule.py
import logging
from typing import Any, Dict

import pytorch_lightning as pl


class MegatronDataModule(pl.LightningDataModule):
    """A mixin that adds `state_dict` and `load_state_dict` methods for datamodule training resumption in NeMo."""

    def __init__(self, *args, **kwargs):
        """Set init_global_step to 0 for datamodule resumption."""
        super().__init__(*args, **kwargs)
        self.init_global_step = 0

    def update_init_global_step(self):
        """Always call this when you get a new dataloader; if you forget, your resumption will not work."""
        # Record the step at which this dataloader was (re-)created so that
        # `state_dict` can measure samples consumed since then.
        self.init_global_step = self.trainer.global_step
        self.data_sampler.init_global_step = self.init_global_step

    def state_dict(self) -> Dict[str, Any]:
        """Called when saving a checkpoint; implement to generate and save datamodule state.

        Returns:
            A dictionary containing datamodule state.

        """
        consumed_samples = self.data_sampler.compute_consumed_samples(self.trainer.global_step - self.init_global_step)
        return {"consumed_samples": consumed_samples}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Called when loading a checkpoint; implement to reload datamodule state given the datamodule state dict.

        Args:
            state_dict: the datamodule state returned by ``state_dict``.

        """
        try:
            from megatron.core.num_microbatches_calculator import update_num_microbatches

        except (ImportError, ModuleNotFoundError):
            logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
            from apex.transformer.pipeline_parallel.utils import update_num_microbatches

        consumed_samples = state_dict["consumed_samples"]
        self.data_sampler.init_consumed_samples = consumed_samples
        self.data_sampler.prev_consumed_samples = consumed_samples

        update_num_microbatches(
            consumed_samples=consumed_samples,
            consistency_check=False,
        )
        self.data_sampler.if_first_step = 1
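
A minimal sketch of how a concrete datamodule might build on this mixin, assuming NeMo's MegatronDataSampler; `MyDataModule` and the dataset argument are illustrative placeholders, not part of this module:

from nemo.lightning.pytorch.plugins import MegatronDataSampler
from torch.utils.data import DataLoader, Dataset


class MyDataModule(MegatronDataModule):
    def __init__(self, dataset: Dataset, seq_len: int = 512, micro_batch_size: int = 4, global_batch_size: int = 32):
        super().__init__()
        self._dataset = dataset
        # The mixin expects `self.data_sampler` to expose init_global_step,
        # compute_consumed_samples, init_consumed_samples, prev_consumed_samples,
        # and if_first_step, which MegatronDataSampler provides.
        self.data_sampler = MegatronDataSampler(
            seq_len=seq_len,
            micro_batch_size=micro_batch_size,
            global_batch_size=global_batch_size,
        )

    def train_dataloader(self) -> DataLoader:
        # Re-sync the step bookkeeping every time a new dataloader is built;
        # `state_dict` relies on it (see update_init_global_step below).
        self.update_init_global_step()
        return DataLoader(self._dataset, batch_size=self.data_sampler.micro_batch_size)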

__init__(*args, **kwargs)

Set init_global_step to 0 for datamodule resumption.

Source code in bionemo/llm/data/datamodule.py
def __init__(self, *args, **kwargs):
    """Set init_global_step to 0 for datamodule resumption."""
    super().__init__(*args, **kwargs)
    self.init_global_step = 0

load_state_dict(state_dict)

Called when loading a checkpoint; implement to reload datamodule state given the datamodule state dict.

Parameters:

    state_dict (Dict[str, Any], required):
        the datamodule state returned by state_dict.

Source code in bionemo/llm/data/datamodule.py
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    """Called when loading a checkpoint, implement to reload datamodule state given datamodule stat.

    Args:
        state_dict: the datamodule state returned by ``state_dict``.

    """
    try:
        from megatron.core.num_microbatches_calculator import update_num_microbatches

    except (ImportError, ModuleNotFoundError):
        logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
        from apex.transformer.pipeline_parallel.utils import update_num_microbatches

    consumed_samples = state_dict["consumed_samples"]
    self.data_sampler.init_consumed_samples = consumed_samples
    self.data_sampler.prev_consumed_samples = consumed_samples

    update_num_microbatches(
        consumed_samples=consumed_samples,
        consistency_check=False,
    )
    self.data_sampler.if_first_step = 1
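
In normal training Lightning invokes these hooks through its checkpointing machinery; a manual round trip, assuming a `datamodule` built as in the sketch above, looks like:

# Save side: called when a checkpoint is written.
dm_state = datamodule.state_dict()      # e.g. {"consumed_samples": 2560}

# Resume side: restores the sampler's counters and refreshes the
# micro-batch schedule via update_num_microbatches.
datamodule.load_state_dict(dm_state)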

state_dict()

Called when saving a checkpoint; implement to generate and save datamodule state.

Returns:

    Dict[str, Any]: A dictionary containing datamodule state.

Source code in bionemo/llm/data/datamodule.py
def state_dict(self) -> Dict[str, Any]:
    """Called when saving a checkpoint, implement to generate and save datamodule state.

    Returns:
        A dictionary containing datamodule state.

    """
    consumed_samples = self.data_sampler.compute_consumed_samples(self.trainer.global_step - self.init_global_step)
    return {"consumed_samples": consumed_samples}

update_init_global_step()

Always call this when you get a new dataloader; if you forget, your resumption will not work.

Source code in bionemo/llm/data/datamodule.py
def update_init_global_step(self):
    """Always call this when you get a new dataloader; if you forget, your resumption will not work."""
    # Record the step at which this dataloader was (re-)created so that
    # `state_dict` can measure samples consumed since then.
    self.init_global_step = self.trainer.global_step
    self.data_sampler.init_global_step = self.init_global_step
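
Why forgetting this call breaks resumption: state_dict computes consumed samples from self.trainer.global_step - self.init_global_step, so a stale init_global_step re-counts steps that were already folded into the sampler's init_consumed_samples. A toy trace with illustrative numbers:

# Training resumes at global_step = 100 and a fresh dataloader is built.
datamodule.update_init_global_step()    # init_global_step becomes 100

# 40 steps later a checkpoint is written:
#   state_dict() measures 140 - 100 = 40 post-resume steps.
# Had update_init_global_step() been skipped, init_global_step would still
# be 0, and all 140 steps' worth of samples would be added on top of the
# already-restored consumed count, double-counting the first 100 steps.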