{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Accelerating Hugging Face Mixtral MoE Fine-Tuning with Transformer Engine\n", "\n", "
\n", "\n", "Goal\n", "\n", "This tutorial showcases how to accelerate fine-tuning a mixture-of-experts model, [Mixtral-8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1), with Transformer Engine (TE) in `BF16` and `MXFP8` precision.\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Setup**\n", "\n", "Mixtral-8x7B has 8 experts and roughly 47B total parameters. In `BF16` the model weights alone consume ~93 GB, and full `AdamW` fine-tuning needs ~370 GB. This tutorial is tested on 8x B300 GPUs with `Expert Parallelism (EP) = 2` and `Data Parallelism (DP) = 4`, so the experts are divided across 2 GPUs and there are 4 replicas. The container used is [pytorch-26.04-py3](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch?version=26.04-py3). A sequence length of 8192 and a global batch size of 48 are used across the experiments.\n", "\n", "Install the required Python packages using the following command in a terminal:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```bash\n", "pip install -r requirements.txt \n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Table of Contents\n", "\n", "1. [Baseline] Running HF Mixtral -- Without Expert Parallelism (Precision: `BF16`)\n", "2. [Improvement 1] Transformer Engine with Expert Parallelism (Precision: `BF16`)\n", "3. [Improvement 2] Batched Expert Execution with `GroupedLinear` (Precision: `BF16`)\n", "4. [Improvement 3] Precision Optimization and Fused MLP (Precision: `MXFP8`)\n", "5. Conclusion\n", "6. Appendix: Dependencies" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## [Baseline] Running HF Mixtral -- Without Expert Parallelism (Precision: `BF16`)\n", "\n", "Before applying any Transformer Engine optimizations, we establish a Hugging Face (HF) baseline. Mixtral replaces the standard Transformer feed-forward network (FFN) with a sparse **Mixture of Experts** (MoE): a learned router selects the top-2 experts out of 8 per token, as shown in Fig 1. \n", "\n", "
\n", "\n", "
Fig 1: Dense Transformer block (left) vs Sparse MoE Transformer block (right).
\n", "
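The top-2 routing step can be sketched in plain Python for a single token. This is a minimal illustration, not the HF or TE implementation: the function name `route_token` and the logits are made up for the example, and the sketch simply applies a softmax followed by a top-k selection.

```python
import math

def route_token(router_logits, top_k=2):
    # Softmax over the expert logits (shifted by the max for stability).
    peak = max(router_logits)
    exps = [math.exp(x - peak) for x in router_logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Keep the top_k experts with the highest probability.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    chosen = ranked[:top_k]
    return chosen, [probs[i] for i in chosen]

# Hypothetical router logits for one token over 8 experts.
logits = [0.1, 2.0, -1.0, 0.3, 1.5, -0.5, 0.0, 0.7]
experts, weights = route_token(logits)
print(experts)  # [1, 4]
```

Only the two selected experts run for this token; the other six are skipped, which is what makes the block sparse.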
\n", "\n", "\n", "The current HF implementation has two limitations.\n", "\n", "1. **Pipeline parallelism.** Because the full model does not fit on one GPU, the baseline uses pipeline parallelism to split the model across GPUs. This is the simplest way to partition a model, but GPU utilization is limited by pipeline bubbles and sequential layer dependencies.\n", "\n", "\n", "2. **Excessive kernel launches.** [HF's MixtralSparseMoeBlock](https://github.com/huggingface/transformers/blob/3ef278124e47832f34406ca3ca85bc50ad8b79bb/src/transformers/models/mixtral/modeling_mixtral.py) iterates over all 8 experts in a Python loop, and each expert triggers its own kernel launches.\n", "\n", "```python\n", "# One iteration per expert; every iteration launches its own kernels.\n", "for expert_idx, expert_layer in enumerate(self.experts):\n", "    # top_x: token positions routed to this expert; idx: their top-k slot\n", "    idx, top_x = torch.where(expert_mask[expert_idx])\n", "    current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)\n", "    # Run the expert FFN and scale each output by its routing weight\n", "    current_hidden = expert_layer(current_state) * routing_weights[top_x, idx, None]\n", "    # Scatter-add the results back to the original token positions\n", "    final_hidden_states.index_add_(0, top_x, current_hidden)\n", "```\n", "\n", "For each layer, HF loops through the experts sequentially. Each expert sees only the tokens routed to it, so each expert GEMM is small and cannot saturate the GPU's tensor cores. Looping over many experts therefore launches many small GEMMs, leaving the FFN dominated by orchestration overhead and memory movement.\n", "\n", "\n", "The script [run_finetune_ep.py](run_finetune_ep.py) initializes the Hugging Face model and then runs fine-tuning. For the full implementation, refer to [utils.py](utils.py). Now, let's execute the following command in the terminal."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```bash\n", "python3 run_finetune_ep.py --improvement 0 --batch-size 48 --max-seq-length 8192 --warmup-steps 5 --train-steps 30\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is the expected output:\n", "\n", "```\n", "30 fine-tuning steps complete!\n", "Median time per step: 2472 ms\n", "```\n", "\n", "Let's add this information in a table and keep comparing it with a few possible improvements in future sections:\n", "\n", "| Models | Precision | Step Time | Speedup (over baseline) |\n", "|---|---|---:|---:|\n", "| HF baseline | BF16 | 2472 ms | 1 |" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## [Improvement 1] Transformer Engine with Expert Parallelism (Precision: `BF16`)\n", "\n", "Now that we have a baseline, let's bring in Transformer Engine. This section replaces the HF Transformer block with TE modules and introduces expert parallelism (EP). \n", "\n", "
\n", "\n", "
Fig 2: HF MixtralDecoderLayer (left) wrapped by TE modules (right).
\n", "
\n", "\n", "**Fused Building Blocks**\n", "\n", "- **Attention block.** Instead of using one module for layer norm and another module for attention, TE combines them (`RMSNorm` and attention) in `te.MultiheadAttention`, where the input `RMSNorm` is bundled with the fused QKV projection. The layer norm weights and QKV weights are stored in the same submodule, `self_attention.layernorm_qkv` (as `layer_norm_weight` and `weight`). Here is how you can use TE's attention block:\n", "\n", "    ```python\n", "    self.self_attention = transformer_engine.pytorch.MultiheadAttention(\n", "        hidden_size=config.hidden_size,\n", "        fuse_qkv_params=True,\n", "        qkv_weight_interleaved=True,\n", "        normalization=\"RMSNorm\",\n", "        input_layernorm=True,\n", "        ...\n", "    )\n", "    ```\n", "\n", "- **MoE block.** TE provides the building blocks for the MoE layer. First, the gate computes router logits, then `softmax` and `top-k` select the top two experts out of eight for each token. The selected experts are then passed to the dispatcher. The dispatcher determines which EP ranks host the selected MoE experts, and NCCL handles the all-to-all communication that moves tokens to those ranks. The dispatcher also uses `transformer_engine.pytorch.moe_permute_and_pad_with_probs` to handle padding requirements, such as the padding to multiples of 32 required by `MXFP8`.\n", "\n", "    Below is overview pseudocode for the MoE block:\n", "\n", "    ```python\n", "    router_logits = self.gate(hidden_states)\n", "\n", "    softmax_probs = torch.nn.functional.softmax(router_logits, dim=-1)\n", "\n", "    routing_weights, selected_experts = torch.topk(softmax_probs, self.top_k, dim=-1)\n", "\n", "    dispatch_output = self.dispatcher.dispatch(hidden_states, selected_experts, routing_weights)\n", "    ```\n", "\n", "**Parallelism layout**\n", "\n", "In this tutorial, EP=2 splits the experts across 2 GPUs, each hosting 4 experts.
In this 8-GPU setup, the model is replicated 4 times, giving 4 data-parallel replicas of 2 GPUs each.\n", "\n", "Here is how to set up EP:\n", "\n", "```python\n", "import torch.distributed as dist\n", "\n", "# Assumes the default process group is already initialized (e.g. by torchrun).\n", "world_size = dist.get_world_size()\n", "\n", "config.expert_parallel_size = 2\n", "ep_size = config.expert_parallel_size\n", "dp_size = world_size // ep_size  # 8 // 2 = 4 replicas\n", "ep_group = None\n", "for dp_rank in range(dp_size):\n", "    # Consecutive ranks form one EP group: [0, 1], [2, 3], [4, 5], [6, 7]\n", "    ranks = list(range(dp_rank * ep_size, (dp_rank + 1) * ep_size))\n", "    # new_group is a collective call: every rank must create every group\n", "    group = dist.new_group(ranks=ranks)\n", "    if dist.get_rank() in ranks:\n", "        ep_group = group\n", "model.model.set_ep_groups(ep_group=ep_group)\n", "```\n", "\n", "**Mapping the HF checkpoint to TE**\n", "\n", "Some parameters must be reshaped and renamed to match the corresponding TE modules. The `replace_params` helper in [te_mixtral.py](te_mixtral.py) performs the mapping (also illustrated in Fig 2 above). The two non-trivial groups are:\n", "\n", "- **Attention.** HF stores Q, K, V as separate projections; TE fuses them into a single QKV weight that lives under the `layernorm_qkv` submodule:\n", "\n", "| HF key | TE key |\n", "|---|---|\n", "| `self_attn.q_proj.weight` | `self_attention.layernorm_qkv.weight` (Q slice) |\n", "| `self_attn.k_proj.weight` | `self_attention.layernorm_qkv.weight` (K slice) |\n", "| `self_attn.v_proj.weight` | `self_attention.layernorm_qkv.weight` (V slice) |\n", "| `input_layernorm.weight` | `self_attention.layernorm_qkv.layer_norm_weight` |\n", "\n", "- **MoE experts.** HF packs all experts' `SwiGLU` projections into two tensors per layer; TE keeps the same packing under different attribute names, so `replace_params` is essentially a copy:\n", "\n", "| HF key | TE key |\n", "|---|---|\n", "| `mlp.experts.gate_up_proj` `[num_experts, 2*ffn, h]` | `mlp.experts_gate_up_weight` |\n", "| `mlp.experts.down_proj` `[num_experts, h, ffn]` | `mlp.experts_down_weight` |\n", "| `mlp.gate.weight` | `mlp.gate.weight` |\n", "\n", "All other weights (embeddings, norms, LM head) are direct copies.
See `replace_params` in `te_mixtral.py` for the full mapping.\n", "\n", "Let's launch the same fine-tuning loop -- this time across 8 GPUs via `torchrun`. See `run_finetune_ep.py` and `utils.py` for the full implementation.\n", "\n", "Now, let's execute the following command." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```bash\n", "torchrun --standalone --nproc_per_node=8 run_finetune_ep.py --improvement 1 --ep-size 2 --batch-size 12 --max-seq-length 8192 --warmup-steps 5 --train-steps 30\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is the expected output:\n", "\n", "```\n", "30 fine-tuning steps complete!\n", "Median time per step: 747 ms\n", "```\n", "\n", "Compared to the baseline implementation, we see the following result:\n", "\n", "| Models | Precision | Step Time | Speedup (over baseline) |\n", "|---|---|---:|---:|\n", "| HF baseline | BF16 | 2472 ms | 1 |\n", "| TE decoder, TE building blocks, and MoE layer | BF16 | 747 ms | 3.31 |\n", "\n", "Improvement 1 is 3.31x faster than the baseline, a **231%** speedup." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## [Improvement 2] Batched Expert Execution with `GroupedLinear` (Precision: `BF16`)\n", "\n", "Improvement 1 kept the per-expert FFN as a Python loop. If each rank owns 4 local experts, the loop launches the per-expert GEMMs one by one. Another limitation is that the expert GEMMs are usually small: each one sees only the tokens routed to it, which is too small to feed the tensor cores efficiently. The following section shows how to execute the GEMMs in a batch. \n", "\n", "
\n", "\n", "
Fig 3: Left: looping through experts one-by-one. Right: one grouped-GEMM over all experts.
\n", "
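For reference, the computation that both sides of Fig 3 perform can be written in plain Python: tokens are sorted by expert, and each expert multiplies only its own contiguous slice. The sketch below uses tiny made-up shapes and hypothetical names (`tokens_per_expert`, `expert_loop`); it shows the looped form, which a grouped GEMM replaces with a single batched call driven by the same per-expert token counts.

```python
def matmul(rows, weight):
    # rows: [n, k], weight: [k, m] -> [n, m]
    return [[sum(r[i] * weight[i][j] for i in range(len(weight)))
             for j in range(len(weight[0]))] for r in rows]

def expert_loop(tokens, weights, tokens_per_expert):
    # tokens are already permuted so each expert's tokens are contiguous.
    outputs, start = [], 0
    for weight, count in zip(weights, tokens_per_expert):
        outputs.extend(matmul(tokens[start:start + count], weight))
        start += count
    return outputs

# 5 tokens with hidden size 2, routed unevenly to 3 local experts.
tokens = [[1, 0], [0, 1], [1, 1], [2, 0], [0, 2]]
weights = [
    [[1, 0], [0, 1]],   # expert 0: identity
    [[2, 0], [0, 2]],   # expert 1: scale by 2
    [[0, 1], [1, 0]],   # expert 2: swap the two features
]
tokens_per_expert = [2, 0, 3]   # expert 1 receives no tokens
result = expert_loop(tokens, weights, tokens_per_expert)
print(result)  # [[1, 0], [0, 1], [1, 1], [0, 2], [2, 0]]
```

The uneven counts (including an expert with zero tokens) are why a grouped call needs the per-expert sizes as an input rather than assuming equal slices.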
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "`GroupedLinear` applies multiple linear transformations in one call. It gathers the experts' weights and input tokens. Although each expert receives a different number of tokens, `GroupedLinear` supports this by accepting per-expert token counts (`split_sizes`). `GroupedLinear` submits the local experts through TE’s grouped GEMM path instead of launching one PyTorch Linear operation per expert. This reduces per-expert launch and scheduling overhead. \n", "\n", "Here are the steps to use `GroupedLinear`. Each expert keeps its own weight tensor (`weight0`, `weight1`, ...), and the call takes the per-expert token counts as an extra positional argument:\n", "\n", "```python\n", "from transformer_engine.pytorch.ops import GroupedLinear\n", "\n", "experts_gate_up = GroupedLinear(\n", " num_groups=num_local_experts,\n", " in_features=hidden_size,\n", " out_features=2 * intermediate_size,\n", " bias=False,\n", " dtype=torch.bfloat16,\n", " device=\"cuda\",\n", ")\n", "\n", "gate_up_output = experts_gate_up(tokens, split_sizes)\n", "```\n", "\n", "Compared with the Python loop in Improvement 1, this becomes one gate-up projection per layer instead of 4 separate calls (4 is the number of experts on a GPU). The expert weights can be imported from HF. In `te_mixtral.py`, the grouped-op path keeps each local expert as a normal per-expert `weight{i}` parameter and loads the owning expert slice directly.\n", "\n", "To see the effect of `GroupedLinear`, we keep everything else unchanged. Execute the following command in the terminal. 
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```bash\n", "torchrun --standalone --nproc_per_node=8 run_finetune_ep.py --improvement 2 --ep-size 2 --batch-size 12 --max-seq-length 8192 --warmup-steps 5 --train-steps 30\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is the expected output:\n", "\n", "```\n", "30 fine-tuning steps complete!\n", "Median time per step: 635 ms\n", "```\n", "\n", "Adding the `GroupedLinear` result gives us:\n", "\n", "| Models | Precision | Step Time | Speedup (over baseline) |\n", "|---|---|---:|---:|\n", "| HF baseline | BF16 | 2472 ms | 1 |\n", "| TE decoder, TE building blocks, and MoE layer | BF16 | 747 ms | 3.31 |\n", "| TE with `GroupedLinear` | BF16 | 635 ms | 3.89 |\n", "\n", "`GroupedLinear` reaches a 3.89x speedup over the baseline, or **289%**." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## [Improvement 3] Precision Optimization and Fused MLP (Precision: `MXFP8`)\n", "\n", "With EP and grouped expert GEMMs in place, the next improvement lowers precision from `BF16` to `MXFP8`. `MXFP8` converts the weight and activation values to 8 bits instead of 16 bits. To preserve dynamic range, it keeps one `E8M0` scale factor for every 32 values; applying that scale recovers a wider numerical range. On Blackwell GPUs, `MXFP8` is native and hardware accelerated, so `MXFP8` GEMMs can run through specialized Tensor Core instructions. Read more about `MXFP8` and block scaling in the [Transformer Engine FP8 primer](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#MXFP8-and-block-scaling).\n", "\n", "The model still keeps its master weights in `BF16`, so using `MXFP8` adds `Quantization` and `De-Quantization` work around the GEMMs. `Quantization` converts `BF16` weights and activations to `MXFP8` before the low-precision GEMM; `De-Quantization` converts the result back to the higher-precision format. 
The naive path performs these as separate operations. This motivates the fused MLP path shown below.\n", "\n", "
\n", "\n", "
Fig 4: The MXFP8 path fuses multiple operations into one kernel before the down projection.
\n", "
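The block-scaling idea can be illustrated with a small pure-Python simulation. It is a sketch under stated assumptions, not TE's kernel: it assumes the `E4M3` element format (4 significant bits, largest finite value 448), picks each block's scale as the smallest power of two that brings the block maximum into range, and ignores how the scale byte itself is encoded. The helper names are made up for the example.

```python
import math

BLOCK = 32         # number of values that share one scale
E4M3_MAX = 448.0   # largest finite E4M3 magnitude (assumed element format)

def block_scale(block):
    # Smallest power of two that maps the block maximum into E4M3 range.
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 1.0
    return 2.0 ** math.ceil(math.log2(amax / E4M3_MAX))

def round_e4m3(x):
    # Round to 4 significant bits; 2**-9 is the smallest E4M3 step.
    if x == 0.0:
        return 0.0
    _, exp = math.frexp(x)   # x = m * 2**exp with 0.5 <= |m| < 1
    step = 2.0 ** max(exp - 4, -9)
    return max(-E4M3_MAX, min(E4M3_MAX, round(x / step) * step))

def mxfp8_roundtrip(values):
    # Quantize block by block, then dequantize with the same scale.
    restored = []
    for start in range(0, len(values), BLOCK):
        block = values[start:start + BLOCK]
        scale = block_scale(block)
        restored.extend(round_e4m3(v / scale) * scale for v in block)
    return restored

# One block of small values followed by one block of large values.
values = [0.001 * (i + 1) for i in range(32)] + [100.0 * (i + 1) for i in range(32)]
restored = mxfp8_roundtrip(values)
worst = max(abs(r - v) / abs(v) for r, v in zip(restored, values))
print(worst <= 2.0 ** -4)
```

Because each block of 32 values carries its own scale, the block of small values keeps the same relative resolution as the block of large values, even though the two differ by five orders of magnitude.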
\n", "\n", "To use `MXFP8`, we define a recipe and pass it to the model.\n", "\n", "```python\n", "import transformer_engine.pytorch as te  # alias used in these snippets\n", "from transformer_engine.common import recipe as te_recipe\n", "fp8_recipe = te_recipe.MXFP8BlockScaling()\n", "model = TEMixtralMXFP8ForCausalLM(config, fp8_recipe=fp8_recipe, dispatcher=dispatcher)\n", "```\n", "\n", "The model's forward and backward passes now run in `MXFP8` precision, which is enabled through TE's `autocast` API:\n", "\n", "```python\n", "with te.autocast(enabled=True, recipe=self._fp8_recipe):\n", "    for decoder_layer in self.layers:\n", "        hidden_states = decoder_layer(hidden_states)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To use fused MLP, we import TE's `Sequential` API to chain together `gate_up`, `ScaledSwiGLU`, and `down`. The fused path also folds in the `De-Quantization` step. `ScaledSwiGLU` is chosen to combine the routing probabilities (\"scales\") with the expert FFN computations.\n", "\n", "```python\n", "from transformer_engine.pytorch.ops import GroupedLinear, ScaledSwiGLU, Sequential\n", "\n", "# gate_up / down are placeholders for each projection's constructor arguments\n", "experts_ffn = Sequential(GroupedLinear(gate_up), ScaledSwiGLU(), GroupedLinear(down))\n", "```\n", "\n", "TE's `Sequential` scans the ops and, if the pattern matches, replaces the `GroupedLinear -> ScaledSwiGLU -> GroupedLinear` pattern with a fused operation object: `ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8` for forward and a matching fused backward op. The fused op reduces framework overhead, fuses the SwiGLU/probability-scaling work into the grouped MLP path, and avoids some intermediate materialization.\n", "\n", "
\n", "\n", "Note\n", "\n", "`NVTE_CUTEDSL_FUSED_GROUPED_MLP=1` must be set before TE imports the fused op registration. In this tutorial, `run_finetune_ep.py` already does that automatically for improvement 3.\n", "\n", "
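If you build your own launcher instead of using `run_finetune_ep.py`, one way to satisfy this ordering is to set the variable at the very top of the entry script, before anything imports Transformer Engine. This is a minimal sketch; only the variable name comes from this tutorial, and the commented import stands in for whatever TE import your script performs first.

```python
import os

# Must run before the first Transformer Engine import in this process.
os.environ['NVTE_CUTEDSL_FUSED_GROUPED_MLP'] = '1'

# import transformer_engine.pytorch as te  # import TE only after this point
flag = os.environ['NVTE_CUTEDSL_FUSED_GROUPED_MLP']
print(flag)  # 1
```

Under `torchrun`, every rank executes the entry script, so each process sets the variable for itself; exporting it in the shell before launching is an equivalent alternative.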
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Execute the following in the terminal:\n", "\n", "```bash\n", "torchrun --standalone --nproc_per_node=8 run_finetune_ep.py --improvement 3 --ep-size 2 --batch-size 12 --max-seq-length 8192 --warmup-steps 5 --train-steps 30\n", "```\n", "\n", "Here is the expected result:\n", "```\n", "30 fine-tuning steps complete!\n", "Median time per step: 542 ms\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With MXFP8 fused MLP included, the final comparison is:\n", "\n", "| Models | Precision | Step Time | Speedup (over baseline) |\n", "|---|---|---:|---:|\n", "| HF baseline | BF16 | 2472 ms | 1 |\n", "| TE EP Python loop | BF16 | 747 ms | 3.31 |\n", "| TE with `GroupedLinear` | BF16 | 635 ms | 3.89 |\n", "| TE with MXFP8 fused MLP | MXFP8 | 542 ms | 4.56 |\n", "\n", "For Mixtral-8x7B, we get the largest speedup with MXFP8 fused MLP: 4.56x faster than the baseline, or **356%**." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Conclusion\n", "\n", "This tutorial walks through three progressive optimization improvements that speed up the fine-tuning of the Mixtral-8x7B model by replacing building blocks in a Hugging Face baseline with TE-native blocks like MXFP8 and fused grouped MLP. The tutorial uses **global batch 48, seq 8192** on 8x B300 to demonstrate the speedups.\n", "\n", "To run all improvements together, execute the following in a terminal. All four runs use the same global batch size of 48. 
The TE runs use DP=4, so the per-rank batch size is 12.\n", "\n", "```bash\n", "python3 run_finetune_ep.py --improvement 0 --batch-size 48 --max-seq-length 8192 --warmup-steps 5 --train-steps 30\n", "\n", "torchrun --standalone --nproc_per_node=8 run_finetune_ep.py --improvement 1 --ep-size 2 --batch-size 12 --max-seq-length 8192 --warmup-steps 5 --train-steps 30\n", "\n", "torchrun --standalone --nproc_per_node=8 run_finetune_ep.py --improvement 2 --ep-size 2 --batch-size 12 --max-seq-length 8192 --warmup-steps 5 --train-steps 30\n", "\n", "torchrun --standalone --nproc_per_node=8 run_finetune_ep.py --improvement 3 --ep-size 2 --batch-size 12 --max-seq-length 8192 --warmup-steps 5 --train-steps 30\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "Note on Scaling\n", "\n", "For large-scale training, check out [Megatron's performance summary](https://docs.nvidia.com/nemo/megatron-bridge/latest/performance-summary.html).\n", "\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Appendix: Dependencies\n", "\n", "| File | Purpose |\n", "|---|---|\n", "| `te_mixtral.py` | BF16 TE Mixtral implementation |\n", "| `te_mixtral_mxfp8.py` | MXFP8 implementation |\n", "| `te_moe_dispatch.py` | Token dispatch/combine for MXFP8 |\n", "| `hf_to_te_weights.py` | Converts Hugging Face weights to Transformer Engine format |\n", "| `utils.py` | Training loop |\n", "| `run_finetune_ep.py` | CLI launcher |\n", "| `requirements.txt` | Python package versions |\n", "| `collator.py` | Input sequence preparation |" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.12.0" }, "nbsphinx": { "execute": "never" } }, "nbformat": 4, "nbformat_minor": 4 }