Writing Megatron-LM Compatible Datamodules

Megatron-LM relies on determinism in the training dataset classes to ensure that input tensors are initialized correctly across model-parallel ranks (see NeMo2 Parallelism). As a consequence, new dataset classes must be careful to preserve the required determinism. Common operations such as data augmentation and masking can cause dataset[i] to return different results each time it is called for the same index, breaking this megatron contract.
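For example, a dataset like the following hypothetical sketch (the class name, mask_prob, and masking logic are purely illustrative) violates the contract, because the mask is drawn from global randomness and dataset[i] returns a different tensor on every call:

import torch
from torch.utils.data import Dataset

class NonDeterministicDataset(Dataset):  # hypothetical example of what NOT to do
    def __init__(self, data: torch.Tensor, mask_prob: float = 0.15):
        self.data = data
        self.mask_prob = mask_prob

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> torch.Tensor:
        # Drawing the mask from global randomness makes dataset[i] differ
        # between calls, which breaks the determinism megatron relies on.
        mask = torch.rand_like(self.data[index]) < self.mask_prob
        return self.data[index].masked_fill(mask, 0.0)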

Multi-Epoch Training

One training regime where this limitation is most apparent is multi-epoch training, where standard training recipes would apply different random masks or different data augmentation strategies each time the data is encountered. BioNeMo provides a number of utilities that make multi-epoch training easier while still obeying the determinism requirements of megatron.

The MultiEpochDatasetResampler class simplifies multi-epoch training, where the data should both be re-shuffled each epoch and have different random effects applied each time it is seen. To be compatible with this resampler, the provided dataset class's __getitem__ method should accept an EpochIndex tuple that contains both an epoch and an index value. Random effects can then be applied by setting the torch random seed based on the epoch value:

import torch

class MyDataset:
    def __getitem__(self, idx: EpochIndex):
        # Use a local generator seeded on the epoch rather than the global
        # torch seed, which Megatron-LM manages internally.
        rng = torch.Generator()
        rng.manual_seed(idx.epoch)
        ...
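A fuller, hypothetical sketch of an epoch-aware masking dataset is shown below. The import path for EpochIndex and its second field name (idx) are assumptions to be checked against your bionemo.core version, and the mask_prob and masking logic are illustrative only. Combining the epoch and index into the seed keeps repeated calls for the same EpochIndex deterministic while giving each sample a different mask each epoch:

import torch
from bionemo.core.data.multi_epoch_dataset import EpochIndex  # assumed import path

class MaskedDataset:
    def __init__(self, data: torch.Tensor, mask_prob: float = 0.15, seed: int = 42):
        self.data = data
        self.mask_prob = mask_prob
        self.seed = seed

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: EpochIndex) -> torch.Tensor:
        # Derive the generator seed from both the epoch and the index so the
        # mask differs per epoch but is identical across repeated calls.
        rng = torch.Generator()
        rng.manual_seed(hash((self.seed, index.epoch, index.idx)))
        mask = torch.rand(self.data[index.idx].shape, generator=rng) < self.mask_prob
        return self.data[index.idx].masked_fill(mask, 0.0)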

Avoid torch.manual_seed

Megatron-LM handles torch seeding internally. Calling torch.manual_seed or torch.cuda.manual_seed inside the user-provided dataset can cause issues with model parallelism. See megatron/core/tensor_parallel/random.py#L198-L199 for more details.

For deterministic datasets that should still be trained for multiple epochs with epoch-level shuffling, the IdentityMultiEpochDatasetWrapper class simplifies the process: it wraps a dataset that accepts integer indices and forwards the index portion of each EpochIndex produced by the resampler.

class MyDeterministicDataset:
    def __getitem__(self, index: int):
        ...

# The wrapper converts the resampler's EpochIndex values back into plain integer indices.
dataset = IdentityMultiEpochDatasetWrapper(MyDeterministicDataset())
for sample in MultiEpochDatasetResampler(dataset, num_epochs=3, shuffle=True):
    ...

Very large datasets

For datasets where len(dataset) is too large for a shuffled list of indices to comfortably fit in memory, [PRNGResampleDataset][bionemo.core.data.resamples.PRNGResampleDataset] offers a simple solution for shuffling a dataset with replacement in O(1) memory.
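The underlying idea can be sketched as follows; this is a purely illustrative class, not the actual PRNGResampleDataset API. Instead of materializing a shuffled index list, each requested index is mapped to a pseudo-random source index by a generator seeded from that index, so lookups remain deterministic while the wrapper keeps no per-index state:

import numpy as np
from torch.utils.data import Dataset

class PRNGResampleSketch(Dataset):  # hypothetical illustration, not the real API
    def __init__(self, dataset: Dataset, seed: int = 0):
        self.dataset = dataset
        self.seed = seed

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int):
        # Deterministically map `index` to a pseudo-random source index,
        # sampling with replacement and using no per-index bookkeeping.
        rng = np.random.default_rng([self.seed, index])
        return self.dataset[int(rng.integers(len(self.dataset)))]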

Training Resumption

To ensure identical behavior with and without job interruption, BioNeMo provides MegatronDataModule, which saves and loads the state dict needed for training resumption, and [WrappedDataLoader][nemo.lightning.data.WrappedDataLoader], which adds a mode attribute to [DataLoader][torch.utils.data.DataLoader].

class MyDataModule(MegatronDataModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        ...

    def train_dataloader(self):
        self.update_init_global_step()  # required to set the correct `global_step` for resumption
        return WrappedDataLoader(
            ...,  # any other arguments for DataLoader
            mode="train",
        )

    def val_dataloader(self):
        self.update_init_global_step()  # required to set the correct `global_step` for resumption
        return WrappedDataLoader(
            ...,  # any other arguments for DataLoader
            mode="validation",
        )

    def test_dataloader(self):
        self.update_init_global_step()  # required to set the correct `global_step` for resumption
        return WrappedDataLoader(
            ...,  # any other arguments for DataLoader
            mode="test",
        )

MegatronDataModule

If a datamodule does not inherit from MegatronDataModule and does not implement equivalent logic itself, the training curve after resumption will not overlap with that of an uninterrupted run. In MegatronDataModule, self.update_init_global_step() must be called right before each dataloader is returned so that training resumes from the correct sample index instead of restarting from 0 every time. We recommend inheriting from MegatronDataModule and following the pattern above.

WrappedDataLoader

The WrappedDataLoader class is a wrapper around the PyTorch DataLoader class that adds a mode attribute to the dataloader. The dataloader resumes from the last sample index only when mode is 'train'; validation and test dataloaders are unaffected.

WARNING: 'train' is the default value of mode in WrappedDataLoader. If mode is not set explicitly, validation and test dataloaders may unexpectedly change behavior by resuming from a non-zero sample index.

Testing Datasets For Megatron Compatibility

BioNeMo also provides utility functions for test suites to validate that datasets conform to the megatron data model. The [assert_dataset_compatible_with_megatron][bionemo.testing.data_utils.assert_dataset_compatible_with_megatron] function calls the dataset with identical indices and ensures the outputs are identical, while also checking whether torch.manual_seed was used.
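A minimal test sketch, assuming the function takes the dataset under test as its only required argument (verify the actual signature before relying on this):

from bionemo.testing.data_utils import assert_dataset_compatible_with_megatron

def test_my_dataset_is_megatron_compatible():
    dataset = MyDataset(...)  # construct the dataset under test as needed
    assert_dataset_compatible_with_megatron(dataset)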

Example datasets in BioNeMo

The ESMMaskedResidueDataset demonstrates one approach for leveraging EpochIndex indices to perform epoch-level randomization within the confines of megatron's data model.