JAX: Dense GEMMs with TransformerEngine
This document walks through replacing a plain flax.linen.Dense’s GEMM with
TransformerEngine’s quantized GEMM.
Recipe. We use MXFP8BlockScaling in this tutorial. MXFP8BlockScaling and
NVFP4BlockScaling require a Blackwell-class GPU; on Hopper, swap in
DelayedScaling or Float8CurrentScaling. For more information on recipes, see this recipe overview.
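The hardware rule above can be written as a small lookup. This is a minimal sketch, not TE API. pick_recipe_name is a hypothetical helper, and it assumes that Blackwell-class GPUs report a compute capability with major version 10 or higher and that Hopper reports 9.x. It only returns a class name from transformer_engine.common.recipe; you still construct the recipe yourself.

```python
from typing import Tuple


def pick_recipe_name(compute_capability: Tuple[int, int],
                     prefer_current_scaling: bool = True) -> str:
    """Hypothetical helper: map a GPU compute capability to a TE recipe class name.

    Assumption: Blackwell-class GPUs have major version >= 10 and Hopper is 9.x.
    Only these two generations are covered by this sketch.
    """
    major, _minor = compute_capability
    if major >= 10:
        # Blackwell-class: block-scaled recipes (MXFP8BlockScaling, NVFP4BlockScaling).
        return "MXFP8BlockScaling"
    if major == 9:
        # Hopper: use one of the per-tensor FP8 recipes instead.
        return "Float8CurrentScaling" if prefer_current_scaling else "DelayedScaling"
    raise ValueError(f"this sketch has no mapping for compute capability {compute_capability}")
```

With TE installed, the name can be turned into an instance, for example recipe = getattr(te_recipe, pick_recipe_name((10, 0)))() after from transformer_engine.common import recipe as te_recipe.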
1. Baseline: a plain Flax Dense block
We isolate the optimization to a single linear layer so it’s clear what’s
changing. dot_general_cls is exposed as a constructor argument so we can swap
in TE later without touching the model definition.
from typing import Any, Callable

import flax.linen as nn
import jax
import jax.numpy as jnp


class FlaxDenseBlock(nn.Module):
    """One linear layer. ``dot_general_cls`` lets us swap the GEMM impl."""

    features: int
    dtype: jnp.dtype = jnp.bfloat16
    # Zero-arg factory for the GEMM. Returning None lets nn.Dense use its
    # default GEMM (jax.lax.dot_general).
    dot_general_cls: Callable[[], Any] = lambda: None

    @nn.compact
    def __call__(self, x):
        return nn.Dense(
            features=self.features,
            use_bias=False,
            dtype=self.dtype,
            dot_general=self.dot_general_cls(),
        )(x)


batch, seq, hidden, out_features = 8, 2048, 8192, 32768
dtype = jnp.bfloat16

key = jax.random.PRNGKey(0)
k_init, k_x, k_dy = jax.random.split(key, 3)
x = jax.random.normal(k_x, (batch, seq, hidden)).astype(dtype)
dy = jax.random.normal(k_dy, (batch, seq, out_features)).astype(dtype)

baseline = FlaxDenseBlock(features=out_features)
baseline_vars = baseline.init(k_init, x)
2. Quantized Dense via make_dot_general_cls
TE exposes a helper, te_flax.make_dot_general_cls(recipe), that returns a Flax
module class. Instantiate it and pass the instance to nn.Dense(..., dot_general=...).
With this API, TE doesn’t create the kernel params; it only wraps the GEMM.
All your initialization, sharding annotations, and optimizer state stay where
they were.
from transformer_engine.jax import flax as te_flax
from transformer_engine.common.recipe import MXFP8BlockScaling
recipe = MXFP8BlockScaling()
te_dot_general_cls = te_flax.make_dot_general_cls(recipe)
te_model = FlaxDenseBlock(features=out_features, dot_general_cls=te_dot_general_cls)
te_vars = te_model.init(k_init, x)
print("Variable collections:", list(te_vars.keys()))
print(jax.tree_util.tree_map(lambda a: (a.shape, a.dtype), te_vars))
If using DelayedScaling, see [1].
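You can check the claim that TE only wraps the GEMM mechanically by comparing the pytree layouts of baseline_vars and te_vars. param_layout and layout_diff are hypothetical helpers built on jax.tree_util. An empty diff means both trees have the same paths, shapes, and dtypes. For MXFP8BlockScaling, the printed te_vars output shows only params/Dense_0/kernel.

```python
import jax
import numpy as np


def param_layout(variables):
    """Flatten a variables pytree into {path string: (shape, dtype name)}."""
    flat, _ = jax.tree_util.tree_flatten_with_path(variables)
    return {
        jax.tree_util.keystr(path): (tuple(leaf.shape), np.dtype(leaf.dtype).name)
        for path, leaf in flat
    }


def layout_diff(vars_a, vars_b):
    """Sorted paths whose shape/dtype differ or that exist in only one tree."""
    la, lb = param_layout(vars_a), param_layout(vars_b)
    return sorted(p for p in la.keys() | lb.keys() if la.get(p) != lb.get(p))
```

In the tutorial's session, assert layout_diff(baseline_vars, te_vars) == [] would express "same params, only the GEMM changed". This only reads .shape and .dtype, so it does not copy the 1 GB kernel to the host.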
3. Single-GPU performance
utils.speedometer (a helper from this tutorial's utils module) runs a JIT-compiled
forward+backward loop with warmup and reports the mean step time; both models
get the same input.
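The source of utils.speedometer isn't shown here, so the following is a hypothetical re-implementation based on assumptions, not the tutorial's helper. It assumes the benchmark jits a forward pass plus a VJP seeded with output_grad, runs a few warmup calls, and then takes wall-clock time with jax.block_until_ready. Its numbers won't necessarily match the helper's.

```python
import time

import jax


def speedometer_sketch(model_apply_fn, variables, input, output_grad,
                       warmup_iters=3, timing_iters=10):
    """Hypothetical approximation of utils.speedometer: mean fwd+bwd time in ms."""

    def fwd_bwd(variables, x, dy):
        # Forward pass, then a VJP seeded with dy: grads w.r.t. params and input.
        y, vjp_fn = jax.vjp(model_apply_fn, variables, x)
        return y, vjp_fn(dy)

    step = jax.jit(fwd_bwd)
    for _ in range(warmup_iters):
        # The first call compiles; the rest keep compilation out of the timing.
        jax.block_until_ready(step(variables, input, output_grad))

    start = time.perf_counter()
    for _ in range(timing_iters):
        result = step(variables, input, output_grad)
    # JAX dispatches asynchronously: wait for the device before stopping the clock.
    jax.block_until_ready(result)
    mean_ms = (time.perf_counter() - start) * 1e3 / timing_iters
    print(f"Mean time: {mean_ms:.3f} ms")
    return mean_ms
```

It takes the same keyword arguments as the calls below, for example speedometer_sketch(model_apply_fn=baseline.apply, variables=baseline_vars, input=x, output_grad=dy).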
def run_single_gpu_bench():
    # utils.speedometer comes from this tutorial's helper module (not shown).
    print("bf16 baseline:")
    utils.speedometer(
        model_apply_fn=baseline.apply,
        variables=baseline_vars,
        input=x,
        output_grad=dy,
    )

    print(f"\nTE {type(recipe).__name__}:")
    utils.speedometer(
        model_apply_fn=te_model.apply,
        variables=te_vars,
        input=x,
        output_grad=dy,
    )


run_single_gpu_bench()
Output (the section 2 prints, then run_single_gpu_bench()):
Variable collections: ['params']
{'params': {'Dense_0': {'kernel': ((8192, 32768), dtype('float32'))}}}
bf16 baseline:
Mean time: 18.056 ms
TE MXFP8BlockScaling:
Mean time: 11.260 ms
On a single GB200, that’s roughly 1.6× faster for the fwd+bwd of one large
Dense — and the only code change was passing dot_general=te_dot_general_cls()
into nn.Dense.
The speedup depends on shape: large GEMMs benefit most. Very small GEMMs may not benefit at all because the cast + scale overhead can dominate.
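The following plain-Python sketch works through the arithmetic for the shapes above, to make "large GEMMs benefit most" concrete. It assumes the timed step is exactly the three GEMMs of a Dense fwd+bwd and ignores quantization and other elementwise work, so the TFLOP/s figures are effective rates derived from the reported mean times, not measurements. elements_per_flop is a rough proxy (an assumption) for how much cast/scale work there is relative to GEMM work.

```python
def dense_fwd_bwd_flops(m: int, k: int, n: int) -> int:
    """FLOPs for Y = X @ W plus its backward (dX = dY @ W^T, dW = X^T @ dY).

    Each of the three GEMMs does m*k*n multiply-adds, i.e. 2*m*k*n FLOPs.
    """
    return 3 * 2 * m * k * n


def achieved_tflops(m: int, k: int, n: int, step_ms: float) -> float:
    """Effective throughput, counting only the three GEMMs above."""
    return dense_fwd_bwd_flops(m, k, n) / (step_ms * 1e-3) / 1e12


def elements_per_flop(m: int, k: int, n: int) -> float:
    """Rough ratio of tensor elements (what cast/scale passes touch) to GEMM FLOPs.

    Elements grow like m*k + k*n + m*n while FLOPs grow like m*k*n, so the
    ratio falls as every dimension gets larger.
    """
    return (m * k + k * n + m * n) / (2 * m * k * n)


# This tutorial's shapes: X is (batch * seq, hidden), W is (hidden, out_features).
m, k, n = 8 * 2048, 8192, 32768
bf16_tflops = achieved_tflops(m, k, n, 18.056)  # about 1461 TFLOP/s
te_tflops = achieved_tflops(m, k, n, 11.260)    # about 2344 TFLOP/s
```

For a 256×256×256 GEMM, elements_per_flop is roughly 55 times larger than for this tutorial's shape. That is one way to see why the cast + scale overhead can dominate small GEMMs.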
Warning
Remat / activation checkpointing. If your training loop uses
jax.checkpoint_policies.checkpoint_dots (or any policy that matches
jax.lax.dot_general), swap it for
transformer_engine.jax.checkpoint_policies.checkpoint_dots_and_te_gemms.
Otherwise TE’s quantized GEMM primitives won’t match the policy and won’t be
checkpointed as intended, so your performance comparison will not be accurate.
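Here is a minimal sketch of where the policy plugs in. It uses jax.checkpoint on a toy loss, and it imports the TE policy by the name given in the warning only when use_te=True, which requires TransformerEngine. remat_policy and loss are illustrative names, not library API.

```python
import jax
import jax.numpy as jnp


def remat_policy(use_te: bool):
    """Choose a checkpoint policy for jax.checkpoint / nn.remat."""
    if use_te:
        # Name taken from the warning above; requires TransformerEngine. Per its
        # name, it treats TE's quantized GEMMs as saveable alongside plain dots.
        from transformer_engine.jax import checkpoint_policies as te_policies

        return te_policies.checkpoint_dots_and_te_gemms
    # Stock JAX policy: save outputs of dot_general, recompute everything else.
    return jax.checkpoint_policies.checkpoint_dots


def loss(w, x):
    # Stand-in for one Dense layer followed by a nonlinearity and a reduction.
    return jnp.sum(jnp.tanh(x @ w) ** 2)


def remat_grad(w, x, use_te: bool = False):
    rematted_loss = jax.checkpoint(loss, policy=remat_policy(use_te))
    return jax.grad(rematted_loss)(w, x)
```

Remat changes what is stored versus recomputed, not the math, so gradients match the non-remat version. In Flax the same policy object can be passed as nn.remat(FlaxDenseBlock, policy=remat_policy(use_te=True)).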
4. Multi-GPU: DP=2 / TP=2 on a single Dense
Prerequisite: this section requires four GPUs.
Reusing the same FlaxDenseBlock from earlier, we run it on a 2×2 mesh with
data parallelism on one axis and tensor parallelism (column-parallel: shard the
kernel’s output dim) on the other.
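This is a single-device check of why the column-parallel split works. It uses plain jax.numpy on toy shapes, with no TE, no real devices, and no communication, and it assumes the usual Megatron-style column-parallel layout described above. Each TP rank computes its own output columns. The input gradient, by contrast, is a sum of per-rank partial products.

```python
import jax
import jax.numpy as jnp

kx_demo, kw_demo, kdy_demo = jax.random.split(jax.random.PRNGKey(0), 3)
tp_size = 2
x_demo = jax.random.normal(kx_demo, (4, 8))    # (batch, hidden)
w_demo = jax.random.normal(kw_demo, (8, 6))    # (hidden, out_features)
dy_demo = jax.random.normal(kdy_demo, (4, 6))  # (batch, out_features)

# Column-parallel: each TP rank holds a contiguous slice of W's output columns,
# and the matching slice of dY.
w_shards = jnp.split(w_demo, tp_size, axis=1)
dy_shards = jnp.split(dy_demo, tp_size, axis=1)

# Forward: every rank produces its own output columns. Concatenating them
# recovers X @ W, so the forward GEMM needs no cross-rank reduction.
y_tp = jnp.concatenate([x_demo @ w_s for w_s in w_shards], axis=1)

# Backward w.r.t. the input: dX = dY @ W^T is a sum of per-rank partial
# products, so column-parallel layers reduce dX across the TP axis.
dx_tp = sum(dy_s @ w_s.T for dy_s, w_s in zip(dy_shards, w_shards))

# Data parallelism on top: batch rows are independent, so splitting X by rows
# and concatenating the results also recovers X @ W.
y_dp = jnp.concatenate([x_b @ w_demo for x_b in jnp.split(x_demo, 2, axis=0)], axis=0)
```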
Two pieces wire this up:
- A jax.sharding.Mesh that you build once, outside JIT (here in build_dp_tp_mesh).
- TE’s MeshResource, set globally via global_shard_guard, which tells TE which mesh axes are DP and TP.
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils
from transformer_engine.jax.sharding import MeshResource, global_shard_guard


def build_dp_tp_mesh():
    # 2x2 mesh: DP on one axis, TP on the other.
    devices = mesh_utils.create_device_mesh((2, 2))
    mesh = Mesh(devices, axis_names=("dp", "tp"))
    # Tell TE which mesh axis is which. This is a *global* setting, established
    # outside JIT, so TE's GEMM primitives can plan comms accordingly.
    mesh_resource = MeshResource(dp_resource="dp", tp_resource="tp")
    return mesh, mesh_resource
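Sharding plan:
| Tensor | Shape | PartitionSpec |
|---|---|---|
| Kernel (column-parallel) | (hidden, out_features) = (8192, 32768) | P(None, "tp") |
| Input activations | (batch, seq, hidden) = (8, 2048, 8192) | P("dp", None, None) |
| Gradient on output | (batch, seq, out_features) = (8, 2048, 32768) | P("dp", None, "tp") |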
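The per-device block shapes implied by the plan are easy to work out by hand. local_shard_shape is a hypothetical helper that mirrors how a NamedSharding splits each dimension evenly over the named mesh axis. This sketch handles only one axis name or None per dimension and requires one spec entry per dimension, which is narrower than what PartitionSpec allows.

```python
from typing import Dict, Optional, Sequence, Tuple


def local_shard_shape(global_shape: Sequence[int],
                      spec: Sequence[Optional[str]],
                      mesh_shape: Dict[str, int]) -> Tuple[int, ...]:
    """Per-device block shape for a PartitionSpec-like tuple of axis names.

    Hypothetical helper: None means the dimension is replicated; a name means
    it is split evenly across that mesh axis.
    """
    if len(spec) != len(global_shape):
        raise ValueError("spec must have one entry per dimension")
    local = []
    for size, axis in zip(global_shape, spec):
        parts = 1 if axis is None else mesh_shape[axis]
        if size % parts:
            raise ValueError(f"dimension of size {size} is not divisible by {parts}")
        local.append(size // parts)
    return tuple(local)


mesh_shape = {"dp": 2, "tp": 2}
kernel_local = local_shard_shape((8192, 32768), (None, "tp"), mesh_shape)
x_local = local_shard_shape((8, 2048, 8192), ("dp", None, None), mesh_shape)
dy_local = local_shard_shape((8, 2048, 32768), ("dp", None, "tp"), mesh_shape)
```

Each device multiplies a (4 × 2048, 8192) activation block by an (8192, 16384) kernel block. Its local output is (4, 2048, 16384), which matches the local dy block, and it does a quarter of the single-GPU GEMM FLOPs.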
def shard_variables(mesh, variables_dict):
    # Column-parallel kernel: split the output (last) dim across "tp".
    kernel_sharding = NamedSharding(mesh, P(None, "tp"))

    def _shard(variables):
        # Rebuild the nested dict with only the Dense kernel re-placed.
        params = variables["params"]
        sharded = jax.device_put(params["Dense_0"]["kernel"], kernel_sharding)
        return {
            **variables,
            "params": {
                **params,
                "Dense_0": {**params["Dense_0"], "kernel": sharded},
            },
        }

    # Uses the global x and dy: batch over "dp"; dy's last dim also over "tp"
    # to match the column-parallel output layout.
    input_sharding = NamedSharding(mesh, P("dp", None, None))
    output_grad_sharding = NamedSharding(mesh, P("dp", None, "tp"))
    return {
        "x": jax.device_put(x, input_sharding),
        "dy": jax.device_put(dy, output_grad_sharding),
        **{name: _shard(vars_) for name, vars_ in variables_dict.items()},
    }
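shard_variables rebuilds the nested dict by hand, which is fine for one Dense. For models with more layers, a common alternative is to derive a PartitionSpec for each leaf from its name. This is a different technique from the one this tutorial uses. param_specs and its rules argument are hypothetical names for this sketch.

```python
import jax
from jax.sharding import PartitionSpec as P


def param_specs(variables, rules):
    """Map each leaf to a PartitionSpec chosen by its leaf key.

    Hypothetical helper: `rules` maps a leaf key (e.g. "kernel") to a spec;
    any leaf without a rule is replicated with P().
    """

    def spec_for(path, leaf):
        name = getattr(path[-1], "key", None)  # DictKey entries carry the dict key
        return rules.get(name, P())

    return jax.tree_util.tree_map_with_path(spec_for, variables)
```

The resulting tree could then be turned into NamedSharding objects with jax.tree_util.tree_map(lambda s: NamedSharding(mesh, s), specs, is_leaf=lambda s: isinstance(s, P)) and passed to jax.device_put together with the variables, since device_put accepts a pytree of shardings that matches its input.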
def run_multi_gpu_bench():
    mesh, mesh_resource = build_dp_tp_mesh()
    sharded = shard_variables(mesh, {"baseline": baseline_vars, "te": te_vars})

    # Enter both the JAX mesh context and TE's global MeshResource before the
    # benchmark traces and runs the models.
    with jax.set_mesh(mesh), global_shard_guard(mesh_resource):
        print("bf16 DP=2/TP=2:")
        utils.speedometer(
            model_apply_fn=baseline.apply,
            variables=sharded["baseline"],
            input=sharded["x"],
            output_grad=sharded["dy"],
        )

        print(f"\nTE {type(recipe).__name__} DP=2/TP=2:")
        utils.speedometer(
            model_apply_fn=te_model.apply,
            variables=sharded["te"],
            input=sharded["x"],
            output_grad=sharded["dy"],
        )


run_multi_gpu_bench()
bf16 DP=2/TP=2:
Mean time: 5.516 ms
TE MXFP8BlockScaling DP=2/TP=2:
Mean time: 3.712 ms
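The reported timings can be compared directly. The sketch below uses only the numbers printed in this document. Scaling efficiency here means single-GPU time divided by four times the DP=2/TP=2 time on the same global problem (strong scaling). These numbers alone don't explain why TE scales slightly less well. One plausible factor, not measured here, is that communication that doesn't shrink with the GEMM becomes a larger share of a faster step.

```python
# Mean step times (ms) reported in this document: (bf16, TE MXFP8BlockScaling).
single_gpu = (18.056, 11.260)
dp2_tp2 = (5.516, 3.712)
n_gpus = 4


def speedup(base_ms: float, new_ms: float) -> float:
    return base_ms / new_ms


def scaling_efficiency(one_gpu_ms: float, multi_gpu_ms: float, gpus: int) -> float:
    """Fraction of ideal strong scaling: 1.0 means time shrank by exactly `gpus`."""
    return one_gpu_ms / (gpus * multi_gpu_ms)


te_speedup_1gpu = speedup(*single_gpu)  # about 1.60
te_speedup_4gpu = speedup(*dp2_tp2)     # about 1.49
bf16_eff = scaling_efficiency(single_gpu[0], dp2_tp2[0], n_gpus)  # about 0.82
te_eff = scaling_efficiency(single_gpu[1], dp2_tp2[1], n_gpus)    # about 0.76
```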
Next steps
Collective GEMM: further speedups by performing the inter-device communication inside the GEMM so it overlaps with compute.
Footnotes