Deploying Quantization Aware Trained models in INT8 using TRTorch

Overview

Quantization Aware Training (QAT) simulates quantization during training by quantizing the weights and activations of layers. This helps reduce the loss in accuracy when a network trained in FP32 is converted to INT8 for faster inference. QAT introduces additional nodes in the graph which are used to learn the dynamic ranges of weights and activations. In this notebook, we illustrate the following steps, from training to inference, of a QAT model in TRTorch.

  1. Requirements

  2. VGG16 Overview

  3. Training a baseline VGG16 model

  4. Apply Quantization

  5. Model calibration

  6. Quantization Aware Training

  7. Export to Torchscript

  8. Inference using TRTorch

  9. References

## 1. Requirements

Please install the required dependencies and import these libraries accordingly.

[2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import trtorch

from torch.utils.tensorboard import SummaryWriter

from pytorch_quantization import nn as quant_nn
from pytorch_quantization import quant_modules
from pytorch_quantization.tensor_quant import QuantDescriptor
from pytorch_quantization import calib
from tqdm import tqdm

import os
import sys
sys.path.insert(0, "../examples/int8/training/vgg16")
from vgg16 import vgg16

WARNING: Logging before flag parsing goes to stderr.
E0831 15:09:13.151450 140586428176192 amp_wrapper.py:31] AMP is not avaialble.

## 2. VGG16 Overview

### Very Deep Convolutional Networks for Large-Scale Image Recognition

VGG is one of the earliest families of image classification networks to use small (3x3) convolution filters, achieving significant improvements on the ImageNet recognition challenge. The network architecture looks as follows:

(VGG16 architecture diagram)

## 3. Training a baseline VGG16 model

We train VGG16 on the CIFAR10 dataset. Define the training and testing datasets and dataloaders below. This will download the CIFAR10 data into your data directory. Data preprocessing is performed using torchvision transforms.

[3]:
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# ========== Define Training dataset and dataloaders =============#
training_dataset = datasets.CIFAR10(root='./data',
                                        train=True,
                                        download=True,
                                        transform=transforms.Compose([
                                            transforms.RandomCrop(32, padding=4),
                                            transforms.RandomHorizontalFlip(),
                                            transforms.ToTensor(),
                                            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                                        ]))

training_dataloader = torch.utils.data.DataLoader(training_dataset,
                                                      batch_size=32,
                                                      shuffle=True,
                                                      num_workers=2)

# ========== Define Testing dataset and dataloaders =============#
testing_dataset = datasets.CIFAR10(root='./data',
                                   train=False,
                                   download=True,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                                   ]))

testing_dataloader = torch.utils.data.DataLoader(testing_dataset,
                                                 batch_size=16,
                                                 shuffle=False,
                                                 num_workers=2)

Files already downloaded and verified
Files already downloaded and verified
[4]:
def train(model, dataloader, crit, opt, epoch):
#     global writer
    model.train()
    running_loss = 0.0
    for batch, (data, labels) in enumerate(dataloader):
        data, labels = data.cuda(), labels.cuda(non_blocking=True)
        opt.zero_grad()
        out = model(data)
        loss = crit(out, labels)
        loss.backward()
        opt.step()

        running_loss += loss.item()
        if batch % 500 == 499:
            print("Batch: [%5d | %5d] loss: %.3f" % (batch + 1, len(dataloader), running_loss / 100))
            running_loss = 0.0

def test(model, dataloader, crit, epoch):
    global writer
    global classes
    total = 0
    correct = 0
    loss = 0.0
    class_probs = []
    class_preds = []
    model.eval()
    with torch.no_grad():
        for data, labels in dataloader:
            data, labels = data.cuda(), labels.cuda(non_blocking=True)
            out = model(data)
            loss += crit(out, labels)
            preds = torch.max(out, 1)[1]
            class_probs.append([F.softmax(i, dim=0) for i in out])
            class_preds.append(preds)
            total += labels.size(0)
            correct += (preds == labels).sum().item()

    test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
    test_preds = torch.cat(class_preds)

    return loss / total, correct / total

def save_checkpoint(state, ckpt_path="checkpoint.pth"):
    torch.save(state, ckpt_path)
    print("Checkpoint saved")

Define the VGG model that we are going to perform QAT on.

[5]:
# CIFAR 10 has 10 classes
model = vgg16(num_classes=len(classes), init_weights=False)
model = model.cuda()
[6]:
# Declare Learning rate
lr = 0.1
state = {}
state["lr"] = lr

# Use cross entropy loss for classification and SGD optimizer
crit = nn.CrossEntropyLoss()
opt = optim.SGD(model.parameters(), lr=state["lr"], momentum=0.9, weight_decay=1e-4)


# Adjust learning rate based on epoch number
def adjust_lr(optimizer, epoch):
    global state
    new_lr = lr * (0.5**(epoch // 12)) if state["lr"] > 1e-7 else state["lr"]
    if new_lr != state["lr"]:
        state["lr"] = new_lr
        print("Updating learning rate: {}".format(state["lr"]))
        for param_group in optimizer.param_groups:
            param_group["lr"] = state["lr"]
[7]:
# Train the model for 25 epochs to get ~80% accuracy.
num_epochs=25
for epoch in range(num_epochs):
    adjust_lr(opt, epoch)
    print('Epoch: [%5d / %5d] LR: %f' % (epoch + 1, num_epochs, state["lr"]))

    train(model, training_dataloader, crit, opt, epoch)
    test_loss, test_acc = test(model, testing_dataloader, crit, epoch)

    print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))

save_checkpoint({'epoch': epoch + 1,
                 'model_state_dict': model.state_dict(),
                 'acc': test_acc,
                 'opt_state_dict': opt.state_dict(),
                 'state': state},
                ckpt_path="vgg16_base_ckpt")
Epoch: [    1 /    25] LR: 0.100000
/home/dperi/Downloads/py3/lib/python3.6/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  /pytorch/c10/core/TensorImpl.h:1156.)
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
Batch: [  500 |  1563] loss: 12.466
Batch: [ 1000 |  1563] loss: 10.726
Batch: [ 1500 |  1563] loss: 10.289
Test Loss: 0.12190 Test Acc: 19.86%
Epoch: [    2 /    25] LR: 0.100000
Batch: [  500 |  1563] loss: 10.107
Batch: [ 1000 |  1563] loss: 9.986
Batch: [ 1500 |  1563] loss: 9.994
Test Loss: 0.12230 Test Acc: 21.54%
Epoch: [    3 /    25] LR: 0.100000
Batch: [  500 |  1563] loss: 9.826
Batch: [ 1000 |  1563] loss: 9.904
Batch: [ 1500 |  1563] loss: 9.771
Test Loss: 0.11709 Test Acc: 22.71%
Epoch: [    4 /    25] LR: 0.100000
Batch: [  500 |  1563] loss: 9.760
Batch: [ 1000 |  1563] loss: 9.629
Batch: [ 1500 |  1563] loss: 9.642
Test Loss: 0.11945 Test Acc: 23.89%
Epoch: [    5 /    25] LR: 0.100000
Batch: [  500 |  1563] loss: 9.590
Batch: [ 1000 |  1563] loss: 9.489
Batch: [ 1500 |  1563] loss: 9.468
Test Loss: 0.11180 Test Acc: 30.01%
Epoch: [    6 /    25] LR: 0.100000
Batch: [  500 |  1563] loss: 9.281
Batch: [ 1000 |  1563] loss: 9.057
Batch: [ 1500 |  1563] loss: 8.957
Test Loss: 0.11106 Test Acc: 28.03%
Epoch: [    7 /    25] LR: 0.100000
Batch: [  500 |  1563] loss: 8.799
Batch: [ 1000 |  1563] loss: 8.808
Batch: [ 1500 |  1563] loss: 8.647
Test Loss: 0.10456 Test Acc: 32.25%
Epoch: [    8 /    25] LR: 0.100000
Batch: [  500 |  1563] loss: 8.672
Batch: [ 1000 |  1563] loss: 8.478
Batch: [ 1500 |  1563] loss: 8.522
Test Loss: 0.10404 Test Acc: 32.40%
Epoch: [    9 /    25] LR: 0.100000
Batch: [  500 |  1563] loss: 8.422
Batch: [ 1000 |  1563] loss: 8.290
Batch: [ 1500 |  1563] loss: 8.474
Test Loss: 0.10282 Test Acc: 41.11%
Epoch: [   10 /    25] LR: 0.100000
Batch: [  500 |  1563] loss: 8.131
Batch: [ 1000 |  1563] loss: 8.005
Batch: [ 1500 |  1563] loss: 8.074
Test Loss: 0.09473 Test Acc: 38.91%
Epoch: [   11 /    25] LR: 0.100000
Batch: [  500 |  1563] loss: 8.132
Batch: [ 1000 |  1563] loss: 8.047
Batch: [ 1500 |  1563] loss: 7.941
Test Loss: 0.09928 Test Acc: 41.69%
Epoch: [   12 /    25] LR: 0.100000
Batch: [  500 |  1563] loss: 7.911
Batch: [ 1000 |  1563] loss: 7.974
Batch: [ 1500 |  1563] loss: 7.871
Test Loss: 0.10598 Test Acc: 38.90%
Updating learning rate: 0.05
Epoch: [   13 /    25] LR: 0.050000
Batch: [  500 |  1563] loss: 6.981
Batch: [ 1000 |  1563] loss: 6.543
Batch: [ 1500 |  1563] loss: 6.377
Test Loss: 0.07362 Test Acc: 53.72%
Epoch: [   14 /    25] LR: 0.050000
Batch: [  500 |  1563] loss: 6.208
Batch: [ 1000 |  1563] loss: 6.113
Batch: [ 1500 |  1563] loss: 6.016
Test Loss: 0.07922 Test Acc: 55.78%
Epoch: [   15 /    25] LR: 0.050000
Batch: [  500 |  1563] loss: 5.945
Batch: [ 1000 |  1563] loss: 5.726
Batch: [ 1500 |  1563] loss: 5.568
Test Loss: 0.05914 Test Acc: 65.33%
Epoch: [   16 /    25] LR: 0.050000
Batch: [  500 |  1563] loss: 5.412
Batch: [ 1000 |  1563] loss: 5.356
Batch: [ 1500 |  1563] loss: 5.143
Test Loss: 0.05833 Test Acc: 68.91%
Epoch: [   17 /    25] LR: 0.050000
Batch: [  500 |  1563] loss: 5.096
Batch: [ 1000 |  1563] loss: 5.064
Batch: [ 1500 |  1563] loss: 4.962
Test Loss: 0.05291 Test Acc: 71.72%
Epoch: [   18 /    25] LR: 0.050000
Batch: [  500 |  1563] loss: 4.958
Batch: [ 1000 |  1563] loss: 4.887
Batch: [ 1500 |  1563] loss: 4.711
Test Loss: 0.05003 Test Acc: 73.61%
Epoch: [   19 /    25] LR: 0.050000
Batch: [  500 |  1563] loss: 4.651
Batch: [ 1000 |  1563] loss: 4.567
Batch: [ 1500 |  1563] loss: 4.603
Test Loss: 0.05046 Test Acc: 73.80%
Epoch: [   20 /    25] LR: 0.050000
Batch: [  500 |  1563] loss: 4.467
Batch: [ 1000 |  1563] loss: 4.399
Batch: [ 1500 |  1563] loss: 4.310
Test Loss: 0.05038 Test Acc: 74.45%
Epoch: [   21 /    25] LR: 0.050000
Batch: [  500 |  1563] loss: 4.226
Batch: [ 1000 |  1563] loss: 4.196
Batch: [ 1500 |  1563] loss: 4.169
Test Loss: 0.05287 Test Acc: 71.18%
Epoch: [   22 /    25] LR: 0.050000
Batch: [  500 |  1563] loss: 4.120
Batch: [ 1000 |  1563] loss: 4.035
Batch: [ 1500 |  1563] loss: 4.018
Test Loss: 0.06157 Test Acc: 70.29%
Epoch: [   23 /    25] LR: 0.050000
Batch: [  500 |  1563] loss: 3.915
Batch: [ 1000 |  1563] loss: 3.968
Batch: [ 1500 |  1563] loss: 3.989
Test Loss: 0.04128 Test Acc: 79.01%
Epoch: [   24 /    25] LR: 0.050000
Batch: [  500 |  1563] loss: 3.871
Batch: [ 1000 |  1563] loss: 3.800
Batch: [ 1500 |  1563] loss: 3.871
Test Loss: 0.04785 Test Acc: 75.77%
Updating learning rate: 0.025
Epoch: [   25 /    25] LR: 0.025000
Batch: [  500 |  1563] loss: 3.141
Batch: [ 1000 |  1563] loss: 2.979
Batch: [ 1500 |  1563] loss: 2.874
Test Loss: 0.03345 Test Acc: 83.15%
Checkpoint saved

## 4. Apply Quantization

quant_modules.initialize() ensures that quantized versions of modules are used in place of the original modules. For example, when you define a model with convolution, linear and pooling layers, QuantConv2d, QuantLinear and QuantPooling are called instead. QuantConv2d basically wraps quantizer nodes around the inputs and weights of a regular Conv2d. Please refer to the list of quantized modules in the pytorch-quantization toolkit for more information. A QuantConv2d is represented in the pytorch-quantization toolkit as follows:

def forward(self, input):
        # the actual quantization happens in the next level of the class hierarchy
        quant_input, quant_weight = self._quant(input)

        if self.padding_mode == 'circular':
            expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
                                (self.padding[0] + 1) // 2, self.padding[0] // 2)
            output = F.conv2d(F.pad(quant_input, expanded_padding, mode='circular'),
                              quant_weight, self.bias, self.stride,
                              _pair(0), self.dilation, self.groups)
        else:
            output = F.conv2d(quant_input, quant_weight, self.bias, self.stride, self.padding, self.dilation,
                              self.groups)

        return output
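
The defaults used by quant_modules.initialize() can optionally be customized through QuantDescriptor before any quantized modules are created. A minimal sketch of that pattern (the explicit descriptor below is only illustrative; this notebook simply relies on the toolkit defaults, i.e. max calibration with per-channel weight quantization):

from pytorch_quantization import quant_modules
from pytorch_quantization import nn as quant_nn
from pytorch_quantization.tensor_quant import QuantDescriptor

# Optionally override how layer inputs are quantized before the torch.nn
# modules are monkey-patched; "max" is already the default calibrator.
quant_desc_input = QuantDescriptor(calib_method="max")
quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)

# From this point on, nn.Conv2d, nn.Linear, pooling layers, etc. resolve to
# their Quant* counterparts whenever a model is constructed.
quant_modules.initialize()

# ... construct the models that should contain quantizer nodes ...

# Restore the original torch.nn modules once quantized construction is done.
quant_modules.deactivate()
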
[8]:
quant_modules.initialize()
[9]:
# All the regular conv and FC layers will be converted to their quantized counterparts due to quant_modules.initialize()
qat_model = vgg16(num_classes=len(classes), init_weights=False)
qat_model = qat_model.cuda()
[10]:
# vgg16_base_ckpt is the checkpoint generated from Step 3 : Training a baseline VGG16 model.
ckpt = torch.load("./vgg16_base_ckpt")
modified_state_dict={}
for key, val in ckpt["model_state_dict"].items():
    # Remove 'module.' from the key names
    if key.startswith('module'):
        modified_state_dict[key[7:]] = val
    else:
        modified_state_dict[key] = val

# Load the pre-trained checkpoint
qat_model.load_state_dict(modified_state_dict)
opt.load_state_dict(ckpt["opt_state_dict"])

## 5. Model Calibration

The quantizer nodes introduced in the model around the desired layers capture the dynamic range (min_value, max_value) observed by those layers. Calibration is the process of computing these dynamic ranges by passing calibration data, usually a subset of the training or validation data, through the model. There are different calibration methods: max, histogram and entropy. We use the max calibration technique here as it is simple and effective.
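
Conceptually, max calibration only needs to track the maximum absolute value (amax) seen by each quantizer; the INT8 scale is then derived from that amax. A small self-contained sketch of the idea (not the toolkit's implementation, which lives in pytorch_quantization.calib.MaxCalibrator):

import torch

def max_calibrate(batches):
    """Track the running maximum absolute value (amax) over calibration data."""
    amax = torch.tensor(0.0)
    for x in batches:
        amax = torch.maximum(amax, x.abs().max())
    return amax

# Toy calibration batches standing in for the activations seen by one layer
batches = [torch.randn(8, 16) * 3.0 for _ in range(4)]
amax = max_calibrate(batches)

# The INT8 scale maps the observed range [-amax, amax] onto [-127, 127]
scale = 127.0 / amax
print("amax = {:.4f}, scale = {:.4f}".format(amax.item(), scale.item()))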

[11]:
def compute_amax(model, **kwargs):
    # Load calib result
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                if isinstance(module._calibrator, calib.MaxCalibrator):
                    module.load_calib_amax()
                else:
                    module.load_calib_amax(**kwargs)
            print(F"{name:40}: {module}")
    model.cuda()

def collect_stats(model, data_loader, num_batches):
    """Feed data to the network and collect statistics"""
    # Enable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    # Feed data to the network for collecting stats
    for i, (image, _) in tqdm(enumerate(data_loader), total=num_batches):
        model(image.cuda())
        if i >= num_batches:
            break

    # Disable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()

def calibrate_model(model, model_name, data_loader, num_calib_batch, calibrator, hist_percentile, out_dir):
    """
        Feed data to the network and calibrate.
        Arguments:
            model: classification model
            model_name: name to use when creating state files
            data_loader: calibration data set
            num_calib_batch: amount of calibration passes to perform
            calibrator: type of calibration to use (max/histogram)
            hist_percentile: percentiles to be used for histogram calibration
            out_dir: dir to save state files in
    """

    if num_calib_batch > 0:
        print("Calibrating model")
        with torch.no_grad():
            collect_stats(model, data_loader, num_calib_batch)

        if not calibrator == "histogram":
            compute_amax(model, method="max")
            calib_output = os.path.join(
                out_dir,
                F"{model_name}-max-{num_calib_batch*data_loader.batch_size}.pth")
            torch.save(model.state_dict(), calib_output)
        else:
            for percentile in hist_percentile:
                print(F"{percentile} percentile calibration")
                compute_amax(model, method="percentile")
                calib_output = os.path.join(
                    out_dir,
                    F"{model_name}-percentile-{percentile}-{num_calib_batch*data_loader.batch_size}.pth")
                torch.save(model.state_dict(), calib_output)

            for method in ["mse", "entropy"]:
                print(F"{method} calibration")
                compute_amax(model, method=method)
                calib_output = os.path.join(
                    out_dir,
                    F"{model_name}-{method}-{num_calib_batch*data_loader.batch_size}.pth")
                torch.save(model.state_dict(), calib_output)
[12]:
# Calibrate the model using the max calibration technique.
with torch.no_grad():
    calibrate_model(
        model=qat_model,
        model_name="vgg16",
        data_loader=training_dataloader,
        num_calib_batch=32,
        calibrator="max",
        hist_percentile=[99.9, 99.99, 99.999, 99.9999],
        out_dir="./")
Calibrating model
100%|██████████| 32/32 [00:00<00:00, 85.05it/s]
W0831 15:32:46.956144 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator
W0831 15:32:46.957227 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator
... (the same "Disable MaxCalibrator" warning is emitted once for every quantizer in the model)
W0831 15:32:46.979063 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0831 15:32:46.979588 140586428176192 tensor_quantizer.py:239] Call .cuda() if running on GPU after loading calibrated amax.
W0831 15:32:46.980288 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([64, 1, 1, 1]).
W0831 15:32:46.987690 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0831 15:32:57.002152 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([64, 1, 1, 1]).
features.0._input_quantizer             : TensorQuantizer(8bit narrow fake per-tensor amax=2.7537 calibrator=MaxCalibrator scale=1.0 quant)
features.0._weight_quantizer            : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0287, 4.4272](64) calibrator=MaxCalibrator scale=1.0 quant)
features.3._input_quantizer             : TensorQuantizer(8bit narrow fake per-tensor amax=30.1997 calibrator=MaxCalibrator scale=1.0 quant)
W0831 15:32:57.006651 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0831 15:32:57.009306 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([128, 1, 1, 1]).
W0831 15:32:57.011739 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0831 15:32:57.014180 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([128, 1, 1, 1]).
W0831 15:32:57.016433 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0831 15:32:57.018157 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([256, 1, 1, 1]).
W0831 15:32:57.019830 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0831 15:32:57.021619 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([256, 1, 1, 1]).
W0831 15:32:57.023381 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0831 15:32:57.024606 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([256, 1, 1, 1]).
W0831 15:32:57.026464 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0831 15:32:57.027716 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([512, 1, 1, 1]).
W0831 15:32:57.029010 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0831 15:32:57.030247 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([512, 1, 1, 1]).
W0831 15:32:57.031455 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0831 15:32:57.032716 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([512, 1, 1, 1]).
W0831 15:32:57.034027 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0831 15:32:57.035287 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([512, 1, 1, 1]).
W0831 15:32:57.036572 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0831 15:32:57.037535 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([512, 1, 1, 1]).
W0831 15:32:57.038545 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0831 15:32:57.039479 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([512, 1, 1, 1]).
W0831 15:32:57.040493 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0831 15:32:57.041564 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([4096, 1]).
W0831 15:32:57.042597 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0831 15:32:57.043280 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([4096, 1]).
W0831 15:32:57.044521 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W0831 15:32:57.045206 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([10, 1]).
features.3._weight_quantizer            : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0137, 2.2506](64) calibrator=MaxCalibrator scale=1.0 quant)
features.7._input_quantizer             : TensorQuantizer(8bit narrow fake per-tensor amax=16.2026 calibrator=MaxCalibrator scale=1.0 quant)
features.7._weight_quantizer            : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0602, 1.3986](128) calibrator=MaxCalibrator scale=1.0 quant)
features.10._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=9.1012 calibrator=MaxCalibrator scale=1.0 quant)
features.10._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0841, 0.9074](128) calibrator=MaxCalibrator scale=1.0 quant)
features.14._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=10.0201 calibrator=MaxCalibrator scale=1.0 quant)
features.14._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0921, 0.7349](256) calibrator=MaxCalibrator scale=1.0 quant)
features.17._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=7.0232 calibrator=MaxCalibrator scale=1.0 quant)
features.17._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0406, 0.5232](256) calibrator=MaxCalibrator scale=1.0 quant)
features.20._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=8.3654 calibrator=MaxCalibrator scale=1.0 quant)
features.20._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0346, 0.4240](256) calibrator=MaxCalibrator scale=1.0 quant)
features.24._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=7.5746 calibrator=MaxCalibrator scale=1.0 quant)
features.24._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0218, 0.2763](512) calibrator=MaxCalibrator scale=1.0 quant)
features.27._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=4.8754 calibrator=MaxCalibrator scale=1.0 quant)
features.27._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0163, 0.1819](512) calibrator=MaxCalibrator scale=1.0 quant)
features.30._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=3.7100 calibrator=MaxCalibrator scale=1.0 quant)
features.30._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0113, 0.1578](512) calibrator=MaxCalibrator scale=1.0 quant)
features.34._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=3.2465 calibrator=MaxCalibrator scale=1.0 quant)
features.34._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0137, 0.1480](512) calibrator=MaxCalibrator scale=1.0 quant)
features.37._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=2.3264 calibrator=MaxCalibrator scale=1.0 quant)
features.37._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0122, 0.2957](512) calibrator=MaxCalibrator scale=1.0 quant)
features.40._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=3.4793 calibrator=MaxCalibrator scale=1.0 quant)
features.40._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0023, 0.6918](512) calibrator=MaxCalibrator scale=1.0 quant)
classifier.0._input_quantizer           : TensorQuantizer(8bit narrow fake per-tensor amax=7.0113 calibrator=MaxCalibrator scale=1.0 quant)
classifier.0._weight_quantizer          : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0027, 0.8358](4096) calibrator=MaxCalibrator scale=1.0 quant)
classifier.3._input_quantizer           : TensorQuantizer(8bit narrow fake per-tensor amax=7.8033 calibrator=MaxCalibrator scale=1.0 quant)
classifier.3._weight_quantizer          : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0024, 0.4038](4096) calibrator=MaxCalibrator scale=1.0 quant)
classifier.6._input_quantizer           : TensorQuantizer(8bit narrow fake per-tensor amax=8.7469 calibrator=MaxCalibrator scale=1.0 quant)
classifier.6._weight_quantizer          : TensorQuantizer(8bit narrow fake axis=0 amax=[0.3125, 0.4321](10) calibrator=MaxCalibrator scale=1.0 quant)

## 6. Quantization Aware Training

In this phase, we finetune the model weights while keeping the quantizer node values frozen. The dynamic ranges obtained for each layer during calibration are held constant, and the model weights are finetuned so that the accuracy of the original FP32 model (the model without quantizer nodes) is preserved. Finetuning a QAT model is usually quick compared to the full training of the original model: use QAT to finetune for around 10% of the original training schedule with an annealing learning rate. Please refer to Achieving FP32 Accuracy for INT8 Inference Using Quantization Aware Training with NVIDIA TensorRT for detailed recommendations. For this VGG model, finetuning for 1 epoch is enough to get acceptable accuracy. During finetuning with QAT, the quantization is applied as a composition of max, clamp, round and mul ops:

# amax is the absolute maximum value observed for an input
# max_bound is the upper bound for integer quantization (127 for int8);
# min_bound is the corresponding lower bound (-128 for signed int8)
max_bound = torch.tensor((2.0**(num_bits - 1 + int(unsigned))) - 1.0, device=amax.device)
scale = max_bound / amax
outputs = torch.clamp((inputs * scale).round_(), min_bound, max_bound)

The tensor_quant function in the pytorch_quantization toolkit is responsible for the above tensor quantization. Per-channel quantization is usually recommended for weights, while per-tensor quantization is recommended for activations. During inference, we use torch.fake_quantize_per_tensor_affine and torch.fake_quantize_per_channel_affine to perform quantization, as these are easier to convert into the corresponding TensorRT operators. Please refer to the next sections for more details on how these operators are exported to Torchscript and converted by TRTorch.
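
The two formulations should agree numerically. The following sketch (an illustration assuming signed, symmetric INT8, not the toolkit's code) compares the scale/round/clamp recipe above with torch.fake_quantize_per_tensor_affine:

import torch

x = torch.randn(4, 8) * 5.0
amax = x.abs().max().item()

# Signed, symmetric INT8: integer range [-128, 127], step size amax / 127
num_bits, unsigned = 8, False
max_bound = 2.0 ** (num_bits - 1 + int(unsigned)) - 1.0   # 127
min_bound = -max_bound - 1.0                              # -128
step = amax / max_bound

# Quantize, clamp, then dequantize back to float ("fake" quantization)
manual = torch.clamp((x / step).round(), min_bound, max_bound) * step

# The operator used at export time; zero_point is 0 for symmetric quantization
fake = torch.fake_quantize_per_tensor_affine(x, step, 0, int(min_bound), int(max_bound))

print(torch.allclose(manual, fake))   # expected to print True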

[13]:
# Finetune the QAT model for 1 epoch
num_epochs=1
for epoch in range(num_epochs):
    adjust_lr(opt, epoch)
    print('Epoch: [%5d / %5d] LR: %f' % (epoch + 1, num_epochs, state["lr"]))

    train(qat_model, training_dataloader, crit, opt, epoch)
    test_loss, test_acc = test(qat_model, testing_dataloader, crit, epoch)

    print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))

save_checkpoint({'epoch': epoch + 1,
                 'model_state_dict': qat_model.state_dict(),
                 'acc': test_acc,
                 'opt_state_dict': opt.state_dict(),
                 'state': state},
                ckpt_path="vgg16_qat_ckpt")
Updating learning rate: 0.1
Epoch: [    1 /     1] LR: 0.100000
Batch: [  500 |  1563] loss: 2.694
Batch: [ 1000 |  1563] loss: 2.682
Batch: [ 1500 |  1563] loss: 2.624
Test Loss: 0.03277 Test Acc: 83.58%
Checkpoint saved

## 7. Export to Torchscript

Export the model to Torchscript: trace the model and convert it into Torchscript for deployment. To learn more about Torchscript, please refer to https://pytorch.org/docs/stable/jit.html . Setting quant_nn.TensorQuantizer.use_fb_fake_quant = True makes the QAT model export quantization through the torch.fake_quantize_per_tensor_affine and torch.fake_quantize_per_channel_affine operators instead of the tensor_quant function. In Torchscript, these are represented as aten::fake_quantize_per_tensor_affine and aten::fake_quantize_per_channel_affine.

[16]:
quant_nn.TensorQuantizer.use_fb_fake_quant = True
with torch.no_grad():
    data = iter(testing_dataloader)
    images, _ = next(data)
    jit_model = torch.jit.trace(qat_model, images.to("cuda"))
    torch.jit.save(jit_model, "trained_vgg16_qat.jit.pt")
E0831 15:41:34.662368 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!
E0831 15:41:34.664751 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!
... (the same "Fake quantize mode doesn't use scale explicitly!" message is emitted for every quantizer while tracing)
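
Before handing the exported module to TRTorch, it can be reassuring to check that the fake-quantize operators actually appear in the traced graph. A quick sketch, assuming the jit_model traced in the cell above (inlined_graph flattens submodule calls so the quantization nodes are visible):

# Count the quantization nodes recorded by tracing. For this QAT VGG16 we
# expect one per-tensor node per quantized input and one per-channel node
# per quantized weight.
graph_str = str(jit_model.inlined_graph)
for kind in ("aten::fake_quantize_per_tensor_affine",
             "aten::fake_quantize_per_channel_affine"):
    print(kind, ":", graph_str.count(kind))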

## 8. Inference using TRTorch

In this phase, we run the exported Torchscript graph of the VGG QAT model using TRTorch. TRTorch is a PyTorch-to-TensorRT compiler that converts Torchscript graphs into TensorRT engines. TensorRT 8.0 supports inference of quantization aware trained models and introduces new APIs: QuantizeLayer and DequantizeLayer. We can observe the quantization nodes across the entire VGG QAT graph in the TRTorch debug log. To enable debug logging, you can set trtorch.logging.set_reportable_log_level(trtorch.logging.Level.Debug). For example, the QuantConv2d layer from the pytorch_quantization toolkit is represented as follows in Torchscript:

%quant_input : Tensor = aten::fake_quantize_per_tensor_affine(%x, %636, %637, %638, %639)
%quant_weight : Tensor = aten::fake_quantize_per_channel_affine(%394, %640, %641, %637, %638, %639)
%input.2 : Tensor = aten::_convolution(%quant_input, %quant_weight, %395, %687, %688, %689, %643, %690, %642, %643, %643, %644, %644)

Internally, TRTorch converts aten::fake_quantize_per_*_affine into a QuantizeLayer + DequantizeLayer pair. Please refer to the quantization op converters in TRTorch for details.
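
To watch this conversion happen, the TRTorch debug log level mentioned above can be raised before compiling; a short (and very verbose) sketch:

import trtorch

# Print per-layer conversion details, including the Quantize/Dequantize layers
# created from the aten::fake_quantize_per_*_affine nodes, during compilation.
trtorch.logging.set_reportable_log_level(trtorch.logging.Level.Debug)

# ... run trtorch.compile(...) as in the next cell, then lower the log level
# again once the conversion has been inspected.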

[17]:
qat_model = torch.jit.load("trained_vgg16_qat.jit.pt").eval()

compile_spec = {"inputs": [trtorch.Input([16, 3, 32, 32])],
                "op_precision": torch.int8,
                }
trt_mod = trtorch.compile(qat_model, compile_spec)

test_loss, test_acc = test(trt_mod, testing_dataloader, crit, 0)
print("VGG QAT accuracy using TensorRT: {:.2f}%".format(100 * test_acc))
VGG QAT accuracy using TensorRT: 83.59%
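
Accuracy is only half the story; a rough latency comparison between the compiled INT8 module and the Torchscript FP32 module can be made with CUDA events. A sketch (the input shape must match the [16, 3, 32, 32] shape fixed in compile_spec; absolute numbers depend on the GPU):

import torch

def benchmark(model, input_shape=(16, 3, 32, 32), warmup=10, iters=100):
    """Average GPU latency in milliseconds for a fixed-shape input."""
    x = torch.randn(input_shape, device="cuda")
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    with torch.no_grad():
        for _ in range(warmup):
            model(x)
        torch.cuda.synchronize()
        start.record()
        for _ in range(iters):
            model(x)
        end.record()
        torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

print("TRTorch INT8 : {:.3f} ms/batch".format(benchmark(trt_mod)))
print("Torchscript  : {:.3f} ms/batch".format(benchmark(qat_model)))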

## 9. References

  * Very Deep Convolutional Networks for Large-Scale Image Recognition
  * Achieving FP32 Accuracy for INT8 Inference Using Quantization Aware Training with NVIDIA TensorRT
  * QAT workflow for VGG16
  * Deploying VGG QAT model in C++ using TRTorch
  * pytorch-quantization toolkit from NVIDIA
  * pytorch-quantization toolkit user guide
  * Quantization basics