mxfp8_tensor
Implements MXFP8 quantization for efficient tensor storage and computation.
Classes
MXFP8QTensor | Implements MXFP8 quantization on tensors for more efficient storage or computation.
- class MXFP8QTensor
Bases: BaseQuantizedTensor
Implements MXFP8 quantization on tensors for more efficient storage or computation.
MXFP8 uses:
- FP8 E4M3 format for elements
- E8M0 format for shared scales (power-of-2 only, stored as a biased uint8 exponent)
- A block size of 32 elements along the last dimension
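The per-block scale logic implied by these constants can be sketched as follows. This is an illustrative reconstruction from the documented format (32-element blocks, E4M3 elements, biased uint8 E8M0 exponents), not the library's exact implementation:

```python
import torch

BLOCK_SIZE = 32
E4M3_MAX = 448.0   # largest finite FP8 E4M3 magnitude
E8M0_BIAS = 127    # E8M0 stores the exponent as uint8 with bias 127

def e8m0_block_scale(x: torch.Tensor) -> torch.Tensor:
    """One biased uint8 exponent per 32-element block of the last dimension."""
    blocks = x.reshape(*x.shape[:-1], -1, BLOCK_SIZE)
    amax = blocks.abs().amax(dim=-1).clamp(min=2.0**-126)
    # Smallest power-of-2 scale such that amax / 2**e fits within E4M3's range.
    e = torch.ceil(torch.log2(amax / E4M3_MAX))
    return (e + E8M0_BIAS).clamp(0, 255).to(torch.uint8)

x = torch.randn(4, 64)
scale = e8m0_block_scale(x)                      # uint8, shape (4, 64 // 32) = (4, 2)
descale = torch.exp2(scale.float() - E8M0_BIAS)  # decode E8M0 back to a float scale
x_fp8 = (x.reshape(4, 2, 32) / descale.unsqueeze(-1)) \
    .to(torch.float8_e4m3fn).reshape(x.shape)
```

Because the shared scale is a pure power of two, decoding it is an exponent shift rather than a multiply by an arbitrary float, which is what makes the E8M0 encoding cheap.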
- quantized_data
The quantized data, stored as a torch.float8_e4m3fn tensor.
- Type:
torch.Tensor
- BLOCK_SIZE = 32
- E4M3_MAX = 448.0
- SCALE_DTYPE = torch.uint8
- dequantize(dtype=None, **kwargs)
Dequantize MXFP8 tensor back to the target dtype.
- Parameters:
dtype (torch.dtype | None) – Target dtype for dequantization. Defaults to original dtype.
**kwargs – Must contain 'scale' (the E8M0 scale as a biased uint8 tensor).
- Returns:
Dequantized tensor in the target dtype.
- Return type:
torch.Tensor
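A minimal round trip, assuming the import path below (it may differ across modelopt versions) and the signatures documented here:

```python
import torch
# Assumed import path; adjust to where MXFP8QTensor lives in your install.
from modelopt.torch.quantization.qtensor import MXFP8QTensor

w = torch.randn(128, 256, dtype=torch.bfloat16)
qt, scale = MXFP8QTensor.quantize(w)      # scale: uint8 E8M0, shape (128, 8)
# dequantize() takes the scale back through kwargs, per the signature above.
w_hat = qt.dequantize(dtype=torch.bfloat16, scale=scale)
assert w_hat.shape == w.shape and w_hat.dtype == torch.bfloat16
```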
- classmethod get_weights_scaling_factor(weight)
Returns the E8M0 scale (uint8 biased exponent) for a weight tensor.
- Parameters:
weight (Tensor) – The weight tensor to compute the scale for. Must be at least 2D. Supports 2D (out_dim, in_dim) and 3D MoE (num_experts, out_dim, in_dim) layouts.
- Returns:
E8M0 scale as a uint8 tensor with shape […, out_dim, in_dim // 32]. For 2D input: (out_dim, in_dim // 32); for 3D MoE input: (num_experts, out_dim, in_dim // 32).
- Return type:
torch.Tensor
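For instance, the documented shapes work out as follows (import path assumed, as above):

```python
import torch
from modelopt.torch.quantization.qtensor import MXFP8QTensor  # assumed path

w2d = torch.randn(1024, 4096)                        # (out_dim, in_dim)
s2d = MXFP8QTensor.get_weights_scaling_factor(w2d)
print(s2d.dtype, s2d.shape)                          # torch.uint8 (1024, 128)

w3d = torch.randn(8, 1024, 4096)                     # (num_experts, out_dim, in_dim)
s3d = MXFP8QTensor.get_weights_scaling_factor(w3d)
print(s3d.shape)                                     # (8, 1024, 128)
```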
- classmethod get_weights_scaling_factor_from_quantizer(weight, weight_quantizer)
Returns the E8M0 scale from the quantizer, or computes it from the weight.
This method handles extracting the scale from a weight quantizer, with proper format conversion and shape correction.
- Parameters:
weight (Tensor) – The weight tensor. Can be 2D (out_dim, in_dim) or 3D for MoE (num_experts, out_dim, in_dim).
weight_quantizer – The weight quantizer with block_sizes and optional _scale.
- Returns:
E8M0 scale as a uint8 tensor with shape […, out_dim, in_dim // 32].
- Return type:
torch.Tensor
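The described behavior amounts to roughly the following. This is a hedged paraphrase of the docstring, not the library's code; `scale_from_quantizer` is a hypothetical stand-in:

```python
import torch
from modelopt.torch.quantization.qtensor import MXFP8QTensor  # assumed path

def scale_from_quantizer(weight: torch.Tensor, weight_quantizer) -> torch.Tensor:
    recorded = getattr(weight_quantizer, "_scale", None)
    if recorded is not None:
        # The real method also applies the "format conversion and shape
        # correction" mentioned above (e.g. float scale -> biased uint8 E8M0).
        return recorded
    # No calibrated scale on the quantizer: fall back to computing one.
    return MXFP8QTensor.get_weights_scaling_factor(weight)
```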
- classmethod quantize(input, weights_scaling_factor=None)
Convert a tensor to MXFP8 quantized format.
- Parameters:
input (torch.Tensor) – The input tensor to be quantized.
weights_scaling_factor (torch.Tensor | None) – Optional pre-computed E8M0 scale as a uint8 biased exponent. If None, the scale is computed from the input. Shape should be […, in_dim // 32], matching the input's leading dimensions.
- Returns:
(MXFP8QTensor, weights_scaling_factor), where weights_scaling_factor is the E8M0 scale as a uint8 biased exponent.
- Return type:
tuple
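Both call patterns, hedged as before on the import path:

```python
import torch
from modelopt.torch.quantization.qtensor import MXFP8QTensor  # assumed path

x = torch.randn(64, 128)

# Let quantize() derive the E8M0 scale itself:
qt, scale = MXFP8QTensor.quantize(x)
print(scale.dtype, scale.shape)        # torch.uint8 (64, 4)

# Or pass in a pre-computed scale (e.g. from get_weights_scaling_factor);
# the returned scale should then echo the one supplied:
pre = MXFP8QTensor.get_weights_scaling_factor(x)
qt2, scale2 = MXFP8QTensor.quantize(x, weights_scaling_factor=pre)
```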
- classmethod quantize_with_scale(weight, weights_scaling_factor)
Quantize weight tensor using a pre-computed E8M0 scale.
This method is useful for export paths where the scale has already been computed.
- Parameters:
weight (Tensor) – The weight tensor to quantize. Must be at least 1D.
weights_scaling_factor (Tensor) – E8M0 scale as uint8 biased exponent (bias = 127). Shape should be […, out_dim, in_dim // 32] for 2D+ tensors, or [in_dim // 32] for 1D tensors.
- Returns:
Quantized weight as float8_e4m3fn, with the same shape as the input.
- Return type:
torch.Tensor
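A typical export-path pairing with get_weights_scaling_factor (illustrative; import path assumed):

```python
import torch
from modelopt.torch.quantization.qtensor import MXFP8QTensor  # assumed path

weight = torch.randn(1024, 4096)
scale = MXFP8QTensor.get_weights_scaling_factor(weight)   # uint8, (1024, 128)
qweight = MXFP8QTensor.quantize_with_scale(weight, scale)
print(qweight.dtype, qweight.shape)    # torch.float8_e4m3fn (1024, 4096)
# An export path would serialize `qweight` and `scale` together.
```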