Quantization
Using Quantized Modules
Several of the modules predefined by Tripy support quantization. For example, the tripy.Linear
module includes two arguments to configure the quantization mode. Let’s construct the following
quantized linear module:
quant_linear = tp.Linear(
    4,
    2,
    quant_dtype=tp.int8,
    weight_quant_dim=None,
)
>>> quant_linear.state_dict()
{
weight: tensor(
[[0.0000, 1.0000, 2.0000, 3.0000],
[4.0000, 5.0000, 6.0000, 7.0000]],
dtype=float32, loc=gpu:0, shape=(2, 4)),
bias: tensor([0.0000, 1.0000], dtype=float32, loc=gpu:0, shape=(2,)),
weight_scale: tensor(0.0, dtype=float32, loc=gpu:0, shape=()),
input_scale: tensor(0.0, dtype=float32, loc=gpu:0, shape=()),
}
As described in tripy.Linear, the quantized linear module has two additional tripy.Parameters compared to a normal linear layer:

weight_scale: The quantization scale for weight.

input_scale: The quantization scale for the input.

weight_scale must always be provided, while input_scale is optional. The input is quantized only if input_scale is provided. For the Linear module in this example, only “per-tensor” quantization is allowed for the input, which is why there is no input_quant_dim argument.
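For intuition, a per-tensor int8 scale simply maps a tensor’s dynamic range onto the int8 integer range. The short sketch below illustrates that relationship; the helper name and the example dynamic range are made up for illustration and are not part of Tripy:

def per_tensor_int8_scale(amax: float) -> float:
    # Symmetric int8 quantization maps the range [-amax, amax] onto [-127, 127],
    # so the scale is the dynamic range divided by the int8 maximum.
    return amax / 127.0

# e.g. if a tensor's values span roughly [-4, 4]:
example_scale = per_tensor_int8_scale(4.0)  # ~0.0315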
Let’s fill the scale parameters with dummy data:
quant_linear.weight_scale = tp.Parameter(1.0)
quant_linear.input_scale = tp.Parameter(1.0)
>>> quant_linear.state_dict()
{
weight: tensor(
[[0.0000, 1.0000, 2.0000, 3.0000],
[4.0000, 5.0000, 6.0000, 7.0000]],
dtype=float32, loc=gpu:0, shape=(2, 4)),
bias: tensor([0.0000, 1.0000], dtype=float32, loc=gpu:0, shape=(2,)),
weight_scale: tensor(1.0, dtype=float32, loc=gpu:0, shape=()),
input_scale: tensor(1.0, dtype=float32, loc=gpu:0, shape=()),
}
and run a forward pass to see the result:
x = tp.iota((3, 4), dtype=tp.float32)
out = quant_linear(x)
>>> x
tensor(
[[0.0000, 0.0000, 0.0000, 0.0000],
[1.0000, 1.0000, 1.0000, 1.0000],
[2.0000, 2.0000, 2.0000, 2.0000]],
dtype=float32, loc=gpu:0, shape=(3, 4))
>>> out
tensor(
[[0.0000, 1.0000],
[6.0000, 23.0000],
[12.0000, 45.0000]],
dtype=float32, loc=gpu:0, shape=(3, 2))
The result still has a data type of tripy.float32, but internally, TensorRT quantized the input and weight, executed the linear layer with tripy.int8 precision, and finally dequantized the output back to the original precision.
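Conceptually, that quantize/dequantize round trip looks roughly like the following NumPy sketch of “fake” int8 quantization; this only illustrates the math and is not the actual TensorRT implementation:

import numpy as np

def fake_quant_int8(x, scale):
    # Quantize: divide by the scale, round, and clamp to the int8 range.
    q = np.clip(np.round(x / scale), -128, 127)
    # Dequantize: map back to the original floating-point range.
    return (q * scale).astype(np.float32)

x = np.arange(12, dtype=np.float32).reshape(3, 4)
print(fake_quant_int8(x, scale=1.0))  # with a scale of 1.0, these small values survive unchanged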
Running Quantized Models
Now that we have covered how quantization works in tripy.Linear, we will walk through the workflow of running a real-world quantized model: nanoGPT.
Calibration With Model Optimizer
Quantization scales are not available unless the model was trained with QAT (quantization-aware training). Otherwise, we need to perform an additional step, called calibration, to compute the correct scales for each quantized layer. There are many ways to do calibration; one of them uses the nvidia-modelopt toolkit. To install it, run:
python3 -m pip install --extra-index-url https://pypi.nvidia.com nvidia-modelopt==0.11.0 transformers==4.46.2 datasets==2.21.0
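For intuition, the max calibration used below simply tracks the largest absolute value each tensor takes on over a few representative batches and derives the scale from that. A rough sketch of the idea (this is not the modelopt implementation):

import torch

def calibrate_amax(batches):
    # Track the running maximum of |x| over the calibration batches.
    amax = None
    for x in batches:
        batch_max = x.detach().abs().max()
        amax = batch_max if amax is None else torch.maximum(amax, batch_max)
    return amax  # later converted to a scale, e.g. amax / 127 for int8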
First, let’s get the pre-trained GPT model from Hugging Face:
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")
Then, we perform int8 weight-only quantization:
from transformers import AutoTokenizer
import modelopt.torch.quantization as mtq

from modelopt.torch.utils.dataset_utils import create_forward_loop

# define the modelopt quant configs
quant_cfg = mtq.INT8_DEFAULT_CFG
# disable input quantization for weight-only
# quantized linear modules
quant_cfg["quant_cfg"]["*input_quantizer"] = {
    "enable": False,
}

# define the forward loop for calibration
MAX_SEQ_LEN = 512
tokenizer = AutoTokenizer.from_pretrained(
    "gpt2",
    use_fast=True,
    model_max_length=MAX_SEQ_LEN,
    padding_side="left",
    trust_remote_code=True,
)
tokenizer.pad_token = tokenizer.eos_token

forward_loop = create_forward_loop(
    model=model,
    dataset_name="cnn_dailymail",
    tokenizer=tokenizer,
    device=model.device,
    num_samples=8,
)

# call the api for calibration
mtq.quantize(model, quant_cfg, forward_loop=forward_loop)
Output:
Inserted 147 quantizers
Warning: The following arguments will not be used in the forward loop:
- Positional argument 0: GPT2LMHeadModel(
(transformer): GPT2Model(
(wte): Embedding(50257, 768)
(wpe): Embedding(1024, 768)
(drop): Dropout(p=0.1, inplace=False)
(h): ModuleList(
(0-11): 12 x GPT2Block(
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): GPT2SdpaAttention(
(c_attn): QuantLinear(
in_features=768, out_features=2304, bias=True
(input_quantizer): TensorQuantizer(disabled)
(output_quantizer): TensorQuantizer(disabled)
(weight_quantizer): TensorQuantizer(8 bit fake axis=0 amax=dynamic calibrator=MaxCalibrator calib)
)
(c_proj): QuantLinear(
in_features=768, out_features=768, bias=True
(input_quantizer): TensorQuantizer(disabled)
(output_quantizer): TensorQuantizer(disabled)
(weight_quantizer): TensorQuantizer(8 bit fake axis=0 amax=dynamic calibrator=MaxCalibrator calib)
)
(attn_dropout): Dropout(p=0.1, inplace=False)
(resid_dropout): Dropout(p=0.1, inplace=False)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): GPT2MLP(
(c_fc): QuantLinear(
in_features=768, out_features=3072, bias=True
(input_quantizer): TensorQuantizer(disabled)
(output_quantizer): TensorQuantizer(disabled)
(weight_quantizer): TensorQuantizer(8 bit fake axis=0 amax=dynamic calibrator=MaxCalibrator calib)
)
(c_proj): QuantLinear(
in_features=3072, out_features=768, bias=True
(input_quantizer): TensorQuantizer(disabled)
(output_quantizer): TensorQuantizer(disabled)
(weight_quantizer): TensorQuantizer(8 bit fake axis=0 amax=dynamic calibrator=MaxCalibrator calib)
)
(act): NewGELUActivation()
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
(ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(lm_head): QuantLinear(
in_features=768, out_features=50257, bias=False
(input_quantizer): TensorQuantizer(disabled)
(output_quantizer): TensorQuantizer(disabled)
(weight_quantizer): TensorQuantizer(disabled)
)
)
mtq.quantize replaces all linear layers specified in quant_cfg with QuantLinear layers, which contain the calibrated parameters.
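To see which layers were calibrated, one option is to walk the model with the standard PyTorch module APIs; a small sketch (the variable names are illustrative):

# Collect the names of the QuantLinear layers inserted by mtq.quantize.
quant_layer_names = [
    name
    for name, module in model.named_modules()
    if type(module).__name__ == "QuantLinear"
]
print(len(quant_layer_names), quant_layer_names[:2])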
Load Scales Into The Tripy Model
Let’s take a look at one of the QuantLinear modules produced by Model Optimizer:
print(model.transformer.h[0].attn.c_attn)
Output:
QuantLinear(
in_features=768, out_features=2304, bias=True
(input_quantizer): TensorQuantizer(disabled)
(output_quantizer): TensorQuantizer(disabled)
(weight_quantizer): TensorQuantizer(8 bit fake axis=0 amax=[0.1202, 2.8436](2304) calibrator=MaxCalibrator quant)
)
The amax attribute gives us the dynamic range of the tensor. Tripy requires scaling factors, so we can convert it like so:
def convert_to_scale(amax, maxbound):
    return amax.float() / maxbound
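For example, int8 quantization uses a maxbound of 127, so the largest amax printed above (about 2.8436) corresponds to a scale of roughly 0.022:

import torch

print(convert_to_scale(torch.tensor(2.8436), 127))  # tensor(0.0224)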
Let’s convert the amax to a scaling factor and load it into a compatible tripy.Linear module:
weight_only_qlinear = tp.Linear(
    768,
    2304,
    quant_dtype=tp.int8,
    weight_quant_dim=0,
)
quantizer = model.transformer.h[0].attn.c_attn.weight_quantizer
scale = convert_to_scale(quantizer.export_amax(), quantizer.maxbound)
scale = scale.squeeze().contiguous()
weight_only_qlinear.weight_scale = tp.Parameter(scale)
>>> weight_only_qlinear.state_dict()
{
weight: tensor(
[[0.0000, 1.0000, 2.0000, ..., 765.0000, 766.0000, 767.0000],
[768.0000, 769.0000, 770.0000, ..., 1533.0000, 1534.0000, 1535.0000],
[1536.0000, 1537.0000, 1538.0000, ..., 2301.0000, 2302.0000, 2303.0000],
...,
[1767168.0000, 1767169.0000, 1767170.0000, ..., 1767933.0000, 1767934.0000, 1767935.0000],
[1767936.0000, 1767937.0000, 1767938.0000, ..., 1768701.0000, 1768702.0000, 1768703.0000],
[1768704.0000, 1768705.0000, 1768706.0000, ..., 1769469.0000, 1769470.0000, 1769471.0000]],
dtype=float32, loc=gpu:0, shape=(2304, 768)),
bias: tensor([0.0000, 1.0000, 2.0000, ..., 2301.0000, 2302.0000, 2303.0000], dtype=float32, loc=gpu:0, shape=(2304,)),
weight_scale: tensor([0.0073, 0.0070, 0.0067, ..., 0.0026, 0.0016, 0.0021], dtype=float32, loc=gpu:0, shape=(2304,)),
input_scale: tensor(0.0, dtype=float32, loc=gpu:0, shape=()),
}
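Repeating this pattern for every quantized layer is just a loop over the calibrated model. The sketch below assumes a hypothetical tripy_linears dictionary that maps layer names to the corresponding tp.Linear modules; it is not part of the nanoGPT example:

# tripy_linears: an assumed {name: tp.Linear} mapping built when constructing
# the Tripy model. Layers whose weight_quantizer is disabled (e.g. lm_head
# above) are assumed to be absent from this mapping.
for name, torch_module in model.named_modules():
    if type(torch_module).__name__ == "QuantLinear" and name in tripy_linears:
        quantizer = torch_module.weight_quantizer
        scale = convert_to_scale(quantizer.export_amax(), quantizer.maxbound)
        tripy_linears[name].weight_scale = tp.Parameter(scale.squeeze().contiguous())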
For an example of how to load weights from a quantized model, refer to load_quant_weights_from_hf from the nanoGPT example.