mxfp4_tensor
Implements MXFP4 quantization for efficient tensor storage and computation.
Classes
- class MXFP4QTensor

  Bases: `BaseQuantizedTensor`

  Implements the MXFP4 quantization on tensors for more efficient storage or computation.
- quantized_data
The quantized data stored as a packed fp8 tensor.
- Type: torch.Tensor
- E2M1_bounds = tensor([0.2500, 0.7500, 1.2500, 1.7500, 2.5000, 3.5000, 5.0000])
- E2M1_max = 6.0
- E2M1_values = [0, 0.5, 1, 1.5, 2, 3, 4, 6]
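The three class attributes above fully determine round-to-nearest on the 4-bit E2M1 grid: `E2M1_bounds` holds the midpoints between adjacent grid values. A minimal pure-Python sketch (no torch), assuming midpoint ties round up; the library's kernel may instead round to even:

```python
import bisect

# E2M1 grid and midpoint bounds, copied from the class attributes above.
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
E2M1_BOUNDS = [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0]
E2M1_MAX = 6.0

def round_to_e2m1(x: float) -> float:
    """Round |x| to the nearest representable E2M1 value, keeping the sign.

    Midpoints round up here; a round-to-even implementation would differ
    only on exact ties.
    """
    sign = -1.0 if x < 0 else 1.0
    mag = min(abs(x), E2M1_MAX)  # clamp to E2M1_max
    idx = bisect.bisect_right(E2M1_BOUNDS, mag)
    return sign * E2M1_VALUES[idx]
```

For example, 2.4 falls below the 2.5 bound and snaps down to 2.0, while anything at or above 5.0 saturates to 6.0.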
- dequantize(dtype=None, **kwarg)

  Dequantize an MXFP4-packed tensor to a target dtype.

  - Parameters: dtype (dtype)
- classmethod dequantize_packed(blocks, scales, *, block_size=32, dtype=torch.bfloat16)

  Dequantize MXFP4-packed bytes to `dtype` without a QTensor wrapper.

  Input layout (DeepSeek-V4 checkpoint convention: group axis and packing axis fused in a single trailing dim):

  - blocks: `[..., K // 2]`, dtype `uint8` or `int8`. Low nibble = even element, high nibble = odd element.
  - scales: `[..., K // block_size]`, dtype `uint8` or `torch.float8_e8m0fnu`. UE8M0: byte `e` maps to `2 ** (e - 127)`.

  Returns a tensor of shape `[..., K]` in the requested `dtype`.

  The GPT-OSS layout stores blocks and scales as `blocks.shape == (..., G, 16)` and `scales.shape == (..., G)`. To feed such inputs here, reshape blocks to `(..., G * 16)` so that `blocks.shape[:-1] == scales.shape[:-1]` holds and the last dim of blocks is `K // 2`. This helper does no trailing transpose, so the result is in the natural `(out, in)` orientation, suitable for feeding a standard `nn.Linear` or a downstream weight quantizer.

  UE8M0 note: per the OCP MX spec, byte `0xFF` is NaN; we match `transformers.integrations.mxfp4._convert_moe_packed_tensors` by treating it as an exponent of +128, which overflows bf16 to +Inf. Real MXFP4 checkpoints do not use `0xFF`.

  - Parameters:
    - blocks (Tensor)
    - scales (Tensor)
    - block_size (int)
    - dtype (dtype)
  - Return type: Tensor
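The layout described above can be sketched in pure Python for a single row (the function name and the 1-D restriction are illustrative, not part of the API): unpack low/high nibbles through a signed E2M1 lookup table, then apply one UE8M0 scale per `block_size` elements.

```python
# Signed E2M1 lookup: nibble bit 3 is the sign, bits 0-2 index the magnitude.
E2M1_LUT = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
            -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

def dequantize_packed_1d(blocks: bytes, scales: bytes,
                         block_size: int = 32) -> list[float]:
    """Pure-Python sketch of the packed layout for one row, K = 2 * len(blocks).

    blocks: K // 2 bytes; low nibble = even element, high nibble = odd element.
    scales: K // block_size UE8M0 bytes; byte e maps to 2.0 ** (e - 127).
    """
    out = []
    for byte in blocks:
        out.append(E2M1_LUT[byte & 0x0F])  # even element (low nibble)
        out.append(E2M1_LUT[byte >> 4])    # odd element (high nibble)
    # Apply one power-of-two scale per block of `block_size` elements.
    return [v * 2.0 ** (scales[j // block_size] - 127)
            for j, v in enumerate(out)]
```

For instance, the byte `0x21` unpacks to codes 1 (low nibble, even slot) and 2 (high nibble, odd slot), i.e. grid values 0.5 and 1.0 before scaling.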
- classmethod quantize(input, block_size)

  Convert a tensor to a quantized format based on MXFP4 quantization. Only E4M3 is supported.

  - Parameters:
    - input (torch.Tensor) – The input tensor to be quantized.
    - block_size (int | None) – The block size for quantization.
  - Return type: tuple
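A hedged sketch of the block-wise scheme behind `quantize`, assuming the shared scale exponent is chosen in the spirit of the OCP MX spec (`floor(log2(amax)) - 2` for E2M1). `quantize_block` is a hypothetical helper: the real classmethod returns packed tensors (two 4-bit codes per byte), and its exact scale selection may differ.

```python
import bisect
import math

# E2M1 grid, midpoint bounds, and max, copied from the class attributes above.
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
E2M1_BOUNDS = [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0]
E2M1_MAX = 6.0

def quantize_block(values):
    """Quantize one block to the E2M1 grid with a shared power-of-two scale.

    Returns (values snapped to the E2M1 grid, UE8M0 scale byte).
    Hypothetical helper; nibble packing is omitted for clarity.
    """
    amax = max((abs(v) for v in values), default=0.0)
    # Shared exponent: floor(log2(amax)) minus the element format's max
    # exponent (2 for E2M1), clamped to the UE8M0 range.
    exp = 0 if amax == 0.0 else math.floor(math.log2(amax)) - 2
    exp = max(-127, min(128, exp))
    scale = 2.0 ** exp
    q = []
    for v in values:
        sign = -1.0 if v < 0 else 1.0
        mag = min(abs(v) / scale, E2M1_MAX)  # clamp, then round to the grid
        q.append(sign * E2M1_VALUES[bisect.bisect_right(E2M1_BOUNDS, mag)])
    return q, exp + 127
```

With this choice of exponent, a block whose largest magnitude is 12.0 gets scale 2.0 (UE8M0 byte 128), so 12.0 lands exactly on the grid value 6.0 and round-trips losslessly.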