fp8_kernel
FP8 Triton Kernel Implementations.
Functions
- weight_dequant(x, s, block_size=128, dtype=torch.float32)
Dequantizes the given weight tensor using the provided scale tensor.
Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
- Parameters:
x (torch.Tensor) – The quantized weight tensor of shape (M, N).
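s (torch.Tensor) – The per-block scale tensor of shape (ceil(M / block_size), ceil(N / block_size)), which equals (M//block_size, N//block_size) when M and N are multiples of block_size.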
block_size (int, optional) – The side length of the square blocks of x that share a single scale value. Defaults to 128.
dtype (torch.dtype, optional) – The dtype of the output tensor. Defaults to torch.get_default_dtype(), shown as torch.float32 in the signature above (PyTorch's out-of-the-box default dtype).
- Returns:
The dequantized weight tensor, with the same shape as x and element type dtype.
- Return type:
torch.Tensor
- Raises:
AssertionError – If x or s is not contiguous, or if either is not 2-dimensional.
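To make the block indexing concrete, here is a minimal plain-Python sketch of the arithmetic weight_dequant is documented to perform. It assumes, as in the referenced DeepSeek-V3 kernel, that each scale covers one block_size × block_size tile of x and that partial tiles at the right and bottom edges get their own scale (hence ceil division). weight_dequant_ref is a hypothetical name, not part of this module; the sketch ignores FP8 storage, the dtype argument, contiguity checks, and GPU execution.

```python
import math
from typing import List

Matrix = List[List[float]]


def weight_dequant_ref(x: Matrix, s: Matrix, block_size: int = 128) -> Matrix:
    """Hypothetical pure-Python reference for block-wise dequantization.

    Element (i, j) of x belongs to tile (i // block_size, j // block_size)
    and is multiplied by that tile's scale. Assumes x is rectangular.
    """
    m = len(x)
    n = len(x[0]) if m else 0
    # One scale per (possibly partial) tile, so the grid uses ceil division.
    expected = (math.ceil(m / block_size), math.ceil(n / block_size))
    actual = (len(s), len(s[0]) if s else 0)
    assert actual == expected, f"scale shape {actual} != expected {expected}"
    return [
        [x[i][j] * s[i // block_size][j // block_size] for j in range(n)]
        for i in range(m)
    ]


# 3x3 weights with block_size=2 need a 2x2 scale grid (ceil(3 / 2) == 2),
# even though 3 // 2 == 1.
x = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
s = [[0.5, 2.0], [10.0, 1.0]]
print(weight_dequant_ref(x, s, block_size=2))
# [[0.5, 1.0, 6.0], [2.0, 2.5, 12.0], [70.0, 80.0, 9.0]]
```

The example deliberately uses a size that is not a multiple of block_size: the last row and column fall into partial tiles, which is where floor-division shapes for s would come up one entry short.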