FP8 Delayed Scaling
The FP8 Delayed Scaling recipe estimates scaling factors from historical amax values rather than computing them from each tensor at quantization time. Compared to the FP8 Current Scaling recipe, this reduces tensor reads per quantization from two to one, which lowers memory traffic.
Both this recipe and the FP8 Current Scaling recipe use the same FP8 formats (E4M3/E5M2) with one FP32 scaling factor per tensor. Reading the FP8 Current Scaling documentation first is recommended.
Quantization with delayed scaling factors
FP8 Current Scaling requires two tensor reads per quantization: one to compute amax, one to cast. FP8 Delayed Scaling eliminates the first read by predicting the scaling factor from historical amax values - hence delayed (using past values) versus current (using present values).
The quantization process works as follows:
1. Compute scaling factor from history (no tensor read needed): The scaling factor is derived from the stored amax_history using the formula scaling_factor = FP8_MAX / amax, where amax is computed from the history using either the max (maximum over the window, default) or the most_recent algorithm.
2. Quantize the tensor (one tensor read): Apply the scaling factor and cast to FP8. Values exceeding the FP8 range are clipped.
3. Update history: Record the actual amax from this quantization for future iterations.
Each module maintains an amax_history tensor of configurable length (amax_history_len)
for each quantized tensor.
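The three quantization steps can be sketched in plain Python. This is a minimal illustration under stated assumptions, not Transformer Engine's implementation: the helper names are hypothetical, the history is a plain list ordered oldest to newest, the fallback scale of 1.0 for an all-zero history is an assumption, and the FP8 cast is modeled by clipping only, without rounding to the FP8 grid. The constant 448 is the largest finite E4M3 magnitude.

```python
FP8_E4M3_MAX = 448.0  # largest finite magnitude in the E4M3 format


def amax_from_history(history, algo="max"):
    """Reduce the stored amax values to a single estimate (hypothetical helper)."""
    if algo == "max":
        return max(history)
    if algo == "most_recent":
        return history[-1]  # assumes the list is ordered oldest to newest
    raise ValueError(f"unknown amax_compute_algo: {algo!r}")


def quantize_delayed(values, history, algo="max", fp8_max=FP8_E4M3_MAX):
    """Scale and clip `values` with a scaling factor predicted from history."""
    # Step 1: scaling factor from history; the tensor itself is not read.
    amax = amax_from_history(history, algo)
    scale = fp8_max / amax if amax > 0.0 else 1.0  # assumed fallback

    # Step 2: a single pass over the tensor scales, clips, and tracks amax.
    scaled = []
    observed_amax = 0.0
    for v in values:
        scaled.append(min(max(v * scale, -fp8_max), fp8_max))
        observed_amax = max(observed_amax, abs(v))

    # Step 3: the caller records observed_amax in the history.
    return scaled, scale, observed_amax


history = [2.0, 4.0, 1.0]
scaled, scale, observed = quantize_delayed([1.0, -8.0, 0.5], history)
print(scale, scaled, observed)
```

In this example the history predicts an amax of 4, but the tensor contains -8, so that value is clipped to -448. The observed amax of 8 would then enter the history and lower the scaling factor in later iterations.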
Figure 1. Comparison of FP8 Current Scaling and FP8 Delayed Scaling quantization processes.
Amax History Management
The amax_history buffer acts as a sliding window of recent amax values.
Position 0 serves as a staging area for the current amax, while positions 1 to N-1
store the history from oldest to newest. Each quantization writes the observed amax
to position 0, and after the pass completes, the history is rotated:
Before rotation: [amax_N, amax_1, amax_2, ..., amax_N-1] (amax_N = current, amax_1 = oldest)
After rotation: [0, amax_2, ..., amax_N-1, amax_N] (amax_1 dropped, amax_N appended)
The scaling factor is computed before the rotation, so it uses all amax_history_len values, including the current amax in position 0.
After the scale update, position 0 is zeroed, ready for the next iteration's amax.
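The scale update and rotation for a single history buffer can be sketched as follows. This is an illustrative plain-Python model of the layout described above, assuming the max algorithm and a hypothetical function name; the fallback scale of 1.0 for an all-zero history is also an assumption.

```python
def update_scale_and_rotate(history, fp8_max=448.0):
    """Return (scale, new_history) for one amax_history buffer.

    history[0] holds the amax observed in the pass that just finished;
    history[1:] holds earlier values from oldest to newest.
    """
    # The scale is computed before rotating, so every slot contributes.
    amax = max(history)
    scale = fp8_max / amax if amax > 0.0 else 1.0  # assumed fallback

    # Roll left by one: the current amax moves to the newest slot and the
    # oldest value lands in position 0, where it is overwritten with zero.
    rotated = history[1:] + history[:1]
    rotated[0] = 0.0
    return scale, rotated


# current amax = 5.0, oldest stored amax = 1.0
scale, new_history = update_scale_and_rotate([5.0, 1.0, 2.0, 3.0])
print(scale, new_history)
```

After the call, the oldest value (1.0) has been dropped, the current value (5.0) sits in the newest slot, and position 0 is zero again.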
The implementation differs between PyTorch and JAX.
PyTorch: Each module creates two amax_history tensors, initialized to zero:
- Forward: shape (amax_history_len, num_gemms * 3), for three FP8 tensors per GEMM (input, weight, output)
- Backward: shape (amax_history_len, num_gemms * 2), for two FP8 tensors per GEMM (grad_output, grad_input)
When the autocast context exits, a single CUDA kernel processes all tensors at once, performing amax reduction across GPUs and history rotation. This batched approach minimizes kernel launch overhead compared to updating each tensor separately.
JAX: Each quantizer maintains its own amax_history with shape (amax_history_len,) and updates independently.
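To make the PyTorch buffer shapes concrete, the arithmetic can be written out as a small helper. The function name is hypothetical, and the num_gemms values are assumptions chosen for illustration (one GEMM for a single linear layer, two for a module that chains two GEMMs).

```python
def amax_history_shapes(amax_history_len, num_gemms):
    """Return (forward_shape, backward_shape) of the per-module history buffers."""
    forward = (amax_history_len, num_gemms * 3)   # input, weight, output
    backward = (amax_history_len, num_gemms * 2)  # grad_output, grad_input
    return forward, backward


print(amax_history_shapes(1024, 1))
print(amax_history_shapes(1024, 2))
```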
Here’s how to use FP8 Delayed Scaling in PyTorch and JAX:
PyTorch:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Create FP8 Delayed Scaling recipe
recipe = DelayedScaling(
    margin=0,  # Margin for scaling factor computation (default: 0)
    amax_history_len=1024,  # Length of amax history window (default: 1024)
    amax_compute_algo="max",  # How to compute amax from history (default: "max")
)

# Create a linear layer with bfloat16 parameters
layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)

# Forward and backward pass
inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda")

with te.autocast(enabled=True, recipe=recipe):
    output = layer(inp)
    loss = output.sum()

loss.backward()
JAX:

import jax
import jax.numpy as jnp
import transformer_engine.jax as te
from transformer_engine.jax.flax import DenseGeneral
from transformer_engine.common.recipe import DelayedScaling

# Create FP8 Delayed Scaling recipe
recipe = DelayedScaling(
    margin=0,  # Margin for scaling factor computation (default: 0)
    amax_history_len=1024,  # Length of amax history window (default: 1024)
    amax_compute_algo="max",  # How to compute amax from history (default: "max")
)

with te.autocast(enabled=True, recipe=recipe):
    # Initialize layer and data
    layer = DenseGeneral(features=1024)
    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16)
    var_collect = layer.init(key, x)

    # Forward and backward pass
    def loss_fn(var_collect):
        output = layer.apply(var_collect, x)
        return output.sum()

    loss, grads = jax.value_and_grad(loss_fn)(var_collect)
Distributed Training
FP8 Delayed Scaling uses the same data formats as FP8 Current Scaling, so quantized all-gather is supported. However, amax reduction works differently in PyTorch and JAX.
In PyTorch, amax reduction is controlled by two parameters:
- reduce_amax in recipe: enables/disables reduction (required for SP and CP)
- amax_reduction_group in autocast: specifies the process group for reduction
We recommend reducing amax across all GPUs where the tensor is sharded, including data parallel ranks.
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Assumes torch.distributed is already initialized and that
# model and inp are defined elsewhere.

# Create process group for amax reduction (e.g., all 8 GPUs)
amax_reduction_group = dist.new_group(ranks=[0, 1, 2, 3, 4, 5, 6, 7])

recipe = DelayedScaling(reduce_amax=True)

with te.autocast(recipe=recipe, amax_reduction_group=amax_reduction_group):
    output = model(inp)
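The effect of amax reduction can be illustrated without any GPUs. The sketch below simulates the collective in plain Python, assuming the reduction takes the maximum across ranks (consistent with amax being a maximum of absolute values). The function names are hypothetical and stand in for the real collective call.

```python
def reduce_amax(local_amaxes):
    """Simulate a max all-reduce: every rank ends up with the global maximum."""
    global_amax = max(local_amaxes)
    return [global_amax] * len(local_amaxes)


def scales(amaxes, fp8_max=448.0):
    """Scaling factor each rank would derive from its amax."""
    return [fp8_max / a for a in amaxes]


local = [1.0, 4.0, 2.0, 0.5]  # amax observed on each of 4 ranks
unreduced = scales(local)                # each rank derives a different scale
reduced = scales(reduce_amax(local))     # all ranks share one scale
print(unreduced)
print(reduced)
```

Without reduction, ranks holding shards of the same tensor would quantize them with different scaling factors. With reduction, all ranks agree on one scaling factor.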
In data parallel training, some modules may not execute on certain ranks (e.g., MoE experts that receive no tokens). This is handled as follows:
- First iteration: All modules must execute on all ranks to register their amax_history tensors in the global buffer. Mismatched registration would cause the all_reduce to hang due to different tensor sizes across ranks.
- Subsequent iterations: The autocast context must be entered and exited on all ranks (this triggers the collective reduction). Individual modules can be skipped; if no rank executes a module, its history is not rotated and its scale remains unchanged.
In JAX, amax reduction is always enabled and managed automatically. The reduction scope covers all parallelism axes except pipeline parallelism (TP, SP, DP/FSDP).
import transformer_engine.jax as te
from transformer_engine.common.recipe import DelayedScaling

# Amax reduction scope is managed internally
recipe = DelayedScaling(reduce_amax=True)  # Must be True in JAX

with te.autocast(enabled=True, recipe=recipe):
    output = layer.apply(params, inp)
Supported devices
Ada and later (SM 8.9+)