Introduction
Transformer Engine accelerates deep learning on NVIDIA GPUs in several ways, with low precision training being one of the most important. This chapter introduces mixed precision training and FP8 support.
Training in BF16/FP16
Deep learning traditionally uses 32-bit floating-point (FP32) numbers. NVIDIA GPUs support lower precision formats—FP16 since Pascal, BF16 since Ampere—which offer higher throughput and lower memory usage. Let’s compare these formats.
Figure 1: Comparison of FP32, BF16, and FP16 floating-point formats showing bit allocation for sign, exponent, and mantissa.
The key differences between these formats are:
FP32 (32 bits total): 1 sign bit + 8 exponent bits + 23 mantissa bits – standard single-precision format
BF16 (16 bits total): 1 sign bit + 8 exponent bits + 7 mantissa bits – maintains FP32’s exponent range but has reduced precision
FP16 (16 bits total): 1 sign bit + 5 exponent bits + 10 mantissa bits – reduced range but higher precision than BF16
BF16’s advantage is that it shares the same exponent range as FP32, making it easier to convert between the two formats without overflow/underflow issues. FP16 offers finer precision for small values, but its limited dynamic range means loss scaling is needed to avoid overflow/underflow; see this paper on loss scaling for more details.
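The practical consequence of the narrower FP16 exponent is easy to see with a couple of casts (a plain PyTorch illustration, independent of Transformer Engine):

import torch

# FP16 overflows above ~65504, while BF16 keeps roughly FP32's exponent range.
print(torch.tensor(70000.0, dtype=torch.float16))   # inf
print(torch.tensor(70000.0, dtype=torch.bfloat16))  # 70144. (representable, but coarse)

# Very small values underflow to zero in FP16 but survive in BF16.
print(torch.tensor(1e-8, dtype=torch.float16))      # 0.
print(torch.tensor(1e-8, dtype=torch.bfloat16))     # ~1e-8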
Mixed precision
Not all operations should be run in reduced precision to preserve accuracy. Modern deep learning frameworks use mixed precision training, where different operations use different precisions based on their numerical properties:
Matrix multiplications are compute-heavy and remain numerically stable at lower precision, making them ideal candidates for acceleration.
Operations like layer normalization and softmax can work with low precision inputs and outputs, but may use high precision internally or for their weights.
Operations like loss computation and exponentiation need high precision throughout.
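As a small framework-level illustration of this per-operation policy (plain PyTorch, independent of Transformer Engine), under the default CUDA autocast rules matrix multiplications are executed in the reduced precision while operations such as softmax are kept in FP32:

import torch

a = torch.randn(1024, 1024, device="cuda")  # FP32 inputs
b = torch.randn(1024, 1024, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    c = a @ b                                   # matmul is autocast to BF16
    p = torch.nn.functional.softmax(c, dim=-1)  # softmax is kept in FP32 by autocast

print(c.dtype)  # torch.bfloat16
print(p.dtype)  # torch.float32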
Master weights
Another consideration in mixed precision training is how to store the model weights. Lower precision formats like FP16 and BF16 have limited representational granularity, which becomes problematic during gradient updates. When a small gradient is added to a much larger weight stored in low precision, the result may round back to the original value because the update falls below the format’s precision threshold. Moreover, some elements of the gradient itself may be too small to represent in low precision, especially after accumulation across multiple GPUs in data-parallel training.
The solution is to maintain master weights in FP32. During training, weights are cast to lower precision for forward and backward passes, but the gradient updates are applied to the full-precision master copy. This ensures that even small gradients accumulate correctly over time.
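A tiny numeric example (plain PyTorch, for illustration only) shows the rounding problem and why the FP32 master copy fixes it:

import torch

update = 1e-4  # a small but meaningful gradient step

weight_bf16 = torch.tensor(1.0, dtype=torch.bfloat16)
weight_fp32 = torch.tensor(1.0, dtype=torch.float32)  # FP32 master weight

# In BF16 the update is below the spacing of representable values around 1.0,
# so the weight does not move at all.
print(weight_bf16 - update)  # tensor(1., dtype=torch.bfloat16)

# The FP32 master copy accumulates the update correctly.
print(weight_fp32 - update)  # tensor(0.9999)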
There are two common software approaches to storing master weights:
In the optimizer: The model holds low-precision weights, while the optimizer maintains FP32 copies alongside momentum and other state. During each step, the optimizer updates its FP32 copy and casts the result back to the model’s low-precision weights.
This approach makes it easier to shard master weights together with other optimizer state, for example in the ZeRO optimizer.
Since the cast happens only during the optimizer step, this approach is also faster when the optimizer runs less frequently than the forward and backward passes, e.g. with gradient accumulation or pipeline-parallel training.
In the model: The model stores weights directly in FP32, and they are cast to lower precision on-the-fly during forward and backward passes. This approach works seamlessly with any standard optimizer, requiring no special support.
Figure 2: Three approaches to weight storage—low precision only (no master weights), master weights stored in the model, and master weights stored in the optimizer.
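To make the "in the optimizer" approach concrete, here is a minimal sketch of an optimizer that keeps FP32 master copies of low-precision model weights. The MasterWeightSGD class is a hypothetical toy, not Transformer Engine's or any production optimizer's implementation:

import torch

class MasterWeightSGD(torch.optim.Optimizer):
    """Toy SGD that keeps FP32 master copies of low-precision parameters."""

    def __init__(self, params, lr=1e-3):
        super().__init__(params, defaults={"lr": lr})

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if "master" not in state:
                    state["master"] = p.detach().clone().float()  # FP32 master copy
                master = state["master"]
                # Apply the update to the FP32 master weight.
                master.add_(p.grad.float(), alpha=-group["lr"])
                # Cast the result back into the model's low-precision parameter.
                p.copy_(master.to(p.dtype))

model = torch.nn.Linear(16, 16, dtype=torch.bfloat16)
opt = MasterWeightSGD(model.parameters(), lr=1e-3)
loss = model(torch.randn(4, 16, dtype=torch.bfloat16)).sum()
loss.backward()
opt.step()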
The PyTorch API of Transformer Engine provides several mechanisms to control precision:
Weight precision: Use the params_dtype argument in any TE layer constructor.
Computation precision: Use the torch.autocast context manager. When enabled, inputs are cast to the autocast dtype before computation.
Input dtype: When torch.autocast is not used, the input tensor’s dtype determines the computation precision. In this case, inputs and parameters must have matching dtypes.
import torch
import transformer_engine.pytorch as te
from contextlib import nullcontext


def run_forward_backward(params_dtype, autocast_precision, grad_scaler_enabled):
    if grad_scaler_enabled:
        grad_scaler = torch.amp.GradScaler("cuda")

    layer = te.TransformerLayer(
        hidden_size=1024,
        ffn_hidden_size=4096,
        num_attention_heads=16,
        params_dtype=params_dtype,
    )
    x = torch.randn(32, 128, 1024, dtype=params_dtype, device="cuda")

    autocast_ctx = (
        torch.autocast(device_type="cuda", dtype=autocast_precision)
        if autocast_precision is not None
        else nullcontext()
    )
    with autocast_ctx:
        output = layer(x)
        assert output.dtype == (
            autocast_precision if autocast_precision is not None else params_dtype
        )
        loss = output.sum()

    if grad_scaler_enabled:
        grad_scaler.scale(loss).backward()
    else:
        loss.backward()


run_forward_backward(torch.float32, torch.float32, False)  # high precision training
run_forward_backward(
    torch.float32, torch.bfloat16, False
)  # bfloat16 training with master weights in FP32
run_forward_backward(
    torch.float32, torch.float16, True
)  # fp16 training with master weights in FP32, needs loss scaling
run_forward_backward(
    torch.bfloat16, torch.bfloat16, False
)  # bfloat16 training with weights in BF16
The JAX API of Transformer Engine provides two mechanisms to control precision:
Weight precision: Use the dtype argument in any TE layer constructor.
Computation precision: Determined by the dtype of the input tensor.
For training with master weights in FP32 and computation in BF16, cast the input tensor to BF16 before passing it to the layer.
import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import TransformerLayer


def run_forward_backward(params_dtype, compute_dtype):
    # Create TransformerLayer
    layer = TransformerLayer(
        hidden_size=1024,
        mlp_hidden_size=4096,
        num_attention_heads=16,
        dtype=params_dtype,
    )

    # Initialize parameters
    init_key, dropout_key = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(init_key, (32, 128, 1024), dtype=compute_dtype)
    var_collect = layer.init({"params": init_key, "dropout": dropout_key}, x)

    # Forward and backward pass
    def loss_fn(var_collect):
        output = layer.apply(var_collect, x, rngs={"dropout": dropout_key})
        assert output.dtype == compute_dtype
        return output.sum()

    loss, grads = jax.value_and_grad(loss_fn)(var_collect)


run_forward_backward(jnp.float32, jnp.float32)  # high precision training
run_forward_backward(jnp.float32, jnp.bfloat16)  # bfloat16 training with master weights in FP32
run_forward_backward(jnp.bfloat16, jnp.bfloat16)  # bfloat16 training with weights in BF16
Lower precisions
Transformer Engine’s primary feature is support for even lower precisions than BF16/FP16, such as FP8, MXFP8, and NVFP4. These formats are more involved than BF16/FP16: they require scaling factors to properly represent the full range of values in a tensor. Sometimes there is one scaling factor per tensor, sometimes one per block of values. A precision format combined with the associated training logic is called a recipe.
In this section we present the logic common to all recipes; each recipe is described in more detail in its own section later. Let’s now see how to train in lower precisions in the supported frameworks.
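To build intuition for the role of the scaling factor, here is a simulated per-tensor quantization using PyTorch’s float8_e4m3fn dtype. This is an illustrative sketch, not Transformer Engine’s internal implementation; 448 is the largest magnitude representable in the FP8 E4M3 format:

import torch

FP8_E4M3_MAX = 448.0  # largest magnitude representable in FP8 E4M3

def quantize_per_tensor(x):
    # One scaling factor for the whole tensor: map the largest value onto the
    # edge of the representable FP8 range, then round to FP8.
    scale = FP8_E4M3_MAX / x.abs().max()
    x_fp8 = (x * scale).to(torch.float8_e4m3fn)
    return x_fp8, scale

def dequantize(x_fp8, scale):
    return x_fp8.to(torch.float32) / scale

x = torch.randn(1024) * 0.01                 # values far below the FP8 range
x_fp8, scale = quantize_per_tensor(x)
max_error = (dequantize(x_fp8, scale) - x).abs().max()
print(scale, max_error)                      # the scale lets small values use the full FP8 range

Per-block recipes such as MXFP8 apply the same idea, but with one scaling factor per small block of values instead of per tensor.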
The PyTorch API of Transformer Engine provides an autocast context manager to control precision.
It’s similar to the torch.autocast context manager, but tailored for low precision training.
The most important argument is the recipe argument, which accepts objects inheriting from
Recipe.
Forward computations need to be performed inside the autocast context manager,
while the .backward() call should be outside of it (it inherits the setting from the
corresponding forward pass).
Here is a basic example:
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format
recipe = DelayedScaling()
layer = te.Linear(1024, 1024)
inp = torch.randn(32, 1024, dtype=torch.float32, device="cuda")
with te.autocast(enabled=True, recipe=recipe):
    output = layer(inp)

# .backward() is called outside of autocast
loss = output.sum()
loss.backward()
You can use multiple recipes in the same model in the following ways:
Sequential contexts – apply different recipes to different parts of your model:
encoder_recipe = DelayedScaling(fp8_format=Format.E4M3)
decoder_recipe = DelayedScaling(fp8_format=Format.HYBRID)
encoder = te.Linear(1024, 1024)
decoder = te.Linear(1024, 1024)
with te.autocast(enabled=True, recipe=encoder_recipe):
    hidden = encoder(inp)

with te.autocast(enabled=True, recipe=decoder_recipe):
    output = decoder(hidden)
Nested contexts – the inner context overrides the outer one for its scope:
outer_recipe = DelayedScaling(fp8_format=Format.E4M3)
inner_recipe = DelayedScaling(fp8_format=Format.HYBRID)
layer1 = te.Linear(1024, 1024)
layer2 = te.Linear(1024, 1024)
layer3 = te.Linear(1024, 1024)
with te.autocast(enabled=True, recipe=outer_recipe):
    # layer1 uses outer_recipe
    x = layer1(inp)

    with te.autocast(enabled=True, recipe=inner_recipe):
        # layer2 uses inner_recipe (overrides outer)
        x = layer2(x)

    # layer3 uses outer_recipe again
    output = layer3(x)
The JAX API of Transformer Engine provides an autocast context manager similar to PyTorch.
The key difference is that in JAX, model initialization must happen inside the autocast context
to properly capture quantization metadata in the parameter tree.
Here is a basic example:
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
from transformer_engine.jax.flax import TransformerLayer
from transformer_engine.common.recipe import DelayedScaling, Format
# Set up recipe
recipe = DelayedScaling()
# Model initialization must happen inside autocast
with te.autocast(enabled=True, recipe=recipe):
    layer = TransformerLayer(
        hidden_size=1024,
        mlp_hidden_size=4096,
        num_attention_heads=16,
    )
    init_key, dropout_key = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(init_key, (32, 128, 1024), dtype=jnp.bfloat16)
    var_collect = layer.init({"params": init_key, "dropout": dropout_key}, x)

    # Forward and backward pass (both inside autocast for JAX)
    def loss_fn(var_collect):
        output = layer.apply(var_collect, x, rngs={"dropout": dropout_key})
        return output.sum()

    loss, grads = jax.value_and_grad(loss_fn)(var_collect)
You can use multiple recipes in the same model in the following ways:
Sequential contexts – apply different recipes to different parts of your model:
encoder_recipe = DelayedScaling(fp8_format=Format.E4M3)
decoder_recipe = DelayedScaling(fp8_format=Format.HYBRID)
with te.autocast(enabled=True, recipe=encoder_recipe):
    encoder = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16)
    encoder_var_collect = encoder.init({"params": init_key, "dropout": dropout_key}, x)
    hidden = encoder.apply(encoder_var_collect, x, rngs={"dropout": dropout_key})

with te.autocast(enabled=True, recipe=decoder_recipe):
    decoder = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16)
    decoder_var_collect = decoder.init({"params": init_key, "dropout": dropout_key}, hidden)
    output = decoder.apply(decoder_var_collect, hidden, rngs={"dropout": dropout_key})
Nested contexts – the inner context overrides the outer one for its scope:
outer_recipe = DelayedScaling(fp8_format=Format.E4M3)
inner_recipe = DelayedScaling(fp8_format=Format.HYBRID)
with te.autocast(enabled=True, recipe=outer_recipe):
    # layer1 uses outer_recipe
    layer1 = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16)
    var_collect1 = layer1.init({"params": init_key, "dropout": dropout_key}, x)
    hidden = layer1.apply(var_collect1, x, rngs={"dropout": dropout_key})

    with te.autocast(enabled=True, recipe=inner_recipe):
        # layer2 uses inner_recipe (overrides outer)
        layer2 = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16)
        var_collect2 = layer2.init({"params": init_key, "dropout": dropout_key}, hidden)
        hidden = layer2.apply(var_collect2, hidden, rngs={"dropout": dropout_key})

    # layer3 uses outer_recipe again
    layer3 = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16)
    var_collect3 = layer3.init({"params": init_key, "dropout": dropout_key}, hidden)
    output = layer3.apply(var_collect3, hidden, rngs={"dropout": dropout_key})
Note
Python context managers like autocast may interact unexpectedly with JAX’s JIT compilation.
For finer-grained control, consider passing the recipe directly to TE modules instead.
See the TE JAX Integration notebook
for details.
Mixed precision with 8- or 4-bit precisions
From now on, we will refer to FP8/MXFP8/NVFP4 etc. as low precision and to FP32/BF16/FP16 as high precision. This terminology will be used throughout the rest of the documentation.
Not all operations run in low precision:
Linear operations: run in low precision.
Attention computations: run in high precision by default (some recipes allow low precision as an option).
Other operations (layer normalization, softmax, etc.): run in high precision.
Within high-precision operations, there are two categories:
Configurable precision: most operations run in parameter precision (FP32/BF16/FP16) or the precision specified by torch.autocast.
Fixed FP32 precision: some operations, or parts of operations (such as the division in layernorm), always run in FP32, regardless of other settings.
Figure 3: Default precision of operations in a TransformerLayer forward pass. Only linear operations are in lower precision. Dot product attention is shown as three separate operations (QK^T, Softmax, Scores * V), though in practice these may be fused into a single kernel.
Linear layer data flow
Let’s see how the data flow of a linear layer works by default on a single H100 GPU with FP8 precision:
H100 (Hopper) architecture natively supports FP8 Matrix Multiplication only in TN layout (Transpose-NoTranspose),
so GEMM with tensors A and B returns B * A^T.
Forward pass
Input is quantized to FP8 – both input and input^T quantized versions are created.
Weights are stored in high precision and quantized to low precision before the GEMM – both weight and weight^T quantized versions are created.
An FP8 GEMM with layout TN is run with the weight and input tensors.
Outputs – the input * weight^T tensor – are returned in high precision.
Backward pass
Output gradients are quantized to FP8 – both output_grad and output_grad^T quantized versions are created.
An FP8 GEMM with layout TN is performed with the weight^T and output_grad tensors to compute the input gradients.
An FP8 GEMM with layout TN is performed with the input^T and output_grad^T tensors to compute the weight gradients.
Input gradients – the output_grad * weight tensor – are returned in high precision.
Weight gradients – the output_grad^T * input tensor – are returned in high precision.
Figure 4: Forward pass of a Linear layer with low precision data flow.
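In plain PyTorch terms (ignoring quantization and the TN-layout details), the three GEMMs above compute the following. This reference math is shown only to make the shapes and transposes concrete; it is not how Transformer Engine implements the layer:

import torch

x = torch.randn(32, 1024, requires_grad=True)    # input
w = torch.randn(1024, 1024, requires_grad=True)  # weight (out_features x in_features)

# Forward GEMM: output = input * weight^T
y = x @ w.t()
grad_y = torch.randn_like(y)                     # output gradient from the next layer

# Backward GEMMs (reference math; TE runs these as FP8 GEMMs):
grad_x = grad_y @ w                              # input gradients  = output_grad * weight
grad_w = grad_y.t() @ x                          # weight gradients = output_grad^T * input

# Cross-check against autograd.
y.backward(grad_y)
torch.testing.assert_close(grad_x, x.grad)
torch.testing.assert_close(grad_w, w.grad)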