fp8_kernel

FP8 Triton Kernel Implementations.

Functions


weight_dequant(x, s, block_size=128, dtype=torch.float32)

Dequantizes a block-quantized weight tensor by multiplying each block_size × block_size tile of x by its corresponding scale in s.

Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py

Parameters:
  • x (torch.Tensor) – The quantized weight tensor of shape (M, N).

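  • s (torch.Tensor) – The scale tensor of shape (ceil(M / block_size), ceil(N / block_size)), holding one scale per block_size × block_size tile of x. This equals (M // block_size, N // block_size) when M and N are multiples of block_size.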

  • block_size (int, optional) – Side length of the square tiles that each share one scale factor. Defaults to 128.

  • dtype (torch.dtype, optional) – The dtype of the output tensor. Defaults to torch.float32.
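The scale grid has one entry per block_size × block_size tile, so its shape follows from the shape of x by ceiling division; when M or N is not a multiple of block_size, the last row or column of tiles is partial but still needs its own scale. A minimal sketch, assuming the tiling used by the referenced DeepSeek-V3 kernel; `expected_scale_shape` is a hypothetical helper, not part of this module:

```python
def expected_scale_shape(m: int, n: int, block_size: int = 128) -> tuple[int, int]:
    """Hypothetical helper: shape of the per-tile scale tensor for an (m, n) weight."""
    # Ceiling division: a partial tile at the bottom/right edge still gets a scale.
    rows = -(-m // block_size)
    cols = -(-n // block_size)
    return rows, cols


# A 7168 x 2048 weight with 128 x 128 tiles needs a 56 x 16 scale grid.
print(expected_scale_shape(7168, 2048))
# 300 x 200 does not divide evenly: floor division would give (2, 1) and drop edge tiles.
print(expected_scale_shape(300, 200))
```

For evenly divisible shapes the result matches floor division, which is why the two forms are often used interchangeably.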

Returns:

The dequantized weight tensor, with the same shape as x and dtype dtype.

Return type:

torch.Tensor

Raises:

AssertionError – If x or s is not contiguous, or if either is not 2-dimensional.
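Element-wise, the operation multiplies each quantized value by the scale of the tile that contains it: y[i][j] = x[i][j] * s[i // block_size][j // block_size]. The plain-Python sketch below spells this out on nested lists so it runs without a GPU or PyTorch; `weight_dequant_reference` is a hypothetical name, and it only illustrates the semantics, whereas the documented function operates on torch tensors through a Triton kernel.

```python
def weight_dequant_reference(x, s, block_size=128):
    """Hypothetical pure-Python reference for block-wise weight dequantization.

    x: M x N nested list of quantized values (stand-ins for FP8 numbers).
    s: ceil(M / block_size) x ceil(N / block_size) nested list of scales.
    """
    m, n = len(x), len(x[0])
    n_tile_rows = -(-m // block_size)
    n_tile_cols = -(-n // block_size)
    # Shape contract: exactly one scale per block_size x block_size tile.
    assert len(s) == n_tile_rows and all(len(row) == n_tile_cols for row in s)
    # Each element is scaled by the entry of s for the tile it falls in.
    return [
        [float(x[i][j]) * s[i // block_size][j // block_size] for j in range(n)]
        for i in range(m)
    ]


x = [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]]
s = [[0.5, 2.0],
     [10.0, 1.0]]
y = weight_dequant_reference(x, s, block_size=2)
print(y)  # [[0.5, 1.0, 6.0], [2.0, 2.5, 12.0], [70.0, 80.0, 9.0]]
```

With block_size=2 and a 3 × 3 input, the third row and column form partial tiles, which is why s needs a 2 × 2 grid rather than 1 × 1.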

weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE)

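Triton kernel in which each program instance dequantizes one BLOCK_SIZE × BLOCK_SIZE tile of weights by multiplying it by that tile's scaling factor, then writes the result to y_ptr.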

Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py

Parameters:
  • x_ptr (tl.pointer) – Pointer to the quantized weights.

  • s_ptr (tl.pointer) – Pointer to the scaling factors.

  • y_ptr (tl.pointer) – Pointer to the output buffer for dequantized weights.

  • M (int) – Number of rows in the weight matrix.

  • N (int) – Number of columns in the weight matrix.

  • BLOCK_SIZE (tl.constexpr) – Side length of the square tile processed by each program instance; must be a compile-time constant.

Returns:

None
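In the referenced DeepSeek-V3 implementation, the kernel is launched on a 2-D grid of (cdiv(M, BLOCK_SIZE), cdiv(N, BLOCK_SIZE)) programs. Program (pid_m, pid_n) computes row-major offsets offs_m[:, None] * N + offs_n[None, :] into the flat x and y buffers, masks out-of-range rows and columns, and loads its single scale from s_ptr + pid_m * cdiv(N, BLOCK_SIZE) + pid_n. The sketch below emulates that pointer arithmetic on flat Python lists, with loops standing in for Triton program instances and tl.arange; `emulate_weight_dequant_kernel` is a hypothetical name, and this is not the Triton code itself.

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division, the same quantity triton.cdiv / tl.cdiv compute.
    return -(-a // b)


def emulate_weight_dequant_kernel(x_flat, s_flat, y_flat, M, N, BLOCK_SIZE):
    """Hypothetical CPU emulation of the per-tile program logic.

    x_flat, y_flat: row-major buffers of length M * N (stand-ins for x_ptr, y_ptr).
    s_flat: row-major buffer of per-tile scales (stand-in for s_ptr).
    Writes into y_flat in place and returns None, like the kernel.
    """
    n_tiles_n = cdiv(N, BLOCK_SIZE)  # number of scales per row of tiles
    # Each (pid_m, pid_n) pair corresponds to one Triton program instance.
    for pid_m in range(cdiv(M, BLOCK_SIZE)):
        for pid_n in range(n_tiles_n):
            scale = s_flat[pid_m * n_tiles_n + pid_n]  # one scale load per tile
            for r in range(BLOCK_SIZE):          # stands in for tl.arange over rows
                for c in range(BLOCK_SIZE):      # stands in for tl.arange over columns
                    row = pid_m * BLOCK_SIZE + r
                    col = pid_n * BLOCK_SIZE + c
                    if row < M and col < N:      # the mask skips out-of-range lanes
                        off = row * N + col      # row-major offset into x and y
                        y_flat[off] = float(x_flat[off]) * scale


M, N, BLOCK = 3, 3, 2
x_flat = [1, 2, 3, 4, 5, 6, 7, 8, 9]
s_flat = [0.5, 2.0, 10.0, 1.0]
y_flat = [0.0] * (M * N)
emulate_weight_dequant_kernel(x_flat, s_flat, y_flat, M, N, BLOCK)
print(y_flat)  # [0.5, 1.0, 6.0, 2.0, 2.5, 12.0, 70.0, 80.0, 9.0]
```

Because each tile reads exactly one scale and the mask handles ragged edges, M and N need not be multiples of BLOCK_SIZE; only the scale buffer has to be sized with ceiling division.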