FP8 Delayed Scaling

The FP8 Delayed Scaling recipe estimates scaling factors from historical amax values instead of computing them from each tensor. Compared to the FP8 Current Scaling recipe, this cuts the number of tensor reads per quantization from two to one, which improves memory efficiency.

This recipe and the FP8 Current Scaling recipe use the same FP8 formats (E4M3/E5M2), with one FP32 scaling factor per tensor. We recommend reading the FP8 Current Scaling documentation first.

Quantization with delayed scaling factors

FP8 Current Scaling needs two tensor reads per quantization: one to compute amax and one to cast. FP8 Delayed Scaling removes the first read by predicting the scaling factor from historical amax values. This is where the name comes from: delayed scaling uses past values, while current scaling uses present values.
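The difference in tensor reads can be shown with a toy model. This is not Transformer Engine code. A "tensor" is a Python list, and the FP8 cast is reduced to "scale, then clip to the E4M3 range" with no rounding. The helper names (`scale_and_clip`, `quantize_current`, `quantize_delayed`) are hypothetical. Each function returns the number of passes it makes over the tensor.

```python
# Toy model, not Transformer Engine code: a "tensor" is a Python list, and the
# FP8 cast is reduced to "scale, then clip to the E4M3 range" (no rounding).
FP8_MAX = 448.0  # largest finite E4M3 value


def scale_and_clip(v, scale):
    """Stand-in for the FP8 cast: scale the value and saturate out-of-range results."""
    return max(-FP8_MAX, min(FP8_MAX, v * scale))


def quantize_current(x):
    """Current scaling: pass 1 finds amax, pass 2 casts with the fresh scale."""
    amax = max(abs(v) for v in x)                 # pass 1 over the tensor
    scale = FP8_MAX / amax if amax > 0 else 1.0
    q = [scale_and_clip(v, scale) for v in x]     # pass 2 over the tensor
    return q, scale, 2


def quantize_delayed(x, scale):
    """Delayed scaling: the scale is given up front, so one pass casts and records amax."""
    amax = 0.0
    q = []
    for v in x:                                   # the only pass over the tensor
        amax = max(amax, abs(v))                  # observed amax, kept for future iterations
        q.append(scale_and_clip(v, scale))
    return q, amax, 1
```

Suppose the predicted scale is stale, for example 300.0, which is larger than the 224.0 that current scaling would choose for `[0.5, -2.0, 1.0]`. In that case `quantize_delayed` clips -2.0 to -448.0. This is the trade-off delayed scaling makes in exchange for skipping the extra read.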

The quantization process works as follows:

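  1. Compute the scaling factor from history (no tensor read needed): The scaling factor is derived from the stored amax_history using the formula:

    scaling_factor = (FP8_MAX / amax) / 2^margin

    where FP8_MAX is the largest representable value of the FP8 format, and margin is the recipe's margin parameter. margin defaults to 0, which reduces the formula to FP8_MAX / amax. amax is derived from the history according to amax_compute_algo: "max" (the maximum over the window, which is the default) or "most_recent" (the latest recorded value).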

  2. Quantize the tensor (one tensor read): Apply the scaling factor and cast to FP8. Values outside the FP8 range are clipped.

  3. Update history: Record the actual amax from this quantization for future iterations.
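Step 1 can be written as a small pure-Python sketch. It assumes the formula above, with FP8_MAX set to 448 for E4M3 and 57344 for E5M2. It also assumes the layout described under Amax History Management, where position 0 holds the most recent amax when the scale is computed. Two choices in the sketch are illustrative, not taken from the library: the function names, and falling back to the previous scale when the history holds no usable amax.

```python
import math

# Largest finite values of the two FP8 formats
FP8_MAX = {"E4M3": 448.0, "E5M2": 57344.0}


def amax_from_history(history, algo="max"):
    """Reduce an amax history to a single value.

    Assumes history[0] holds the most recent amax (the staging entry)."""
    if algo == "max":
        return max(history)      # maximum over the whole window
    if algo == "most_recent":
        return history[0]        # latest observed amax only
    raise ValueError(f"unknown amax_compute_algo: {algo}")


def compute_scale(history, prev_scale, fp8_format="E4M3", margin=0, algo="max"):
    """scaling_factor = (FP8_MAX / amax) / 2**margin (hypothetical helper)."""
    amax = amax_from_history(history, algo)
    if amax <= 0.0 or not math.isfinite(amax):
        return prev_scale        # sketch choice: keep the old scale if nothing usable was seen
    return (FP8_MAX[fp8_format] / amax) / (2 ** margin)
```

Take the history `[3.5, 2.0, 7.0, 1.0]` as an example. The default "max" algorithm gives 448 / 7 = 64, and "most_recent" gives 448 / 3.5 = 128. A margin of 1 halves the "max" result to 32, which leaves one extra power of two of headroom before clipping.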

Each module maintains an amax_history tensor of configurable length (amax_history_len) for each quantized tensor.

[Figure: Current Scaling: Tensor → Amax Computation → Quantization (uses tensor + amax) → FP8 Tensor. Delayed Scaling: amax read from the amax history → Quantization (uses tensor + amax from history) → FP8 Tensor; the quantization then updates the amax history.]

Figure 1. Comparison of FP8 Current Scaling and FP8 Delayed Scaling quantization processes.

Amax History Management

The amax_history buffer acts as a sliding window of recent amax values. Position 0 serves as a staging area for the current amax, while positions 1 to N-1 store the history from oldest to newest. Each quantization writes the observed amax to position 0, and after the pass completes, the history is rotated:

Before rotation: [amax_N, amax_1, amax_2, ..., amax_N-1]   (amax_N = current, amax_1 = oldest)
After rotation:  [0,      amax_2, ..., amax_N-1, amax_N]   (amax_1 dropped, amax_N appended)

The scaling factor is computed before the rotation, so it uses all amax_history_len values, including the current amax in position 0. Position 0 is then zeroed after the scale update, ready for the next iteration’s amax.
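The rotation shown above can be sketched with plain Python lists: roll left by one, then clear position 0. The oldest entry lands in the staging slot and is overwritten. The sketch also runs a few iterations of "write amax to position 0, compute the scale with the max algorithm, rotate". This shows that a single large amax keeps the scale low for amax_history_len updates and is then forgotten. The function names and the window length of 4 are illustrative assumptions.

```python
FP8_MAX = 448.0  # E4M3


def rotate_history(history):
    """[amax_N, amax_1, ..., amax_N-1] -> [0, amax_2, ..., amax_N-1, amax_N]"""
    rolled = history[1:] + history[:1]   # amax_N moves to the end; amax_1 moves to position 0
    rolled[0] = 0.0                      # amax_1 is overwritten (dropped); slot is ready for reuse
    return rolled


def end_of_iteration(history, observed_amax, scale, margin=0):
    """Record this iteration's amax, compute the next scale, then rotate."""
    history[0] = observed_amax           # write the current amax into the staging slot
    amax = max(history)                  # "max" algorithm over all amax_history_len values
    if amax > 0.0:
        scale = (FP8_MAX / amax) / (2 ** margin)   # used by the *next* iteration
    return rotate_history(history), scale


# A spike of 4.0 in iteration 2 dominates the next 4 scale updates (window length 4).
history, scale = [0.0] * 4, 1.0
scales = []
for observed in [2.0, 4.0, 1.0, 1.0, 1.0, 1.0]:
    history, scale = end_of_iteration(history, observed, scale)
    scales.append(scale)
# scales == [224.0, 112.0, 112.0, 112.0, 112.0, 448.0]
```

A longer window makes scales more robust to short-lived dips in amax, but it also takes longer to recover once a spike has passed.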

The implementation differs between PyTorch and JAX. In PyTorch, each module creates two amax_history tensors, initialized to zero:

  • Forward: shape (amax_history_len, num_gemms * 3) — three FP8 tensors per GEMM (input, weight, output)

  • Backward: shape (amax_history_len, num_gemms * 2) — two FP8 tensors per GEMM (grad_output, grad_input)

When the autocast context exits, a single CUDA kernel processes all tensors at once — performing amax reduction across GPUs and history rotation. This batched approach minimizes kernel launch overhead compared to updating each tensor separately.
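The batched layout can be illustrated in pure Python. Rows are history positions and columns are FP8 tensors: three per GEMM in the forward buffer and two per GEMM in the backward buffer. One update computes every column's scale and rotates all columns together. In Transformer Engine this work runs on the GPU. The helpers below (`make_history`, `batched_update`) are hypothetical and only model the data layout and the "max" algorithm.

```python
def make_history(amax_history_len, num_gemms, tensors_per_gemm):
    """Zero-initialized buffer of shape (amax_history_len, num_gemms * tensors_per_gemm)."""
    return [[0.0] * (num_gemms * tensors_per_gemm) for _ in range(amax_history_len)]


def batched_update(history, fp8_max, margin=0, prev_scales=None):
    """Update every tensor (column) at once: per-column scales, then one row rotation."""
    num_tensors = len(history[0])
    scales = list(prev_scales) if prev_scales else [1.0] * num_tensors
    for j in range(num_tensors):
        amax = max(row[j] for row in history)        # "max" algorithm for column j
        if amax > 0.0:
            scales[j] = (fp8_max / amax) / (2 ** margin)
    rotated = history[1:] + history[:1]              # rolling rows rotates every column together
    rotated[0] = [0.0] * num_tensors                 # clear the staging row
    return rotated, scales


# Two GEMMs: forward has 2 * 3 columns, backward has 2 * 2 columns.
fwd = make_history(4, num_gemms=2, tensors_per_gemm=3)
bwd = make_history(4, num_gemms=2, tensors_per_gemm=2)
fwd[0] = [1.0, 2.0, 4.0, 0.5, 7.0, 0.0]              # amaxes observed in this iteration
fwd, fwd_scales = batched_update(fwd, fp8_max=448.0)
```

The last forward column saw no nonzero amax, so it keeps its previous scale (1.0 here). This mirrors the zero-amax fallback assumed in the scale sketch above.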

Here’s how to use FP8 Delayed Scaling in PyTorch:

Requires SM89 (Ada) or later

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Create FP8 Delayed Scaling recipe
recipe = DelayedScaling(
    margin=0,  # Margin for scaling factor computation (default: 0)
    amax_history_len=1024,  # Length of amax history window (default: 1024)
    amax_compute_algo="max",  # How to compute amax from history (default: "max")
)

# Create a linear layer with bfloat16 parameters
layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)

# Forward and backward pass
inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda")

with te.autocast(enabled=True, recipe=recipe):
    output = layer(inp)
    loss = output.sum()

loss.backward()

Distributed Training

FP8 Delayed Scaling uses the same data formats as FP8 Current Scaling, so quantized all-gather is supported. However, amax reduction works slightly differently across frameworks.

Amax reduction is controlled by two parameters:

  • reduce_amax in recipe: enables or disables amax reduction; required for sequence parallelism (SP) and context parallelism (CP)

  • amax_reduction_group in autocast: specifies the process group for reduction

We recommend reducing amax across all GPUs where the tensor is sharded, including data parallel ranks.
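The reason for the reduction can be shown without GPUs. If two ranks each hold one shard of the same tensor and compute scales independently, the shards end up with different scales. A MAX reduction over the group gives every rank the same amax, and therefore the same scale. The sketch assumes the reduction operator is MAX, as the name amax suggests. `all_reduce_max` is a hypothetical stand-in for a collective over the amax reduction group.

```python
FP8_MAX = 448.0  # E4M3


def all_reduce_max(local_amaxes):
    """Simulated all_reduce(MAX): local_amaxes[r][j] is tensor j's amax on rank r."""
    return [max(per_rank) for per_rank in zip(*local_amaxes)]


def scales_from_amax(amaxes):
    """Per-tensor scale FP8_MAX / amax (margin 0), falling back to 1.0 for zero amax."""
    return [FP8_MAX / a if a > 0.0 else 1.0 for a in amaxes]


# Two ranks each hold one shard of the same activation tensor.
local = [[4.0], [16.0]]
unreduced = [scales_from_amax(r) for r in local]   # per-rank scales: the shards disagree
reduced = scales_from_amax(all_reduce_max(local))  # one shared scale for all ranks
```

Without reduction, the shards would be quantized with scales 112 and 28. After reduction, both use 28, so the gathered tensor has a single consistent scale.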

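import os
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Launch with torchrun, e.g. torchrun --nproc_per_node=8 train.py
dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Create process group for amax reduction (e.g., all 8 GPUs)
amax_reduction_group = dist.new_group(ranks=[0, 1, 2, 3, 4, 5, 6, 7])

recipe = DelayedScaling(reduce_amax=True)

model = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda")

with te.autocast(enabled=True, recipe=recipe, amax_reduction_group=amax_reduction_group):
    output = model(inp)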

In data parallel training, some modules may not execute on certain ranks (e.g., MoE experts that receive no tokens). This is handled as follows:

  • First iteration: All modules must execute on all ranks to register their amax_history tensors in the global buffer. Mismatched registration would cause the all_reduce to hang due to different tensor sizes across ranks.

  • Subsequent iterations: The autocast context must be entered and exited on all ranks, because exiting it triggers the collective reduction. Individual modules can be skipped: if no rank executes a module, its history is not rotated and its scale remains unchanged.
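The two rules can be modelled as toy bookkeeping. This is not how Transformer Engine detects skipped modules. A registration check stands in for the first-iteration requirement: a real all_reduce over buffers of different sizes would hang, whereas this toy raises an error. An OR over per-rank "executed" flags then decides which modules get their scale updated and their history rotated. All names here (`check_registration`, `executed_anywhere`, `end_of_autocast`) are hypothetical.

```python
def check_registration(buffer_sizes_per_rank):
    """First iteration: every rank must register the same number of amax entries.

    A real all_reduce with mismatched sizes would hang; this toy raises instead."""
    if len(set(buffer_sizes_per_rank)) != 1:
        raise RuntimeError(f"mismatched amax buffers across ranks: {buffer_sizes_per_rank}")


def executed_anywhere(flags_per_rank):
    """flags_per_rank[r][m] is True if rank r ran module m this iteration (OR across ranks)."""
    return [any(per_rank) for per_rank in zip(*flags_per_rank)]


def end_of_autocast(histories, scales, executed, fp8_max=448.0):
    """Per module: leave it untouched if no rank ran it, else update scale and rotate history."""
    new_histories, new_scales = [], []
    for hist, scale, ran in zip(histories, scales, executed):
        if ran:
            amax = max(hist)
            if amax > 0.0:
                scale = fp8_max / amax
            hist = hist[1:] + hist[:1]   # roll the window
            hist[0] = 0.0                # clear the staging slot
        new_histories.append(hist)       # unchanged when no rank ran the module
        new_scales.append(scale)
    return new_histories, new_scales


# Module 0 ran on rank 0 only; module 1 (e.g. an idle MoE expert) ran nowhere.
ran = executed_anywhere([[True, False], [False, False]])
histories, scales = end_of_autocast([[2.0, 0.0, 1.0], [0.0, 3.0, 4.0]], [1.0, 10.0], ran)
```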

Supported devices

Ada and later (SM 8.9+)