mxfp8_tensor

Implements MXFP8 quantization for efficient tensor storage and computation.

Classes

MXFP8QTensor

Implements MXFP8 quantization of tensors for more efficient storage or computation.

class MXFP8QTensor

Bases: BaseQuantizedTensor

Implements MXFP8 quantization of tensors for more efficient storage or computation.

MXFP8 uses:
  • FP8 E4M3 format for elements

  • E8M0 format for shared scales (power-of-2 only, stored as a biased uint8 exponent)

  • A block size of 32 elements along the last dimension
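As a quick illustration of that layout (a sketch of the encoding conventions stated above, not this library's internals, and assuming the usual E8M0 bias of 127):

```python
# Sketch: decoding an E8M0 shared scale (assumed bias of 127) and
# counting the per-row scale blocks of a (4, 128) tensor.

def decode_e8m0(biased_exponent: int) -> float:
    """E8M0 stores only a biased uint8 exponent; the value is 2**(e - 127)."""
    return 2.0 ** (biased_exponent - 127)

BLOCK_SIZE = 32

rows, cols = 4, 128
num_blocks_per_row = cols // BLOCK_SIZE  # 128 // 32 = 4 shared scales per row

print(decode_e8m0(127))  # exponent 0 -> scale 1.0
print(decode_e8m0(130))  # exponent 3 -> scale 8.0
print(num_blocks_per_row)  # 4
```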

quantized_data

The quantized data, stored as a float8_e4m3fn tensor.

Type:

torch.Tensor

BLOCK_SIZE = 32
E4M3_MAX = 448.0
SCALE_DTYPE = torch.uint8
dequantize(dtype=None, **kwargs)

Dequantize MXFP8 tensor back to the target dtype.

Parameters:
  • dtype (torch.dtype | None) – Target dtype for dequantization. Defaults to original dtype.

  • **kwargs – Must contain 'scale', the E8M0 scale as a biased uint8 exponent tensor.

Returns:

Dequantized tensor in the target dtype.

Return type:

torch.Tensor

classmethod get_weights_scaling_factor(weight)

Returns the E8M0 scale (uint8 biased exponent) for a weight tensor.

Parameters:

weight (Tensor) – The weight tensor to compute scale for. Must be at least 2D. Supports 2D (out_dim, in_dim) and 3D MoE (num_experts, out_dim, in_dim).

Returns:

E8M0 scale as uint8 tensor with shape […, out_dim, in_dim // 32].

For 2D input: (out_dim, in_dim // 32)

For 3D MoE input: (num_experts, out_dim, in_dim // 32)

Return type:

torch.Tensor
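A minimal sketch of how such a per-block power-of-2 scale can be derived. The rounding recipe here (exponent = floor(log2(amax)) - floor(log2(448)), clamped to the uint8 range) is an assumption for illustration, not necessarily the one this method uses:

```python
import math
import torch

BLOCK_SIZE = 32
E4M3_MAX = 448.0  # largest finite float8_e4m3fn value

def e8m0_scale_sketch(weight: torch.Tensor) -> torch.Tensor:
    """Sketch of a per-block E8M0 scale for a (..., out_dim, in_dim) weight.

    One common recipe (an assumption, not necessarily this library's): take
    the amax of each 32-element block and pick the power-of-2 scale 2**e with
    e = floor(log2(amax)) - floor(log2(E4M3_MAX)), stored as a biased uint8
    exponent (bias 127).
    """
    *lead, in_dim = weight.shape
    blocks = weight.reshape(*lead, in_dim // BLOCK_SIZE, BLOCK_SIZE)
    amax = blocks.abs().amax(dim=-1).clamp_min(torch.finfo(weight.dtype).tiny)
    exp = torch.floor(torch.log2(amax)) - math.floor(math.log2(E4M3_MAX))
    return (exp + 127).clamp(0, 255).to(torch.uint8)

w = torch.randn(16, 64)
s = e8m0_scale_sketch(w)
print(s.shape)  # torch.Size([16, 2]), i.e. (out_dim, in_dim // 32)
```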

classmethod get_weights_scaling_factor_from_quantizer(weight, weight_quantizer)

Returns the E8M0 scale from the quantizer, or computes it from the weight.

This method handles extracting the scale from a weight quantizer, with proper format conversion and shape correction.

Parameters:
  • weight (Tensor) – The weight tensor. Can be 2D (out_dim, in_dim) or 3D for MoE (num_experts, out_dim, in_dim).

  • weight_quantizer – The weight quantizer with block_sizes and optional _scale.

Returns:

E8M0 scale as uint8 tensor with shape […, out_dim, in_dim // 32].

Return type:

torch.Tensor

classmethod quantize(input, weights_scaling_factor=None)

Convert a tensor to MXFP8 quantized format.

Parameters:
  • input (torch.Tensor) – The input tensor to be quantized.

  • weights_scaling_factor (torch.Tensor | None) – Optional pre-computed E8M0 scale as uint8 biased exponent. If None, the scale will be computed from the input. Shape should be […, in_dim // 32] matching input dimensions.

Returns:

(MXFP8QTensor, weights_scaling_factor), where weights_scaling_factor is the E8M0 scale as a uint8 biased exponent.

Return type:

tuple

classmethod quantize_with_scale(weight, weights_scaling_factor)

Quantize weight tensor using a pre-computed E8M0 scale.

This method is useful for export paths where the scale has already been computed.

Parameters:
  • weight (Tensor) – The weight tensor to quantize. Must be at least 1D.

  • weights_scaling_factor (Tensor) – E8M0 scale as uint8 biased exponent (bias = 127). Shape should be […, out_dim, in_dim // 32] for 2D+ tensors, or [in_dim // 32] for 1D tensors.

Returns:

Quantized weight as float8_e4m3fn, with the same shape as the input.

Return type:

torch.Tensor