JAX: Dense GEMMs with TransformerEngine

This document walks through replacing the GEMM inside a plain flax.linen.Dense with TransformerEngine’s quantized GEMM.

Recipe. We use MXFP8BlockScaling in this tutorial. MXFP8BlockScaling and NVFP4BlockScaling require a Blackwell-class GPU; on Hopper, swap in DelayedScaling or Float8CurrentScaling. For more information on recipes, see this recipe overview.
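If one script has to run on both GPU generations, the recipe can be keyed off the device's compute capability. The helper below is a hypothetical sketch, not a TE API. It returns recipe class names as strings so it runs without TE installed. It also assumes that SM 10.x or newer counts as Blackwell-class and that SM 9.x is Hopper. Check TE's documentation for the exact support matrix.

```python
def pick_recipe_name(compute_capability):
    """Map a (major, minor) compute capability to a TE recipe class name.

    Hypothetical helper. Thresholds are assumptions based on the note above:
    Blackwell-class (assumed SM >= 10.0) gets MXFP8BlockScaling, Hopper
    (SM 9.x) gets Float8CurrentScaling, anything else stays in bf16 (None).
    """
    major, _minor = compute_capability
    if major >= 10:
        return "MXFP8BlockScaling"
    if major == 9:
        # DelayedScaling is the other option the note above allows on Hopper.
        return "Float8CurrentScaling"
    return None


print(pick_recipe_name((10, 0)))  # MXFP8BlockScaling
print(pick_recipe_name((9, 0)))   # Float8CurrentScaling
```

With TE installed, a returned name could be resolved with getattr(transformer_engine.common.recipe, name)(). A None result means staying in bf16.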


1. Baseline: a plain Flax Dense block

We isolate the optimization to a single linear layer so it’s clear what’s changing. dot_general_cls is exposed as a constructor argument so we can swap in TE later without touching the model definition.

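from typing import Any, Callable

import flax.linen as nn
import jax
import jax.numpy as jnp


class FlaxDenseBlock(nn.Module):
    """One linear layer. ``dot_general_cls`` lets us swap the GEMM impl."""

    features: int
    dtype: jnp.dtype = jnp.bfloat16
    # Zero-argument factory for the GEMM. The default returns None, which
    # recent Flax versions treat as "use jax.lax.dot_general".
    dot_general_cls: Callable[[], Any] = lambda: None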

    @nn.compact
    def __call__(self, x):
        return nn.Dense(
            features=self.features,
            use_bias=False,
            dtype=self.dtype,
            dot_general=self.dot_general_cls(),
        )(x)


batch, seq, hidden, out_features = 8, 2048, 8192, 32768
dtype = jnp.bfloat16

key = jax.random.PRNGKey(0)
k_init, k_x, k_dy = jax.random.split(key, 3)
x = jax.random.normal(k_x, (batch, seq, hidden)).astype(dtype)
dy = jax.random.normal(k_dy, (batch, seq, out_features)).astype(dtype)

baseline = FlaxDenseBlock(features=out_features)
baseline_vars = baseline.init(k_init, x)
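For reference, here is the GEMM that nn.Dense hands to its dot_general argument, written as plain JAX. dense_gemm is an illustrative helper of ours, not a Flax API. To our understanding, it uses the same dimension numbers as nn.Dense: contract the input's last axis with the kernel's first axis, with no batch dimensions. Any dot_general replacement has to accept this same call.

```python
import jax
import jax.numpy as jnp
from jax import lax


def dense_gemm(x, kernel, dot_general=lax.dot_general):
    """Hypothetical helper mirroring the GEMM call inside flax.linen.Dense."""
    # ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)):
    # contract x's last axis with the kernel's first axis, no batch dims.
    dimension_numbers = (((x.ndim - 1,), (0,)), ((), ()))
    return dot_general(x, kernel, dimension_numbers)


x_small = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 8))
kernel_small = jax.random.normal(jax.random.PRNGKey(1), (8, 5))
y_small = dense_gemm(x_small, kernel_small)
print(y_small.shape)  # (2, 3, 5)
```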

2. Quantized Dense via make_dot_general_cls

TE exposes a helper, te_flax.make_dot_general_cls(recipe), that returns a Flax module class. Instantiate that class and pass the instance to nn.Dense(..., dot_general=...); FlaxDenseBlock does this by calling self.dot_general_cls().

With this API, TE doesn’t create the kernel params; it only wraps the GEMM. All your initialization, sharding annotations, and optimizer state stay where they were.

from transformer_engine.jax import flax as te_flax
from transformer_engine.common.recipe import MXFP8BlockScaling

recipe = MXFP8BlockScaling()
te_dot_general_cls = te_flax.make_dot_general_cls(recipe)

te_model = FlaxDenseBlock(features=out_features, dot_general_cls=te_dot_general_cls)
te_vars = te_model.init(k_init, x)

print("Variable collections:", list(te_vars.keys()))
print(jax.tree_util.tree_map(lambda a: (a.shape, a.dtype), te_vars))

If using DelayedScaling, see [1].
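To make "TE only wraps the GEMM" concrete, here is a hypothetical stand-in for a quantized dot_general, simulated in plain JAX. It is not TE's implementation. It quantizes each operand to float8_e4m3fn in blocks of 32 values along the contraction axis, with one power-of-two scale per block. This loosely follows the OCP MX format that MXFP8 is named after. It then dequantizes and runs an ordinary lax.dot_general, whereas a real quantized GEMM runs on FP8 tensor cores. The names fake_quant_blocks and fake_mx_dot_general are ours.

```python
import jax
import jax.numpy as jnp
from jax import lax

E4M3_MAX = 448.0  # largest finite float8_e4m3fn magnitude
E4M3_EMAX = 8     # exponent of the largest E4M3 power of two (256 = 2**8)
BLOCK = 32        # number of values that share one scale


def fake_quant_blocks(x, axis, block=BLOCK):
    """Quantize-dequantize x in blocks along `axis` (numerical simulation only)."""
    x = jnp.moveaxis(x.astype(jnp.float32), axis, -1)
    *lead, k = x.shape
    assert k % block == 0, "contraction dim must be a multiple of the block size"
    xb = x.reshape(*lead, k // block, block)
    amax = jnp.max(jnp.abs(xb), axis=-1, keepdims=True)
    # One power-of-two scale per block: 2 ** (floor(log2(amax)) - emax).
    # frexp gives amax = m * 2**e with m in [0.5, 1), so floor(log2(amax)) = e - 1.
    # All-zero blocks get amax 1.0; they quantize to zero either way.
    safe_amax = jnp.where(amax > 0, amax, 1.0)
    _, exp = jnp.frexp(safe_amax)
    scale = jnp.exp2((exp - 1 - E4M3_EMAX).astype(jnp.float32))
    # amax / scale lands in [256, 512); clip the top of that range to 448.
    q = jnp.clip(xb / scale, -E4M3_MAX, E4M3_MAX).astype(jnp.float8_e4m3fn)
    deq = q.astype(jnp.float32) * scale
    return jnp.moveaxis(deq.reshape(*lead, k), -1, axis)


def fake_mx_dot_general(lhs, rhs, dimension_numbers, precision=None,
                        preferred_element_type=None):
    """Accepts the arguments nn.Dense passes to dot_general; quantizes, then GEMMs."""
    (lhs_contract, rhs_contract), _batch_dims = dimension_numbers
    assert len(lhs_contract) == 1 and len(rhs_contract) == 1
    lhs_q = fake_quant_blocks(lhs, lhs_contract[0])
    rhs_q = fake_quant_blocks(rhs, rhs_contract[0])
    out = lax.dot_general(lhs_q, rhs_q, dimension_numbers, precision=precision,
                          preferred_element_type=preferred_element_type)
    if preferred_element_type is not None:
        return out
    # Otherwise return the caller's compute dtype (e.g. bf16), like the plain GEMM.
    return out.astype(jnp.result_type(lhs, rhs))
```

Because it has the same calling convention, a function like this could be passed as nn.Dense(..., dot_general=fake_mx_dot_general). The kernel parameter, its initializer, and its sharding would stay untouched, which is the same property described above for TE's helper.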

3. Single-GPU performance

utils.speedometer runs a JIT-compiled forward+backward loop, with warmup iterations, on the same inputs for both models.
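utils.speedometer itself is not shown in this document. The sketch below is a hypothetical, minimal version that takes the same keyword arguments. It assumes the helper times a jax.vjp-based forward+backward step. The real helper may differ in iteration counts, timing method, and reporting. speedometer_sketch and the toy model at the end are our own illustrative names.

```python
import time

import jax
import jax.numpy as jnp


def speedometer_sketch(model_apply_fn, variables, input, output_grad,
                       warmup_iters=5, timing_iters=20):
    """Hypothetical stand-in for utils.speedometer: times a JIT'd fwd+bwd step.

    `input` mirrors the keyword used in the calls below (it shadows the builtin).
    """

    def fwd_bwd(variables, x, dy):
        # Forward pass, then pull dy back through it to get gradients for
        # both the variables and the input, as a training step would.
        y, vjp_fn = jax.vjp(model_apply_fn, variables, x)
        grad_vars, grad_x = vjp_fn(dy)
        return y, grad_vars, grad_x

    step = jax.jit(fwd_bwd)

    # Warmup: the first call compiles; later calls reuse the compiled step.
    for _ in range(warmup_iters):
        jax.block_until_ready(step(variables, input, output_grad))

    # JAX dispatches asynchronously, so wait for the last step before
    # stopping the clock.
    start = time.perf_counter()
    for _ in range(timing_iters):
        out = step(variables, input, output_grad)
    jax.block_until_ready(out)
    mean_ms = (time.perf_counter() - start) * 1e3 / timing_iters

    print(f"Mean time: {mean_ms:.3f} ms")
    return mean_ms


# Toy usage: a one-matmul "model" whose variables mimic Flax's {"params": ...} layout.
toy_vars = {"params": {"w": jnp.ones((256, 512), jnp.float32)}}
toy_x = jnp.ones((64, 256), jnp.float32)
toy_dy = jnp.ones((64, 512), jnp.float32)
toy_ms = speedometer_sketch(lambda v, a: a @ v["params"]["w"], toy_vars, toy_x, toy_dy)
```

Pulling the backward pass through jax.vjp with the supplied output gradient means the timed step includes the forward GEMM plus the two backward GEMMs (one for the input gradient, one for the kernel gradient).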

def run_single_gpu_bench():
    print("bf16 baseline:")
    utils.speedometer(
        model_apply_fn=baseline.apply,
        variables=baseline_vars,
        input=x,
        output_grad=dy,
    )

    print(f"\nTE {type(recipe).__name__}:")
    utils.speedometer(
        model_apply_fn=te_model.apply,
        variables=te_vars,
        input=x,
        output_grad=dy,
    )


Output:
Variable collections: ['params']
{'params': {'Dense_0': {'kernel': ((8192, 32768), dtype('float32'))}}}

bf16 baseline:
Mean time: 18.056 ms

TE MXFP8BlockScaling:
Mean time: 11.260 ms

On a single GB200, that’s roughly 1.6× faster for the fwd+bwd of one large Dense — and the only code change was passing dot_general=te_dot_general_cls() into nn.Dense.

The speedup depends on shape: large GEMMs benefit most. Very small GEMMs may not benefit at all, because the cost of casting and scaling the operands can outweigh the faster FP8 math.
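A rough way to see why shape matters: casting work grows with the operand sizes (M·K + K·N elements), while GEMM work grows with M·K·N. The sketch below compares the two for the tutorial's shape and for a hypothetical small layer. It ignores the backward pass, memory bandwidth, and kernel-launch overhead, so treat it as intuition rather than a performance model.

```python
def gemm_flops(m, k, n):
    # (m x k) @ (k x n): one multiply and one add per output element per k.
    return 2 * m * k * n


def cast_elements(m, k, n):
    # Forward pass: both operands are cast (and scaled) before the GEMM.
    return m * k + k * n


shapes = {
    # Tutorial GEMM: M = batch * seq tokens, K = hidden, N = out_features.
    "tutorial": (8 * 2048, 8192, 32768),
    # Hypothetical small layer, for comparison only.
    "small": (128, 1024, 1024),
}
for name, (m, k, n) in shapes.items():
    ratio = gemm_flops(m, k, n) / cast_elements(m, k, n)
    print(f"{name}: {ratio:,.0f} GEMM FLOPs per cast element")
# tutorial: 21,845 GEMM FLOPs per cast element
# small: 228 GEMM FLOPs per cast element
```

The tutorial's GEMM does roughly 96× more math per cast element than the small one, so the fixed cost of quantizing its operands is amortized far better.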

Warning

Remat / activation checkpointing. If your training loop uses jax.checkpoint_policies.checkpoint_dots (or any policy that matches jax.lax.dot_general), swap it for transformer_engine.jax.checkpoint_policies.checkpoint_dots_and_te_gemms. TE’s quantized GEMMs are separate primitives rather than jax.lax.dot_general, so a dots-only policy does not recognize them. They won’t be checkpointed as intended, and your performance comparison will not be accurate.
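A checkpoint policy is plugged in through jax.checkpoint, so the swap in the warning above is a one-argument change. The toy sketch below uses JAX's stock checkpoint_dots so it runs without TE; two_layer and remat_grad are illustrative names of ours. In a TE model, the TE policy named in the warning would be passed in the same place.

```python
import jax
import jax.numpy as jnp


def two_layer(params, x):
    # Two dots with a nonlinearity in between; stands in for a model block.
    h = jnp.tanh(x @ params["w1"])
    return jnp.sum(h @ params["w2"])


def remat_grad(fn, policy):
    # jax.checkpoint recomputes fn's intermediates during the backward pass,
    # except the values that `policy` marks as saveable.
    return jax.grad(jax.checkpoint(fn, policy=policy))


# Stock JAX policy: saves lax.dot_general outputs. With TE GEMMs in the
# model, pass transformer_engine.jax.checkpoint_policies
# .checkpoint_dots_and_te_gemms here instead (per the warning above).
policy = jax.checkpoint_policies.checkpoint_dots

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params_demo = {"w1": jax.random.normal(k1, (8, 16)), "w2": jax.random.normal(k2, (16, 4))}
x_demo = jax.random.normal(k3, (2, 8))

g_remat = remat_grad(two_layer, policy)(params_demo, x_demo)
g_plain = jax.grad(two_layer)(params_demo, x_demo)
```

Remat only changes what is stored versus recomputed, so both gradients match. The policy choice shows up in memory use and step time.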

4. Multi-GPU: DP=2 / TP=2 on a single Dense

Prerequisite: this section requires four GPUs.

Keeping the same FlaxDenseBlock from the rest of the document, we run it on a 2×2 mesh with data parallelism on one axis and tensor parallelism (column-parallel: shard the kernel’s output dim) on the other.
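Here is why column-parallel works, in plain arithmetic on one device. Splitting the kernel by output columns splits Y = X·W into independent column blocks, while the backward pass produces partial sums that must be reduced. The sketch below uses tiny made-up shapes and simulates the 2-way splits with jnp.split. It illustrates the math only, not how TE or XLA actually schedule the collectives.

```python
import jax
import jax.numpy as jnp

# Tiny stand-ins for (tokens, hidden) @ (hidden, out_features).
kx, kw, kdy = jax.random.split(jax.random.PRNGKey(0), 3)
X = jax.random.normal(kx, (4, 6))
W = jax.random.normal(kw, (6, 8))
dY = jax.random.normal(kdy, (4, 8))

# TP: column-parallel split of the kernel's output dim across 2 "devices".
W0, W1 = jnp.split(W, 2, axis=1)
dY0, dY1 = jnp.split(dY, 2, axis=1)

# Forward: with X replicated across TP, each TP rank computes its own output
# columns, and no communication is needed.
Y_tp = jnp.concatenate([X @ W0, X @ W1], axis=1)

# Backward dX = dY @ W.T: each TP rank holds only a partial sum, so the
# ranks must all-reduce (sum) over the TP axis.
dX_tp = dY0 @ W0.T + dY1 @ W1.T

# DP: rows (tokens) split across 2 "devices". dW = X.T @ dY is a sum of
# per-rank contributions, i.e. an all-reduce over the DP axis.
Xa, Xb = jnp.split(X, 2, axis=0)
dYa, dYb = jnp.split(dY, 2, axis=0)
dW_dp = Xa.T @ dYa + Xb.T @ dYb
```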

Two pieces wire this up:

  1. A jax.sharding.Mesh, built once outside JIT (here, inside build_dp_tp_mesh()).

  2. TE’s MeshResource, set globally via global_shard_guard, which tells TE which mesh axes are DP and TP.

from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils
from transformer_engine.jax.sharding import MeshResource, global_shard_guard


def build_dp_tp_mesh():
    # 2x2 mesh: DP on one axis, TP on the other.
    devices = mesh_utils.create_device_mesh((2, 2))
    mesh = Mesh(devices, axis_names=("dp", "tp"))

    # Tell TE which mesh axis is which. This is a *global* setting, established
    # outside JIT, so TE's GEMM primitives can plan comms accordingly.
    mesh_resource = MeshResource(dp_resource="dp", tp_resource="tp")
    return mesh, mesh_resource


Sharding plan:

Tensor                      Shape                         PartitionSpec
Kernel (column-parallel)    (hidden, out_features)        P(None, 'tp')
Input activations           (batch, seq, hidden)          P('dp', None, None)
Gradient on output          (batch, seq, out_features)    P('dp', None, 'tp')
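You can check the per-device block each GPU holds under this plan without any GPUs. The standalone sketch below fakes four CPU devices with XLA's --xla_force_host_platform_device_count flag. Run it as its own script, since the flag must be set before JAX starts. It builds the same 2×2 (dp, tp) mesh directly rather than through mesh_utils and asks each NamedSharding for its shard_shape. The plan and local_shapes names are illustrative.

```python
import os

# Fake four CPU devices. XLA reads this flag when JAX initializes, so it must
# be set before `import jax`.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
)

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

batch, seq, hidden, out_features = 8, 2048, 8192, 32768

# Same 2x2 (dp, tp) layout as build_dp_tp_mesh(), but on the fake CPU devices.
mesh = Mesh(np.array(jax.devices("cpu")[:4]).reshape(2, 2), axis_names=("dp", "tp"))

plan = {
    "kernel": ((hidden, out_features), P(None, "tp")),
    "x": ((batch, seq, hidden), P("dp", None, None)),
    "dy": ((batch, seq, out_features), P("dp", None, "tp")),
}

local_shapes = {}
for name, (global_shape, spec) in plan.items():
    # shard_shape computes the per-device block; nothing is allocated.
    local = tuple(int(d) for d in NamedSharding(mesh, spec).shard_shape(global_shape))
    local_shapes[name] = local
    print(f"{name}: global {global_shape} -> per device {local}")
# kernel: global (8192, 32768) -> per device (8192, 16384)
# x: global (8, 2048, 8192) -> per device (4, 2048, 8192)
# dy: global (8, 2048, 32768) -> per device (4, 2048, 16384)
```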

def shard_variables(mesh, variables_dict):
    kernel_sharding = NamedSharding(mesh, P(None, "tp"))

    def _shard(variables):
        params = variables["params"]
        sharded = jax.device_put(params["Dense_0"]["kernel"], kernel_sharding)
        return {
            **variables,
            "params": {
                **params,
                "Dense_0": {**params["Dense_0"], "kernel": sharded},
            },
        }

    input_sharding = NamedSharding(mesh, P("dp", None, None))
    output_grad_sharding = NamedSharding(mesh, P("dp", None, "tp"))

    return {
        "x": jax.device_put(x, input_sharding),
        "dy": jax.device_put(dy, output_grad_sharding),
        **{name: _shard(vars_) for name, vars_ in variables_dict.items()},
    }


def run_multi_gpu_bench():
    mesh, mesh_resource = build_dp_tp_mesh()
    sharded = shard_variables(mesh, {"baseline": baseline_vars, "te": te_vars})

    with jax.set_mesh(mesh), global_shard_guard(mesh_resource):
        print("bf16 DP=2/TP=2:")
        utils.speedometer(
            model_apply_fn=baseline.apply,
            variables=sharded["baseline"],
            input=sharded["x"],
            output_grad=sharded["dy"],
        )

        print(f"\nTE {type(recipe).__name__} DP=2/TP=2:")
        utils.speedometer(
            model_apply_fn=te_model.apply,
            variables=sharded["te"],
            input=sharded["x"],
            output_grad=sharded["dy"],
        )


Output:
bf16 DP=2/TP=2:
Mean time: 5.516 ms

TE MXFP8BlockScaling DP=2/TP=2:
Mean time: 3.712 ms

Next steps

Footnotes