Pruning

Tip

Check out ResNet20 on CIFAR-10 Notebook and HF BERT Prune, Distill & Quantize for end-to-end examples of pruning.

ModelOpt provides three main pruning methods (aka modes) - Minitron, FastNAS and GradNAS - via a unified API, mtp.prune. Given a base model, these methods find a subnet which meets the given deployment constraints (e.g. FLOPs, parameters) with little to no accuracy degradation (depending on how aggressive the pruning is). These pruning methods support pruning the convolutional and linear layers, and the attention heads of the model. More details on these pruning modes are as follows (a usage sketch follows the list):

  1. fastnas: A pruning method recommended for Computer Vision models. Given a pretrained model, FastNAS finds the subnet which maximizes the score function while meeting the given constraints.

  2. mcore_gpt_minitron: A pruning method developed by NVIDIA Research for pruning GPT-style models (e.g. Llama 3) in the NVIDIA NeMo or Megatron-LM framework that use Pipeline Parallelism. It uses activation magnitudes to prune the MLP, attention heads, and GQA query groups. Check out more details of the algorithm in the paper.

  3. gradnas: A lightweight pruning method recommended for language models like Hugging Face BERT and GPT-J. It uses gradient information to prune the model’s linear layers and attention heads to meet the given constraints.
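For reference, below is a minimal sketch of what a pruning call could look like with the fastnas mode on a computer vision model. The constraint value, dummy input shape, data loader, and score function are illustrative placeholders, and the exact argument and config names may differ per mode - please consult the mtp.prune API reference for details.

import torch
from torchvision.models import resnet50

import modelopt.torch.opt as mto
import modelopt.torch.prune as mtp

# Base model to be pruned
model = resnet50()

# Dummy input used to profile the model (shape is an example)
dummy_input = torch.randn(1, 3, 224, 224)

# Prune the model to (for example) at most 60% of its original FLOPs.
# `train_loader` and `score_func` are assumed to be your existing data loader
# and validation-score callable.
pruned_model, _ = mtp.prune(
    model=model,
    mode="fastnas",
    constraints={"flops": "60%"},
    dummy_input=dummy_input,
    config={"data_loader": train_loader, "score_func": score_func},
)

# Save the pruned architecture and weights for later fine-tuning
mto.save(pruned_model, "modelopt_pruned_model.pth")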

Follow the steps described below to obtain the optimal model satisfying your requirements using mtp:

  1. Training: Simply train your model using your existing training pipeline or load a pre-trained checkpoint for your model.

  2. Pruning: Prune the model using our provided mtp.prune API and get an optimal subnet describing the pruned network architecture.

  3. Fine-tuning: Fine-tune the resulting subnet to recover accuracy.

To find out more about the concepts behind NAS and pruning, please refer to NAS concepts.

Training

To perform pruning, you can either use a model converted from a pre-trained checkpoint or train the model from scratch.

Simply initialize your model and load the checkpoint before you start using ModelOpt.

You can simply use your existing training pipeline to train the model without further modifications.
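For example, with a torchvision model you might initialize the network and load pre-trained weights as follows (the checkpoint path below is a hypothetical placeholder):

import torch
from torchvision.models import resnet50, ResNet50_Weights

# Option 1: start from torchvision's pre-trained ImageNet weights
model = resnet50(weights=ResNet50_Weights.DEFAULT)

# Option 2: build the architecture and load your own checkpoint
# model = resnet50()
# model.load_state_dict(torch.load("my_pretrained_checkpoint.pth"))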

Fine-tuning

The final step of architecture search is to fine-tune the pruned model on your dataset. This ensures that you obtain the best possible performance from your pruned model.

Prerequisites

  1. To perform fine-tuning you need a pruned subnet as explained in the previous section.

  2. You can reuse your existing training pipeline. We recommend running fine-tuning with your original training schedule (see the sketch after this list):

    • 1x training epochs (or 1x downstream task fine-tuning),

    • same or smaller (0.5x-1x) learning rate.
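As a rough sketch, assuming your original schedule is known (the learning rate and epoch count below are hypothetical), the fine-tuning setup could look like this:

import torch
import modelopt.torch.opt as mto
from torchvision.models import resnet50

# Restore the pruned model (as shown in the next section)
pruned_model = mto.restore(resnet50(), "modelopt_pruned_model.pth")

# Hypothetical values taken from the original (pre-pruning) training schedule
original_lr = 0.1
original_epochs = 90

# 1x the original epochs, with 0.5x-1x the original learning rate
finetune_epochs = original_epochs
optimizer = torch.optim.SGD(pruned_model.parameters(), lr=0.5 * original_lr, momentum=0.9)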

Load the pruned model

You can simply restore your pruned model (weights and architecture) using mto.restore():

import modelopt.torch.opt as mto
from torchvision.models import resnet50

# Build original model
model = resnet50()

# Restore the pruned architecture and weights
pruned_model = mto.restore(model, "modelopt_pruned_model.pth")

Run fine-tuning

Now, go ahead and fine-tune the pruned subnet using your standard training pipeline with the pre-configured hyperparameters. A good fine-tuning schedule is usually to repeat the pre-training schedule with 0.5x-1x the initial learning rate.

Do not forget to save the model using mto.save().

# `train` is your existing training/fine-tuning loop
train(pruned_model)

mto.save(pruned_model, "modelopt_pruned_finetuned_model.pth")

Deploy

The pruned and fine-tuned model is now ready for downstream tasks like deployment. The model you now have in hand should be the best neural network meeting your deployment-aware search constraints.

import modelopt.torch.opt as mto
from torchvision.models import resnet50

# Build original model
model = resnet50()

model = mto.restore(model, "modelopt_pruned_finetuned_model.pth")

# Continue with downstream tasks like deployment (e.g. TensorRT or TensorRT-LLM)
...
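As an illustration of one possible downstream path (an assumption, not a prescribed flow), you could export the restored model from the snippet above to ONNX as a first step toward building a TensorRT engine:

import torch

# Export to ONNX before building a TensorRT engine (input shape and opset are examples)
model.eval()
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, "modelopt_pruned_finetuned_model.onnx", opset_version=17)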

Pruning Concepts

Pruning is the process of removing redundant components from a neural network for a given task. Conceptually, pruning is similar to NAS but has less computational overhead, at the cost of potentially finding a less optimal architecture. Most APIs are based on the corresponding NAS APIs but are adapted to reflect the simpler workflow.

Specifically, for pruning we do not explicitly train the search space and all its subnets. Instead, a pre-trained checkpoint is used to approximate the search space. Therefore, we can skip the (potentially expensive) search space training step and directly search for a subnet architecture before fine-tuning the resulting subnet.

Note

If you want to learn more about the concepts behind NAS and pruning, take a look at NAS Concepts, which includes a more detailed comparison between NAS and pruning.