nvfp4_act_max

Calibrator for the NVFP4 activation global scale (nvfp4_act_max).

Calibrates the per-tensor NVFP4 global amax (g_amax) of an activation quantizer from the distribution of its per-block (16-wide) amaxes, using the B_min-anchored recipe documented in experimental/nvfp4_global_scale_study/:

g_amax = rho * B_min (rho in (0, 28672), default 16384)

where B_min is a robust low percentile of the per-block amaxes. Unlike plain max calibration (which sets g_amax to the literal per-tensor max and so sits on the saturation cliff), this spends the format’s wide normal-FP8 window as upward headroom so unseen activation outliers degrade gracefully rather than clipping. See the study README §4 for the derivation.
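The recipe above can be sketched in a few lines of plain Python. This is a minimal illustration, not the library's implementation: the helper names `nearest_rank_percentile` and `b_min_anchored_amax` are hypothetical, the nearest-rank percentile method is an assumption, and treating "represented blocks" as blocks with a nonzero amax is also an assumption.

```python
import math

def nearest_rank_percentile(samples, pct):
    """Nearest-rank percentile (assumed method; the real calibrator may differ)."""
    ordered = sorted(samples)
    rank = max(1, math.ceil(pct / 100.0 * len(ordered)))
    return ordered[rank - 1]

def b_min_anchored_amax(block_amaxes, rho=16384.0, b_min_percentile=1.0):
    """Sketch of g_amax = rho * B_min, with B_min a low percentile of block amaxes."""
    if not 0.0 < rho < 28672.0:
        raise ValueError("rho must be in (0, 28672)")
    # Assumption: all-zero blocks are ignored when locating B_min.
    nonzero = [a for a in block_amaxes if a > 0.0]
    b_min = nearest_rank_percentile(nonzero, b_min_percentile)
    return rho * b_min

# Toy data: 64 per-block amaxes, evenly spread from 0.5 to 32.0.
block_amaxes = [0.5 * (k + 1) for k in range(64)]
plain_max = max(block_amaxes)               # what plain max calibration would pick
g_amax = b_min_anchored_amax(block_amaxes)  # B_min-anchored choice
headroom = g_amax / plain_max               # upward room left for unseen outliers
```

On this toy input the anchored g_amax sits well above the observed maximum, which is the headroom the paragraph above describes. The bound 28672 equals 448 × 64, the ratio of the largest FP8 E4M3 value (448) to its smallest normal value (2^-6); linking that to the "normal-FP8 window" is an inference, not something this page states.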

Classes

NVFP4ActMaxCalibrator

Calibrates the NVFP4 activation global amax via the B_min-anchored recipe.

class NVFP4ActMaxCalibrator

Bases: _Calibrator

Calibrates the NVFP4 activation global amax via the B_min-anchored recipe.

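The calibrator accumulates a base-2 log-spaced histogram of the per-block amaxes seen during calibration, which keeps memory bounded regardless of how many batches are collected. When compute_amax is called, it derives robust B_min / B_max percentiles from the histogram and returns g_amax = max(rho * B_min, margin * B_max), i.e. rho * B_min with a sanity floor of margin * B_max.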

Parameters:
  • num_bits – quantizer num_bits ((2, 1) for NVFP4); kept for interface parity.

  • axis – unused (the global amax is per-tensor); kept for interface parity.

  • unsigned – unused; kept for interface parity.

  • block_size – NVFP4 block width along the last dim (16).

  • rho – window-split factor; g_amax = rho * B_min. Must be in (0, 28672).

  • b_min_percentile – low percentile (over represented blocks) used for B_min.

  • b_max_percentile – high percentile used for B_max (100 => literal max).

  • margin – sanity-floor multiplier; g_amax >= margin * B_max.

  • num_bins – number of log2 histogram bins.

  • log2_min / log2_max – log2 range covered by the histogram.
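To make the histogram parameters concrete, here is a minimal sketch of a fixed-size log2 histogram with a percentile readout. The class name `Log2Histogram`, the edge-bin clamping, and the choice to report a bin's upper edge are all assumptions for illustration; the actual calibrator may bin and interpolate differently.

```python
import math

class Log2Histogram:
    """Hypothetical fixed-size histogram over log2(value); memory is O(num_bins)."""

    def __init__(self, num_bins=512, log2_min=-40.0, log2_max=40.0):
        self.num_bins = num_bins
        self.log2_min = log2_min
        self.log2_max = log2_max
        self.bin_width = (log2_max - log2_min) / num_bins
        self.counts = [0] * num_bins

    def add(self, value):
        """Count one value; non-positive values are skipped, out-of-range ones clamp."""
        if value <= 0.0:
            return
        idx = math.floor((math.log2(value) - self.log2_min) / self.bin_width)
        self.counts[min(max(idx, 0), self.num_bins - 1)] += 1

    def percentile(self, pct):
        """Upper edge of the first non-empty bin whose cumulative count reaches pct%."""
        total = sum(self.counts)
        if total == 0:
            return None
        target = pct / 100.0 * total
        running = 0
        for idx, count in enumerate(self.counts):
            running += count
            if count and running >= target:
                return 2.0 ** (self.log2_min + (idx + 1) * self.bin_width)
        return 2.0 ** self.log2_max

# 99 blocks with amax 0.25 and one outlier block with amax 8.0.
hist = Log2Histogram()
for amax in [0.25] * 99 + [8.0]:
    hist.add(amax)
b_min = hist.percentile(1.0)    # low percentile, lands in the 0.25 bin
b_max = hist.percentile(100.0)  # 100 => bin holding the literal max
```

With the defaults, each bin spans 80 / 512 = 0.15625 in log2, so a percentile read from the histogram is accurate to within a factor of about 1.11.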

__init__(num_bits=(2, 1), axis=None, unsigned=False, *, block_size=16, rho=16384.0, b_min_percentile=1.0, b_max_percentile=99.99, margin=1.0, num_bins=512, log2_min=-40.0, log2_max=40.0)

Initialize.

collect(x)

Accumulate the per-block amax histogram for one activation batch.

Parameters:

x (Tensor)

Return type:

None
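As a rough sketch of what accumulating per-block amaxes across batches looks like, the following uses nested Python lists in place of tensors and one histogram bin per octave instead of num_bins fine bins. The class `ToyBlockAmaxCollector` is a hypothetical stand-in, and skipping all-zero blocks is an assumption.

```python
import math
from collections import Counter

class ToyBlockAmaxCollector:
    """Hypothetical stand-in: one bin per octave, keyed by floor(log2(amax))."""

    def __init__(self, block_size=16):
        self.block_size = block_size
        self.counts = Counter()

    def collect(self, batch):
        """`batch` is a list of rows; blocks run along the last dimension."""
        for row in batch:
            for start in range(0, len(row), self.block_size):
                block = row[start:start + self.block_size]
                amax = max(abs(v) for v in block)
                if amax > 0.0:  # assumption: all-zero blocks are not counted
                    self.counts[math.floor(math.log2(amax))] += 1

    def reset(self):
        """Drop everything collected so far."""
        self.counts.clear()

collector = ToyBlockAmaxCollector()
batch_a = [[1.0] * 16 + [-3.0] * 16]  # block amaxes 1.0 and 3.0
batch_b = [[0.0] * 16 + [0.3] * 16]   # zero block skipped; amax 0.3 counted
collector.collect(batch_a)
collector.collect(batch_b)
octave_counts = dict(collector.counts)
```

Only the bin counts persist between calls, so the state stays the same size no matter how many batches are collected.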

compute_amax()

Return the calibrated NVFP4 activation global amax (g_amax).

Also records a diagnostic self._stats dict for offline analysis: the literal max, the p1 and p99.99 percentiles, the values actually used, the chosen g_amax, and which term set it.

Return type:

Tensor | None
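The final step can be sketched as combining the two candidate terms and recording which one won. The function name `combine_terms` and the stats keys below are hypothetical, and returning None when nothing was collected is an assumption based on the Tensor | None return type.

```python
def combine_terms(b_min, b_max, rho=16384.0, margin=1.0):
    """Sketch: g_amax = max(rho * B_min, margin * B_max), plus a small stats dict."""
    if b_min is None or b_max is None:
        return None, {}  # assumption: no data collected => no amax
    anchored = rho * b_min   # B_min-anchored term
    floor = margin * b_max   # sanity floor
    g_amax = max(anchored, floor)
    stats = {
        "b_min_used": b_min,
        "b_max_used": b_max,
        "g_amax": g_amax,
        "set_by": "rho*B_min" if anchored >= floor else "margin*B_max",
    }
    return g_amax, stats

# Typical case: the anchored term dominates and leaves headroom above B_max.
g_typical, stats_typical = combine_terms(b_min=0.5, b_max=32.0)

# Very wide spread: rho * B_min falls below B_max, so the floor takes over.
g_wide, stats_wide = combine_terms(b_min=1e-3, b_max=100.0)
```

The floor only matters when the spread between B_min and B_max exceeds rho / margin, in which case g_amax falls back to margin * B_max.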

reset()

Reset the collected histogram and statistics.

Return type:

None