gptq_fused_kernel

Fused Triton kernels for GPTQ blockwise weight-update.

A kernel for scalar (NVFP4) quantization with inline two-level scale computation. It fuses scale computation, quantization, and per-column GPTQ error propagation into a single launch per GPTQ block, avoiding the Python-level per-column loop.

Architecture:
  • One Triton program per output row.

  • w_full [BLOCK_SIZE] register tensor holds working weights.

  • Per-column: calls nvfp4_scalar_qdq() for FP4 QDQ with inline scale computation, then propagates error via w_full -= err * h_inv_row.
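
The per-column loop that the kernel fuses can be sketched as a plain NumPy reference (a non-fused, illustrative version; `quantize` here is a stand-in for any per-column fake-quant function, not the kernel's actual `nvfp4_scalar_qdq()`):

```python
import numpy as np

def gptq_block_loop(w_full, h_inv_cho_blk, quantize):
    """Reference (non-fused) version of the per-column GPTQ loop.

    w_full:        [num_rows, block_size] working weights.
    h_inv_cho_blk: [block_size, block_size] upper-Cholesky inverse Hessian block.
    quantize:      fake-quant function applied to one column at a time.
    """
    w = w_full.copy()
    qw = np.empty_like(w)
    err = np.empty_like(w)
    for j in range(w.shape[1]):
        qw[:, j] = quantize(w[:, j])
        # Quantization error for column j, normalized by the Cholesky diagonal.
        err[:, j] = (w[:, j] - qw[:, j]) / h_inv_cho_blk[j, j]
        # Propagate the error into the not-yet-quantized columns
        # (the kernel's `w_full -= err * h_inv_row` step).
        w[:, j + 1:] -= np.outer(err[:, j], h_inv_cho_blk[j, j + 1:])
    return qw, err
```

The fused kernel performs exactly this loop in registers, one Triton program per row, instead of launching one quantization kernel per column from Python.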

Functions

gptq_fused_block_scalar

Run scalar GPTQ (NVFP4) column loop for one block in a single Triton kernel launch.

gptq_fused_block_scalar(w_block, block_amax, global_scale, h_inv_cho_blk, quant_block_size, block_start)

Computes FP8-quantized scales from per-block amax inline via nvfp4_scalar_qdq(), then performs NVFP4 fake quantization and GPTQ error propagation per column.
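
A minimal NumPy sketch of the two-level scale computation that `nvfp4_scalar_qdq()` performs inline. It assumes the standard NVFP4 layout (FP4 E2M1 code points with magnitude at most 6.0, per-block scales stored in FP8 E4M3 with maximum 448, consistent with the `global_amax / (6.0 * 448.0)` global scale below); `fp8_e4m3_round` is a crude hypothetical helper that ignores subnormals and round-to-nearest-even tie-breaking:

```python
import numpy as np

# The eight non-negative FP4 (E2M1) code points.
FP4_VALUES = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def fp8_e4m3_round(x):
    # Crude E4M3 rounding: 3 mantissa bits, clamp to the E4M3 max of 448.
    # Hypothetical stand-in for the kernel's FP8 scale quantization.
    x = np.clip(x, 1e-12, 448.0)
    exp = np.floor(np.log2(x))
    mant = np.round(x / 2.0 ** exp * 8.0) / 8.0
    return mant * 2.0 ** exp

def nvfp4_scalar_qdq(block, block_amax, global_scale):
    """Two-level-scale FP4 fake quantization for one quant block."""
    # Level 1: per-block decode scale, quantized to FP8 E4M3 in units
    # of the pre-computed global scale.
    scale = fp8_e4m3_round(block_amax / (6.0 * global_scale)) * global_scale
    # Level 2: round each element to the nearest FP4 code point.
    scaled = block / scale
    mag = np.abs(scaled)
    idx = np.abs(mag[..., None] - FP4_VALUES).argmin(axis=-1)
    return np.sign(scaled) * FP4_VALUES[idx] * scale
```

Values that are exactly representable round-trip unchanged; everything else snaps to the nearest FP4 code point at the block's decoded scale.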

Parameters:
  • w_block (Tensor) – Working weights [num_rows, block_size] (float32).

  • block_amax (Tensor) – Per-block amax values [num_rows, n_amax_blocks] (float32).

  • global_scale (float) – Pre-computed global_amax / (6.0 * 448.0) (scalar).

  • h_inv_cho_blk (Tensor) – Block of upper-Cholesky inverse Hessian [block_size, block_size].

  • quant_block_size (int) – Number of elements sharing one scale factor.

  • block_start (int) – Column offset of this block in the full weight matrix.

Returns:

(qw_block, err_block), each of shape [num_rows, block_size]: the fake-quantized weights and the propagated quantization errors.

Return type:

tuple[Tensor, Tensor]
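
Putting the pieces together, the function's contract can be mirrored by a hypothetical NumPy reference. This sketch assumes `block_start` is a multiple of `quant_block_size` (so the amax block for column `j` is simply `j // quant_block_size`) and uses exact per-block scales rather than the kernel's FP8-quantized ones, so `global_scale` is accepted but unused:

```python
import numpy as np

# The eight non-negative FP4 (E2M1) code points.
FP4_VALUES = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def _qdq(x, scale):
    # Fake-quantize x to the nearest FP4 code point at the given per-row scale.
    mag = np.abs(x / scale)
    idx = np.abs(mag[..., None] - FP4_VALUES).argmin(axis=-1)
    return np.sign(x) * FP4_VALUES[idx] * scale

def gptq_fused_block_scalar_ref(w_block, block_amax, global_scale,
                                h_inv_cho_blk, quant_block_size, block_start):
    """Hypothetical NumPy reference for the fused kernel's contract.

    Simplification vs. the kernel: scales are exact block_amax / 6.0
    rather than FP8 E4M3-quantized, so global_scale is unused here.
    """
    w = w_block.astype(np.float64).copy()
    qw = np.empty_like(w)
    err = np.empty_like(w)
    for j in range(w.shape[1]):
        scale = np.maximum(block_amax[:, j // quant_block_size] / 6.0, 1e-12)
        qw[:, j] = _qdq(w[:, j], scale)
        err[:, j] = (w[:, j] - qw[:, j]) / h_inv_cho_blk[j, j]
        w[:, j + 1:] -= np.outer(err[:, j], h_inv_cho_blk[j, j + 1:])
    return qw, err
```

As in the fused kernel, the returned qw_block holds the fake-quantized weights and err_block the normalized errors that the caller uses to update the columns beyond this block.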