Accelerating Hugging Face Mixtral MoE Fine-Tuning with Transformer Engine
Goal
This tutorial showcases how to accelerate fine-tuning a mixture-of-experts model, Mixtral-8x7B, with Transformer Engine (TE) in BF16 and MXFP8 precision.
Setup
Mixtral-8x7B has 8 experts and roughly 47B total parameters. In BF16, the model weights alone take ~93 GB, and full AdamW fine-tuning needs ~370 GB. This tutorial was tested on 8x B300 GPUs with Expert Parallelism (EP) = 2 and Data Parallelism (DP) = 4, so the experts are divided across 2 GPUs and there are 4 replicas. The container used is pytorch-26.04-py3. All experiments use a sequence length of 8192 and a global batch size of 48.
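The memory figures above follow from per-parameter byte counts. The sketch below assumes roughly 46.7B parameters (the total commonly reported for Mixtral-8x7B) and, to match the ~370 GB figure, assumes that weights, gradients, and both AdamW moment buffers are all kept in BF16 (2 bytes each). Activations are not included.

```python
# Back-of-the-envelope memory estimate for Mixtral-8x7B fine-tuning.
# Assumptions: ~46.7e9 parameters; weights, gradients, and both AdamW
# moment buffers stored in BF16 (2 bytes each); activations excluded.
NUM_PARAMS = 46.7e9
BYTES_BF16 = 2

def gigabytes(num_bytes: float) -> float:
    return num_bytes / 1e9

weights_gb = gigabytes(NUM_PARAMS * BYTES_BF16)
# weights + gradients + AdamW first moment + AdamW second moment
full_adamw_gb = gigabytes(NUM_PARAMS * BYTES_BF16 * 4)

print(f"BF16 weights:     ~{weights_gb:.0f} GB")
print(f"Full AdamW state: ~{full_adamw_gb:.0f} GB")
```

With FP32 optimizer moments or FP32 master weights the total would be noticeably larger, which is why the per-parameter byte count is the first thing to check when comparing such estimates.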
Install the required Python packages using the following command in a terminal:
pip install -r requirements.txt
Table of Contents
[Baseline] Running HF Mixtral – Without Expert Parallelism (Precision: BF16)
[Improvement 1] Transformer Engine with Expert Parallelism (Precision: BF16)
[Improvement 2] Batched Expert Execution with GroupedLinear (Precision: BF16)
[Improvement 3] Precision Optimization and Fused MLP (Precision: MXFP8)
Conclusion
Appendix: Dependencies
[Baseline] Running HF Mixtral – Without Expert Parallelism (Precision: BF16)
Before applying any Transformer Engine optimizations, we establish a Hugging Face (HF) baseline. Mixtral replaces the standard Transformer feed-forward network (FFN) with a sparse Mixture of Experts (MoE): a learned router selects the top-2 experts out of 8 per token, as shown in Fig 1.
Fig 1: Dense Transformer block (left) vs Sparse MoE Transformer block (right).
The HF implementation used for the baseline has two limitations:
Pipeline parallelism. Because the weights, gradients, and optimizer state for full fine-tuning (~370 GB) do not fit on one GPU, the baseline uses pipeline parallelism to split the model’s layers across GPUs. This is the simplest way to partition a model, but GPU utilization is limited by pipeline bubbles and sequential layer dependencies.
Excessive kernel launches. HF’s MixtralSparseMoeBlock iterates over all 8 experts in a Python loop, and each expert triggers its own kernel launches:
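for expert_idx, expert_layer in enumerate(self.experts):
    # Tokens routed to this expert (top_x) and which top-k slot chose it (idx)
    idx, top_x = torch.where(expert_mask[expert_idx])
    # Gather only those tokens: one small GEMM input per expert
    current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
    # Run the expert FFN, then scale by the router weight
    current_hidden = expert_layer(current_state) * routing_weights[top_x, idx, None]
    # Scatter-add the results back to their token positions
    final_hidden_states.index_add_(0, top_x, current_hidden)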
For each layer, HF loops through the experts sequentially. Each expert GEMM covers only the tokens routed to that expert (on average a quarter of the tokens, since each token picks 2 of 8 experts), so it is much smaller than a dense FFN GEMM over the whole batch and can leave the GPU’s tensor cores underutilized. Looping over the experts therefore launches many small GEMMs, and the FFN becomes dominated by orchestration overhead and memory movement.
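To make the loop concrete, here is a minimal, CPU-runnable sketch of the same pattern with toy sizes. It is not the HF implementation: the experts are plain SwiGLU MLPs built from hypothetical random weights, and the router follows Mixtral’s scheme of softmax, top-2 selection, and renormalizing the two selected weights. The counter shows that every layer makes one expert call per expert, however few tokens each expert receives.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_tokens, hidden, inter, num_experts, top_k = 64, 16, 32, 8, 2

x = torch.randn(num_tokens, hidden)
router = torch.randn(hidden, num_experts)
# Toy SwiGLU experts (hypothetical weights): (w_gate, w_up, w_down) per expert
experts = [
    (torch.randn(hidden, inter) * 0.1, torch.randn(hidden, inter) * 0.1,
     torch.randn(inter, hidden) * 0.1)
    for _ in range(num_experts)
]

def expert_ffn(tokens, w_gate, w_up, w_down):
    return (F.silu(tokens @ w_gate) * (tokens @ w_up)) @ w_down

# Router: softmax over experts, keep the top-2, renormalize the two weights to sum to 1
probs = F.softmax(x @ router, dim=-1)
routing_weights, selected_experts = torch.topk(probs, top_k, dim=-1)
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

# expert_mask[e, k, t] == 1 when token t chose expert e in its k-th routing slot
expert_mask = F.one_hot(selected_experts, num_experts).permute(2, 1, 0)

final_hidden_states = torch.zeros_like(x)
expert_calls = 0
for expert_idx in range(num_experts):  # one iteration (and one small GEMM chain) per expert
    idx, top_x = torch.where(expert_mask[expert_idx])
    current_hidden = expert_ffn(x[top_x], *experts[expert_idx]) * routing_weights[top_x, idx, None]
    final_hidden_states.index_add_(0, top_x, current_hidden)
    expert_calls += 1

tokens_per_expert = expert_mask.sum(dim=(1, 2)).tolist()
print(f"{expert_calls} sequential expert calls; tokens per expert: {tokens_per_expert}")
```

Each of the 8 iterations runs three small matmuls on its own slice of tokens, which is the per-expert launch pattern that the later improvements replace.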
The script run_finetune_ep.py initializes the Hugging Face model and then runs fine-tuning; for the full implementation, refer to utils.py. Now, let’s execute the following command in the terminal.
python3 run_finetune_ep.py --improvement 0 --batch-size 48 --max-seq-length 8192 --warmup-steps 5 --train-steps 30
Here is the expected output:
30 fine-tuning steps complete!
Median time per step: 2472 ms
Let’s record this result in a table and compare each improvement against it in the following sections:
| Models | Precision | Step Time | Speedup (over baseline) |
|---|---|---|---|
| HF baseline | BF16 | 2472 ms | 1 |
[Improvement 1] Transformer Engine with Expert Parallelism (Precision: BF16)
Now that we have a baseline, let’s bring in Transformer Engine. This section replaces the HF Transformer block with TE modules and introduces expert parallelism (EP).
Fig 2: HF MixtralDecoderLayer (left) wrapped by TE modules (right).
Fused Building Blocks
Attention block. Instead of using one module for layer norm and another for attention, TE combines them (RMSNorm and attention) in te.MultiheadAttention, where the input RMSNorm is bundled with the fused QKV projection. The layer norm weights and the fused QKV weight are stored in the same building block, self_attention.layernorm_qkv; the fused QKV weight is self_attention.layernorm_qkv.weight. Here is how you can use TE’s attention block:
self.self_attention = transformer_engine.pytorch.MultiheadAttention(
    hidden_size=config.hidden_size,
    fuse_qkv_params=True,          # one QKV weight instead of separate Q, K, V weights
    qkv_weight_interleaved=True,
    normalization="RMSNorm",
    input_layernorm=True,          # apply RMSNorm to the input inside this module
    ...
)
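Numerically, the fused block computes RMSNorm on the input followed by a single matmul against one QKV weight. The plain-PyTorch sketch below checks that, and computes the fused weight size for Mixtral-8x7B’s attention (hidden size 4096, 32 query heads, 8 key/value heads, head dim 128, RMSNorm eps 1e-5). For simplicity it concatenates Q, K, and V along the output dimension; TE’s qkv_weight_interleaved=True orders the rows differently, and that reordering is handled by replace_params, so treat the layout here as an assumption.

```python
import torch

def qkv_out_features(num_heads: int, num_kv_heads: int, head_dim: int) -> int:
    # Q has num_heads heads; K and V each have num_kv_heads heads (grouped-query attention)
    return (num_heads + 2 * num_kv_heads) * head_dim

# Mixtral-8x7B: 32 query heads, 8 KV heads, head_dim 128 -> fused weight is [6144, 4096]
print(qkv_out_features(32, 8, 128))  # 6144

def rms_norm(x, weight, eps=1e-5):
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight

# Toy check: RMSNorm + one fused QKV GEMM == RMSNorm + three separate GEMMs
torch.manual_seed(0)
hidden, heads, kv_heads, head_dim = 64, 4, 2, 16
x = torch.randn(8, hidden)
norm_weight = torch.randn(hidden)
w_q = torch.randn(heads * head_dim, hidden)
w_k = torch.randn(kv_heads * head_dim, hidden)
w_v = torch.randn(kv_heads * head_dim, hidden)
# Simplification: plain concatenation; TE's interleaved layout orders rows differently
w_qkv = torch.cat([w_q, w_k, w_v], dim=0)

h = rms_norm(x, norm_weight)
q, k, v = (h @ w_qkv.T).split([w_q.shape[0], w_k.shape[0], w_v.shape[0]], dim=-1)
```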
MoE block. TE provides the building blocks for the MoE layer. First, the gate computes router logits; softmax and top-k then select the top two of the eight experts for each token. The selected experts are passed to the dispatcher, which determines which EP ranks host them, and NCCL handles the all-to-all communication that moves tokens to those ranks. The dispatcher also uses transformer_engine.pytorch.moe_permute_and_pad_with_probs to handle padding requirements, such as the padding to multiples of 32 that MXFP8 requires. Below is the overview pseudocode for the MoE block:
# Router: one logit per expert for each token
router_logits = self.gate(hidden_states)
softmax_probs = torch.nn.functional.softmax(router_logits, dim=-1)
# Keep the top_k (= 2) experts per token together with their probabilities
routing_weights, selected_experts = torch.topk(softmax_probs, self.top_k, dim=-1)
# Send each token to the EP ranks that host its selected experts
dispatch_output = self.dispatcher.dispatch(hidden_states, selected_experts, routing_weights)
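On a single rank, the core of dispatch is a permutation: token copies are sorted so that all tokens for the same expert are contiguous, and each expert’s group is padded up to a multiple of 32. The sketch below shows that bookkeeping in plain PyTorch. It is an illustration under stated assumptions, not TE’s moe_permute_and_pad_with_probs: it ignores the all-to-all between EP ranks, assumes padding is applied per expert group, and fills padding rows with zeros. The function name permute_and_pad is hypothetical.

```python
import torch

def permute_and_pad(hidden_states, selected_experts, routing_weights, num_experts, align=32):
    """Group token copies by expert and pad each group to a multiple of `align`.

    hidden_states:    [num_tokens, hidden]
    selected_experts: [num_tokens, top_k] expert id per routing slot
    routing_weights:  [num_tokens, top_k] router probability per routing slot
    """
    num_tokens, top_k = selected_experts.shape
    flat_experts = selected_experts.reshape(-1)                 # one entry per (token, slot)
    token_ids = torch.arange(num_tokens).repeat_interleave(top_k)
    order = torch.argsort(flat_experts, stable=True)            # expert 0's tokens first, then 1, ...
    counts = torch.bincount(flat_experts, minlength=num_experts)
    padded_counts = (counts + align - 1) // align * align       # round up to a multiple of align

    total = int(padded_counts.sum())
    out = hidden_states.new_zeros(total, hidden_states.shape[1])
    probs = routing_weights.new_zeros(total)
    src = dst = 0
    for e in range(num_experts):
        n = int(counts[e])
        rows = order[src:src + n]
        out[dst:dst + n] = hidden_states[token_ids[rows]]
        probs[dst:dst + n] = routing_weights.reshape(-1)[rows]
        src += n
        dst += int(padded_counts[e])                            # skip over the zero padding rows
    return out, probs, counts, padded_counts

torch.manual_seed(0)
num_tokens, hidden, num_experts = 50, 8, 8
x = torch.randn(num_tokens, hidden)
weights, selected = torch.softmax(torch.randn(num_tokens, num_experts), dim=-1).topk(2, dim=-1)
permuted, permuted_probs, counts, padded_counts = permute_and_pad(x, selected, weights, num_experts)
print(counts.tolist(), padded_counts.tolist())
```

The per-expert counts produced here play the same role as the split sizes that the grouped expert GEMMs consume later in the tutorial.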
Parallelism layout
In this tutorial, EP=2 is used to split the model between 2 GPUs, each hosting 4 experts. In this 8-GPU setup, the model is replicated 4 times, creating 4 data-parallel groups.
Here is how to set up EP:
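import torch.distributed as dist

config.expert_parallel_size = 2
ep_size = config.expert_parallel_size
world_size = dist.get_world_size()
dp_size = world_size // ep_size  # 8 GPUs / EP=2 -> 4 data-parallel replicas

# Every rank must call new_group() for every group, in the same order,
# so loop over all groups and keep only the one this rank belongs to.
ep_group = None
for dp_rank in range(dp_size):
    ranks = list(range(dp_rank * ep_size, (dp_rank + 1) * ep_size))  # [0, 1], [2, 3], ...
    group = dist.new_group(ranks=ranks)
    if dist.get_rank() in ranks:
        ep_group = group

model.model.set_ep_groups(ep_group=ep_group)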
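For the 8-GPU layout used here, the group construction above yields the EP groups [0, 1], [2, 3], [4, 5], and [6, 7]. The sketch below spells out that rank arithmetic in plain Python. The expert placement (EP rank r hosting experts 4r to 4r + 3) is an assumption for illustration; the tutorial’s code may assign experts differently. The helper name parallel_layout is hypothetical.

```python
def parallel_layout(world_size: int, ep_size: int, num_experts: int):
    """Map each global rank to its EP group, EP rank, DP replica, and local experts."""
    assert world_size % ep_size == 0 and num_experts % ep_size == 0
    experts_per_rank = num_experts // ep_size
    layout = {}
    for rank in range(world_size):
        dp_rank = rank // ep_size  # which model replica (index of the EP group)
        ep_rank = rank % ep_size   # position inside that EP group
        ep_group = list(range(dp_rank * ep_size, (dp_rank + 1) * ep_size))
        # Assumed contiguous placement: EP rank r owns experts [r*E_local, (r+1)*E_local)
        local_experts = list(range(ep_rank * experts_per_rank, (ep_rank + 1) * experts_per_rank))
        layout[rank] = {"dp_rank": dp_rank, "ep_rank": ep_rank,
                        "ep_group": ep_group, "local_experts": local_experts}
    return layout

layout = parallel_layout(world_size=8, ep_size=2, num_experts=8)
for rank, info in layout.items():
    print(rank, info)
```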
Mapping the HF checkpoint to TE
Some weights need to be reshaped and renamed to match the corresponding parameters in the TE modules. The replace_params helper in te_mixtral.py performs this mapping (also illustrated in Fig 2 above). The two non-trivial groups are:
Attention. HF stores Q, K, V as separate projections; TE fuses them into a single QKV weight that lives under the layernorm_qkv submodule:
| HF key | TE key |
|---|---|
| | |
| | |
| | |
| | |
MoE experts. HF packs all experts’ SwiGLU projections into two tensors per layer; TE keeps the same packing under different attribute names, so replace_params is essentially a copy:
| HF key | TE key |
|---|---|
| | |
| | |
| | |
All other weights (embeddings, norms, LM head) are direct copies. See replace_params in te_mixtral.py for the full mapping.
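Because each EP rank hosts only 4 of the 8 experts, loading expert weights amounts to slicing the packed HF tensors along the expert dimension and, for a per-expert weight{i} layout, splitting that slice into one tensor per local expert. The sketch below assumes a packed shape of [num_experts, out_features, in_features] and contiguous expert placement, and uses toy sizes; check the actual HF tensor layout and TE parameter names in replace_params before relying on it. The helper names are hypothetical.

```python
import torch

def local_expert_slice(packed: torch.Tensor, ep_rank: int, ep_size: int) -> torch.Tensor:
    """Return the experts owned by `ep_rank` from a packed [num_experts, ...] tensor."""
    per_rank = packed.shape[0] // ep_size
    return packed[ep_rank * per_rank:(ep_rank + 1) * per_rank]

def split_per_expert(local: torch.Tensor) -> dict:
    """Split a [num_local_experts, out, in] slice into weight0, weight1, ... tensors."""
    return {f"weight{i}": local[i].clone() for i in range(local.shape[0])}

# Toy packed gate_up tensor: 8 experts, out = 2 * intermediate, in = hidden (assumed layout)
num_experts, hidden, intermediate = 8, 16, 32
gate_up = torch.randn(num_experts, 2 * intermediate, hidden)

local = local_expert_slice(gate_up, ep_rank=1, ep_size=2)  # experts 4..7 under this assumption
per_expert = split_per_expert(local)
print(sorted(per_expert), tuple(per_expert["weight0"].shape))
```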
Let’s launch the same fine-tuning loop – this time across 8 GPUs via torchrun. See run_finetune_ep.py and utils.py for the full implementation.
Now, let’s execute the following command.
torchrun --standalone --nproc_per_node=8 run_finetune_ep.py --improvement 1 --ep-size 2 --batch-size 12 --max-seq-length 8192 --warmup-steps 5 --train-steps 30
Here is the expected output:
30 fine-tuning steps complete!
Median time per step: 747 ms
Compared to the baseline implementation, we see the following result:
| Models | Precision | Step Time | Speedup (over baseline) |
|---|---|---|---|
| HF baseline | BF16 | 2472 ms | 1 |
| TE decoder, TE building blocks, and MoE layer | BF16 | 747 ms | 3.31 |
Improvement 1 is 3.31x faster than the baseline, i.e., 231% faster.
[Improvement 2] Batched Expert Execution with GroupedLinear (Precision: BF16)
Improvement 1 kept the per-expert FFN as a Python loop: because each rank owns 4 local experts, the loop launches the per-expert GEMMs one by one. The expert GEMMs are also usually small, since each one sees only the tokens routed to it, which is often too few to keep the tensor cores busy. This section shows how to execute the expert GEMMs as a batch instead.
Fig 3: Left: looping through experts one-by-one. Right: one grouped-GEMM over all experts.
GroupedLinear applies multiple linear transformations in one call. It takes the local experts’ weights together with their input tokens, arranged so that each expert’s tokens are contiguous. Because each expert can receive a different number of tokens, GroupedLinear also takes the per-expert token counts (split_sizes). It then runs all local experts through TE’s grouped GEMM path instead of launching one PyTorch Linear operation per expert, which reduces per-expert launch and scheduling overhead.
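The semantics of a grouped linear layer can be stated in plain PyTorch: the input rows are already sorted by expert, split_sizes says how many rows belong to each expert, and chunk i is multiplied by weight i. The reference below is a sketch of those semantics only, with a hypothetical function name; it still loops in Python, which is exactly the overhead TE’s grouped GEMM path is meant to avoid.

```python
import torch

def grouped_linear_reference(tokens, split_sizes, weights):
    """Apply weights[i] to the i-th contiguous chunk of `tokens`.

    tokens:      [sum(split_sizes), in_features], sorted by expert
    split_sizes: list[int], rows per expert (may differ, may be 0)
    weights:     list of [out_features, in_features] tensors, one per expert
    """
    chunks = torch.split(tokens, split_sizes, dim=0)
    return torch.cat([chunk @ w.T for chunk, w in zip(chunks, weights)], dim=0)

torch.manual_seed(0)
in_features, out_features = 16, 8
split_sizes = [5, 0, 12, 3]  # uneven token counts for 4 local experts
weights = [torch.randn(out_features, in_features) for _ in split_sizes]
tokens = torch.randn(sum(split_sizes), in_features)

out = grouped_linear_reference(tokens, split_sizes, weights)
print(tuple(out.shape))
```

An expert with zero routed tokens simply contributes an empty chunk, so the output keeps exactly one row per input row.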
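Here is how to use GroupedLinear. Each expert keeps its own weight tensor (weight0, weight1, …), and the call takes the per-expert token counts as an extra positional argument:
import torch
from transformer_engine.pytorch.ops import GroupedLinear

num_local_experts = 4      # 8 experts / EP=2
hidden_size = 4096         # Mixtral-8x7B hidden size
intermediate_size = 14336  # Mixtral-8x7B expert FFN size

# One module holds the fused gate+up projection for all local experts
experts_gate_up = GroupedLinear(
    num_groups=num_local_experts,
    in_features=hidden_size,
    out_features=2 * intermediate_size,
    bias=False,
    dtype=torch.bfloat16,
    device="cuda",
)
# tokens: [sum(split_sizes), hidden_size], grouped by local expert
# split_sizes: number of tokens routed to each local expert
gate_up_output = experts_gate_up(tokens, split_sizes)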
Compared with the Python loop in Improvement 1, this is one gate-up projection call per layer instead of 4 separate calls (one per local expert on the GPU). The expert weights can still be imported from HF: in te_mixtral.py, the grouped-op path keeps each local expert as an ordinary per-expert weight{i} parameter and loads that expert’s slice of the HF checkpoint directly into it.
To see the effect of GroupedLinear, we keep everything else unchanged. Execute the following command in the terminal.
torchrun --standalone --nproc_per_node=8 run_finetune_ep.py --improvement 2 --ep-size 2 --batch-size 12 --max-seq-length 8192 --warmup-steps 5 --train-steps 30
Here is the expected output:
30 fine-tuning steps complete!
Median time per step: 635 ms
Adding the GroupedLinear result gives us:
| Models | Precision | Step Time | Speedup (over baseline) |
|---|---|---|---|
| HF baseline | BF16 | 2472 ms | 1 |
| TE decoder, TE building blocks, and MoE layer | BF16 | 747 ms | 3.31 |
| TE with GroupedLinear | BF16 | 635 ms | 3.89 |
With GroupedLinear, the speedup over the baseline reaches 3.89x, i.e., 289% faster.
[Improvement 3] Precision Optimization and Fused MLP (Precision: MXFP8)
With EP and grouped expert GEMMs in place, the next improvement lowers precision from BF16 to MXFP8. MXFP8 stores weight and activation values in 8 bits instead of 16 bits. To preserve dynamic range, it keeps one E8M0 (power-of-two) scale factor for every block of 32 values, so each block is scaled into the FP8 range independently and the tensor as a whole can cover a much wider range of magnitudes than a single FP8 format. On Blackwell GPUs, MXFP8 is supported natively in hardware, so MXFP8 GEMMs run on specialized Tensor Core instructions. Read more about MXFP8 and block scaling in the Transformer Engine FP8 primer.
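To see what a shared per-block scale does, the sketch below quantizes one 32-value block in plain Python. It follows the shared-exponent rule from the OCP Microscaling (MX) specification for E4M3 elements (scale = 2^(floor(log2(amax)) - 8), with element values saturating at ±448) and emulates E4M3 rounding in software. TE’s kernels may compute the scale differently, so treat this as an illustration of block scaling rather than TE’s exact numerics. All names here are hypothetical.

```python
import math

E4M3_MAX = 448.0   # largest finite FP8 E4M3 value
E4M3_EMAX = 8      # exponent of the largest E4M3 binade (448 = 1.75 * 2**8)
BLOCK = 32         # MXFP8 block size: one shared scale per 32 values

def round_to_e4m3(v: float) -> float:
    """Round to the nearest FP8 E4M3 value (3 mantissa bits), saturating at +/-448."""
    if v == 0.0:
        return 0.0
    a = min(abs(v), E4M3_MAX)
    exp = max(math.floor(math.log2(a)), -6)  # below 2**-6 values are subnormal
    step = 2.0 ** (exp - 3)                  # spacing between representable values
    q = min(round(a / step) * step, E4M3_MAX)
    return math.copysign(q, v)

def mxfp8_quantize_block(block):
    """Return (shared_exp, fp8_values) for one block, per the OCP MX shared-exponent rule."""
    assert len(block) == BLOCK
    amax = max(abs(x) for x in block)
    if amax == 0.0:
        return -127, [0.0] * BLOCK
    shared_exp = math.floor(math.log2(amax)) - E4M3_EMAX
    shared_exp = max(-127, min(127, shared_exp))  # E8M0 covers 2**-127 .. 2**127
    scale = 2.0 ** shared_exp
    return shared_exp, [round_to_e4m3(x / scale) for x in block]

def mxfp8_dequantize_block(shared_exp, values):
    return [v * 2.0 ** shared_exp for v in values]

block = [float(i) for i in range(1, BLOCK + 1)]  # 1.0 .. 32.0
shared_exp, q = mxfp8_quantize_block(block)
restored = mxfp8_dequantize_block(shared_exp, q)
print(shared_exp, restored[:4], restored[-4:])
```

Small values in the block survive exactly, while larger ones are rounded to E4M3’s coarser spacing; the relative error stays within 2^-4 for normal E4M3 values.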
The model still keeps its master weights in BF16, so using MXFP8 adds Quantization and De-Quantization work around the GEMMs. Quantization converts BF16 weights and activations to MXFP8 before the low-precision GEMM; De-Quantization converts the result back to the higher-precision format. The naive path performs these as separate operations. This motivates the fused MLP path shown below.
Fig 4: The MXFP8 path fuses multiple operations into one kernel before the down projection.
To use MXFP8, we simply define a recipe and pass it to the model.
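from transformer_engine.common import recipe as te_recipe

fp8_recipe = te_recipe.MXFP8BlockScaling()  # FP8 values with one E8M0 scale per 32-value block
model = TEMixtralMXFP8ForCausalLM(config, fp8_recipe=fp8_recipe, dispatcher=dispatcher)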
Now the model’s forward and backward passes run in MXFP8, which is enabled through TE’s autocast API:
# TE modules called inside this context run their GEMMs with the MXFP8 recipe
with te.autocast(enabled=True, recipe=self._fp8_recipe):
    for decoder_layer in self.layers:
        hidden_states = decoder_layer(hidden_states)
To use the fused MLP, we use TE’s Sequential API to chain together gate_up, ScaledSwiGLU, and down. The fused path also folds the De-Quantization step into the kernel. ScaledSwiGLU is used so that the routing probabilities (“scales”) are applied as part of the expert FFN computation.
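from transformer_engine.pytorch.ops import GroupedLinear, ScaledSwiGLU, Sequential
# gate_up / down are shorthand for the two GroupedLinear configurations
# (hidden -> 2 * intermediate, and intermediate -> hidden), one group per local expert
experts_ffn = Sequential(GroupedLinear(gate_up), ScaledSwiGLU(), GroupedLinear(down))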
TE’s Sequential scans the ops and, if the pattern matches, replaces the GroupedLinear -> ScaledSwiGLU -> GroupedLinear chain with a fused operation object: ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8 for the forward pass and a matching fused op for the backward pass. This reduces framework overhead, fuses the SwiGLU and probability-scaling work into the grouped MLP path, and avoids materializing some intermediate tensors.
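In unfused form, the chain that Sequential recognizes computes, for each local expert, a gate-up projection, SwiGLU, a per-token scale by the routing probability, and a down projection. The plain-PyTorch reference below spells that out under two assumptions: the first half of the gate-up output is the gate (passed through SiLU) and the second half is the linear branch, and the probability multiplies the SwiGLU output before the down projection. Because the down projection is linear, scaling before it gives the same result as scaling the expert output afterwards (as the HF loop does), which is what makes folding the probabilities into the activation step valid.

```python
import torch
import torch.nn.functional as F

def grouped_mlp_reference(tokens, probs, split_sizes, w_gate_up, w_down):
    """Unfused reference for GroupedLinear -> ScaledSwiGLU -> GroupedLinear.

    tokens: [sum(split_sizes), hidden] grouped by expert; probs: [sum(split_sizes)]
    w_gate_up[i]: [2 * inter, hidden]; w_down[i]: [hidden, inter]
    """
    outputs = []
    for x, p, wgu, wd in zip(torch.split(tokens, split_sizes),
                             torch.split(probs, split_sizes), w_gate_up, w_down):
        gate, up = (x @ wgu.T).chunk(2, dim=-1)  # assumed order: gate first, then up
        act = F.silu(gate) * up * p[:, None]      # SwiGLU scaled by routing probability
        outputs.append(act @ wd.T)
    return torch.cat(outputs)

torch.manual_seed(0)
hidden, inter, split_sizes = 16, 32, [6, 3, 0, 7]
w_gate_up = [torch.randn(2 * inter, hidden) * 0.1 for _ in split_sizes]
w_down = [torch.randn(hidden, inter) * 0.1 for _ in split_sizes]
tokens = torch.randn(sum(split_sizes), hidden)
probs = torch.rand(sum(split_sizes))

scaled_before_down = grouped_mlp_reference(tokens, probs, split_sizes, w_gate_up, w_down)

# Same experts, but scale the expert output afterwards (the order used in the HF loop)
unscaled = grouped_mlp_reference(tokens, torch.ones_like(probs), split_sizes, w_gate_up, w_down)
scaled_after_down = unscaled * probs[:, None]
```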
Note
NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 must be set in the environment before Transformer Engine is imported, because the fused op registration happens at import time. In this tutorial, run_finetune_ep.py already sets it automatically for improvement 3.
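If you build your own training script instead of using run_finetune_ep.py, the NVTE_CUTEDSL_FUSED_GROUPED_MLP flag has to be in place before the first Transformer Engine import. One way to enforce that is a small guard like the one below; the helper name is hypothetical, and only the environment variable itself comes from this tutorial. Exporting the variable in the shell that launches torchrun works as well, since the worker processes inherit the environment.

```python
import os
import sys

FLAG = "NVTE_CUTEDSL_FUSED_GROUPED_MLP"

def enable_fused_grouped_mlp() -> None:
    """Opt in to the fused grouped MLP; must run before transformer_engine is imported."""
    already_imported = any(
        name == "transformer_engine" or name.startswith("transformer_engine.")
        for name in sys.modules
    )
    if already_imported:
        # Per the note above, setting the flag after TE is imported is too late.
        raise RuntimeError(f"Set {FLAG}=1 before importing transformer_engine")
    os.environ[FLAG] = "1"

enable_fused_grouped_mlp()
# import transformer_engine.pytorch as te  # import TE only after the flag is set
```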
Execute the following in the terminal:
torchrun --standalone --nproc_per_node=8 run_finetune_ep.py --improvement 3 --ep-size 2 --batch-size 12 --max-seq-length 8192 --warmup-steps 5 --train-steps 30
Here is the expected result:
30 fine-tuning steps complete!
Median time per step: 542 ms
With MXFP8 fused MLP included, the final comparison is:
| Models | Precision | Step Time | Speedup (over baseline) |
|---|---|---|---|
| HF baseline | BF16 | 2472 ms | 1 |
| TE EP Python loop | BF16 | 747 ms | 3.31 |
| TE with GroupedLinear | BF16 | 635 ms | 3.89 |
| TE with MXFP8 fused MLP | MXFP8 | 542 ms | 4.56 |
For Mixtral-8x7B, the MXFP8 fused MLP gives the largest speedup: 4.56x over the baseline, i.e., 356% faster.
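The speedup column is the baseline step time divided by each run’s step time, and the “% faster” figures quoted in the text are (speedup - 1) x 100. A quick check of the table in Python:

```python
# Median step times from the runs above (ms)
step_time_ms = {
    "HF baseline (BF16)": 2472,
    "TE EP Python loop (BF16)": 747,
    "TE with GroupedLinear (BF16)": 635,
    "TE with MXFP8 fused MLP (MXFP8)": 542,
}

baseline = step_time_ms["HF baseline (BF16)"]
for name, ms in step_time_ms.items():
    speedup = baseline / ms
    percent_faster = (speedup - 1) * 100
    print(f"{name:34s} {ms:5d} ms  {speedup:.2f}x  ({percent_faster:.0f}% faster)")
```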
Conclusion
This tutorial walked through three progressive optimizations that speed up fine-tuning of Mixtral-8x7B by replacing building blocks of a Hugging Face baseline with TE-native components: TE attention and MoE blocks with expert parallelism, GroupedLinear for batched expert GEMMs, and an MXFP8 fused grouped MLP. The speedups were measured on 8x B300 GPUs with a sequence length of 8192.
To reproduce all four runs, execute the following commands in a terminal. All four runs use the same global batch size of 48; the TE runs use DP=4, so the per-rank batch size is 12.
python3 run_finetune_ep.py --improvement 0 --batch-size 48 --max-seq-length 8192 --warmup-steps 5 --train-steps 30
torchrun --standalone --nproc_per_node=8 run_finetune_ep.py --improvement 1 --ep-size 2 --batch-size 12 --max-seq-length 8192 --warmup-steps 5 --train-steps 30
torchrun --standalone --nproc_per_node=8 run_finetune_ep.py --improvement 2 --ep-size 2 --batch-size 12 --max-seq-length 8192 --warmup-steps 5 --train-steps 30
torchrun --standalone --nproc_per_node=8 run_finetune_ep.py --improvement 3 --ep-size 2 --batch-size 12 --max-seq-length 8192 --warmup-steps 5 --train-steps 30
Note on Scaling
For large-scale training, check out Megatron’s performance summary.
Appendix: Dependencies
| File | Purpose |
|---|---|
| | BF16 TE Mixtral implementation |
| | MXFP8 implementation |
| | Token dispatch/combine for MXFP8 |
| | Converts Hugging Face weights to Transformer Engine format |
| | Training loop |
| | CLI launcher |
| requirements.txt | Python package versions |
| | Input sequence preparation |