fp8_kernel

FP8 Triton Kernel Implementations.

Functions

weight_dequant(x, s, block_size=128, dtype=torch.float32)

Dequantizes the given weight tensor using the provided scale tensor.

Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py

Parameters:
  • x (torch.Tensor) – The quantized weight tensor of shape (M, N).

  • s (torch.Tensor) – The scale tensor of shape (M//block_size, N//block_size), holding one scale per block_size × block_size block of x.

  • block_size (int, optional) – The block size to use for dequantization. Defaults to 128.

  • dtype (torch.dtype, optional) – The dtype of the output tensor. Defaults to torch.float32.

Returns:

The dequantized weight tensor of the same shape as x.

Return type:

torch.Tensor

Raises:

AssertionError – If x or s is not contiguous, or if either is not 2-dimensional.
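To illustrate the block-wise dequantization this function performs, here is a hedged pure-PyTorch sketch (not the Triton kernel itself): each block_size × block_size tile of x is multiplied by its corresponding scalar from s. The helper name weight_dequant_ref and the exact broadcasting strategy are illustrative assumptions; the real kernel computes the same product on the GPU via Triton.

```python
import torch

def weight_dequant_ref(x: torch.Tensor, s: torch.Tensor,
                       block_size: int = 128) -> torch.Tensor:
    # Pure-PyTorch reference sketch of block-wise dequantization:
    # every (block_size x block_size) tile of x is scaled by the
    # matching scalar in s. Names and layout here are assumptions,
    # not the actual Triton implementation.
    assert x.is_contiguous() and s.is_contiguous(), "inputs must be contiguous"
    assert x.dim() == 2 and s.dim() == 2, "inputs must be 2-D"
    M, N = x.shape
    # Expand each per-block scale over its tile, then trim to (M, N)
    # in case the last blocks are partial.
    s_full = s.repeat_interleave(block_size, dim=0)[:M]
    s_full = s_full.repeat_interleave(block_size, dim=1)[:, :N]
    return x.to(torch.float32) * s_full

# Usage: a 256x256 "quantized" weight with a 2x2 grid of scales.
x = torch.ones(256, 256)
s = torch.full((2, 2), 0.5)
y = weight_dequant_ref(x, s)
```

Here y has the same shape as x, and every element equals 0.5, since each tile of ones was multiplied by its block's scale.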