Quantization¶
Quantization reduces memory and compute requirements by running operations in low precision:
- Scaling is required to translate to/from low precision.
- Scaling factors are chosen such that they minimize accuracy loss.
They can be either:
- Loaded into quantization-enabled nvtripy.Modules, or
- Used with nvtripy.quantize()/nvtripy.dequantize().
See also
The TensorRT developer guide explains quantization in more detail.
Post-Training Quantization With ModelOpt¶
If the model was not trained with quantization-aware training (QAT), we can use TensorRT ModelOpt to run calibration and determine scaling factors.
Info
Calibration runs a model with a small set of input data to determine the numerical distribution of each tensor.
The dynamic range is the most important range within this distribution and scales are chosen to target this range.
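As a rough sketch of the arithmetic (using made-up numbers, not values from this model): for int8, the scale is the dynamic range divided by 127, and quantization maps each value to the nearest representable integer:
amax = 0.9               # Dynamic range found during calibration (illustrative).
maxbound = 127           # Largest value representable by int8.
scale = amax / maxbound  # Scaling factor used to translate to/from int8.

x = 0.45                                   # A float32 value to quantize.
q = max(-128, min(127, round(x / scale)))  # Quantized int8 value.
x_approx = q * scale                       # Dequantized approximation of `x`.
print(q, x_approx)                         # 64 0.4535...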
Let’s calibrate a GPT model:
Install ModelOpt:
python3 -m pip install nvidia-modelopt==0.11.1 transformers==4.46.2 datasets==2.21.0
Download the model:
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")
Calibrate for int8 precision:
Define the forward pass:
from transformers import AutoTokenizer
from modelopt.torch.utils.dataset_utils import create_forward_loop

MAX_SEQ_LEN = 512
tokenizer = AutoTokenizer.from_pretrained(
    "gpt2",
    use_fast=True,
    model_max_length=MAX_SEQ_LEN,
    padding_side="left",
    trust_remote_code=True,
)
tokenizer.pad_token = tokenizer.eos_token

forward_loop = create_forward_loop(
    model=model,
    dataset_name="cnn_dailymail",
    tokenizer=tokenizer,
    device=model.device,
    num_samples=8,
)
Set up quantization configuration:
import modelopt.torch.quantization as mtq

quant_cfg = mtq.INT8_DEFAULT_CFG
Run calibration to replace linear layers with QuantLinear layers, which contain calibration information:
mtq.quantize(model, quant_cfg, forward_loop=forward_loop)
The amax attributes of QuantLinear’s quantizers specify dynamic ranges:
torch_qlinear = model.transformer.h[0].attn.c_attn
print(torch_qlinear)
Output:
QuantLinear(
in_features=768, out_features=2304, bias=True
(input_quantizer): TensorQuantizer(8 bit fake per-tensor amax=0.8646 calibrator=MaxCalibrator quant)
(output_quantizer): TensorQuantizer(disabled)
(weight_quantizer): TensorQuantizer(8 bit fake axis=0 amax=[0.1202, 2.8436](2304) calibrator=MaxCalibrator quant)
)
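The raw dynamic ranges can also be read directly off the quantizers (via the same export_amax() call used in the conversion below):
# Per-tensor dynamic range of the input (a single value):
print(torch_qlinear.input_quantizer.export_amax())
# Per-channel dynamic ranges of the weights (one value per output channel):
print(torch_qlinear.weight_quantizer.export_amax())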
We must convert dynamic ranges to scaling factors to load them into Tripy:
def get_scale(quantizer):
    amax = quantizer.export_amax()
    # `maxbound` is the maximum value representable by the data type.
    # For `int8`, this is 127.
    scale = amax.float() / quantizer.maxbound
    return tp.Tensor(scale.squeeze().contiguous())


input_scale = get_scale(torch_qlinear.input_quantizer)
weight_scale = get_scale(torch_qlinear.weight_quantizer)
Local Variables
>>> input_scale
tensor(0.006808243691921234, dtype=float32, loc=gpu:0, shape=())
>>> weight_scale
tensor([0.0073, 0.0070, 0.0067, ..., 0.0026, 0.0016, 0.0021], dtype=float32, loc=gpu:0, shape=(2304,))
Loading Scales Into Tripy¶
Using Modules¶
Modules that support quantization usually:
Expose additional model parameters for scales.
Accept arguments that control how quantization is performed.
Let’s load the scales into an nvtripy.Linear module:
qlinear = tp.Linear(
    768,
    2304,
    # The data type to quantize to:
    quant_dtype=tp.int8,
    # The dimension along which the weights are quantized:
    weight_quant_dim=torch_qlinear.weight_quantizer.axis,
)

# Load weights:
qlinear.weight = tp.Tensor(torch_qlinear.weight.detach().contiguous())
qlinear.bias = tp.Tensor(torch_qlinear.bias.detach().contiguous())

# Load scaling factors:
qlinear.input_scale = input_scale
qlinear.weight_scale = weight_scale
Local Variables
>>> qlinear
Linear(
weight: Parameter = (shape=[2304, 768], dtype=float32),
bias: Parameter = (shape=[2304], dtype=float32),
weight_scale: Parameter = (shape=[2304], dtype=float32),
input_scale: Parameter = (shape=[], dtype=float32),
)
>>> qlinear.state_dict()
{
weight: tensor(
[[-0.4738, 0.0874, 0.0039, ..., -0.2592, 0.1517, -0.4100],
[-0.2614, 0.1473, 0.0695, ..., -0.0164, 0.2170, -0.1924],
[-0.0978, 0.2387, 0.3668, ..., 0.1991, 0.1043, -0.2400],
...,
[0.0513, -0.0525, 0.1143, ..., 0.0095, 0.0293, -0.0046],
[-0.0584, -0.0113, 0.0363, ..., -0.0516, -0.0429, 0.0070],
[0.0250, -0.0156, -0.0318, ..., 0.0319, -0.0475, 0.0198]],
dtype=float32, loc=gpu:0, shape=(2304, 768)),
bias: tensor([0.4803, -0.5254, -0.4293, ..., 0.0126, -0.0499, 0.0032], dtype=float32, loc=gpu:0, shape=(2304,)),
weight_scale: tensor([0.0073, 0.0070, 0.0067, ..., 0.0026, 0.0016, 0.0021], dtype=float32, loc=gpu:0, shape=(2304,)),
input_scale: tensor(0.006808243691921234, dtype=float32, loc=gpu:0, shape=()),
}
Note
We use scales from ModelOpt here, but scaling factors can come from anywhere.
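For example, per-channel weight scales equivalent to max calibration can be computed directly from the weights with PyTorch. This is only a sketch (the name manual_weight_scale is illustrative, and activation scales would still require running representative input data through the model):
# Largest absolute weight per output channel (the quantization axis is 0,
# so we reduce over the input dimension):
amax = torch_qlinear.weight.detach().abs().amax(dim=1)
# Convert dynamic ranges to int8 scaling factors, as in `get_scale` above:
manual_weight_scale = tp.Tensor((amax.float() / 127).contiguous())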
We can run it just like a regular float32 module. Inputs/weights are quantized internally:
input = tp.ones((1, 768), dtype=tp.float32)

output = qlinear(input)
Local Variables
>>> input
tensor(
[[1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000]],
dtype=float32, loc=gpu:0, shape=(1, 768))
>>> output
tensor(
[[-11.8799, 11.4679, 12.3159, ..., 0.1293, 1.8775, -0.6599]],
dtype=float32, loc=gpu:0, shape=(1, 2304))
See also
load_quant_weights_from_hf in the nanoGPT weight loader is an example of loading scaling factors for an entire model.
Manually¶
When using nvtripy.quantize()/nvtripy.dequantize(), dequantize must immediately follow quantize. TensorRT will rotate dequantize over subsequent ops as needed.
See also
The TensorRT developer guide includes recommendations on placement of quantization and dequantization ops.
To mimic the behavior of the nvtripy.Linear module above, we can:
Quantize the input:
input = tp.ones((1, 768), dtype=tp.float32)

input = tp.quantize(input, input_scale, dtype=tp.int8)
# Note the placement of dequantize:
input = tp.dequantize(input, input_scale, dtype=tp.float32)
Quantize the weights:
weight = tp.Tensor(torch_qlinear.weight.detach().contiguous())

dim = torch_qlinear.weight_quantizer.axis
weight = tp.quantize(weight, weight_scale, dtype=tp.int8, dim=dim)
weight = tp.dequantize(weight, weight_scale, dtype=tp.float32, dim=dim)
Perform the computation (matrix multiply in this case):
bias = tp.Tensor(torch_qlinear.bias.detach().contiguous())

output = input @ tp.transpose(weight, 0, 1) + bias
Local Variables
>>> output
tensor(
    [[-11.8781, 11.4667, 12.3143, ..., 0.1296, 1.8768, -0.6599]],
    dtype=float32, loc=gpu:0, shape=(1, 2304))
Warning
Evaluating the tensor produced by dequantize will affect accuracy.
Why: Evaluation replaces the tensor with a constant, losing information like which op produced it.
So, TensorRT won’t see dequantize when evaluating subsequent ops and won’t rotate it correctly.
For example, don’t do this:
tensor = tp.ones(...)

tensor = tp.quantize(tensor, ...)
tensor = tp.dequantize(tensor, ...)

# The `print` below will trigger an evaluation of the tensor which will prevent
# TensorRT from rotating the dequantization node. This will affect accuracy!
print(tensor)

# Rest of the program, including some computation involving tensor
...
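Instead, defer evaluation until after the computation that consumes the dequantized tensor, so TensorRT sees the full quantize/dequantize pattern. A corrected sketch of the same example:
tensor = tp.ones(...)

tensor = tp.quantize(tensor, ...)
tensor = tp.dequantize(tensor, ...)

# Perform the computation involving `tensor` first so that TensorRT can
# rotate the dequantization node over these ops:
...

# Evaluate (e.g. print) results only after the computation:
...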