JAX: Integrating TransformerEngine into an existing framework
This is the landing page for a series of focused documents on bringing TransformerEngine into a JAX+Flax codebase one optimization at a time. Each linked page isolates a single feature so you can see exactly which changes it requires and what performance benefit it brings.
Pick a topic
| Document | Status | Covers |
|---|---|---|
| | Available | |
| | Coming soon | |
| | Coming soon | |
| | Coming soon | |
Quantization recipes at a glance
TransformerEngine (TE) exposes its quantization choices as recipes. See Low-precision Training for a detailed description of each recipe.
| Recipe | Hardware | State | Description |
|---|---|---|---|
| | Blackwell+ | none | Block-scaled FP8 (32-element blocks) |
| | Blackwell+ | requires a Flax RNG | FP4 with 2D block scaling and stochastic rounding |
| | Hopper+ | amax history (Flax variables) | Per-tensor FP8 with amax history |
| | Hopper+ | none | Per-tensor FP8 without an amax history |
Import them from transformer_engine.common.recipe.
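To make the State column concrete, here is a minimal pure-JAX sketch of three of the FP8 scaling schemes in the table: per-tensor scaling from the current amax, per-tensor scaling from an amax history, and block scaling with one scale per 32 elements. It is not TE's implementation. The function names (`compute_scale`, `current_scaling`, `delayed_scaling`, `block_scaling`) are hypothetical. The sketch only models the scale arithmetic and a cast to E4M3, ignoring TE's fused kernels, scaling margins, and other FP8 formats such as E5M2. The power-of-two block scales are a loose simplification of MX-style scales, and the FP4 recipe is not covered.

```python
import jax
import jax.numpy as jnp

E4M3_MAX = 448.0  # largest finite magnitude in FP8 E4M3 (float8_e4m3fn)


def compute_scale(amax):
    # Map amax onto the top of the E4M3 range; fall back to 1.0 when amax is 0
    # (e.g. an all-zero amax history on the first step).
    safe_amax = jnp.maximum(amax, 1e-12)
    return jnp.where(amax > 0, E4M3_MAX / safe_amax, 1.0)


def quantize(x, scale):
    # Saturate to the representable range, then cast to FP8.
    return jnp.clip(x * scale, -E4M3_MAX, E4M3_MAX).astype(jnp.float8_e4m3fn)


def dequantize(q, scale):
    return q.astype(jnp.float32) / scale


def current_scaling(x):
    # Per-tensor, stateless: the scale comes from this tensor's own amax.
    scale = compute_scale(jnp.max(jnp.abs(x)))
    return quantize(x, scale), scale


def delayed_scaling(x, amax_history):
    # Per-tensor, stateful: the scale comes from amaxes seen on earlier steps,
    # and this step's amax is pushed into the history for later steps.
    scale = compute_scale(jnp.max(amax_history))
    new_history = jnp.roll(amax_history, 1).at[0].set(jnp.max(jnp.abs(x)))
    return quantize(x, scale), scale, new_history


def block_scaling(x, block=32):
    # One scale per `block` contiguous elements along the last axis.
    # Simplification: each scale is rounded down to a power of two, loosely
    # mimicking MX-style E8M0 scales; TE's exact rounding rule may differ.
    blocks = x.reshape(*x.shape[:-1], -1, block)
    amax = jnp.max(jnp.abs(blocks), axis=-1, keepdims=True)
    scale = jnp.exp2(jnp.floor(jnp.log2(compute_scale(amax))))
    return quantize(blocks, scale), scale


x = jax.random.normal(jax.random.PRNGKey(0), (4, 64), dtype=jnp.float32)

q, s = current_scaling(x)
err_current = jnp.max(jnp.abs(dequantize(q, s) - x))

history = jnp.zeros((4,), jnp.float32)
_, _, history = delayed_scaling(x, history)  # step 1: empty history, scale = 1
q, s, history = delayed_scaling(x, history)  # step 2: scale from step-1 amax
err_delayed = jnp.max(jnp.abs(dequantize(q, s) - x))

qb, sb = block_scaling(x)
err_block = jnp.max(jnp.abs(dequantize(qb, sb).reshape(x.shape) - x))
```

Only `delayed_scaling` needs an array (`amax_history`) that persists across steps, which mirrors why that recipe carries Flax variables while the other two derive their scales from the tensor being quantized.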
Conventions used across these documents
Framework. Flax Linen. (TE/JAX uses Linen; see Flax NNX/Linen interop and Haiku/Flax interop if you’re on a different stack.)
Baseline dtype. bf16 for inputs and parameters.
Benchmarking. quickstart_jax_utils.speedometer runs a JIT-compiled fwd+bwd loop with warmup.
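The benchmarking convention can be approximated with a short timing loop. This is a sketch under stated assumptions: `hypothetical_speedometer` and `mlp_loss` are illustrative stand-ins, not the signature or internals of `quickstart_jax_utils.speedometer`. The toy model uses plain JAX with bf16 inputs and parameters, matching the baseline dtype, rather than a Flax Linen module so that the block stays self-contained.

```python
import time

import jax
import jax.numpy as jnp


def hypothetical_speedometer(loss_fn, params, batch, warmup_iters=5, timing_iters=20):
    """Return mean milliseconds per fwd+bwd step of loss_fn (hypothetical helper)."""
    # One JIT-compiled function that computes the loss and its gradients.
    step = jax.jit(jax.value_and_grad(loss_fn))

    # Warmup: the first call triggers compilation, so it is excluded from timing.
    for _ in range(max(1, warmup_iters)):
        loss, grads = step(params, batch)
    jax.block_until_ready((loss, grads))

    # Timed loop; block_until_ready waits for asynchronously dispatched work.
    start = time.perf_counter()
    for _ in range(timing_iters):
        loss, grads = step(params, batch)
    jax.block_until_ready((loss, grads))
    return (time.perf_counter() - start) * 1000.0 / timing_iters


def mlp_loss(params, x):
    # bf16 inputs and parameters; the loss is reduced in fp32.
    h = jax.nn.gelu(x @ params["w1"])
    y = h @ params["w2"]
    return jnp.mean(y.astype(jnp.float32) ** 2)


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "w1": jax.random.normal(k1, (256, 1024), jnp.bfloat16) * 0.02,
    "w2": jax.random.normal(k2, (1024, 256), jnp.bfloat16) * 0.02,
}
batch = jax.random.normal(k3, (64, 256), jnp.bfloat16)

ms = hypothetical_speedometer(mlp_loss, params, batch)
print(f"{ms:.3f} ms per fwd+bwd step")
```

Blocking before each clock read matters because JAX dispatches work asynchronously; without it the timer can stop before the device has finished the step.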