This tutorial is available as a Jupyter Notebook! You can download the notebook from here.

ResNet20 on CIFAR-10: Pruning

Tip

This notebook takes about 3 hours to run on Google Colab, or as little as 1 hour on a faster GPU. You may see slightly different accuracies than those reported below depending on the system you run it on. The purpose of this notebook is to demonstrate the pruning workflow with Model Optimizer, not to achieve the best possible accuracy.

In this tutorial, we will use Model Optimizer to prune the ResNet model so that it meets our target deployment constraints and runs faster, without sacrificing much accuracy!

By the end of this tutorial, you will:

  • Understand how to use Model Optimizer to prune a user-provided model to the best-performing subnet architecture that fits your target deployment constraints.

  • Save and restore your pruned model for downstream tasks like fine-tuning and inference.

All of this with just a few lines of code! Yes, it’s that simple!

Let’s first install Model Optimizer following the installation steps.

[ ]:
%pip install "nvidia-modelopt[torch]" --extra-index-url https://pypi.nvidia.com
[2]:
import math
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.models.resnet import BasicBlock
from tqdm.auto import tqdm

seed = 123
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

device = torch.device("cuda")

CIFAR-10 Dataset for Image Classification

For this tutorial, we will be working with the well-known CIFAR-10 dataset for image classification. The dataset consists of 60k 32x32 images from 10 classes, split into 50k training and 10k test images. We will further randomly set aside 5k images from the training set to use as our validation set.

[3]:
def get_cifar10_dataloaders(train_batch_size: int):
    """Return Train-Val-Test data loaders for the CIFAR-10 dataset."""
    np.random.seed(seed)

    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])

    # Split Train dataset into Train-Val datasets
    train_dataset = torchvision.datasets.CIFAR10(
        root="./data",
        train=True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                normalize,
            ]
        ),
        download=True,
    )

    n_trainval = len(train_dataset)
    n_train = int(n_trainval * 0.9)
    ids = np.arange(n_trainval)
    np.random.shuffle(ids)
    train_ids, val_ids = ids[:n_train], ids[n_train:]

    train_dataset.data = train_dataset.data[train_ids]
    train_dataset.targets = np.array(train_dataset.targets)[train_ids]

    val_dataset = torchvision.datasets.CIFAR10(
        root="./data",
        train=True,
        transform=transforms.Compose([transforms.ToTensor(), normalize]),
        download=True,
    )
    val_dataset.data = val_dataset.data[val_ids]
    val_dataset.targets = np.array(val_dataset.targets)[val_ids]

    test_dataset = torchvision.datasets.CIFAR10(
        root="./data",
        train=False,
        transform=val_dataset.transform,
        download=True,
    )

    num_workers = min(8, os.cpu_count())
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=train_batch_size, num_workers=num_workers, pin_memory=True, shuffle=True
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=1024, num_workers=num_workers, pin_memory=True
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=1024, num_workers=num_workers, pin_memory=True
    )
    print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

    return train_loader, val_loader, test_loader

ResNet for CIFAR dataset

We will be working with the ResNet variants for the CIFAR dataset, namely ResNet-20 and ResNet-32, since these are small models that are quick to train. You can find more details about these models in this paper. Below is a regular PyTorch implementation of the model with nothing Model Optimizer-specific yet.

Setting up the model

We first set up the model and add some helper functions for training.

[4]:
def _weights_init(m):
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        torch.nn.init.kaiming_normal_(m.weight)


class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class ResNet(nn.Module):
    def __init__(self, num_blocks, num_classes=10):
        super().__init__()
        self.in_planes = 16
        self.layers = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            self._make_layer(16, num_blocks, stride=1),
            self._make_layer(32, num_blocks, stride=2),
            self._make_layer(64, num_blocks, stride=2),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(64, num_classes)
        )
        self.apply(_weights_init)

    def _make_layer(self, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            downsample = None
            if stride != 1 or self.in_planes != planes:
                downsample = LambdaLayer(
                    lambda x: F.pad(
                        x[:, :, ::2, ::2], (0, 0, 0, 0, planes // 4, planes // 4), "constant", 0
                    )
                )
            layers.append(BasicBlock(self.in_planes, planes, stride, downsample))
            self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


def resnet20(ckpt=None):
    model = ResNet(num_blocks=3).to(device)
    if ckpt is not None:
        model.load_state_dict(torch.load(ckpt, map_location=device))
    return model


def resnet32(ckpt=None):
    model = ResNet(num_blocks=5).to(device)
    if ckpt is not None:
        model.load_state_dict(torch.load(ckpt, map_location=device))
    return model
[5]:
class CosineLRwithWarmup(torch.optim.lr_scheduler._LRScheduler):
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_steps: int,
        decay_steps: int,
        warmup_lr: float = 0.0,
        last_epoch: int = -1,
    ) -> None:
        self.warmup_steps = warmup_steps
        self.warmup_lr = warmup_lr
        self.decay_steps = decay_steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            return [
                (base_lr - self.warmup_lr) * self.last_epoch / self.warmup_steps + self.warmup_lr
                for base_lr in self.base_lrs
            ]
        else:
            current_steps = self.last_epoch - self.warmup_steps
            return [
                0.5 * base_lr * (1 + math.cos(math.pi * current_steps / self.decay_steps))
                for base_lr in self.base_lrs
            ]


def get_optimizer_scheduler(model, lr, weight_decay, warmup_steps, decay_steps):
    optimizer = torch.optim.SGD(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr,
        momentum=0.9,
        weight_decay=weight_decay,
    )
    lr_scheduler = CosineLRwithWarmup(optimizer, warmup_steps, decay_steps)
    return optimizer, lr_scheduler


def train_one_epoch(model, train_loader, loss_fn, optimizer, lr_scheduler):
    """Train the given model for 1 epoch."""
    model.train()
    epoch_loss = 0.0
    for imgs, labels in train_loader:
        output = model(imgs.to(device))
        loss = loss_fn(model, output, labels.to(device))
        epoch_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

    epoch_loss /= len(train_loader.dataset)
    return epoch_loss


@torch.no_grad()
def evaluate(model, test_loader):
    """Evaluate the model on the given test_loader and return accuracy percentage."""
    model.eval()
    correct = total = 0.0
    for imgs, labels in test_loader:
        output = model(imgs.to(device))
        predicted = output.argmax(dim=1).detach().cpu()
        correct += torch.sum(labels == predicted).item()
        total += len(labels)

    accuracy = 100 * correct / total
    return accuracy


def loss_fn_default(model, output, labels):
    return F.cross_entropy(output, labels)


def train_model(
    model,
    train_loader,
    val_loader,
    optimizer,
    lr_scheduler,
    num_epochs,
    loss_fn=loss_fn_default,
    print_freq=25,
    ckpt_path="temp_saved_model.pth",
):
    """Train the given model with provided parameters.

    loss_fn: function that takes model, output, labels and returns loss. This allows us to obtain the loss
        from the model as well if needed.
    """
    best_val_acc, best_ep = 0.0, 0
    print(f"Training the model for {num_epochs} epochs...")
    for ep in tqdm(range(1, num_epochs + 1)):
        train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer, lr_scheduler)

        val_acc = evaluate(model, val_loader)
        if val_acc >= best_val_acc:
            best_val_acc, best_ep = val_acc, ep
            torch.save(model.state_dict(), ckpt_path)

        if ep == 1 or ep % print_freq == 0 or ep == num_epochs:
            print(f"Epoch {ep:3d}\t Training loss: {train_loss:.4f}\t Val Accuracy: {val_acc:.2f}%")

    print(
        f"Model Trained! Restoring to parameters that gave best Val Accuracy ({best_val_acc:.2f}% at Epoch {best_ep})."
    )
    model.load_state_dict(torch.load(ckpt_path, map_location=device))

You can uncomment the print statement below to see the ResNet20 model details.

[6]:
# print(resnet20())

Training a baseline model

Training should take about 10-30 minutes depending on your GPU and CPU. We use slightly different training hyperparameters than the original setup described in the paper to make training faster for this tutorial.

You can also reduce the num_epochs parameter below to make the whole notebook run faster at the cost of accuracy.
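The learning rate in the next cell follows the common linear scaling heuristic: scale the base learning rate in proportion to the batch size. Below is a minimal sketch of the arithmetic, assuming the paper's reference setup of learning rate 0.1 at batch size 128 (illustrative only, not a separate notebook cell):

# Linear learning-rate scaling heuristic used in the next cell (illustrative sketch)
base_lr, base_batch_size = 0.1, 128  # reference setup from the original ResNet CIFAR paper
scaled_lr = base_lr * 512 / base_batch_size  # 512 is the batch size used in this tutorial
print(scaled_lr)  # -> 0.4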

[7]:
batch_size = 512
num_epochs = 120
learning_rate = 0.1 * batch_size / 128
weight_decay = 1e-4

train_loader, val_loader, test_loader = get_cifar10_dataloaders(batch_size)

batch_per_epoch = len(train_loader)
warmup_steps = 5 * batch_per_epoch
decay_steps = num_epochs * batch_per_epoch
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Train: 45000, Val: 5000, Test: 10000
[8]:
resnet20_model = resnet20()
optimizer, lr_scheduler = get_optimizer_scheduler(
    resnet20_model, learning_rate, weight_decay, warmup_steps, decay_steps
)
train_model(
    resnet20_model,
    train_loader,
    val_loader,
    optimizer,
    lr_scheduler,
    num_epochs,
    ckpt_path="resnet20.pth",
)
print(f"Test Accuracy of ResNet20: {evaluate(resnet20_model, test_loader)}")
Training the model for 120 epochs...
Epoch   1        Training loss: 0.0049   Val Accuracy: 22.82%
Epoch  25        Training loss: 0.0006   Val Accuracy: 78.84%
Epoch  50        Training loss: 0.0004   Val Accuracy: 85.06%
Epoch  75        Training loss: 0.0002   Val Accuracy: 88.12%
Epoch 100        Training loss: 0.0001   Val Accuracy: 90.34%
Epoch 120        Training loss: 0.0000   Val Accuracy: 90.80%
Model Trained! Restoring to parameters that gave best Val Accuracy (90.92% at Epoch 119).
Test Accuracy of ResNet20: 90.97

We have now established a baseline model and baseline accuracy to compare against once we optimize the model with Model Optimizer.

So far, we have seen a regular PyTorch model trained in the usual way. Now, let's optimize the model for our target constraints using Model Optimizer!

FastNAS Pruning with Model Optimizer

Model Optimizer's modelopt.torch.prune module provides advanced state-of-the-art pruning algorithms that enable you to search for the best subnet architecture from your provided base model.

Model Optimizer can be used in one of the following complementary modes to create a search space for optimizing the model:

  1. fastnas: A pruning method recommended for Computer Vision models. Given a pretrained model, FastNAS finds the subnet which maximizes the score function while meeting the given constraints.

  2. mcore_gpt_minitron: A pruning method developed by NVIDIA Research for pruning GPT-style models in the NVIDIA NeMo or Megatron-LM framework that use Pipeline Parallelism. It uses the activation magnitudes to prune the MLP, attention heads, and GQA query groups.

  3. gradnas: A lightweight pruning method recommended for language models like Hugging Face BERT and GPT-J. It uses the gradient information to prune the model's linear layers and attention heads to meet the given constraints.

In this example, we will use the fastnas mode to prune the ResNet20 model for the CIFAR-10 dataset. Check out the Model Optimizer GitHub repository for more examples.

Let’s first use the FastNAS mode to convert a ResNet model and reduce its FLOPs, number of parameters, and latency.

[9]:
import modelopt.torch.opt as mto
import modelopt.torch.prune as mtp

Prune base model and store pruned net

Using mtp.prune you can:

  • generate a search space for pruning from your base model;

  • prune the model;

  • obtain a valid PyTorch model that can be used for fine-tuning.

Let's say we have the ResNet20 model as our base model to prune from and we are looking for a model with at most 30M FLOPs. We can provide search constraints for flops and/or params as upper bounds. The values can either be absolute numbers (e.g. 30e6) or percentage strings (e.g. "75%"). In addition, we should provide our training data loader to mtp.prune; it will be used to calibrate the normalization layers in the model. We will also specify a custom config for the pruning search space to get a more fine-grained selection of pruned nets.

Finally, we can store the pruned architecture and weights using mto.save.

Note

We are optimizing a relatively small model here, so a finer-grained search can be more effective; this is why we specify a custom config. In general, however, it is recommended to convert models with the default config.
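For reference, here is a minimal sketch of what the same call could look like with the default FastNAS config and the constraint expressed as a percentage instead of an absolute FLOPs budget. It is not executed in this tutorial: the 75% value and the helper name are illustrative, the remaining arguments mirror the cell below, and we assume mode also accepts the plain string "fastnas" for the default config.

# Illustrative sketch: default FastNAS config and a percentage-based constraint.
def score_func_sketch(model):
    return evaluate(model, val_loader)

pruned_model_default, _ = mtp.prune(
    model=resnet20(ckpt="resnet20.pth"),
    mode="fastnas",  # default config, no custom channel divisors
    constraints={"flops": "75%"},  # keep at most 75% of the base model's FLOPs
    dummy_input=torch.randn(1, 3, 32, 32, device=device),
    config={"data_loader": train_loader, "score_func": score_func_sketch},
)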

[10]:
# config with more fine-grained channel choices for fastnas
config = mtp.config.FastNASConfig()
config["nn.Conv2d"]["*"]["channel_divisor"] = 16
config["nn.BatchNorm2d"]["*"]["feature_divisor"] = 16

# A single 32x32 image for computing FLOPs
dummy_input = torch.randn(1, 3, 32, 32, device=device)

# Wrap your original validation function to only take the model as input.
# This function acts as the score function to rank models.
def score_func(model):
    return evaluate(model, val_loader)

# prune the model
pruned_model, _ = mtp.prune(
    model=resnet20(ckpt="resnet20.pth"),
    mode=[("fastnas", config)],
    constraints={"flops": 30e6},
    dummy_input=dummy_input,
    config={
        "data_loader": train_loader,
        "score_func": score_func,
        "checkpoint": "modelopt_seaarch_checkpoint_fastnas.pth",
    },
)

# save the pruned model for future use
mto.save(pruned_model, "modelopt_pruned_model_fastnas.pth")

# evaluate the pruned model
print(f"Test Accuracy of Pruned ResNet20: {evaluate(pruned_model, test_loader)}")

Profiling the following subnets from the given model: ('min', 'centroid', 'max').
--------------------------------------------------------------------------------

                              Profiling Results
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Constraint   ┃ min          ┃ centroid     ┃ max          ┃ max/min ratio ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ flops        │ 24.33M       │ 27.57M       │ 40.55M       │ 1.67          │
│ params       │ 90.94K       │ 141.63K      │ 268.35K      │ 2.95          │
└──────────────┴──────────────┴──────────────┴──────────────┴───────────────┘

            Constraints Evaluation
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓
┃              ┃              ┃ Satisfiable  ┃
┃ Constraint   ┃ Upper Bound  ┃ Upper Bound  ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩
│ flops        │ 30.00M       │ True         │
└──────────────┴──────────────┴──────────────┘


Search Space Summary:
----------------------------------------------------------------------------------------------------
  layers.depth                                                                     [9]
  layers.0.out_channels                                                            [16]
  layers.0.in_channels                                                             [3]
  layers.3.depth                                                                   [3]
  layers.3.0.conv1.out_channels                                                    [16]
  layers.3.0.conv1.in_channels                                                     [16]
  layers.3.0.bn1.num_features                                                      [16]
  layers.3.0.conv2.out_channels                                                    [16]
  layers.3.0.conv2.in_channels                                                     [16]
  layers.3.1.conv1.out_channels                                                    [16]
  layers.3.1.conv1.in_channels                                                     [16]
  layers.3.1.bn1.num_features                                                      [16]
  layers.3.1.conv2.out_channels                                                    [16]
  layers.3.1.conv2.in_channels                                                     [16]
  layers.3.2.conv1.out_channels                                                    [16]
  layers.3.2.conv1.in_channels                                                     [16]
  layers.3.2.bn1.num_features                                                      [16]
  layers.3.2.conv2.out_channels                                                    [16]
  layers.3.2.conv2.in_channels                                                     [16]
  layers.4.depth                                                                   [3]
* layers.4.0.conv1.out_channels                                                    [16, 32]
  layers.4.0.conv1.in_channels                                                     [16]
  layers.4.0.bn1.num_features                                                      [16, 32]
  layers.4.0.conv2.out_channels                                                    [32]
  layers.4.0.conv2.in_channels                                                     [16, 32]
* layers.4.1.conv1.out_channels                                                    [16, 32]
  layers.4.1.conv1.in_channels                                                     [32]
  layers.4.1.bn1.num_features                                                      [16, 32]
  layers.4.1.conv2.out_channels                                                    [32]
  layers.4.1.conv2.in_channels                                                     [16, 32]
* layers.4.2.conv1.out_channels                                                    [16, 32]
  layers.4.2.conv1.in_channels                                                     [32]
  layers.4.2.bn1.num_features                                                      [16, 32]
  layers.4.2.conv2.out_channels                                                    [32]
  layers.4.2.conv2.in_channels                                                     [16, 32]
  layers.5.depth                                                                   [3]
* layers.5.0.conv1.out_channels                                                    [16, 32, 48, 64]
  layers.5.0.conv1.in_channels                                                     [32]
  layers.5.0.bn1.num_features                                                      [16, 32, 48, 64]
  layers.5.0.conv2.out_channels                                                    [64]
  layers.5.0.conv2.in_channels                                                     [16, 32, 48, 64]
* layers.5.1.conv1.out_channels                                                    [16, 32, 48, 64]
  layers.5.1.conv1.in_channels                                                     [64]
  layers.5.1.bn1.num_features                                                      [16, 32, 48, 64]
  layers.5.1.conv2.out_channels                                                    [64]
  layers.5.1.conv2.in_channels                                                     [16, 32, 48, 64]
* layers.5.2.conv1.out_channels                                                    [16, 32, 48, 64]
  layers.5.2.conv1.in_channels                                                     [64]
  layers.5.2.bn1.num_features                                                      [16, 32, 48, 64]
  layers.5.2.conv2.out_channels                                                    [64]
  layers.5.2.conv2.in_channels                                                     [16, 32, 48, 64]
----------------------------------------------------------------------------------------------------
Number of configurable hparams: 6
Total size of the search space: 5.12e+02
Note: all constraints can be satisfied within the search space!


Beginning pre-search estimation. If the runtime of score function is longer than a few minutes, consider subsampling the dataset used in score function.
A PyTorch dataset can be subsampled using torch.utils.data.Subset (https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset) as following:
 subset_dataset = torch.utils.data.Subset(dataset, indices)
Collecting pre-search statistics: 100%|██████████| 18/18 [00:10<00:00,  1.76it/s, cur=layers.5.2.conv1.out_channels(64/64): 0.00]
[num_satisfied] = 11:   0%|          | 20/10000 [00:02<17:43,  9.39it/s]
[best_subnet_constraints] = {'params': '173.88K', 'flops': '29.64M'}

Test Accuracy of Pruned ResNet20: 60.37

As we can see, the best subnet (29.6M FLOPs) fits our constraint of 30M FLOPs. We can also see a drop in validation accuracy for the searched model. This is very common after pruning, and fine-tuning is necessary to recover the accuracy.

Restore the pruned subnet using mto.restore

[11]:
pruned_model = mto.restore(resnet20(), "modelopt_pruned_model_fastnas.pth")

Fine-tuning

To fine-tune the subnet, you can simply repeat the training pipeline of the original model (1x the original training time, 0.5x-1x of the original learning rate). The fine-tuned model is the final model, giving the best trade-off between accuracy and your provided constraints, and is the model used for deployment.

Note that fine-tuning should take about 5-15 minutes depending on your GPU and CPU.

[12]:
optimizer, lr_scheduler = get_optimizer_scheduler(
    pruned_model, 0.5 * learning_rate, weight_decay, warmup_steps, decay_steps
)
train_model(
    pruned_model,
    train_loader,
    val_loader,
    optimizer,
    lr_scheduler,
    num_epochs,
)
# store final model
mto.save(pruned_model, "modelopt_pruned_model_fastnas_finetuned.pth")
Training the model for 120 epochs...
Epoch   1        Training loss: 0.0011   Val Accuracy: 79.90%
Epoch  25        Training loss: 0.0004   Val Accuracy: 82.82%
Epoch  50        Training loss: 0.0003   Val Accuracy: 87.00%
Epoch  75        Training loss: 0.0002   Val Accuracy: 88.62%
Epoch 100        Training loss: 0.0000   Val Accuracy: 90.62%
Epoch 120        Training loss: 0.0000   Val Accuracy: 90.58%
Model Trained! Restoring to parameters that gave best Val Accuracy (90.70% at Epoch 101).

Evaluate the searched subnet

[13]:
# you can restore the fine-tuned model from the vanilla model
optimized_model = mto.restore(resnet20(), "modelopt_pruned_model_fastnas_finetuned.pth")

# test the accuracy
print(f"Test Accuracy of the fine-tuned pruned net: {evaluate(optimized_model, test_loader)}")
Test Accuracy of the fine-tuned pruned net: 90.28

Conclusion

The comparison can be summarized as follows:

Model            FLOPs    Params   Test Accuracy
ResNet20         40.6M    268k     90.9%
FastNAS subnet   29.6M    174k     90.3%

As we see here, we have reduced the FLOPs and the number of parameters, which should also translate into an improvement in latency, with very little loss in accuracy. Good job!
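If you would like to verify the latency improvement on your own hardware, below is a minimal sketch of a GPU timing comparison using CUDA events. This helper is not part of Model Optimizer; the batch size and iteration counts are arbitrary choices.

@torch.no_grad()
def measure_latency_ms(model, batch_size=256, warmup=10, iters=50):
    """Rough GPU latency per batch in milliseconds (illustrative sketch)."""
    model.eval()
    x = torch.randn(batch_size, 3, 32, 32, device=device)
    for _ in range(warmup):  # warm up kernels and caches
        model(x)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        model(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

print(f"Baseline ResNet20 : {measure_latency_ms(resnet20(ckpt='resnet20.pth')):.2f} ms/batch")
print(f"Pruned subnet     : {measure_latency_ms(optimized_model):.2f} ms/batch")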

Next: check out the Model Optimizer GitHub repository for more examples.