Bases: LightningDataModule
A mixin that adds a `state_dict` and `load_state_dict` method for datamodule training resumption in NeMo.
Source code in bionemo/llm/data/datamodule.py
class MegatronDataModule(pl.LightningDataModule):
    """A mixin that adds a `state_dict` and `load_state_dict` method for datamodule training resumption in NeMo."""

    def __init__(self, *args, **kwargs):
        """Set init_global_step to 0 for datamodule resumption."""
        super().__init__(*args, **kwargs)
        self.init_global_step = 0

    def update_init_global_step(self):
        """Please always call this when you get a new dataloader... if you forget, your resumption will not work."""
        self.init_global_step = self.trainer.global_step  # Update the init_global_step whenever we re-init training
        self.data_sampler.init_global_step = (
            self.init_global_step
        )  # Update the init_global_step whenever we re-init training

    def state_dict(self) -> Dict[str, Any]:
        """Called when saving a checkpoint, implement to generate and save datamodule state.

        Returns:
            A dictionary containing datamodule state.
        """
        consumed_samples = self.data_sampler.compute_consumed_samples(self.trainer.global_step - self.init_global_step)
        return {"consumed_samples": consumed_samples}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Called when loading a checkpoint, implement to reload datamodule state given datamodule state.

        Args:
            state_dict: the datamodule state returned by ``state_dict``.
        """
        try:
            from megatron.core.num_microbatches_calculator import update_num_microbatches
        except (ImportError, ModuleNotFoundError):
            logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
            from apex.transformer.pipeline_parallel.utils import update_num_microbatches

        consumed_samples = state_dict["consumed_samples"]
        self.data_sampler.init_consumed_samples = consumed_samples
        self.data_sampler.prev_consumed_samples = consumed_samples
        update_num_microbatches(
            consumed_samples=consumed_samples,
            consistency_check=False,
        )
        self.data_sampler.if_first_step = 1
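A minimal usage sketch follows; the class, dataset wiring, and batch size below are hypothetical, and the sampler is assumed to be a Megatron-style sampler (for example NeMo's `MegatronDataSampler`) exposing the attributes the mixin reads.

```python
# Hypothetical usage sketch -- the class, dataset, and sampler wiring are illustrative only.
from torch.utils.data import DataLoader, Dataset

from bionemo.llm.data.datamodule import MegatronDataModule


class MyDataModule(MegatronDataModule):
    def __init__(self, dataset: Dataset, data_sampler, micro_batch_size: int = 4):
        super().__init__()  # sets self.init_global_step = 0
        self.dataset = dataset
        # The mixin expects a Megatron-style sampler here exposing init_global_step,
        # init_consumed_samples, prev_consumed_samples, compute_consumed_samples, and if_first_step.
        self.data_sampler = data_sampler
        self.micro_batch_size = micro_batch_size

    def train_dataloader(self) -> DataLoader:
        # Per the mixin's docstring: always refresh init_global_step when handing out
        # a new dataloader, or resumption bookkeeping breaks.
        self.update_init_global_step()
        return DataLoader(self.dataset, batch_size=self.micro_batch_size)
```

The important detail is the `update_init_global_step()` call inside `train_dataloader()`, which the mixin's own docstring insists on.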
__init__(*args, **kwargs)
Set init_global_step to 0 for datamodule resumption.
Source code in bionemo/llm/data/datamodule.py
def __init__(self, *args, **kwargs):
    """Set init_global_step to 0 for datamodule resumption."""
    super().__init__(*args, **kwargs)
    self.init_global_step = 0
load_state_dict(state_dict)
Called when loading a checkpoint, implement to reload datamodule state given datamodule state.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state_dict` | `Dict[str, Any]` | the datamodule state returned by `state_dict`. | required |
Source code in bionemo/llm/data/datamodule.py
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    """Called when loading a checkpoint, implement to reload datamodule state given datamodule state.

    Args:
        state_dict: the datamodule state returned by ``state_dict``.
    """
    try:
        from megatron.core.num_microbatches_calculator import update_num_microbatches
    except (ImportError, ModuleNotFoundError):
        logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
        from apex.transformer.pipeline_parallel.utils import update_num_microbatches

    consumed_samples = state_dict["consumed_samples"]
    self.data_sampler.init_consumed_samples = consumed_samples
    self.data_sampler.prev_consumed_samples = consumed_samples
    update_num_microbatches(
        consumed_samples=consumed_samples,
        consistency_check=False,
    )
    self.data_sampler.if_first_step = 1
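To make the bookkeeping concrete, here is a self-contained toy, not NeMo code: the stand-in sampler only mirrors the attribute names `load_state_dict` touches, and the Megatron/Apex `update_num_microbatches` call is omitted.

```python
# Self-contained toy: a stand-in sampler exposing only the fields load_state_dict touches.
from typing import Any, Dict


class _StandInSampler:
    def __init__(self) -> None:
        self.init_consumed_samples = 0
        self.prev_consumed_samples = 0
        self.if_first_step = 0


def restore_sampler(sampler: _StandInSampler, state_dict: Dict[str, Any]) -> None:
    """Mirrors load_state_dict minus the Megatron/Apex update_num_microbatches call."""
    consumed_samples = state_dict["consumed_samples"]
    sampler.init_consumed_samples = consumed_samples
    sampler.prev_consumed_samples = consumed_samples
    sampler.if_first_step = 1


sampler = _StandInSampler()
restore_sampler(sampler, {"consumed_samples": 128_000})  # payload as written by state_dict()
print(sampler.init_consumed_samples, sampler.if_first_step)  # 128000 1
```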
state_dict()
Called when saving a checkpoint, implement to generate and save datamodule state.
Returns:

| Type | Description |
| --- | --- |
| `Dict[str, Any]` | A dictionary containing datamodule state. |
Source code in bionemo/llm/data/datamodule.py
def state_dict(self) -> Dict[str, Any]:
    """Called when saving a checkpoint, implement to generate and save datamodule state.

    Returns:
        A dictionary containing datamodule state.
    """
    consumed_samples = self.data_sampler.compute_consumed_samples(self.trainer.global_step - self.init_global_step)
    return {"consumed_samples": consumed_samples}
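For intuition, the sketch below assumes that `compute_consumed_samples(steps)` amounts to `init_consumed_samples + steps * global_batch_size`, which is roughly what Megatron-style samplers do; the exact formula lives in NeMo's data sampler, so treat the helper as an illustration, not the real implementation.

```python
# Illustrative stand-in (assumed formula) for the sampler's compute_consumed_samples.
def compute_consumed_samples_sketch(init_consumed_samples: int, steps: int, global_batch_size: int) -> int:
    """Assumed: samples consumed before this dataloader was built, plus samples drawn since."""
    return init_consumed_samples + steps * global_batch_size


# state_dict() passes steps = trainer.global_step - init_global_step, i.e. steps taken
# with the current dataloader; with a global batch size of 256 and 250 such steps:
print(compute_consumed_samples_sketch(init_consumed_samples=0, steps=250, global_batch_size=256))  # 64000
```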
update_init_global_step()
Please always call this when you get a new dataloader... if you forget, your resumption will not work.
Source code in bionemo/llm/data/datamodule.py
def update_init_global_step(self):
    """Please always call this when you get a new dataloader... if you forget, your resumption will not work."""
    self.init_global_step = self.trainer.global_step  # Update the init_global_step whenever we re-init training
    self.data_sampler.init_global_step = (
        self.init_global_step
    )  # Update the init_global_step whenever we re-init training
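A toy calculation, using the same assumed consumed-samples formula as above with made-up numbers, shows why forgetting this call breaks resumption: a stale `init_global_step` makes the `global_step - init_global_step` delta in `state_dict` count pre-resume steps a second time, so the checkpoint over-reports consumed samples and the next resumption skips data.

```python
# Made-up numbers: a run resumed from a step-1000 checkpoint, checkpointed again at step 1500.
global_batch_size = 128
init_consumed_samples = 1_000 * global_batch_size  # restored by load_state_dict at resume time
global_step_at_next_checkpoint = 1_500

# Correct: update_init_global_step() ran when the post-resume dataloader was built,
# so init_global_step == 1_000 and only post-resume steps are added.
print(init_consumed_samples + (global_step_at_next_checkpoint - 1_000) * global_batch_size)  # 192000

# Wrong: the call was forgotten, init_global_step is still 0, and the 1_000 pre-resume
# steps are counted twice -- the next resumption would skip that much data.
print(init_consumed_samples + (global_step_at_next_checkpoint - 0) * global_batch_size)  # 320000
```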