Writing Megatron-LM Compatible Datamodules

Megatron-LM relies on determinism in the training dataset classes to ensure that input tensors are initialized correctly across model-parallel ranks (see NeMo2 Parallelism). As a consequence, new dataset classes must preserve this determinism. Common operations such as data augmentation and masking can cause dataset[i] to return different results on repeated calls with the same index, breaking this Megatron contract.

Multi-Epoch Training

One training regime where this limitation is most apparent is multi-epoch training, where standard training recipes would apply different random masks or different data augmentation strategies each time the data is encountered. BioNeMo provides utilities that make multi-epoch training easier while obeying the determinism requirements of Megatron.

The MultiEpochDatasetResampler class simplifies multi-epoch training in which the data should both be re-shuffled each epoch and have different random effects applied each time it is seen. To be compatible with this resampler, the provided dataset class's __getitem__ method should accept an EpochIndex tuple that contains both an epoch and an index value. Random effects can then be performed by seeding a local torch generator from the epoch value:

class MyDataset:
    def __getitem__(self, idx: EpochIndex):
        # Use a local generator seeded from the epoch so that random
        # effects differ between epochs but the same (epoch, index)
        # pair always yields the same sample.
        rng = torch.Generator()
        rng.manual_seed(idx.epoch)
        ...

Avoid torch.manual_seed

Megatron-LM handles torch seeding internally. Calling `torch.manual_seed` or `torch.cuda.manual_seed` inside the user-provided dataset
can cause issues with model parallelism. See [megatron/core/tensor_parallel/random.py#L198-L199](
https://github.com/NVIDIA/Megatron-LM/blob/dddecd19/megatron/core/tensor_parallel/random.py#L198-L199) for more
details.

For deterministic datasets that should still train for multiple epochs with epoch-level shuffling, the IdentityMultiEpochDatasetWrapper class simplifies this process: it wraps a dataset that accepts integer indices and passes along the index component of the EpochIndex values produced by the resampled dataset.

class MyDeterministicDataset:
    def __getitem__(self, index: int): ...


dataset = IdentityMultiEpochDatasetWrapper(MyDeterministicDataset())
for sample in MultiEpochDatasetResampler(dataset, num_epochs=3, shuffle=True):
    ...

Training Resumption

To ensure identical behavior with and without job interruption, Megatron datamodules must manage sample-exact training resumption. When writing your own datamodule, preserve these constraints:

  • Persist enough dataloader state (e.g. the global step or sample index) so training resumes from the correct position rather than restarting from index 0.
  • Distinguish between train, validation, and test dataloaders explicitly. Only the training dataloader should resume from a saved sample index — validation and test dataloaders should always start from the beginning.
  • Update the global step immediately before returning each dataloader so the resume position is accurate.
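The bullets above can be sketched as a minimal resumable sampler. This is an illustration of the state-saving pattern, not the BioNeMo or Megatron API; the class and method names are hypothetical, though the state_dict/load_state_dict naming follows the common PyTorch convention:

```python
from typing import Iterator


class ResumableSampler:
    """Sketch of a sampler that persists its position so training can
    resume sample-exactly after an interruption."""

    def __init__(self, num_samples: int):
        self.num_samples = num_samples
        self.consumed = 0  # samples already yielded before a restart

    def __iter__(self) -> Iterator[int]:
        # The *training* sampler resumes from the saved position;
        # validation/test samplers should always start from index 0.
        for i in range(self.consumed, self.num_samples):
            self.consumed = i + 1
            yield i

    def state_dict(self) -> dict:
        # Persisted with the checkpoint so resumption is sample-exact.
        return {"consumed": self.consumed}

    def load_state_dict(self, state: dict) -> None:
        self.consumed = state["consumed"]
```

After consuming three samples and checkpointing, a freshly constructed sampler that loads the saved state continues from sample 3 rather than restarting at 0.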

See the evo2_megatron and eden_megatron recipes in bionemo-recipes for working examples of Megatron datamodule implementations with training resumption.

Testing Datasets for Megatron Compatibility

The key invariant for Megatron-compatible datasets is determinism: repeated calls with the same index must yield the same sample. When writing tests for your dataset, confirm that:

  • Repeated indexing with the same index returns identical results.
  • Epoch-aware randomization is driven only by the epoch component of the index (via a local torch.Generator, not the global seed).
  • torch.manual_seed is not called inside dataset __getitem__ paths, as Megatron-LM manages torch seeding internally for model parallelism.
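The checks above can be expressed as a pytest-style sketch. The toy dataset and test below are hypothetical; for brevity the example uses Python's random.Random as the local generator, in the same role torch.Generator plays in a real dataset:

```python
import random


class ToyDataset:
    """Hypothetical dataset whose randomness depends only on the epoch."""

    def __getitem__(self, index):
        # index is an (epoch, idx) pair, mirroring EpochIndex.
        epoch, idx = index
        rng = random.Random(epoch)  # local generator, not the global seed
        return idx + rng.random()


def test_megatron_determinism():
    ds = ToyDataset()
    # Repeated indexing with the same index returns identical results.
    assert ds[(0, 5)] == ds[(0, 5)]
    # Randomization is driven only by the epoch component of the index.
    assert ds[(0, 5)] != ds[(1, 5)]
    assert ds[(2, 5)] == ds[(2, 5)]
```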

Recipe-local tests in bionemo-recipes (e.g. in the evo2_megatron recipe) are the best reference for how to validate these assumptions.