
TransformerEngine-accelerated ESM-2 training with native PyTorch training loop

This folder demonstrates how to train TE-accelerated ESM-2 with a native PyTorch training loop, including sequence packing and FP8 precision, using fully sharded data parallel (FSDP) for distributed training.

How to use this recipe

This folder contains an independent, minimal training example; it does not depend on any other code in the top-level bionemo-framework repository and can be downloaded as a standalone zipped directory.

How to deploy this recipe on cloud providers

🚧 Under development

Supported Models and Training Features

  Model   | BF16 | FP8[1] | THD Input Format | FP8 with THD Input Format | MXFP8[2] | Context Parallelism
  --------|------|--------|------------------|---------------------------|----------|--------------------
  ESM-2   | ✅   | ✅     | ✅               | ✅                        | ✅       | ✅
  AMPLIFY | ✅   | ✅     | ✅               | ✅                        | 🚧       | 🚧

✅: Supported
🚧: Under development
❌: Not supported

[1]: Requires compute capability 9.0 or above (Hopper+)
[2]: Requires compute capability 10.0 or 10.3 (Blackwell); 12.0 support pending

Installing Dependencies

The easiest way to get started with this recipe is to use the provided Dockerfile, which uses the latest NVIDIA PyTorch base image to provide optimized versions of PyTorch and TransformerEngine. To build the container, run:

docker build -t esm2_native_te .

To run the container, run:

docker run -it --gpus all --network host --ipc=host --rm -v ${PWD}:/workspace/bionemo esm2_native_te /bin/bash

Alternatively, the dependencies can be installed manually in an environment with CUDA support. Refer to Dockerfile.cuda for an example of installing the dependencies in a fresh Python environment (here targeting CUDA 13.0):

uv venv --python 3.12 --seed /workspace/.venv
source /workspace/.venv/bin/activate
uv pip install torch==2.9.0 --index-url https://download.pytorch.org/whl/cu130
uv pip install wheel packaging psutil
pip install --no-build-isolation "flash-attn>=2.1.1,<=2.8.1"
pip install --no-build-isolation transformer-engine[pytorch]==2.9.0
uv pip install -r /requirements.txt

To build and run the CUDA base container, run:

docker build -t esm2_native_te_cuda -f Dockerfile.cuda .
docker run -it --gpus all --network host --ipc=host --rm -v ${PWD}:/workspace/bionemo esm2_native_te_cuda /bin/bash -c "pytest -v ."

Performance Benchmarks

[Benchmark results figure]

Note: "compiled" refers to torch.compile; "fa2" is FlashAttention-2. We recently measured a training throughput of 2800 tokens/second/GPU on H100 with the HuggingFace Transformers ESM-2 implementation using THD sequence packing; however, we have not yet been able to make this configuration work on Blackwell, and that work is in progress.

Distributed Training

This recipe supports distributed training using DDP, FSDP2, and Megatron-FSDP, shown in three separate training entrypoints:

  • train_ddp.py (distributed data parallel)
  • train_fsdp2.py (PyTorch FSDP2)
  • train_mfsdp.py (Megatron-FSDP)

Commands to Launch Training

To run single-process training on one GPU, run:

python train_ddp.py  # or train_fsdp2.py / train_mfsdp.py

To run multi-process training locally on 2+ GPUs, run:

torchrun --nproc_per_node=2 train_fsdp2.py  # or train_mfsdp.py / train_ddp.py

Multi-node training is supported with all three strategies; refer to slurm.sh for an example SLURM script.
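
For orientation, a minimal sketch of the usual torchrun-under-SLURM launch pattern is shown below. The node and GPU counts are placeholders, and slurm.sh remains the authoritative example.

#!/bin/bash
#SBATCH --nodes=2              # placeholder: adjust to your allocation
#SBATCH --ntasks-per-node=1    # one torchrun launcher per node
#SBATCH --gpus-per-node=8

# The first node in the allocation acts as the rendezvous host.
MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)

srun torchrun \
  --nnodes="$SLURM_NNODES" \
  --nproc_per_node=8 \
  --rdzv_backend=c10d \
  --rdzv_endpoint="${MASTER_ADDR}:29500" \
  train_fsdp2.py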

FP8 Training

To run training with FP8, override the fp8_config.enabled=true configuration parameter. Additional FP8 configuration parameters, including switching to MXFP8BlockScaling, can be set through the hydra configuration.

python train_fsdp2.py --config-name L0_sanity fp8_config.enabled=true
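
Under the hood, TransformerEngine FP8 training wraps the forward pass in an fp8_autocast context configured with a scaling recipe. The following is a rough sketch of that pattern, not the recipe's actual code; model and batch are assumed to be defined elsewhere.

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format, MXFP8BlockScaling

# Delayed tensor scaling (Hopper+); on Blackwell, MXFP8BlockScaling() can be used instead.
recipe = DelayedScaling(fp8_format=Format.HYBRID)

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    outputs = model(**batch)  # TE layers run their GEMMs in FP8
outputs.loss.backward()       # backward is called outside the autocast context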

FP8 Debugging

We also provide a mechanism to log tensor statistics for FP8 layers during training, covering activations, weights, and gradients.

To enable it, set the following configuration options:

python train_fsdp2.py \
  fp8_config.enabled=True \
  fp8_stats_config.enabled=True \
  fp8_stats_config.fp8_log_dir=./logs/fp8_stats_logs_dummy \
  fp8_stats_config.fp8_stats_file=./fp8_debugging_stats.yaml

Here fp8_stats_config.enabled toggles stats logging, fp8_stats_config.fp8_log_dir sets where the logs are stored, and fp8_stats_config.fp8_stats_file points to a YAML file specifying which stats to collect. Note that fp8_config.enabled=True is required; stats logging will not work without FP8 enabled.

Note: This feature is available for the train_ddp and the train_fsdp2 scripts. It is not yet available for train_mfsdp.

The structure of the fp8_debugging_stats.yaml config file is explained in more detail in the NVIDIA TransformerEngine config file documentation. Below we cover some basic elements of the file structure.
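
As a rough sketch, a minimal stats file might look like the following. The field names follow the TransformerEngine precision-debug-tools documentation; treat the linked docs as authoritative.

# fp8_debugging_stats.yaml (illustrative sketch)
tensor_stats_section:
  enabled: True
  layers:
    layer_name_regex_pattern: .*       # instrument all matching layers
  transformer_engine:
    LogTensorStats:
      enabled: True
      stats: [min, max, mean, std, l1_norm]
      tensors: [activation, gradient, weight]
      freq: 10                         # collect stats every 10 training steps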

This comes at a performance cost that depends on the freq parameter shown above. freq=1 collects stats on every step, which in our experiments caused a ~29% decrease in throughput (measured on a single RTX 5090). We recommend using freq>=10 to reduce this performance hit.

Sequence Packing (THD input format)

Sequence packing is handled using a padding-free collator (in collator.py) that provides the input arguments, such as cu_seq_lens_q, needed for padding-free attention. To enable sequence packing, set use_sequence_packing=true in the hydra configuration.

python train_fsdp2.py --config-name L0_sanity use_sequence_packing=true
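
Conceptually, the collator concatenates each batch's sequences into a single token stream and records the sequence boundaries. A toy sketch (tensor values are illustrative):

import torch

# Two sequences of lengths 3 and 5, packed into one stream of 8 tokens with no padding.
seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9, 10, 11, 12])]
input_ids = torch.cat(seqs).unsqueeze(0)                    # shape [1, 8], THD layout
cu_seq_lens_q = torch.tensor([0, 3, 8], dtype=torch.int32)  # cumulative sequence boundaries

# Padding-free attention kernels use the cumulative lengths to keep tokens from
# attending across sequence boundaries within the packed stream.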

FP8 and Sequence Packing

To combine FP8 training with sequence packing, the number of unpadded input tokens must be a multiple of 16. The data collator will automatically pad packed sequences to the maximum number of tokens per batch.

python train_fsdp2.py --config-name L0_sanity \
  fp8_config.enabled=true \
  use_sequence_packing=true
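
The padding arithmetic is simply rounding up to the next multiple of 16, which the collator performs automatically; a toy illustration:

total_tokens = 23                                # unpadded tokens in the packed batch
padded_total = ((total_tokens + 15) // 16) * 16  # round up to the next multiple of 16
print(padded_total)                              # 32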

Context Parallelism

We provide a training script, train_ddp_cp.py, and a sample config, L0_sanity_cp, that use context parallelism.

In the config, the cp_size argument sets the size of the context-parallel distributed group. When paired with distributed data parallelism (DDP), the number of context-parallel groups is determined by world_size // cp_size.

Thus, if a user has 8 processes and sets cp_size=2, they will have 4 CP groups of 2 ranks each, i.e., 4 DDP groups. During dataloading we make no assumptions about whether the data pipeline is deterministic. Each DDP group receives unique data, while the ranks within a CP group hold replicas of that data (which are then sharded, as described below).

For example, suppose we have 2 DDP groups and 2 CP ranks per group. Each DDP group has its own unique dataloader: DP0 for DDP group 0 and DP1 for DDP group 1. CP works by running ring attention, which expects the tokens on each device to be in a particular layout; this CP implementation uses a scheme called dual chunk swapping. If DP0 outputs the sequence 1 2 3 4 5 6 7 8 and DP1 outputs 9 10 11 12 13 14 15 16, then the CPAwareDataloader defined in datasets creates CP shards from each DP group as follows:

      |   DP0   |    DP1        |
  CP0 | 1,2,7,8 | 9, 10, 15, 16 |
  CP1 | 3,4,5,6 | 11, 12, 13, 14|

You may notice these shards and wonder why they look the way they do. The reason is that CP groups are sharded using slices: the full input sequence (such as 1 2 3 4 5 6 7 8) is sliced into 2 * cp_size chunks. CP0 then takes the first and last chunk of each sequence, while CP1 takes the middle chunks.

In this example we show only one sequence, but it's important to note that the slicing applies to every sequence: if a second sequence is present, it is sliced in the same manner, with CP0 taking its first and last chunk and CP1 taking its middle chunks.
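
A toy sketch of the dual-chunk-swapping slice for a single sequence (the function name is illustrative; the recipe's actual logic lives in the CPAwareDataloader):

import torch

def dual_chunk_shard(seq: torch.Tensor, cp_size: int, cp_rank: int) -> torch.Tensor:
    """Split seq into 2 * cp_size chunks; rank r keeps chunks r and (2 * cp_size - 1 - r)."""
    chunks = seq.chunk(2 * cp_size)
    return torch.cat([chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]])

seq = torch.arange(1, 9)                            # tokens 1..8 from DP0
print(dual_chunk_shard(seq, cp_size=2, cp_rank=0))  # tensor([1, 2, 7, 8])
print(dual_chunk_shard(seq, cp_size=2, cp_rank=1))  # tensor([3, 4, 5, 6])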

Comparing Against the HF Transformers Reference Implementation

To launch training with the ESM-2 model as implemented in HF Transformers, pass a facebook/esm2 checkpoint as the model tag:

python train_fsdp2.py --config-name L0_sanity model_tag=facebook/esm2_t6_8M_UR50D

Downloading Pre-Training Data For Offline Training

An example pre-training dataset for ESM-2 is available in the nvidia/esm2_uniref_pretraining_data Hugging Face dataset. This dataset can be streamed from the Hugging Face Hub as follows:

>>> from datasets import load_dataset
>>> dataset = load_dataset('nvidia/esm2_uniref_pretraining_data', split='train', streaming=True)
>>> print(next(iter(dataset)))
{'sequence': 'MSPRRTGGARPPGPCTPCGPRPRCPSRRSAAARPAPSAAPARRARPGRRPGCRPGTDCPGTARRPGGGP...',
 'ur50_id': 'UniRef50_A0A081XN86',
 'ur90_id': 'UniRef90_UPI002FBE17D9'}

For large-scale training, the dataset should be downloaded locally with the Hugging Face CLI, with appropriate values set for the HF_HOME and HF_TOKEN environment variables. If the CLI is not already installed, install it with uv tool install huggingface_hub.

export HF_TOKEN=<your_huggingface_token>
hf download nvidia/esm2_uniref_pretraining_data --repo-type dataset --local-dir /path/to/download/directory
# Test to ensure the dataset can be loaded correctly
python -c "import datasets; datasets.load_dataset('/path/to/download/directory', split='train', streaming=True)"

Pass the downloaded dataset directory to the training script as the dataset.path configuration parameter.

HF_DATASETS_OFFLINE=1 python train_fsdp2.py --config-name L0_sanity \
  dataset.load_dataset_kwargs.path=/path/to/download/directory

Saving and Loading Checkpoints

To enable checkpoint saving, ensure that checkpoint.ckpt_dir is set to a writable directory. Checkpointing frequency is controlled by the checkpoint.save_every_n_steps configuration parameter.

python train_fsdp2.py --config-name L0_sanity \
  checkpoint.ckpt_dir=/path/to/ckpt_dir \
  checkpoint.save_every_n_steps=100

To enable checkpoint loading, set checkpoint.resume_from_checkpoint=true to resume from the latest checkpoint.

python train_fsdp2.py --config-name L0_sanity \
  checkpoint.ckpt_dir=/path/to/ckpt_dir \
  checkpoint.resume_from_checkpoint=true

We also show how to export a final model at the end of training, which is suitable for uploading to the Hugging Face Hub or for local inference as a more durable format than torch distributed checkpoints. To enable this, set checkpoint.save_final_model=true in the hydra configuration. The resulting model will be saved to the final_model directory within the checkpoint directory.

Checkpointing is implemented for all three strategies; see checkpoint.py for details.
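
For orientation, a minimal sketch of the save/load pattern using PyTorch's torch.distributed.checkpoint (DCP) APIs is shown below, assuming model and optimizer are already constructed; checkpoint.py is the authoritative implementation.

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

# Save: collect the (possibly sharded) state and write a distributed checkpoint.
model_sd, optim_sd = get_state_dict(model, optimizer)
dcp.save({"model": model_sd, "optim": optim_sd}, checkpoint_id="/path/to/ckpt_dir/step_100")

# Load: read the shards back into matching state dicts, then restore them in place.
model_sd, optim_sd = get_state_dict(model, optimizer)
dcp.load({"model": model_sd, "optim": optim_sd}, checkpoint_id="/path/to/ckpt_dir/step_100")
set_state_dict(model, optimizer, model_state_dict=model_sd, optim_state_dict=optim_sd)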

Saving Dataloader State with StatefulDataLoader

These examples show how to save and resume your dataloader by passing the dataloader instance to our save_checkpoint_* and load_checkpoint_* functions using the StatefulDataLoader class from torchdata. See checkpoint.py for implementation details.

For references on StatefulDataLoader and its integration with datasets, see:

  • https://github.com/meta-pytorch/data/tree/main/torchdata/stateful_dataloader
  • https://huggingface.co/docs/datasets/en/stream#save-a-dataset-checkpoint-and-resume-iteration

Known limitations:

  • When loading the dataloader from a saved checkpoint, you must provide the same num_workers used when the dataloader state was saved, because state is saved at the worker level.
  • Dataloader state is also saved on a per-rank basis, so if you resume training with a different number of nodes/GPUs than was used when saving, the state will not resume exactly.
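
A minimal sketch of the save/resume flow with StatefulDataLoader is shown below; dataset is assumed to be defined elsewhere, and in this recipe the dataloader state is bundled into the save_checkpoint_* / load_checkpoint_* functions rather than handled directly.

from torchdata.stateful_dataloader import StatefulDataLoader

loader = StatefulDataLoader(dataset, batch_size=8, num_workers=2)

for step, batch in enumerate(loader):
    if step == 100:
        state = loader.state_dict()  # captures per-worker iteration state
        break

# On resume: recreate the loader with the SAME num_workers, then restore the state.
resumed = StatefulDataLoader(dataset, batch_size=8, num_workers=2)
resumed.load_state_dict(state)       # iteration continues where it left off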

Running Inference with the Trained Model

Models can be loaded from the final checkpoint directory using the AutoModel.from_pretrained method. For example:

import torch
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("path/to/final_model")
tokenizer = AutoTokenizer.from_pretrained("...")

gfp_P42212 = (
    "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTL"
    "VTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLV"
    "NRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLAD"
    "HYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK"
)

model.eval()
inputs = tokenizer(gfp_P42212, return_tensors="pt")
with torch.no_grad():
    output = model(**inputs)

Performance

🚧 Under development

Reference

Developer Guide

Running Tests

To run tests locally, run recipes_local_test.py from the repository root with the recipe directory as an argument.

./ci/scripts/recipes_local_test.py bionemo-recipes/recipes/esm2_native_te/

Tests should be kept relatively fast, using the smallest model and fewest training steps required to validate the feature. Tests with hardware requirements beyond those available in CI (a single L4) should be annotated with pytest.mark markers such as requires_fp8 and requires_multi_gpu.

Development Container

To use the provided devcontainer, use "Dev Containers: Reopen in Container" from the VSCode menu, and choose the "BioNeMo Recipes Dev Container" option. To run the tests inside the container, run pytest -v . in the recipe directory.

Hydra Tips

Hydra is a powerful configuration management library for Python. This recipe uses Hydra to manage training configurations, allowing for easy modification of training hyper-parameters and model settings.

Configuration parameters can be overridden from the command line. For example, python train_fsdp2.py --config-name L0_sanity fp8_config.enabled=true.

For verbose logging, use the Hydra command-line override hydra.verbose=true; see https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ for more details.