Quick Start: Pruning

Tip

Check out the ResNet20 on CIFAR-10 Notebook and HF BERT Prune, Distill & Quantize for end-to-end examples of pruning.

ModelOpt’s Pruning library provides lightweight pruning methods like Minitron, FastNAS, and GradNAS that can be run on any user-provided model. Check out this doc for more details on these pruning methods and recommendations on which pruning method to use when.

Pruning a pretrained model involves three steps: setting up your model, setting up the search, and finally running the search (pruning).

Set up your model

To set up your model for pruning, simply initialize the model and load a pre-trained checkpoint. Alternatively, you can train the model from scratch.
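
For instance, with a torchvision classification model (a minimal sketch; ResNet-50 and the checkpoint path are illustrative placeholders, and any PyTorch model works):

import torch
import torchvision.models as models

# Initialize the architecture and load pre-trained weights.
# The model choice and checkpoint path below are placeholders for your own.
model = models.resnet50()
model.load_state_dict(torch.load("resnet50_pretrained.pth"))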

Prune the model

To prune your model, you can simply call the mtp.prune API and save the pruned model using mto.save.
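
The inputs referenced in the example below can be set up along these lines. This is a minimal sketch: the FLOPs constraint value, input shape, and validation routine are illustrative placeholders, and val_loader is assumed to be a user-provided validation data loader.

import torch

# Search for a subnet with at most 75% of the original model's FLOPs.
prune_constraints = {"flops": "75%"}

# Dummy input matching the model's expected input shape (batch size 1 is fine).
dummy_input = torch.randn(1, 3, 224, 224)

def score_func(model):
    # Hypothetical validation routine: top-1 accuracy on a held-out set.
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for images, labels in val_loader:  # `val_loader` is user-provided
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.numel()
    return correct / total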

An example of FastNAS pruning is given below:

import modelopt.torch.opt as mto
import modelopt.torch.prune as mtp

# prune_res (dict) contains state_dict / stats of the pruner/searcher.
pruned_model, prune_res = mtp.prune(
    model=model,
    mode="fastnas",
    constraints=prune_constraints,
    dummy_input=dummy_input,
    config={
        "data_loader": train_loader,  # training data is used for calibrating BN layers
        "score_func": score_func,  # validation score is used to rank the subnets
        # checkpoint to store the search state, so you can resume or re-run the search with a different constraint
        "checkpoint": "modelopt_fastnas_search_checkpoint.pth",
    },
)

# Save the pruned model.
mto.save(pruned_model, "modelopt_pruned_model.pth")

Note

Fine-tuning is required after pruning to recover accuracy. Please refer to pruning fine-tuning for more details.
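
A minimal sketch of that flow, reusing the torchvision model from the earlier sketch (the optimizer settings and the train_one_epoch helper are illustrative placeholders for your regular training recipe):

import torch
import torchvision.models as models
import modelopt.torch.opt as mto

# Re-create the original architecture, then restore the pruned
# architecture and weights from the saved checkpoint.
pruned_model = mto.restore(models.resnet50(), "modelopt_pruned_model.pth")

# Fine-tune with your usual recipe, typically at a reduced learning rate.
optimizer = torch.optim.SGD(pruned_model.parameters(), lr=0.01, momentum=0.9)
for epoch in range(num_epochs):  # `num_epochs` is user-chosen
    train_one_epoch(pruned_model, train_loader, optimizer)  # user-defined helper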


Next steps
  • Learn more about the Pruning API and supported algorithms / models.

  • Learn more about NAS, which is a generalization of pruning.

  • See ModelOpt API documentation for detailed functionality and usage information.