gptq_fused_kernel
Fused Triton kernels for GPTQ blockwise weight updates.
A kernel for scalar (NVFP4) quantization with inline two-level scale computation. It fuses scale computation, quantization, and per-column GPTQ error propagation into a single launch per GPTQ block, avoiding the Python-level per-column loop.
- Architecture:
  - One Triton program per output row.
  - `w_full`, a `[BLOCK_SIZE]` register tensor, holds the working weights.
  - Per column: calls `nvfp4_scalar_qdq()` for FP4 QDQ with inline scale computation, then propagates the error via `w_full -= err * h_inv_row`.
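The per-column loop the kernel fuses can be sketched as a pure-NumPy reference. This is an illustrative stand-in, not the module's code: `gptq_block_reference` and the `quantize` callback are hypothetical names, and `quantize` stands in for the NVFP4 QDQ step; the real kernel runs the equivalent loop per row in registers.

```python
import numpy as np

def gptq_block_reference(w_block, h_inv_cho_blk, quantize):
    """Pure-NumPy sketch of the per-column GPTQ update for one block.

    w_block       : [rows, B] working weights
    h_inv_cho_blk : [B, B] upper-Cholesky inverse-Hessian block
    quantize      : column -> fake-quantized column (stand-in for NVFP4 QDQ)
    """
    w = w_block.copy()
    qw = np.empty_like(w)
    err = np.empty_like(w)
    B = w.shape[1]
    for j in range(B):
        # fake-quantize column j of the current working weights
        qw[:, j] = quantize(w[:, j])
        # quantization error, scaled by the Cholesky diagonal
        err[:, j] = (w[:, j] - qw[:, j]) / h_inv_cho_blk[j, j]
        # propagate the error into the not-yet-quantized columns
        w[:, j + 1:] -= np.outer(err[:, j], h_inv_cho_blk[j, j + 1:])
    return qw, err
```

With an identity `quantize` the error is zero and the weights pass through unchanged, which is a quick sanity check on the propagation step.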
Functions
- `gptq_fused_block_scalar(w_block, block_amax, global_scale, h_inv_cho_blk, quant_block_size, block_start)`
  Run the scalar GPTQ (NVFP4) column loop for one block in a single Triton kernel launch.
  Computes FP8-quantized scales from per-block amax inline via `nvfp4_scalar_qdq()`, then performs NVFP4 fake quantization and GPTQ error propagation per column.
- Parameters:
  - w_block (Tensor) – Working weights `[num_rows, block_size]` (float32).
  - block_amax (Tensor) – Per-block amax values `[num_rows, n_amax_blocks]` (float32).
  - global_scale (float) – Pre-computed `global_amax / (6.0 * 448.0)` (scalar).
  - h_inv_cho_blk (Tensor) – Block of the upper-Cholesky inverse Hessian `[block_size, block_size]`.
  - quant_block_size (int) – Number of elements sharing one scale factor.
  - block_start (int) – Column offset of this block in the full weight matrix.
- Returns:
  `(qw_block, err_block)`, each `[num_rows, block_size]`.
- Return type:
  `tuple[Tensor, Tensor]`
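The two-level NVFP4 scaling implied by `global_scale = global_amax / (6.0 * 448.0)` can be illustrated with a NumPy sketch. This is a hedged approximation, not the module's `nvfp4_scalar_qdq`: the name `nvfp4_qdq_reference` is hypothetical, and the FP8 (E4M3) rounding of the per-block scale is approximated by float32, so results will differ slightly from the real kernel.

```python
import numpy as np

# FP4 E2M1 representable magnitudes (max value 6.0)
E2M1 = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def nvfp4_qdq_reference(x, block_amax, global_scale):
    """NumPy sketch of two-level NVFP4 fake quantization.

    Level 2: the per-block scale block_amax / 6.0 is encoded relative to
    global_scale (in the real kernel it is rounded to FP8 E4M3; here we
    keep it in float32 as an approximation).
    Level 1: each element is rounded to the nearest FP4 (E2M1) value.
    """
    scale = (np.asarray(block_amax, dtype=np.float64) / 6.0) / global_scale
    scale = np.maximum(scale, 1e-12)        # guard against divide-by-zero
    decode = scale * global_scale           # effective dequantization scale
    mag = np.abs(x) / decode
    # nearest-value rounding onto the FP4 grid (tie-breaking is simplified)
    idx = np.argmin(np.abs(mag[..., None] - E2M1), axis=-1)
    return np.sign(x) * E2M1[idx] * decode
```

For example, with `block_amax = 6.0` and `global_scale = 1.0` the decode scale is 1.0, so values snap directly to the E2M1 grid: 0.4 rounds to 0.5, 2.9 to 3.0, and -5.5 saturates toward -6.0.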