fp8_kernel
FP8 Triton Kernel Implementations.
Functions
- weight_dequant(x, s, block_size=128, dtype=torch.float32)
Dequantizes the given weight tensor using the provided scale tensor.
Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
- Parameters:
x (torch.Tensor) – The quantized weight tensor of shape (M, N).
s (torch.Tensor) – The scale tensor of shape (ceil(M/block_size), ceil(N/block_size)), one scale per block.
block_size (int, optional) – The block size to use for dequantization. Defaults to 128.
dtype (torch.dtype, optional) – The dtype of the output tensor. Defaults to torch.float32.
- Returns:
The dequantized weight tensor of the same shape as x.
- Return type:
torch.Tensor
- Raises:
AssertionError – If x or s is not contiguous, or if either tensor is not 2-dimensional.
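The block-wise semantics can be illustrated without a GPU. Below is a minimal NumPy sketch (not the Triton implementation): it assumes each block_size × block_size tile of x shares the single scale stored at the corresponding position in s, which is the layout the parameter shapes above imply.

```python
import numpy as np

def weight_dequant_ref(x, s, block_size=128):
    # Reference (non-Triton) block-wise dequantization sketch:
    # each block_size x block_size tile of x is multiplied by one
    # scale from s. Edge tiles may be partial, so the expanded scale
    # map is cropped back to (M, N).
    M, N = x.shape
    s_full = np.repeat(np.repeat(s, block_size, axis=0),
                       block_size, axis=1)[:M, :N]
    return x.astype(np.float32) * s_full
```

For example, with block_size=2 and a 4×4 weight of ones, a scale matrix [[2, 3], [4, 5]] yields 2s in the top-left 2×2 tile, 3s in the top-right, and so on.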
- weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE)
Dequantizes weights using the provided scaling factors and stores the result.
Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
- Parameters:
x_ptr (tl.pointer) – Pointer to the quantized weights.
s_ptr (tl.pointer) – Pointer to the scaling factors.
y_ptr (tl.pointer) – Pointer to the output buffer for dequantized weights.
M (int) – Number of rows in the weight matrix.
N (int) – Number of columns in the weight matrix.
BLOCK_SIZE (tl.constexpr) – Size of the block for tiling.
- Returns:
None
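The kernel's tiling can also be sketched in plain NumPy. The helper below simulates one hypothetical program instance: program (pid_m, pid_n) computes its row/column offsets from BLOCK_SIZE, masks out-of-bounds positions on partial edge tiles (as a masked tl.load would), and scales the tile by its single entry of s. This is an illustration of the index arithmetic, not the actual Triton code.

```python
import numpy as np

def dequant_tile(x, s, BLOCK_SIZE, pid_m, pid_n):
    # Simulate one program instance of a block-wise dequant kernel:
    # program (pid_m, pid_n) owns one BLOCK_SIZE x BLOCK_SIZE tile.
    M, N = x.shape
    offs_m = pid_m * BLOCK_SIZE + np.arange(BLOCK_SIZE)
    offs_n = pid_n * BLOCK_SIZE + np.arange(BLOCK_SIZE)
    # Boundary masks stand in for the load mask on partial edge tiles.
    valid_m = offs_m[offs_m < M]
    valid_n = offs_n[offs_n < N]
    tile = x[np.ix_(valid_m, valid_n)]
    # One scale per tile, indexed by the program IDs.
    return tile * s[pid_m, pid_n]
```

Looping this helper over the full grid of (ceil(M/BLOCK_SIZE), ceil(N/BLOCK_SIZE)) programs reproduces the block-wise dequantization described above.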