tensor_quant

Basic tensor quantization functions.

Classes

DynamicBlockQuantizationFunction

Dynamic block quantization functional.

FakeTensorQuantFunction

Fake version of TensorQuantFunction that uses the CUDA extension.

ScaledE4M3Function

E4M3fy input with scale.

StaticBlockwiseFP4FakeQuantFunction

Static blockwise FP4 fake quantization functional.

Functions

fake_quant_impl

Implementation of fake-quantizing the input according to the number of bits.

fp8_eager

Eager mode implementation of FP8 quantization.

scaled_e4m3_impl

Implementation of fake-quantizing the input to FP8.

class DynamicBlockQuantizationFunction

Bases: Function

Dynamic block quantization functional.

static backward(ctx, grad_outputs)

Implements straight-through estimation with clipping.

static forward(ctx, inputs, block_size, amax, bias, num_bits, scale_bits, trt_high_precision_dtype=None, onnx_quantizer_type='dynamic', pass_through_bwd=True)

Forward method.

static symbolic(g, inputs, block_size, amax, bias, num_bits, scale_bits, trt_high_precision_dtype=None, onnx_quantizer_type='dynamic', pass_through_bwd=True)

ONNX symbolic function.
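
The autograd Functions on this page are invoked through their apply classmethod with the positional arguments of forward. A minimal sketch for DynamicBlockQuantizationFunction follows; the import path and the argument values (block_size=16, an FP4-style num_bits=(2, 1), an E4M3-style scale_bits=(4, 3), bias=None) are illustrative assumptions, not values documented here.

    import torch
    # Assumed import path; adjust to where tensor_quant lives in your install.
    from modelopt.torch.quantization.tensor_quant import DynamicBlockQuantizationFunction

    x = torch.randn(2, 64, device="cuda", dtype=torch.float16)
    amax = x.abs().amax()  # absolute-max range of the input
    # block_size=16, num_bits=(2, 1), scale_bits=(4, 3) are assumptions for illustration.
    x_fq = DynamicBlockQuantizationFunction.apply(x, 16, amax, None, (2, 1), (4, 3))

The trailing arguments keep their documented defaults.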

class FakeTensorQuantFunction

Bases: Function

Fake version of TensorQuantFunction that uses the CUDA extension.

static backward(ctx, grad_outputs)

Implements straight-through estimation with clipping.

static forward(ctx, inputs, amax, bias=None, num_bits=8, unsigned=False, narrow_range=True, trt_high_precision_dtype=None, pass_through_bwd=False, block_size=None, axis=None)

Forward method.

static symbolic(g, inputs, amax, bias=None, num_bits=8, unsigned=False, narrow_range=True, trt_high_precision_dtype=None, pass_through_bwd=False, block_size=None, axis=None)

ONNX symbolic function.
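
A minimal sketch of 8-bit fake quantization through this Function, leaving the trailing arguments at their documented defaults; the import path is an assumption.

    import torch
    # Assumed import path; adjust to your installation.
    from modelopt.torch.quantization.tensor_quant import FakeTensorQuantFunction

    x = torch.randn(4, 16, device="cuda")   # CUDA tensor, since the CUDA extension path is used
    amax = x.abs().amax()                   # absolute-max calibration range
    x_fq = FakeTensorQuantFunction.apply(x, amax)  # num_bits=8, narrow_range=True by default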

class ScaledE4M3Function

Bases: Function

E4M3fy input with scale.

static backward(ctx, grad_outputs)

Implements straight-through estimation with clipping.

static forward(ctx, inputs, amax, bias, E, M, trt_high_precision_dtype=None, pass_through_bwd=False)

Forward method.

static symbolic(g, inputs, amax=None, bias=None, E=4, M=3, trt_high_precision_dtype=None, pass_through_bwd=False)

ONNX symbolic function.
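
A sketch of FP8 (E4M3) fake quantization via this Function. forward takes bias, E, and M positionally, so they are spelled out here; the import path and bias=None are assumptions.

    import torch
    from modelopt.torch.quantization.tensor_quant import ScaledE4M3Function  # assumed path

    x = torch.randn(8, 32, device="cuda")
    amax = x.abs().amax()  # absolute-max range of the input
    # E=4, M=3 selects the E4M3 format; bias=None is an illustrative assumption.
    x_fp8 = ScaledE4M3Function.apply(x, amax, None, 4, 3)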

class StaticBlockwiseFP4FakeQuantFunction

Bases: Function

Static blockwise FP4 fake quantization functional.

static backward(ctx, grad_outputs)

Implements straight-through estimation with clipping.

static forward(ctx, x, scale, scale_fp8_quant_amax, skip_scale_quant, out_dtype, pass_through_bwd=False)

Forward method.
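
A sketch of static blockwise FP4 fake quantization. The per-block scale layout (16-element blocks along the last axis), the use of 6.0 as the maximum E2M1 magnitude, and the import path are all assumptions made for illustration.

    import torch
    from modelopt.torch.quantization.tensor_quant import (  # assumed path
        StaticBlockwiseFP4FakeQuantFunction,
    )

    BLOCK = 16  # assumed block size
    x = torch.randn(2, 64, device="cuda", dtype=torch.float16)
    # Precomputed per-block scales: block amax divided by the assumed max |E2M1| value (6.0).
    scale = x.reshape(2, -1, BLOCK).abs().amax(dim=-1) / 6.0
    y = StaticBlockwiseFP4FakeQuantFunction.apply(
        x,
        scale,
        scale.abs().amax(),   # scale_fp8_quant_amax: range for quantizing the scales themselves
        False,                # skip_scale_quant
        torch.float16,        # out_dtype
    )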

fake_quant_impl(inputs, amax, num_bits=8, unsigned=False, narrow_range=True)

Implementation of fake-quantizing the input according to the number of bits.

Parameters:
  • inputs (Tensor)

  • amax (Tensor)
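
A short sketch of calling this helper directly, assuming the import path shown in the code.

    import torch
    from modelopt.torch.quantization.tensor_quant import fake_quant_impl  # assumed path

    x = torch.randn(4, 16)
    amax = x.abs().amax()   # per-tensor absolute-max range
    x_int8 = fake_quant_impl(x, amax, num_bits=8, unsigned=False, narrow_range=True)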

fp8_eager(x, amax)

Eager mode implementation of FP8 quantization.
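
A minimal sketch under the same assumed import path; fp8_eager runs the eager-mode (plain PyTorch) FP8 quantization path.

    import torch
    from modelopt.torch.quantization.tensor_quant import fp8_eager  # assumed path

    x = torch.randn(8, 32)
    x_fp8 = fp8_eager(x, x.abs().amax())   # FP8 quantization of x using its absolute-max range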

scaled_e4m3_impl(inputs, amax=None)

Implementation of fake-quantizing the input to FP8.

Parameters:
  • inputs (Tensor) – Torch tensor.

  • amax (Tensor | None) – Absolute max range of the input tensor.

Returns:

The input tensor fake-quantized to FP8.

Return type:

Tensor
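
A short usage sketch under the same assumed import path; amax supplies the absolute-max range described above.

    import torch
    from modelopt.torch.quantization.tensor_quant import scaled_e4m3_impl  # assumed path

    x = torch.randn(8, 32)
    x_fp8 = scaled_e4m3_impl(x, amax=x.abs().amax())   # fake-quantize the input to FP8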