Accelerating Hugging Face Mixtral MoE Fine-Tuning with Transformer Engine

Goal

This tutorial shows how to accelerate fine-tuning of a mixture-of-experts (MoE) model, Mixtral-8x7B, with Transformer Engine (TE) in BF16 and MXFP8 precision.

Setup

Mixtral-8x7B has 8 experts per MoE layer and roughly 47B total parameters. In BF16, the model weights alone consume ~93 GB, and full AdamW fine-tuning needs ~370 GB. This tutorial is tested on 8x B300 GPUs with Expert Parallelism (EP) = 2 and Data Parallelism (DP) = 4, so each layer’s experts are split across 2 GPUs and the model is replicated 4 times. The tutorial uses the pytorch-26.04-py3 container. All experiments use a sequence length of 8192 and a global batch size of 48.
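As a quick check of the memory figures above, the arithmetic below reproduces them under stated assumptions: 46.7B parameters (the published total behind “roughly 47B”), decimal gigabytes, and 2 bytes per value for the weights, the gradients, and both AdamW moment buffers. That all-BF16 accounting is one way to arrive at ~370 GB, not necessarily the tutorial’s own; keeping the AdamW moments in FP32 would need more memory, and activations are not included.

```python
# Back-of-the-envelope memory estimate for Mixtral-8x7B fine-tuning.
# Assumptions (not from the tutorial): 46.7e9 parameters, decimal GB,
# and every per-parameter buffer stored in BF16 (2 bytes per value).
NUM_PARAMS = 46.7e9
BYTES_BF16 = 2


def gigabytes(num_bytes: float) -> float:
    return num_bytes / 1e9


# Weights only.
weights_gb = gigabytes(NUM_PARAMS * BYTES_BF16)

# Full AdamW fine-tuning: weights + gradients + first and second moments.
buffers_per_param = 4  # weight, grad, exp_avg, exp_avg_sq
full_adamw_gb = gigabytes(NUM_PARAMS * BYTES_BF16 * buffers_per_param)

print(f"BF16 weights:     ~{weights_gb:.0f} GB")
print(f"Full AdamW state: ~{full_adamw_gb:.0f} GB (activations excluded)")
```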

Install the required Python packages using the following command in a terminal:

pip install -r requirements.txt

Table of Contents

  1. [Baseline] Running HF Mixtral – Without Expert Parallelism (Precision: BF16)

  2. [Improvement 1] Transformer Engine with Expert Parallelism (Precision: BF16)

  3. [Improvement 2] Batched Expert Execution with GroupedLinear (Precision: BF16)

  4. [Improvement 3] Precision Optimization and Fused MLP (Precision: MXFP8)

  5. Conclusion

  6. Appendix: Dependencies

[Baseline] Running HF Mixtral – Without Expert Parallelism (Precision: BF16)

Before applying any Transformer Engine optimizations, we establish a Hugging Face (HF) baseline. Mixtral replaces the standard Transformer feed-forward network (FFN) with a sparse Mixture of Experts (MoE): a learned router selects the top-2 experts out of 8 per token, as shown in Fig 1.


Fig 1: Dense Transformer block (left) vs Sparse MoE Transformer block (right).

The current HF implementation has two limitations.

  1. Pipeline parallelism. Because full fine-tuning (~370 GB) does not fit in a single GPU’s memory, the baseline uses pipeline parallelism to split the model’s layers across GPUs. This is the simplest way to partition a model, but GPU utilization is limited by pipeline bubbles: because each layer depends on the previous one, GPUs holding later layers sit idle while earlier layers compute.

  2. Excessive kernel launches. HF’s MixtralSparseMoeBlock iterates over all 8 experts in a Python loop, and each expert triggers its own kernel launches.

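# Excerpt (simplified) from HF's MixtralSparseMoeBlock.forward
for expert_idx, expert_layer in enumerate(self.experts):
    # Tokens (top_x) that selected this expert, and the top-k slot (idx) they used
    idx, top_x = torch.where(expert_mask[expert_idx])
    # Gather those tokens, run this expert's SwiGLU FFN, and weight by routing prob
    current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
    current_hidden = expert_layer(current_state) * routing_weights[top_x, idx, None]
    # Scatter-add the weighted expert output back to each token's row
    final_hidden_states.index_add_(0, top_x, current_hidden)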

For each layer, HF loops through the experts sequentially. Each expert processes only the tokens routed to it, so each expert GEMM is small and cannot saturate the GPU’s tensor cores. Looping over the experts therefore launches many small GEMMs, and the MoE layer ends up dominated by orchestration overhead and memory movement rather than math.
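The loop’s behavior can be reproduced at toy scale on the CPU. The sketch below is a self-contained approximation of the HF pattern, not HF’s code: the sizes, the random router probabilities, and the `expert_ffn` helper are hypothetical. Every expert that receives tokens costs its own set of three small GEMMs (gate, up, down), and the result matches a straightforward per-token computation.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_tokens, hidden, ffn, num_experts, top_k = 16, 8, 32, 8, 2  # toy sizes

# Toy SwiGLU experts, each with its own weights (scaled to keep values near 1).
w_gate = torch.randn(num_experts, ffn, hidden) / hidden**0.5
w_up = torch.randn(num_experts, ffn, hidden) / hidden**0.5
w_down = torch.randn(num_experts, hidden, ffn) / ffn**0.5


def expert_ffn(e: int, x: torch.Tensor) -> torch.Tensor:
    """SwiGLU FFN of expert e (hypothetical helper)."""
    return F.linear(F.silu(F.linear(x, w_gate[e])) * F.linear(x, w_up[e]), w_down[e])


x = torch.randn(num_tokens, hidden)
probs = F.softmax(torch.randn(num_tokens, num_experts), dim=-1)  # stand-in router output
routing_weights, selected = torch.topk(probs, top_k, dim=-1)

# HF-style loop: a separate, small set of GEMMs for every expert.
out = torch.zeros_like(x)
expert_calls = 0
expert_mask = F.one_hot(selected, num_experts).permute(2, 1, 0)  # [experts, top_k, tokens]
for e in range(num_experts):
    idx, top_x = torch.where(expert_mask[e])
    if top_x.numel() == 0:
        continue
    expert_calls += 1
    y = expert_ffn(e, x[top_x]) * routing_weights[top_x, idx, None]
    out.index_add_(0, top_x, y)

# Reference: each token's weighted top-k expert outputs, one token at a time.
ref = torch.stack([
    sum(routing_weights[t, k] * expert_ffn(int(selected[t, k]), x[t:t + 1])[0] for k in range(top_k))
    for t in range(num_tokens)
])
print("expert calls in this layer:", expert_calls)
print("matches per-token reference:", torch.allclose(out, ref, atol=1e-4))
```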

The script run_finetune_ep.py initializes the Hugging Face model and then runs fine-tuning; see utils.py for the full implementation. Now, let’s execute the following command in the terminal.

python3 run_finetune_ep.py --improvement 0 --batch-size 48 --max-seq-length 8192 --warmup-steps 5 --train-steps 30

Here is the expected output:

30 fine-tuning steps complete!
Median time per step: 2472 ms

Let’s record this result in a table and extend it with each improvement in the following sections:

| Model | Precision | Step Time | Speedup (over baseline) |
|---|---|---|---|
| HF baseline | BF16 | 2472 ms | 1.00 |

[Improvement 1] Transformer Engine with Expert Parallelism (Precision: BF16)

Now that we have a baseline, let’s bring in Transformer Engine. This section replaces the HF Transformer block with TE modules and introduces expert parallelism (EP).


Fig 2: HF MixtralDecoderLayer (left) wrapped by TE modules (right).

Fused Building Blocks

  • Attention block. Instead of separate modules for layer norm and attention, TE’s te.MultiheadAttention bundles the input RMSNorm with a fused QKV projection. The norm weights and the QKV weights live in the same submodule, self_attention.layernorm_qkv (as layer_norm_weight and weight, respectively). Here is how you can use TE’s attention block:

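    self.self_attention = transformer_engine.pytorch.MultiheadAttention(
        hidden_size=config.hidden_size,
        fuse_qkv_params=True,         # store Q, K, V as one fused weight
        qkv_weight_interleaved=True,  # Q/K/V rows interleaved per head, not concatenated
        normalization="RMSNorm",      # Mixtral uses RMSNorm
        input_layernorm=True,         # apply the norm inside this module, before QKV
        ...
    )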
    
  • MoE block. TE provides the building blocks for the MoE layer. First, the gate computes router logits; softmax turns them into probabilities, and top-k selects the two highest-probability experts out of eight for each token. The hidden states, selected expert indices, and routing weights are then passed to the dispatcher. The dispatcher determines which EP ranks host the selected experts, and NCCL handles the all-to-all communication that moves tokens to those ranks. The dispatcher also uses transformer_engine.pytorch.moe_permute_and_pad_with_probs to handle padding requirements, such as the multiple-of-32 padding that MXFP8 requires.

    The following pseudocode outlines the MoE block:

    # Router: one logit per expert for every token
    router_logits = self.gate(hidden_states)
    # Softmax over the experts, then keep the top-2 experts and their probabilities
    softmax_probs = torch.nn.functional.softmax(router_logits, dim=-1)
    routing_weights, selected_experts = torch.topk(softmax_probs, self.top_k, dim=-1)
    # Move each token to the EP ranks hosting its selected experts (all-to-all)
    dispatch_output = self.dispatcher.dispatch(hidden_states, selected_experts, routing_weights)
    # The local experts then run on dispatch_output, and the dispatcher's combine
    # step returns the expert outputs to the tokens' original ranks
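The routing and padding steps above can be sketched on a single device with plain PyTorch. This is a simplified stand-in, not TE’s moe_permute_and_pad_with_probs or the tutorial’s dispatcher: it omits the all-to-all and uses random router logits in place of self.gate. It assumes contiguous expert placement (EP rank 0 hosts experts 0-3, EP rank 1 hosts experts 4-7) and pads each expert’s token group to a multiple of 32.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_tokens, hidden, num_experts, top_k = 100, 16, 8, 2
ep_size, pad_multiple = 2, 32              # EP=2; MXFP8 needs multiples of 32
experts_per_rank = num_experts // ep_size  # 4 local experts per GPU

hidden_states = torch.randn(num_tokens, hidden)
router_logits = torch.randn(num_tokens, num_experts)  # stand-in for self.gate(hidden_states)
probs = F.softmax(router_logits, dim=-1)
routing_weights, selected_experts = torch.topk(probs, top_k, dim=-1)

# "Permute": flatten the (token, k) pairs and sort them by expert id.
flat_experts = selected_experts.reshape(-1)
flat_probs = routing_weights.reshape(-1)
flat_token_ids = torch.arange(num_tokens).repeat_interleave(top_k)
order = torch.argsort(flat_experts, stable=True)
permuted_tokens = hidden_states[flat_token_ids[order]]
permuted_probs = flat_probs[order]  # probabilities travel with their tokens

# Tokens per expert (the split sizes), rounded up to the padding multiple.
tokens_per_expert = torch.bincount(flat_experts, minlength=num_experts)
padded_per_expert = (tokens_per_expert + pad_multiple - 1) // pad_multiple * pad_multiple

# Copy each expert's tokens into its zero-padded slot.
padded_tokens = torch.zeros(int(padded_per_expert.sum()), hidden)
src_start = torch.cumsum(tokens_per_expert, 0) - tokens_per_expert
dst_start = torch.cumsum(padded_per_expert, 0) - padded_per_expert
for e in range(num_experts):
    n, s, d = int(tokens_per_expert[e]), int(src_start[e]), int(dst_start[e])
    padded_tokens[d:d + n] = permuted_tokens[s:s + n]

# EP rank hosting each expert, assuming contiguous placement.
owner_rank = torch.arange(num_experts) // experts_per_rank

print("tokens per expert:", tokens_per_expert.tolist())
print("padded sizes:     ", padded_per_expert.tolist())
print("owner EP rank:    ", owner_rank.tolist())
```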
    

Parallelism layout

In this tutorial, EP=2 splits each MoE layer’s 8 experts between 2 GPUs, so each GPU hosts 4 experts. On 8 GPUs, this gives 4 copies of the model (DP=4), each spread over one EP group of 2 GPUs.

Here is how to set up EP:

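import torch.distributed as dist

world_size = dist.get_world_size()
config.expert_parallel_size = 2
ep_size = config.expert_parallel_size
dp_size = world_size // ep_size  # 8 // 2 = 4 model replicas
ep_group = None
# dist.new_group must be called by every rank for every group, in the same order,
# so each rank creates all EP groups and keeps the one that contains it.
for dp_rank in range(dp_size):
    ranks = list(range(dp_rank * ep_size, (dp_rank + 1) * ep_size))  # [0, 1], [2, 3], ...
    group = dist.new_group(ranks=ranks)
    if dist.get_rank() in ranks:
        ep_group = group
model.model.set_ep_groups(ep_group=ep_group)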
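The resulting layout can be enumerated without launching any processes. The pure-Python sketch below mirrors the loop above for world_size=8 and ep_size=2. The expert placement (EP rank r hosts experts 4r to 4r+3) and the second grouping (ranks that hold the same expert shard in different replicas, which in a typical EP + DP setup would average those experts’ gradients) are illustrative assumptions; the tutorial code above only creates the EP groups.

```python
# Pure-Python view of the rank layout (world_size=8, ep_size=2).
world_size, ep_size, num_experts = 8, 2, 8
dp_size = world_size // ep_size
experts_per_rank = num_experts // ep_size

# Same construction as the loop above: consecutive ranks form one EP group.
ep_groups = [list(range(r * ep_size, (r + 1) * ep_size)) for r in range(dp_size)]

# Position of each global rank inside its EP group.
ep_rank_of = {rank: rank % ep_size for rank in range(world_size)}

# Ranks holding the same expert shard across replicas (assumed gradient-averaging groups).
expert_replica_groups = [list(range(i, world_size, ep_size)) for i in range(ep_size)]

# Experts hosted by each EP rank, assuming contiguous placement.
local_experts = {
    ep_rank: list(range(ep_rank * experts_per_rank, (ep_rank + 1) * experts_per_rank))
    for ep_rank in range(ep_size)
}

print("EP groups:            ", ep_groups)
print("expert replica groups:", expert_replica_groups)
print("local experts:        ", local_experts)
```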

Mapping the HF checkpoint to TE

Some parameters must be reshaped and renamed to match the TE modules. The replace_params helper in te_mixtral.py performs this mapping (also illustrated in Fig 2 above). The two non-trivial groups are:

  • Attention. HF stores Q, K, V as separate projections; TE fuses them into a single QKV weight that lives under the layernorm_qkv submodule:

| HF key | TE key |
|---|---|
| self_attn.q_proj.weight | self_attention.layernorm_qkv.weight (Q slice) |
| self_attn.k_proj.weight | self_attention.layernorm_qkv.weight (K slice) |
| self_attn.v_proj.weight | self_attention.layernorm_qkv.weight (V slice) |
| input_layernorm.weight | self_attention.layernorm_qkv.layer_norm_weight |

  • MoE experts. HF packs all experts’ SwiGLU projections into two tensors per layer; TE keeps the same packing under different attribute names, so replace_params is essentially a copy:

| HF key | TE key |
|---|---|
| mlp.experts.gate_up_proj [num_experts, 2*ffn, h] | mlp.experts_gate_up_weight |
| mlp.experts.down_proj [num_experts, h, ffn] | mlp.experts_down_weight |
| mlp.gate.weight | mlp.gate.weight |

All other weights (embeddings, norms, LM head) are direct copies. See replace_params in te_mixtral.py for the full mapping.
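Both non-trivial mappings can be checked at toy scale in plain PyTorch. The first half stacks separate Q, K, V weights into one fused weight and splits it back. It uses plain Q|K|V concatenation, which is the layout TE uses when qkv_weight_interleaved=False; the tutorial sets qkv_weight_interleaved=True (per-head interleaving), so the real replace_params presumably orders the rows differently. The second half computes one expert’s SwiGLU FFN from the packed [num_experts, 2*ffn, h] tensor, assuming the gate rows precede the up rows. Verify both orderings against replace_params before reusing this sketch.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# --- Attention: fuse separate Q, K, V projections into one QKV weight ---
# Toy sizes with Mixtral's 4:1 query-to-KV head ratio (Mixtral: 4096 hidden, 32 / 8 heads).
hidden, num_heads, num_kv_heads = 64, 8, 2
head_dim = hidden // num_heads
q_proj = torch.randn(num_heads * head_dim, hidden)
k_proj = torch.randn(num_kv_heads * head_dim, hidden)
v_proj = torch.randn(num_kv_heads * head_dim, hidden)

# Plain Q|K|V concatenation along the output dim (assumed order; see the lead-in).
qkv = torch.cat([q_proj, k_proj, v_proj], dim=0)
q_back, k_back, v_back = torch.split(
    qkv, [num_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim], dim=0
)
# Toy result: (96, 64). At Mixtral scale: ((32 + 2 * 8) * 128, 4096) = (6144, 4096).
print(tuple(qkv.shape))

# --- MoE experts: one expert's SwiGLU FFN from the packed tensors ---
num_experts, ffn = 8, 48  # toy sizes (Mixtral: 8 experts, ffn = 14336)
gate_up_proj = torch.randn(num_experts, 2 * ffn, hidden) / hidden**0.5
down_proj = torch.randn(num_experts, hidden, ffn) / ffn**0.5


def expert_forward(e: int, x: torch.Tensor) -> torch.Tensor:
    """SwiGLU FFN of expert e, assuming gate rows precede up rows in gate_up_proj."""
    gate, up = F.linear(x, gate_up_proj[e]).chunk(2, dim=-1)
    return F.linear(F.silu(gate) * up, down_proj[e])


x = torch.randn(5, hidden)
print(tuple(expert_forward(3, x).shape))  # (5, 64)
```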

Let’s launch the same fine-tuning loop – this time across 8 GPUs via torchrun. See run_finetune_ep.py and utils.py for the full implementation.

Now, let’s execute the following command.

torchrun --standalone --nproc_per_node=8 run_finetune_ep.py --improvement 1 --ep-size 2 --batch-size 12 --max-seq-length 8192 --warmup-steps 5 --train-steps 30

Here is the expected output:

30 fine-tuning steps complete!
Median time per step: 747 ms

Compared to the baseline implementation, we see the following result:

| Model | Precision | Step Time | Speedup (over baseline) |
|---|---|---|---|
| HF baseline | BF16 | 2472 ms | 1.00 |
| TE decoder, TE building blocks, and MoE layer | BF16 | 747 ms | 3.31 |

Improvement 1 is 3.31x faster than the baseline, that is, 231% higher throughput.

[Improvement 2] Batched Expert Execution with GroupedLinear (Precision: BF16)

Improvement 1 still runs the per-expert FFN as a Python loop: because each rank owns 4 local experts, the loop launches the per-expert GEMMs one by one. The expert GEMMs are also usually small: each one processes only the tokens routed to that expert, which is often too few to keep the tensor cores busy. This section shows how to execute all local experts’ GEMMs as one batched call.


Fig 3: Left: looping through experts one-by-one. Right: one grouped-GEMM over all experts.

GroupedLinear applies multiple linear transformations, one per expert, in a single call. It takes the local experts’ weights together with their input tokens, sorted by expert. Because each expert receives a different number of tokens, GroupedLinear also accepts the per-expert token counts (split_sizes). It submits the local experts through TE’s grouped GEMM path instead of launching one PyTorch Linear operation per expert, which reduces per-expert launch and scheduling overhead.
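Semantically, a grouped linear layer is equivalent to splitting the expert-sorted tokens by split_sizes and multiplying each chunk by its own weight. The plain-PyTorch reference below shows that contract with hypothetical toy sizes. It still loops over experts, so it illustrates what GroupedLinear computes, not how TE’s grouped GEMM kernel computes it.

```python
import torch


def grouped_linear_reference(tokens, split_sizes, weights):
    """Rows of group i are multiplied by weights[i].T (no bias)."""
    chunks = torch.split(tokens, split_sizes, dim=0)
    return torch.cat([c @ w.T for c, w in zip(chunks, weights)], dim=0)


torch.manual_seed(0)
hidden, out_features = 16, 24
split_sizes = [5, 0, 7, 3]  # tokens routed to each of 4 local experts; counts differ
weights = [torch.randn(out_features, hidden) for _ in split_sizes]
tokens = torch.randn(sum(split_sizes), hidden)  # already sorted by expert

out = grouped_linear_reference(tokens, split_sizes, weights)
print(tuple(out.shape))  # (15, 24)
```

An expert that receives no tokens (the 0 above) simply contributes an empty chunk, which is why the per-expert counts are passed alongside the tokens.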

Here are the steps to use GroupedLinear. Each expert keeps its own weight tensor (weight0, weight1, …), and the call takes the per-expert token counts as an extra positional argument:

from transformer_engine.pytorch.ops import GroupedLinear

experts_gate_up = GroupedLinear(
    num_groups=num_local_experts,
    in_features=hidden_size,
    out_features=2 * intermediate_size,
    bias=False,
    dtype=torch.bfloat16,
    device="cuda",
)

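# tokens: local tokens sorted by expert, shape [sum(split_sizes), hidden_size]
# split_sizes: number of tokens routed to each local expert
gate_up_output = experts_gate_up(tokens, split_sizes)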

Compared with the Python loop in Improvement 1, each layer now makes one gate-up projection call instead of 4 separate calls, one per local expert. The expert weights can still be imported from HF: in te_mixtral.py, the grouped-op path keeps each local expert as a regular per-expert weight{i} parameter and loads that expert’s slice of the HF tensors directly.
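The per-expert loading described above can be sketched as slicing the owning rank’s experts out of the packed HF tensor and naming them weight0 to weight3. The contiguous placement (EP rank r owns experts 4r to 4r+3), the plain dict, and the local_expert_weights helper are illustrative assumptions, not the code in te_mixtral.py.

```python
import torch


def local_expert_weights(packed: torch.Tensor, ep_rank: int, ep_size: int) -> dict:
    """Slice this rank's experts from a packed [num_experts, out, in] tensor.

    Hypothetical helper: assumes each EP rank owns a contiguous block of experts.
    """
    num_experts = packed.shape[0]
    per_rank = num_experts // ep_size
    first = ep_rank * per_rank
    # clone() so each per-expert weight owns its storage instead of viewing the packed tensor
    return {f"weight{i}": packed[first + i].clone() for i in range(per_rank)}


num_experts, hidden, ffn = 8, 16, 48  # toy sizes
gate_up_proj = torch.randn(num_experts, 2 * ffn, hidden)

rank1 = local_expert_weights(gate_up_proj, ep_rank=1, ep_size=2)
print(sorted(rank1))                  # ['weight0', 'weight1', 'weight2', 'weight3']
print(tuple(rank1["weight0"].shape))  # (96, 16)
```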

To see the effect of GroupedLinear, we keep everything else unchanged. Execute the following command in the terminal.

torchrun --standalone --nproc_per_node=8 run_finetune_ep.py --improvement 2 --ep-size 2 --batch-size 12 --max-seq-length 8192 --warmup-steps 5 --train-steps 30

Here is the expected output:

30 fine-tuning steps complete!
Median time per step: 635 ms

Adding the GroupedLinear result gives us:

| Model | Precision | Step Time | Speedup (over baseline) |
|---|---|---|---|
| HF baseline | BF16 | 2472 ms | 1.00 |
| TE decoder, TE building blocks, and MoE layer | BF16 | 747 ms | 3.31 |
| TE with GroupedLinear | BF16 | 635 ms | 3.89 |

GroupedLinear reaches a 3.89x speedup over the baseline (289% higher throughput) and runs 1.18x faster than Improvement 1.

[Improvement 3] Precision Optimization and Fused MLP (Precision: MXFP8)

With EP and grouped expert GEMMs in place, the next improvement lowers precision from BF16 to MXFP8. MXFP8 represents weights and activations with 8-bit floating-point values instead of 16-bit BF16 values. To preserve dynamic range, it stores one E8M0 scale factor (a power of two) for every block of 32 values, so each block is scaled into the narrow FP8 range independently. Blackwell GPUs support MXFP8 natively, so MXFP8 GEMMs run on dedicated Tensor Core instructions. Read more about MXFP8 and block scaling in the Transformer Engine FP8 primer.

The model still keeps its master weights in BF16, so using MXFP8 adds quantization and dequantization work around the GEMMs. Quantization converts BF16 weights and activations to MXFP8 before the low-precision GEMM; dequantization converts the result back to the higher-precision format. A naive implementation runs these as separate operations, which motivates the fused MLP path in Fig 4.


Fig 4: The MXFP8 path fuses multiple operations into one kernel before the down projection.

To use MXFP8, we simply define a recipe and pass it to the model.

import transformer_engine.common.recipe as te_recipe

fp8_recipe = te_recipe.MXFP8BlockScaling()  # 32-value blocks with E8M0 scales
model = TEMixtralMXFP8ForCausalLM(config, fp8_recipe=fp8_recipe, dispatcher=dispatcher)

The model’s forward and backward passes then run in MXFP8, which is enabled through TE’s autocast API:

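import transformer_engine.pytorch as te

# Inside the model's forward(): TE modules in this block run their GEMMs in MXFP8
with te.autocast(enabled=True, recipe=self._fp8_recipe):
    for decoder_layer in self.layers:
        hidden_states = decoder_layer(hidden_states)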

To use the fused MLP, we chain gate_up, ScaledSwiGLU, and down with TE’s Sequential API. The fused path also folds the dequantization step into the computation. ScaledSwiGLU applies the routing probabilities (“scales”) inside the SwiGLU activation, combining them with the expert FFN computation instead of applying them in a separate step.

from transformer_engine.pytorch.ops import GroupedLinear, ScaledSwiGLU, Sequential

# gate_up and down stand for the GroupedLinear arguments of each projection (see Improvement 2)
experts_ffn = Sequential(GroupedLinear(gate_up), ScaledSwiGLU(), GroupedLinear(down))

TE’s Sequential scans its ops and, when they match the GroupedLinear -> ScaledSwiGLU -> GroupedLinear pattern, replaces them with a fused operation object: ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8 for the forward pass and a matching fused op for the backward pass. This reduces framework overhead, fuses the SwiGLU and probability-scaling work into the grouped MLP path, and avoids materializing some intermediate tensors.
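Folding the routing probabilities into the activation is valid because the down projection is linear and bias-free: scaling a token’s SwiGLU output by its probability before the down projection gives the same result as scaling the expert output afterwards, as the HF loop does. The sketch below checks this in plain PyTorch. The scaled_swiglu function, its gate-then-up split, and the toy sizes are assumptions for illustration, not TE’s ScaledSwiGLU implementation.

```python
import torch
import torch.nn.functional as F


def scaled_swiglu(gate_up: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    """silu(gate) * up, multiplied per token by its routing probability (assumed semantics)."""
    gate, up = gate_up.chunk(2, dim=-1)
    return F.silu(gate) * up * scales.unsqueeze(-1)


torch.manual_seed(0)
tokens, hidden, ffn = 6, 16, 32
x = torch.randn(tokens, hidden)
probs = torch.rand(tokens)  # routing probability of each token for this expert
w_gate_up = torch.randn(2 * ffn, hidden) / hidden**0.5
w_down = torch.randn(hidden, ffn) / ffn**0.5

# Fused-style order: scale inside the activation, then the down projection.
fused = F.linear(scaled_swiglu(F.linear(x, w_gate_up), probs), w_down)

# HF-loop order: full expert FFN first, then scale its output.
gate, up = F.linear(x, w_gate_up).chunk(2, dim=-1)
unfused = F.linear(F.silu(gate) * up, w_down) * probs.unsqueeze(-1)

print(torch.allclose(fused, unfused, atol=1e-4))  # True
```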

Note

NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 must be set before Transformer Engine registers its fused ops, that is, before TE is imported. In this tutorial, run_finetune_ep.py sets it automatically for improvement 3.
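If you write your own launcher instead of using run_finetune_ep.py, set the variable at the very top of the entry-point script. The guard below is a defensive sketch, not part of the tutorial’s code: it refuses to continue if Transformer Engine has already been imported in the process, because by then the setting comes too late to take effect.

```python
import os
import sys

# Must happen before the first import of transformer_engine in this process.
if "transformer_engine" in sys.modules:
    raise RuntimeError("Set NVTE_CUTEDSL_FUSED_GROUPED_MLP before importing transformer_engine")
os.environ["NVTE_CUTEDSL_FUSED_GROUPED_MLP"] = "1"

# import transformer_engine.pytorch as te  # safe to import from here on
```

Setting it in the launching shell also works, since torchrun’s worker processes inherit the environment:

NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 torchrun --standalone --nproc_per_node=8 run_finetune_ep.py --improvement 3 --ep-size 2 --batch-size 12 --max-seq-length 8192 --warmup-steps 5 --train-steps 30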

Execute the following in the terminal:

torchrun --standalone --nproc_per_node=8 run_finetune_ep.py --improvement 3 --ep-size 2 --batch-size 12 --max-seq-length 8192 --warmup-steps 5 --train-steps 30

Here is the expected result:

30 fine-tuning steps complete!
Median time per step: 542 ms

With MXFP8 fused MLP included, the final comparison is:

| Model | Precision | Step Time | Speedup (over baseline) |
|---|---|---|---|
| HF baseline | BF16 | 2472 ms | 1.00 |
| TE EP Python loop | BF16 | 747 ms | 3.31 |
| TE with GroupedLinear | BF16 | 635 ms | 3.89 |
| TE with MXFP8 fused MLP | MXFP8 | 542 ms | 4.56 |

For Mixtral-8x7B, MXFP8 with the fused MLP gives the largest speedup: 4.56x faster than the baseline (356% higher throughput) and 1.17x faster than the BF16 GroupedLinear run.
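The speedup column and the percentages quoted in the text follow directly from the measured step times. At a fixed global batch size, throughput is the inverse of step time, so a 3.31x speedup corresponds to 231% more throughput. The snippet below recomputes the table’s numbers from the reported step times.

```python
# Step times (ms) reported in the tables above.
step_ms = {
    "HF baseline": 2472,
    "TE EP Python loop": 747,
    "TE with GroupedLinear": 635,
    "TE with MXFP8 fused MLP": 542,
}
baseline_ms = step_ms["HF baseline"]

# Speedup = baseline step time / step time; extra throughput = speedup - 1.
speedups = {name: baseline_ms / ms for name, ms in step_ms.items()}
for name, s in speedups.items():
    print(f"{name:<24} {s:.2f}x  (+{(s - 1) * 100:.0f}% throughput)")
```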

Conclusion

This tutorial walks through three progressive optimizations that speed up fine-tuning of Mixtral-8x7B: replacing Hugging Face building blocks with TE modules and expert parallelism, batching the expert GEMMs with GroupedLinear, and running the experts in MXFP8 with a fused grouped MLP. All results use a global batch size of 48 and a sequence length of 8192 on 8x B300 GPUs.

To reproduce all four results, run the following commands one after another in a terminal. All four runs use the same global batch size of 48; because the TE runs use DP=4, their per-rank batch size is 12.

python3 run_finetune_ep.py --improvement 0 --batch-size 48 --max-seq-length 8192 --warmup-steps 5 --train-steps 30

torchrun --standalone --nproc_per_node=8 run_finetune_ep.py --improvement 1 --ep-size 2 --batch-size 12 --max-seq-length 8192 --warmup-steps 5 --train-steps 30

torchrun --standalone --nproc_per_node=8 run_finetune_ep.py --improvement 2 --ep-size 2 --batch-size 12 --max-seq-length 8192 --warmup-steps 5 --train-steps 30

torchrun --standalone --nproc_per_node=8 run_finetune_ep.py --improvement 3 --ep-size 2 --batch-size 12 --max-seq-length 8192 --warmup-steps 5 --train-steps 30

Note on Scaling

For large-scale training, check out Megatron’s performance summary.

Appendix: Dependencies

| File | Purpose |
|---|---|
| te_mixtral.py | BF16 TE Mixtral implementation |
| te_mixtral_mxfp8.py | MXFP8 implementation |
| te_moe_dispatch.py | Token dispatch/combine for MXFP8 |
| hf_to_te_weights.py | Converts Hugging Face weights to Transformer Engine format |
| utils.py | Training loop |
| run_finetune_ep.py | CLI launcher |
| requirements.txt | Python package versions |
| collator.py | Input sequence preparation |