Transformer Engine documentation
Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper, Ada, and Blackwell GPUs, to provide better performance with lower memory utilization in both training and inference. On Blackwell GPUs, TE also supports the MXFP8 (Microscaling FP8) and NVFP4 formats for even greater efficiency. TE provides a collection of highly optimized building blocks for popular Transformer architectures and an API, similar to automatic mixed precision, that can be used seamlessly with your framework-specific code. TE also includes a framework-agnostic C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers.
As Transformer models scale to hundreds of billions of parameters across large language models, MoE architectures, and multimodal models, training and inference become increasingly memory- and compute-intensive. Mixed-precision training, which combines single precision (FP32) with lower-precision formats, delivers significant speedups with minimal impact on accuracy. FP8, introduced with the Hopper GPU architecture, offers further performance gains over FP16 with no degradation in accuracy, and newer formats such as MXFP8 and NVFP4 on Blackwell push efficiency even further.
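To make the formats concrete: FP8 comes in two encodings, E4M3 (more mantissa bits, typically used for weights and activations) and E5M2 (more exponent bits and hence a wider range, typically used for gradients). As a small sketch, assuming a PyTorch build with FP8 dtypes (2.1 or newer), their numeric ranges can be inspected directly:
import torch
# E4M3 trades range for precision; E5M2 trades precision for range.
for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    info = torch.finfo(dtype)
    print(dtype, "max =", info.max, "smallest normal =", info.tiny)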
TE integrates with popular LLM frameworks and provides optimizations that make low-precision training work seamlessly with advanced features like MoE, tensor/sequence/context parallelism, and fused operations. It provides a Python API consisting of modules to easily build a Transformer layer as well as a framework-agnostic library in C++ including structs and kernels needed for FP8 support. Modules provided by TE internally maintain scaling factors and other values needed for FP8 training, greatly simplifying mixed precision training for users.
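For example, a complete Transformer block (attention plus MLP) can be instantiated from a single TE module and run under the same autocast shown in the examples below. This is a minimal sketch with illustrative sizes; te.TransformerLayer accepts many more configuration arguments:
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
# A full Transformer block as a single TE module (sizes are illustrative).
layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
)
# Default input layout is (sequence, batch, hidden).
x = torch.randn(128, 4, 1024, device="cuda")
# FP8 scaling factors are tracked inside the module; the recipe only
# selects the FP8 format and scaling strategy.
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)
with te.autocast(enabled=True, recipe=fp8_recipe):
    y = layer(x)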
Highlights
Easy-to-use modules for building Transformer layers with FP8 support
Optimizations (e.g. fused kernels) for Transformer models
Support for FP8 on NVIDIA Hopper, Ada, and Blackwell GPUs
Support for MXFP8 and NVFP4 on NVIDIA Blackwell GPUs
Support for optimizations across all precisions (FP16, BF16) on NVIDIA Ampere and later GPU architectures
Examples
PyTorch
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
# Set dimensions.
in_features = 768
out_features = 3072
hidden_size = 2048
# Initialize model and inputs.
model = te.Linear(in_features, out_features, bias=True)
inp = torch.randn(hidden_size, in_features, device="cuda")
# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)
# Enable autocasting for the forward pass
with te.autocast(enabled=True, recipe=fp8_recipe):
    out = model(inp)

loss = out.sum()
loss.backward()
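The same model and autocast call also work with the other recipes TE provides; for example, on Blackwell GPUs the delayed-scaling recipe above could be swapped for an MXFP8 one. This is a sketch under the assumption that an MXFP8BlockScaling recipe class is available in your installed TE version (check transformer_engine.common.recipe for the recipes your release exposes):
# On Blackwell GPUs, an MXFP8 recipe can be used in place of delayed scaling.
mxfp8_recipe = recipe.MXFP8BlockScaling()  # assumed recipe class; verify availability
with te.autocast(enabled=True, recipe=mxfp8_recipe):
    out = model(inp)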
JAX (Flax)
import flax
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.common import recipe
BATCH = 32
SEQLEN = 128
HIDDEN = 1024
# Initialize RNG and inputs.
rng = jax.random.PRNGKey(0)
init_rng, data_rng = jax.random.split(rng)
inp = jax.random.normal(data_rng, [BATCH, SEQLEN, HIDDEN], jnp.float32)
# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)
# Enable autocasting for the forward pass
with te.autocast(enabled=True, recipe=fp8_recipe):
    model = te_flax.DenseGeneral(features=HIDDEN)

    def loss_fn(params, other_vars, inp):
        out = model.apply({'params': params, **other_vars}, inp)
        return jnp.mean(out)

    # Initialize models.
    variables = model.init(init_rng, inp)
    other_variables, params = flax.core.pop(variables, 'params')

    # Construct the forward and backward function
    fwd_bwd_fn = jax.value_and_grad(loss_fn, argnums=(0, 1))

    for _ in range(10):
        loss, (param_grads, other_grads) = fwd_bwd_fn(params, other_variables, inp)
For a more comprehensive tutorial, check out our Getting Started Guide.