ResNet20 on CIFAR-10: Pruning
Tip
This notebook takes about 3 hours on Google Colab but can finish in about 1 hour on a faster GPU. You may see slightly different accuracies than those reported below depending on the system you run this notebook on. The purpose of this notebook is to demonstrate the pruning workflow with Model Optimizer, not to achieve the best possible accuracy.
In this tutorial, we will use Model Optimizer to prune the ResNet model so that it meets our target deployment constraints without sacrificing much accuracy!
By the end of this tutorial, you will:
Understand how to use Model Optimizer to prune a user-provided model down to the best-performing subnet architecture that fits your target deployment constraints.
Save and restore your pruned model for downstream tasks like fine-tuning and inference.
All of this with just a few lines of code! Yes, it’s that simple!
Let’s first install Model Optimizer
following the installation steps.
[ ]:
%pip install "nvidia-modelopt[torch]" --extra-index-url https://pypi.nvidia.com
[2]:
import math
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.models.resnet import BasicBlock
from tqdm.auto import tqdm
seed = 123
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
device = torch.device("cuda")
CIFAR-10 Dataset for Image Classification
For this tutorial, we will be working with the well-known CIFAR-10 dataset for image classification. The dataset consists of 60k 32x32 images from 10 classes, split into 50k training and 10k testing images. We will further hold out 5k randomly chosen images from the training set to use as our validation set.
[3]:
def get_cifar10_dataloaders(train_batch_size: int):
"""Return Train-Val-Test data loaders for the CIFAR-10 dataset."""
np.random.seed(seed)
normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
# Split Train dataset into Train-Val datasets
train_dataset = torchvision.datasets.CIFAR10(
root="./data",
train=True,
transform=transforms.Compose(
[
transforms.ToTensor(),
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, 4),
normalize,
]
),
download=True,
)
n_trainval = len(train_dataset)
n_train = int(n_trainval * 0.9)
ids = np.arange(n_trainval)
np.random.shuffle(ids)
train_ids, val_ids = ids[:n_train], ids[n_train:]
train_dataset.data = train_dataset.data[train_ids]
train_dataset.targets = np.array(train_dataset.targets)[train_ids]
val_dataset = torchvision.datasets.CIFAR10(
root="./data",
train=True,
transform=transforms.Compose([transforms.ToTensor(), normalize]),
download=True,
)
val_dataset.data = val_dataset.data[val_ids]
val_dataset.targets = np.array(val_dataset.targets)[val_ids]
test_dataset = torchvision.datasets.CIFAR10(
root="./data",
train=False,
transform=val_dataset.transform,
download=True,
)
num_workers = min(8, os.cpu_count())
train_loader = torch.utils.data.DataLoader(
train_dataset, train_batch_size, num_workers=num_workers, pin_memory=True, shuffle=True
)
val_loader = torch.utils.data.DataLoader(
val_dataset, batch_size=1024, num_workers=num_workers, pin_memory=True
)
test_loader = torch.utils.data.DataLoader(
test_dataset, batch_size=1024, num_workers=num_workers, pin_memory=True
)
print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
return train_loader, val_loader, test_loader
ResNet for CIFAR dataset
We will be working with the ResNet variants designed for the CIFAR dataset, namely ResNet-20 and ResNet-32, since these models are small and quick to train. You can find more details about these models in this paper. Below is an example of a regular PyTorch model with nothing new added.
Setting up the model
We first set up the model definition and add some helper functions for training.
[4]:
def _weights_init(m):
if isinstance(m, (nn.Linear, nn.Conv2d)):
torch.nn.init.kaiming_normal_(m.weight)
class LambdaLayer(nn.Module):
def __init__(self, lambd):
super().__init__()
self.lambd = lambd
def forward(self, x):
return self.lambd(x)
class ResNet(nn.Module):
def __init__(self, num_blocks, num_classes=10):
super().__init__()
self.in_planes = 16
self.layers = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(16),
nn.ReLU(),
self._make_layer(16, num_blocks, stride=1),
self._make_layer(32, num_blocks, stride=2),
self._make_layer(64, num_blocks, stride=2),
nn.AdaptiveAvgPool2d((1, 1)),
nn.Flatten(),
nn.Linear(64, num_classes),
)
self.apply(_weights_init)
def _make_layer(self, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
downsample = None
if stride != 1 or self.in_planes != planes:
downsample = LambdaLayer(
lambda x: F.pad(
x[:, :, ::2, ::2], (0, 0, 0, 0, planes // 4, planes // 4), "constant", 0
)
)
layers.append(BasicBlock(self.in_planes, planes, stride, downsample))
self.in_planes = planes
return nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
def resnet20(ckpt=None):
model = ResNet(num_blocks=3).to(device)
if ckpt is not None:
model.load_state_dict(torch.load(ckpt, device))
return model
def resnet32(ckpt=None):
model = ResNet(num_blocks=5).to(device)
if ckpt is not None:
model.load_state_dict(torch.load(ckpt, device))
return model
[5]:
class CosineLRwithWarmup(torch.optim.lr_scheduler._LRScheduler):
def __init__(
self,
optimizer: torch.optim.Optimizer,
warmup_steps: int,
decay_steps: int,
warmup_lr: float = 0.0,
last_epoch: int = -1,
) -> None:
self.warmup_steps = warmup_steps
self.warmup_lr = warmup_lr
self.decay_steps = decay_steps
super().__init__(optimizer, last_epoch)
def get_lr(self):
if self.last_epoch < self.warmup_steps:
return [
(base_lr - self.warmup_lr) * self.last_epoch / self.warmup_steps + self.warmup_lr
for base_lr in self.base_lrs
]
else:
current_steps = self.last_epoch - self.warmup_steps
return [
0.5 * base_lr * (1 + math.cos(math.pi * current_steps / self.decay_steps))
for base_lr in self.base_lrs
]
def get_optimizer_scheduler(model, lr, weight_decay, warmup_steps, decay_steps):
optimizer = torch.optim.SGD(
filter(lambda p: p.requires_grad, model.parameters()),
lr,
momentum=0.9,
weight_decay=weight_decay,
)
lr_scheduler = CosineLRwithWarmup(optimizer, warmup_steps, decay_steps)
return optimizer, lr_scheduler
def train_one_epoch(model, train_loader, loss_fn, optimizer, lr_scheduler):
"""Train the given model for 1 epoch."""
model.train()
epoch_loss = 0.0
for imgs, labels in train_loader:
output = model(imgs.to(device))
loss = loss_fn(model, output, labels.to(device))
epoch_loss += loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
lr_scheduler.step()
epoch_loss /= len(train_loader.dataset)
return epoch_loss
@torch.no_grad()
def evaluate(model, test_loader):
"""Evaluate the model on the given test_loader and return accuracy percentage."""
model.eval()
correct = total = 0.0
for imgs, labels in test_loader:
output = model(imgs.to(device))
predicted = output.argmax(dim=1).detach().cpu()
correct += torch.sum(labels == predicted).item()
total += len(labels)
accuracy = 100 * correct / total
return accuracy
def loss_fn_default(model, output, labels):
return F.cross_entropy(output, labels)
def train_model(
model,
train_loader,
val_loader,
optimizer,
lr_scheduler,
num_epochs,
loss_fn=loss_fn_default,
print_freq=25,
ckpt_path="temp_saved_model.pth",
):
"""Train the given model with provided parameters.
loss_fn: function that takes model, output, labels and returns loss. This allows us to obtain the loss
from the model as well if needed.
"""
best_val_acc, best_ep = 0.0, 0
print(f"Training the model for {num_epochs} epochs...")
for ep in tqdm(range(1, num_epochs + 1)):
train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer, lr_scheduler)
val_acc = evaluate(model, val_loader)
if val_acc >= best_val_acc:
best_val_acc, best_ep = val_acc, ep
torch.save(model.state_dict(), ckpt_path)
if ep == 1 or ep % print_freq == 0 or ep == num_epochs:
print(f"Epoch {ep:3d}\t Training loss: {train_loss:.4f}\t Val Accuracy: {val_acc:.2f}%")
print(
f"Model Trained! Restoring to parameters that gave best Val Accuracy ({best_val_acc:.2f}% at Epoch {best_ep})."
)
model.load_state_dict(torch.load(ckpt_path, device))
You can uncomment the print statement below to see the ResNet20 model details.
[6]:
# print(resnet20())
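If you just want a quick size summary instead of the full module printout, a one-line parameter count also works (a small optional snippet we add here; it is not part of the original notebook):
# Count the trainable parameters of the baseline ResNet20 (expect roughly 0.27M).
num_params = sum(p.numel() for p in resnet20().parameters() if p.requires_grad)
print(f"ResNet20 trainable parameters: {num_params:,}")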
Training a baseline model
It should take about 10-30 mins to train depending on your GPU and CPU. We use slightly different training hyperparameters compared to the original setup described in the paper to make the training faster for this tutorial.
You can also reduce the num_epochs
parameter below to make the whole notebook run faster at the cost of accuracy.
[7]:
batch_size = 512
num_epochs = 120
learning_rate = 0.1 * batch_size / 128
weight_decay = 1e-4
train_loader, val_loader, test_loader = get_cifar10_dataloaders(batch_size)
batch_per_epoch = len(train_loader)
warmup_steps = 5 * batch_per_epoch
decay_steps = num_epochs * batch_per_epoch
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Train: 45000, Val: 5000, Test: 10000
[8]:
resnet20_model = resnet20()
optimizer, lr_scheduler = get_optimizer_scheduler(
resnet20_model, learning_rate, weight_decay, warmup_steps, decay_steps
)
train_model(
resnet20_model,
train_loader,
val_loader,
optimizer,
lr_scheduler,
num_epochs,
ckpt_path="resnet20.pth",
)
print(f"Test Accuracy of ResNet20: {evaluate(resnet20_model, test_loader)}")
Training the model for 120 epochs...
Epoch 1 Training loss: 0.0049 Val Accuracy: 22.82%
Epoch 25 Training loss: 0.0006 Val Accuracy: 78.84%
Epoch 50 Training loss: 0.0004 Val Accuracy: 85.06%
Epoch 75 Training loss: 0.0002 Val Accuracy: 88.12%
Epoch 100 Training loss: 0.0001 Val Accuracy: 90.34%
Epoch 120 Training loss: 0.0000 Val Accuracy: 90.80%
Model Trained! Restoring to parameters that gave best Val Accuracy (90.92% at Epoch 119).
Test Accuracy of ResNet20: 90.97
We have now established a baseline model and accuracy against which we will compare our results with Model Optimizer.
So far, we have seen a regular PyTorch model trained without anything new. Now, let's optimize the model for our target constraints using Model Optimizer!
FastNAS Pruning with Model Optimizer
The Model Optimizer’s modelopt.torch.prune
module provides advanced state-of-the-art pruning algorithms that enable you to search for the best subnet architecture from your provided base model.
Model Optimizer can be used in one of the following complementary modes to create a search space for optimizing the model:
fastnas: A pruning method recommended for computer vision models. Given a pretrained model, FastNAS finds the subnet which maximizes the score function while meeting the given constraints.
mcore_gpt_minitron: A pruning method developed by NVIDIA Research for pruning GPT-style models in the NVIDIA NeMo or Megatron-LM framework that use Pipeline Parallelism. It uses activation magnitudes to prune the MLP, attention heads, and GQA query groups.
gradnas: A light-weight pruning method recommended for language models like Hugging Face BERT and GPT-J. It uses gradient information to prune the model's linear layers and attention heads to meet the given constraints.
In this example, we will use the fastnas
mode to prune the ResNet20 model for the CIFAR-10 dataset. Check out the Model Optimizer GitHub repository for more examples.
Let’s first use the FastNAS mode to convert a ResNet model and reduce its FLOPs, number of parameters, and latency.
[9]:
import modelopt.torch.opt as mto
import modelopt.torch.prune as mtp
Prune base model and store pruned net
Using mtp.prune, you can:
generate a search space for pruning from your base model;
prune the model;
obtain a valid PyTorch model that can be used for fine-tuning.
Let's say we have the ResNet20 model as our base model to prune from and we are looking for a model with at most 30M FLOPs. We can provide search constraints for flops
and/or params
as an upper bound. The values can either be absolute numbers (e.g. 30e6
) or a percentage string (e.g. "75%"
). In addition, we should also provide our training data loader to mtp.prune; the training data
loader will be used to calibrate the normalization layers in the model. We will also specify a custom config for the pruning search space to get a more fine-grained selection of pruned nets.
Finally, we can store the pruned architecture and weights using mto.save.
Note
We are optimizing a relatively small model here, and a finer-grained search can be more effective in such cases; this is why we specify a custom config. In general, however, it is recommended to convert models with the default config.
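For reference, the same constraint could also be expressed relative to the base model instead of as an absolute budget (a minimal sketch for illustration; it is not executed in this notebook):
# Relative constraint: keep at most 75% of the baseline FLOPs.
constraints = {"flops": "75%"}
# Parameters can be bounded the same way, e.g. constraints = {"params": "50%"}.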
[10]:
# config with more fine-grained channel choices for fastnas
config = mtp.config.FastNASConfig()
config["nn.Conv2d"]["*"]["channel_divisor"] = 16
config["nn.BatchNorm2d"]["*"]["feature_divisor"] = 16
# A single 32x32 image for computing FLOPs
dummy_input = torch.randn(1, 3, 32, 32, device=device)
# Wrap your original validation function to only take the model as input.
# This function acts as the score function to rank models.
def score_func(model):
return evaluate(model, val_loader)
# prune the model
pruned_model, _ = mtp.prune(
model=resnet20(ckpt="resnet20.pth"),
mode=[("fastnas", config)],
constraints={"flops": 30e6},
dummy_input=dummy_input,
config={
"data_loader": train_loader,
"score_func": score_func,
"checkpoint": "modelopt_seaarch_checkpoint_fastnas.pth",
},
)
# save the pruned model for future use
mto.save(pruned_model, "modelopt_pruned_model_fastnas.pth")
# evaluate the pruned model
print(f"Test Accuracy of Pruned ResNet20: {evaluate(pruned_model, test_loader)}")
Profiling the following subnets from the given model: ('min', 'centroid', 'max').
--------------------------------------------------------------------------------
Profiling Results
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Constraint ┃ min ┃ centroid ┃ max ┃ max/min ratio ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ flops │ 24.33M │ 27.57M │ 40.55M │ 1.67 │
│ params │ 90.94K │ 141.63K │ 268.35K │ 2.95 │
└──────────────┴──────────────┴──────────────┴──────────────┴───────────────┘
Constraints Evaluation
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓
┃ ┃ ┃ Satisfiable ┃
┃ Constraint ┃ Upper Bound ┃ Upper Bound ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩
│ flops │ 30.00M │ True │
└──────────────┴──────────────┴──────────────┘
Search Space Summary:
----------------------------------------------------------------------------------------------------
layers.depth [9]
layers.0.out_channels [16]
layers.0.in_channels [3]
layers.3.depth [3]
layers.3.0.conv1.out_channels [16]
layers.3.0.conv1.in_channels [16]
layers.3.0.bn1.num_features [16]
layers.3.0.conv2.out_channels [16]
layers.3.0.conv2.in_channels [16]
layers.3.1.conv1.out_channels [16]
layers.3.1.conv1.in_channels [16]
layers.3.1.bn1.num_features [16]
layers.3.1.conv2.out_channels [16]
layers.3.1.conv2.in_channels [16]
layers.3.2.conv1.out_channels [16]
layers.3.2.conv1.in_channels [16]
layers.3.2.bn1.num_features [16]
layers.3.2.conv2.out_channels [16]
layers.3.2.conv2.in_channels [16]
layers.4.depth [3]
* layers.4.0.conv1.out_channels [16, 32]
layers.4.0.conv1.in_channels [16]
layers.4.0.bn1.num_features [16, 32]
layers.4.0.conv2.out_channels [32]
layers.4.0.conv2.in_channels [16, 32]
* layers.4.1.conv1.out_channels [16, 32]
layers.4.1.conv1.in_channels [32]
layers.4.1.bn1.num_features [16, 32]
layers.4.1.conv2.out_channels [32]
layers.4.1.conv2.in_channels [16, 32]
* layers.4.2.conv1.out_channels [16, 32]
layers.4.2.conv1.in_channels [32]
layers.4.2.bn1.num_features [16, 32]
layers.4.2.conv2.out_channels [32]
layers.4.2.conv2.in_channels [16, 32]
layers.5.depth [3]
* layers.5.0.conv1.out_channels [16, 32, 48, 64]
layers.5.0.conv1.in_channels [32]
layers.5.0.bn1.num_features [16, 32, 48, 64]
layers.5.0.conv2.out_channels [64]
layers.5.0.conv2.in_channels [16, 32, 48, 64]
* layers.5.1.conv1.out_channels [16, 32, 48, 64]
layers.5.1.conv1.in_channels [64]
layers.5.1.bn1.num_features [16, 32, 48, 64]
layers.5.1.conv2.out_channels [64]
layers.5.1.conv2.in_channels [16, 32, 48, 64]
* layers.5.2.conv1.out_channels [16, 32, 48, 64]
layers.5.2.conv1.in_channels [64]
layers.5.2.bn1.num_features [16, 32, 48, 64]
layers.5.2.conv2.out_channels [64]
layers.5.2.conv2.in_channels [16, 32, 48, 64]
----------------------------------------------------------------------------------------------------
Number of configurable hparams: 6
Total size of the search space: 5.12e+02
Note: all constraints can be satisfied within the search space!
Beginning pre-search estimation. If the runtime of score function is longer than a few minutes, consider subsampling the dataset used in score function.
A PyTorch dataset can be subsampled using torch.utils.data.Subset (https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset) as following:
subset_dataset = torch.utils.data.Subset(dataset, indices)
Collecting pre-search statistics: 100%|██████████| 18/18 [00:10<00:00, 1.76it/s, cur=layers.5.2.conv1.out_channels(64/64): 0.00]
[num_satisfied] = 11: 0%| | 20/10000 [00:02<17:43, 9.39it/s]
[best_subnet_constraints] = {'params': '173.88K', 'flops': '29.64M'}
Test Accuracy of Pruned ResNet20: 60.37
As we can see, the best subnet (29.6M FLOPs) satisfies our constraint of 30M FLOPs. We can also see a drop in accuracy for the searched model. This is very common after pruning, and fine-tuning is necessary to recover the accuracy.
Restore the pruned subnet using mto.restore
[11]:
pruned_model = mto.restore(resnet20(), "modelopt_pruned_model_fastnas.pth")
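As an optional sanity check (not part of the original notebook), you can evaluate the restored subnet; it should roughly reproduce the pruned-model test accuracy printed above:
# The restored pruned subnet should match the "Test Accuracy of Pruned ResNet20" value above.
print(f"Test Accuracy of restored pruned model: {evaluate(pruned_model, test_loader):.2f}%")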
Fine-tuning
To fine-tune the subnet, you can simply repeat the training pipeline of the original model (1x training time, 0.5x-1x of the original learning rate). The fine-tuned model is the final model: it provides the best trade-off between accuracy and your provided constraints and is the one used for deployment.
Note that fine-tuning takes about 5-15 mins depending on your GPU and CPU.
[12]:
optimizer, lr_scheduler = get_optimizer_scheduler(
pruned_model, 0.5 * learning_rate, weight_decay, warmup_steps, decay_steps
)
train_model(
pruned_model,
train_loader,
val_loader,
optimizer,
lr_scheduler,
num_epochs,
)
# store final model
mto.save(pruned_model, "modelopt_pruned_model_fastnas_finetuned.pth")
Training the model for 120 epochs...
Epoch 1 Training loss: 0.0011 Val Accuracy: 79.90%
Epoch 25 Training loss: 0.0004 Val Accuracy: 82.82%
Epoch 50 Training loss: 0.0003 Val Accuracy: 87.00%
Epoch 75 Training loss: 0.0002 Val Accuracy: 88.62%
Epoch 100 Training loss: 0.0000 Val Accuracy: 90.62%
Epoch 120 Training loss: 0.0000 Val Accuracy: 90.58%
Model Trained! Restoring to parameters that gave best Val Accuracy (90.70% at Epoch 101).
Evaluate the searched subnet
[13]:
# you can restore the fine-tuned model from the vanilla model
optimized_model = mto.restore(resnet20(), "modelopt_pruned_model_fastnas_finetuned.pth")
# test the accuracy
print(f"Test Accuracy of the fine-tuned pruned net: {evaluate(optimized_model, test_loader)}")
Test Accuracy of the fine-tuned pruned net: 90.28
Conclusion
The comparison can be summarized as below:
| Model | FLOPs | Params | Test Accuracy |
|---|---|---|---|
| ResNet20 | 40.6M | 268k | 90.9% |
| FastNAS subnet | 29.6M | 174k | 90.3% |
As we can see, we have reduced the FLOPs and the number of parameters, which should also translate into an improvement in latency, with very little loss in accuracy. Good job!
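If you want to verify the latency claim on your own hardware, a rough timing loop like the following can be used (a sketch we add for illustration; the helper name is ours and absolute numbers will vary with your GPU):
import time

@torch.no_grad()
def measure_latency(model, batch_size=1, n_warmup=20, n_iters=100):
    """Rough average forward-pass latency in milliseconds."""
    model.eval()
    x = torch.randn(batch_size, 3, 32, 32, device=device)
    for _ in range(n_warmup):  # warm up CUDA kernels
        model(x)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iters):
        model(x)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_iters * 1e3

print(f"Baseline ResNet20 latency: {measure_latency(resnet20('resnet20.pth')):.3f} ms")
print(f"Pruned subnet latency:     {measure_latency(optimized_model):.3f} ms")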
Next: check out the Model Optimizer GitHub repository for more examples.