Introduction

Transformer Engine accelerates deep learning on NVIDIA GPUs in several ways, with low precision training being one of the most important. This chapter introduces mixed precision training and FP8 support.

Training in BF16/FP16

Deep learning traditionally uses 32-bit floating-point (FP32) numbers. NVIDIA GPUs support lower precision formats—FP16 since Pascal, BF16 since Ampere—which offer higher throughput and lower memory usage. Let’s compare these formats.

[Figure: bit-level layouts of FP32, BF16, and FP16; the FP32 value 0.3952 rounds to ≈0.3945 in BF16 and ≈0.3950 in FP16.]

Figure 1: Comparison of FP32, BF16, and FP16 floating-point formats showing bit allocation for sign, exponent, and mantissa.

The key differences between these formats are:

  • FP32 (32 bits total): 1 sign bit + 8 exponent bits + 23 mantissa bits – standard single-precision format

  • BF16 (16 bits total): 1 sign bit + 8 exponent bits + 7 mantissa bits – maintains FP32’s exponent range but has reduced precision

  • FP16 (16 bits total): 1 sign bit + 5 exponent bits + 10 mantissa bits – reduced range but higher precision than BF16

BF16’s advantage is that it shares the same exponent range as FP32, so conversions between the two formats rarely overflow or underflow. FP16 offers more mantissa precision for values within its range, but its limited dynamic range means FP16 training requires loss scaling to avoid overflow and underflow; see the paper on loss scaling for more details.
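The range difference is easy to observe directly. Here is a small sketch in plain PyTorch (no Transformer Engine required) that casts the same FP32 values to both 16-bit formats:

```python
import torch

# Tiny, moderate, and large FP32 values
x = torch.tensor([1e-8, 0.3952, 1e5], dtype=torch.float32)

bf16 = x.to(torch.bfloat16)
fp16 = x.to(torch.float16)

print(bf16)  # all finite: BF16 shares FP32's exponent range
print(fp16)  # 1e-8 underflows to 0, 1e5 overflows to inf (FP16 max is 65504)
```

The loss scaling mentioned above exists precisely to keep gradients away from these underflow and overflow regions of FP16.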

Mixed precision

Not all operations should be run in reduced precision to preserve accuracy. Modern deep learning frameworks use mixed precision training, where different operations use different precisions based on their numerical properties:

  • Matrix multiplications are compute-heavy and remain numerically stable at lower precision, making them ideal candidates for acceleration.

  • Operations like layer normalization and softmax can work with low precision inputs and outputs, but may use high precision internally or for their weights.

  • Operations like loss computation and exponentiation need high precision throughout.
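This division of labor is what torch.autocast automates: inside the context, autocast-eligible operations like matmul are cast down, while numerically sensitive computations can be kept in FP32. A minimal CPU sketch (the same mechanics apply on GPU):

```python
import torch

a = torch.randn(4, 8)  # FP32 inputs
b = torch.randn(8, 4)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    c = a @ b  # matmul is autocast-eligible: inputs are cast to BF16

# Numerically sensitive ops (e.g. softmax feeding a loss) can be
# run in FP32 by upcasting explicitly
probs = torch.softmax(c.float(), dim=-1)

print(c.dtype, probs.dtype)  # bfloat16 vs. float32
```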

Master weights

Another consideration in mixed precision training is how to store the model weights. Lower precision formats like FP16 and BF16 have limited representational granularity, which becomes problematic during gradient updates. When a small gradient is added to a much larger weight stored in low precision, the result may round back to the original value because the update falls below the format’s precision threshold. Moreover, some elements of the gradient itself may be too small to represent in low precision at all, especially after accumulation across multiple GPUs in data-parallel training.

The solution is to maintain master weights in FP32. During training, weights are cast to lower precision for forward and backward passes, but the gradient updates are applied to the full-precision master copy. This ensures that even small gradients accumulate correctly over time.
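The rounding problem is easy to reproduce. In BF16 the spacing between representable values around 1.0 is 2^-7 ≈ 0.0078, so a smaller update is simply lost, while an FP32 master copy accumulates it:

```python
import torch

weight = torch.tensor(1.0, dtype=torch.bfloat16)  # low-precision weight
master = torch.tensor(1.0, dtype=torch.float32)   # FP32 master copy
grad = 1e-3  # update below BF16's spacing at 1.0 (~0.0078)

for _ in range(10):
    weight = weight + grad  # rounds back to 1.0 on every step
    master = master + grad  # accumulates correctly

print(float(weight), float(master))  # 1.0 vs ~1.01
```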

There are two common software approaches to storing master weights:

  • In the optimizer: The model holds low-precision weights, while the optimizer maintains FP32 copies alongside momentum and other state. During each step, the optimizer updates its FP32 copy and casts the result back to the model’s low-precision weights.

    This approach makes it easier to shard master weights together with other optimizer state, for example in ZeRO optimizer.

    Since the casting happens only during the optimizer step, this approach is also faster when the optimizer step runs less frequently than the forward/backward passes, e.g. when performing gradient accumulation or pipeline-parallel training.

  • In the model: The model stores weights directly in FP32, and they are cast to lower precision on-the-fly during forward and backward passes. This approach works seamlessly with any standard optimizer, requiring no special support.
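As an illustration of the first approach, here is a hypothetical minimal SGD-style optimizer (the class and its names are invented for this sketch, not a Transformer Engine API) that keeps FP32 master copies in its state and writes cast results back into the model’s BF16 weights:

```python
import torch

class MasterWeightSGD:
    """Toy optimizer: FP32 master weights live in optimizer state."""

    def __init__(self, params, lr=0.1):
        self.params = list(params)
        self.lr = lr
        # FP32 master copies, kept alongside what would be momentum etc.
        self.master = [p.detach().float().clone() for p in self.params]

    @torch.no_grad()
    def step(self):
        for p, m in zip(self.params, self.master):
            m -= self.lr * p.grad.float()  # update the FP32 master
            p.copy_(m)                     # cast back to low precision

model = torch.nn.Linear(4, 4, dtype=torch.bfloat16)
opt = MasterWeightSGD(model.parameters())
model(torch.randn(2, 4, dtype=torch.bfloat16)).sum().backward()
opt.step()
```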


Figure 2: Three approaches to weight storage—low precision only (no master weights), master weights stored in the model, and master weights stored in the optimizer.

The PyTorch API of Transformer Engine provides several mechanisms to control precision:

  • Weight precision: Use the params_dtype argument in any TE layer constructor.

  • Computation precision: Use the torch.autocast context manager. When enabled, inputs are cast to the autocast dtype before computation.

  • Input dtype: When torch.autocast is not used, the input tensor’s dtype determines the computation precision. In this case, inputs and parameters must have matching dtypes.


import torch
import transformer_engine.pytorch as te
from contextlib import nullcontext


def run_forward_backward(params_dtype, autocast_precision, grad_scaler_enabled):
    if grad_scaler_enabled:
        grad_scaler = torch.amp.GradScaler("cuda")

    layer = te.TransformerLayer(
        hidden_size=1024,
        ffn_hidden_size=4096,
        num_attention_heads=16,
        params_dtype=params_dtype,
    )
    x = torch.randn(32, 128, 1024, dtype=params_dtype, device="cuda")

    autocast_ctx = (
        torch.autocast(device_type="cuda", dtype=autocast_precision)
        if autocast_precision is not None
        else nullcontext()
    )
    with autocast_ctx:
        output = layer(x)
        expected_dtype = autocast_precision if autocast_precision is not None else params_dtype
        assert output.dtype == expected_dtype
        loss = output.sum()
    if grad_scaler_enabled:
        grad_scaler.scale(loss).backward()
    else:
        loss.backward()


run_forward_backward(torch.float32, torch.float32, False)  # high precision training
run_forward_backward(
    torch.float32, torch.bfloat16, False
)  # bfloat16 training with master weights in FP32
run_forward_backward(
    torch.float32, torch.float16, True
)  # fp16 training with master weights in FP32, needs loss scaling
run_forward_backward(
    torch.bfloat16, torch.bfloat16, False
)  # bfloat16 training with weights in BF16

Lower precisions

Transformer Engine’s primary feature is support for precisions even lower than BF16/FP16, such as FP8, MXFP8, and NVFP4. These formats are more complex to use than BF16/FP16: they require scaling factors to map the range of values in a tensor onto the format’s narrow representable range. Some formats use one scaling factor per tensor, others one scaling factor per block of values. A precision format combined with the logic for training with it is called a recipe.
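The two scaling granularities can be sketched in a few lines of plain PyTorch (448 is the largest magnitude representable in FP8 E4M3; the block size of 32 here mirrors MXFP8-style recipes):

```python
import torch

FP8_E4M3_MAX = 448.0  # largest magnitude representable in FP8 E4M3

t = torch.randn(4 * 32) * 1000  # raw values far exceed the FP8 range

# Per-tensor scaling: one factor for the whole tensor
scale = FP8_E4M3_MAX / t.abs().max()

# Per-block scaling: one factor per block of 32 values
blocks = t.view(-1, 32)
block_scales = FP8_E4M3_MAX / blocks.abs().amax(dim=1, keepdim=True)

# After scaling, every value fits in the FP8 representable range
print((t * scale).abs().max(), (blocks * block_scales).abs().max())
```

Per-block scaling adapts to local value magnitudes, so outliers in one block do not squash the precision of the rest of the tensor.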

In this section we present common logic for all the recipes. Each one of them is described in more detail in a separate section later. Let’s now see how we can train in lower precisions in supported frameworks.

The PyTorch API of Transformer Engine provides an autocast context manager to control precision. It’s similar to the torch.autocast context manager, but tailored for low precision training. The most important argument is the recipe argument, which accepts objects inheriting from Recipe.

Forward computations need to be performed inside the autocast context manager, while the .backward() call should be outside of it (it inherits the setting from the corresponding forward pass).

Here is a basic example:

Needs to be run on SM89+ (Ada or newer)

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

recipe = DelayedScaling()
layer = te.Linear(1024, 1024)
inp = torch.randn(32, 1024, dtype=torch.float32, device="cuda")

with te.autocast(enabled=True, recipe=recipe):
    output = layer(inp)

# .backward() is called outside of autocast
loss = output.sum()
loss.backward()

You can use multiple recipes in the same model in the following ways:

Sequential contexts – apply different recipes to different parts of your model:

Needs to be run on SM89+ (Ada or newer)

encoder_recipe = DelayedScaling(fp8_format=Format.E4M3)
decoder_recipe = DelayedScaling(fp8_format=Format.HYBRID)

encoder = te.Linear(1024, 1024)
decoder = te.Linear(1024, 1024)

with te.autocast(enabled=True, recipe=encoder_recipe):
    hidden = encoder(inp)

with te.autocast(enabled=True, recipe=decoder_recipe):
    output = decoder(hidden)

Nested contexts – the inner context overrides the outer one for its scope:

Needs to be run on SM89+ (Ada or newer)

outer_recipe = DelayedScaling(fp8_format=Format.E4M3)
inner_recipe = DelayedScaling(fp8_format=Format.HYBRID)

layer1 = te.Linear(1024, 1024)
layer2 = te.Linear(1024, 1024)
layer3 = te.Linear(1024, 1024)

with te.autocast(enabled=True, recipe=outer_recipe):
    # layer1 uses outer_recipe
    x = layer1(inp)

    with te.autocast(enabled=True, recipe=inner_recipe):
        # layer2 uses inner_recipe (overrides outer)
        x = layer2(x)

    # layer3 uses outer_recipe again
    output = layer3(x)

Mixed precision with 8- or 4-bit precisions

From now on, we will refer to FP8/MXFP8/NVFP4 etc. as low precision and to FP32/BF16/FP16 as high precision. This terminology will be used throughout the rest of the documentation.

Not all operations run in low precision:

  • Linear operations: run in low precision.

  • Attention computations: run in high precision by default (some recipes allow low precision as an option).

  • Other operations (layer normalization, softmax, etc.): run in high precision.

Within high-precision operations, there are two categories:

  • Configurable precision: most operations run in parameter precision (FP32/BF16/FP16) or the precision specified by torch.autocast.

  • Fixed FP32 precision: some operations, or parts of operations—such as the division in layernorm—always run in FP32, regardless of other settings.
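For example, a BF16 LayerNorm in PyTorch accepts and returns BF16 tensors, while common implementations accumulate the mean and variance internally in higher precision for stability:

```python
import torch

ln = torch.nn.LayerNorm(16, dtype=torch.bfloat16)
x = torch.randn(8, 16, dtype=torch.bfloat16)

y = ln(x)  # BF16 in, BF16 out; statistics are typically accumulated in FP32
print(y.dtype)  # torch.bfloat16
```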


Figure 3: Default precision of operations in a TransformerLayer forward pass. Only linear operations are in lower precision. Dot product attention is shown as three separate operations (QK^T, Softmax, Scores * V), though in practice these may be fused into a single kernel.

Linear layer data flow

Let’s see how data flow of a linear layer works by default on a single H100 GPU with FP8 precision:

H100 (Hopper) architecture natively supports FP8 Matrix Multiplication only in TN layout (Transpose-NoTranspose), so GEMM with tensors A and B returns B * A^T.
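This TN layout matches how a linear layer is already defined: the weight is stored as (out_features, in_features), so the forward pass computes input · weight^T, exactly the B * A^T product above. A quick check in plain PyTorch:

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, 16)   # input: (batch, in_features)
w = torch.randn(32, 16)  # weight: (out_features, in_features)

out = F.linear(x, w)     # computes x @ w.T, i.e. the TN-style B * A^T product
assert torch.allclose(out, x @ w.t(), atol=1e-5)
```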

Forward pass

  • Input is quantized to FP8 – both input and input^T quantized versions are created.

  • Weights are stored in high precision and quantized to low precision before the GEMM – both weight and weight^T quantized versions are created.

  • FP8 GEMM with layout TN is run with the weight and input tensors.

  • Outputs – input * weight^T tensor – are returned in high precision.

Backward pass

  • Output gradients are quantized to FP8 – both output_grad and output_grad^T quantized versions are created.

  • FP8 GEMM with layout TN is performed with weight^T and output_grad tensors to compute input gradients.

  • FP8 GEMM with layout TN is performed with input^T and output_grad^T tensors to compute weight gradients.

  • Input gradients – output_grad * weight tensor – are returned in high precision.

  • Weight gradients – output_grad^T * input tensor – are returned in high precision.


Figure 4: Forward and backward pass of a Linear layer with low precision data flow.