# Writing Megatron-LM Compatible Datamodules
Megatron-LM relies on determinism in the training dataset classes to ensure that input tensors are initialized
correctly across model-parallel ranks (see NeMo2 Parallelism). As a consequence, new dataset classes must be careful to
preserve the required determinism. Common operations such as data augmentation and masking can cause `dataset[i]` to
return a different result each time it is called for a given index, breaking this Megatron contract.
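As a purely illustrative sketch (the `MyBrokenDataset` name and masking fraction below are made up), a dataset that
applies an unseeded random mask on every access violates this contract:

```python
import torch
from torch.utils.data import Dataset


class MyBrokenDataset(Dataset):
    def __init__(self, data: torch.Tensor):
        self.data = data

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> torch.Tensor:
        # Unseeded randomness: two calls with the same `idx` return different
        # masks, so model-parallel ranks can receive inconsistent inputs.
        mask = torch.rand_like(self.data[idx]) < 0.15
        return self.data[idx].masked_fill(mask, 0.0)
```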
## Multi-Epoch Training
One training regime where this limitation is most apparent is multi-epoch training, where standard training recipes would apply different random masks or data augmentation strategies each time a sample is encountered. BioNeMo provides a number of utilities that make multi-epoch training easier while still obeying Megatron's determinism requirements.
The `MultiEpochDatasetResampler` class simplifies multi-epoch training, where the data should be re-shuffled each epoch
and different random effects applied each time a sample is seen. To be compatible with this resampler, the provided
dataset class's `__getitem__` method should accept an `EpochIndex` tuple that contains both an epoch and an index
value. Random effects can then be performed by seeding a torch generator with the epoch value:
```python
class MyDataset:
    def __getitem__(self, idx: EpochIndex):
        # Seed a local generator from the epoch so random effects differ per
        # epoch but remain deterministic for a given EpochIndex.
        rng = torch.Generator()
        rng.manual_seed(idx.epoch)
        ...
```
### Avoid `torch.manual_seed`

Megatron-LM handles torch seeding internally. Calling `torch.manual_seed` or `torch.cuda.manual_seed` inside the
user-provided dataset can cause issues with model parallelism. See megatron/core/tensor_parallel/random.py#L198-L199
for more details.
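As a sketch of the pattern to avoid (the class name here is hypothetical), compare global seeding inside `__getitem__`
with the local-generator approach shown above:

```python
class MySeededDataset:
    def __getitem__(self, idx: EpochIndex):
        # Anti-pattern: globally re-seeding torch inside the dataset interferes
        # with the seeds Megatron-LM manages for model-parallel ranks.
        torch.manual_seed(idx.epoch)       # avoid
        torch.cuda.manual_seed(idx.epoch)  # avoid

        # Prefer a local generator instead, as in MyDataset above.
        rng = torch.Generator()
        rng.manual_seed(idx.epoch)
        ...
```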
For deterministic datasets that still want to train for multiple epochs with epoch-level shuffling, the `IdentityMultiEpochDatasetWrapper` class can simplify this process by wrapping a dataset that accepts integer indices and passing along the index values from the `EpochIndex` entries produced by the resampled dataset:
```python
class MyDeterministicDataset:
    def __getitem__(self, index: int):
        ...


dataset = IdentityMultiEpochDatasetWrapper(MyDeterministicDataset())
for sample in MultiEpochDatasetResampler(dataset, num_epochs=3, shuffle=True):
    ...
```
## Very large datasets
For datasets where `len(dataset)` is too large for a shuffled list of indices to comfortably fit in memory,
[PRNGResampleDataset][bionemo.core.data.resamples.PRNGResampleDataset] offers a simple solution for shuffling a
dataset with replacement in O(1) memory.
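A minimal usage sketch is given below; the constructor arguments shown (`seed` and `num_samples`, alongside the
wrapped dataset) are assumptions and should be checked against the `PRNGResampleDataset` API reference:

```python
# Sketch only: the argument names below are assumed, not the verified API.
huge_dataset = MyDeterministicDataset()
resampled = PRNGResampleDataset(huge_dataset, seed=42, num_samples=1_000_000)

# Each index maps to a pseudo-randomly chosen sample (with replacement),
# so a full shuffled index list never needs to be materialized in memory.
sample = resampled[0]
```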
## Training Resumption
To ensure identical behavior with and without job interruption, BioNeMo provides `MegatronDataModule` to save and load
the state dict needed for training resumption, and provides [WrappedDataLoader][nemo.lightning.data.WrappedDataLoader]
to add a `mode` attribute to [DataLoader][torch.utils.data.DataLoader]:
```python
class MyDataModule(MegatronDataModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        ...

    def train_dataloader(self):
        self.update_init_global_step()  # required to set the correct `global_step` for resumption
        return WrappedDataLoader(
            ...,  # any other arguments for DataLoader
            mode="train",
        )

    def val_dataloader(self):
        self.update_init_global_step()  # required to set the correct `global_step` for resumption
        return WrappedDataLoader(
            ...,  # any other arguments for DataLoader
            mode="validation",
        )

    def test_dataloader(self):
        self.update_init_global_step()  # required to set the correct `global_step` for resumption
        return WrappedDataLoader(
            ...,  # any other arguments for DataLoader
            mode="test",
        )
```
### MegatronDataModule

Users will see training curves that do not overlap with an uninterrupted run if their datamodule does not inherit from
`MegatronDataModule`, unless equivalent logic is handled by the user. In `MegatronDataModule`,
`self.update_init_global_step()` must be called right before the dataloaders are returned to ensure that training
resumes from the correct sample index instead of restarting from 0 every time. We recommend that users inherit from
`MegatronDataModule` following the pattern above.
WrappedDataLoader
The WrappedDataLoader
class is a wrapper around the PyTorch DataLoader class that adds the mode
attribute to the dataloader. The dataloader will resume from the last sample index only when mode is 'train'. val_dataloader
and test_dataloader
are unaffected.
WARNING: 'train' is the default value of mode
in WrappedDataLoader
. If not set, users might find their validation/test dataloader changes behavior by resuming from a non-zero sample index.
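For illustration (the dataset name and batch size here are placeholders), always pass `mode` explicitly for
non-training dataloaders:

```python
# Pitfall: without an explicit mode, this validation loader defaults to
# mode="train" and would resume from the last training sample index.
val_loader = WrappedDataLoader(dataset=my_val_dataset, batch_size=4)

# Safe: set the mode explicitly so validation always starts from sample 0.
val_loader = WrappedDataLoader(dataset=my_val_dataset, batch_size=4, mode="validation")
```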
## Testing Datasets For Megatron Compatibility
BioNeMo also provides utility functions for test suites to validate that datasets conform to the Megatron data model.
The [assert_dataset_compatible_with_megatron][bionemo.testing.data_utils.assert_dataset_compatible_with_megatron]
function calls the dataset with identical indices and ensures the outputs are identical, while also checking whether
`torch.manual_seed` was used.
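For example, a pytest-style test might look like the following sketch, which assumes the dataset is passed as the
first positional argument (any other keyword arguments are omitted here):

```python
from bionemo.testing.data_utils import assert_dataset_compatible_with_megatron


def test_my_dataset_is_megatron_compatible():
    # Fails if repeated lookups of the same index return different outputs,
    # or if the dataset seeds torch's global RNG internally.
    assert_dataset_compatible_with_megatron(MyDataset())
```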
## Example datasets in BioNeMo
The `ESMMaskedResidueDataset` demonstrates one approach for leveraging `EpochIndex` indices to perform epoch-level randomization within the confines of Megatron's data model.