Distillation

Introduction

ModelOpt’s Distillation API (modelopt.torch.distill) allows you to enable a knowledge-distillation training pipeline with minimal script modification.

Follow the steps described below to train a model with knowledge transferred directly from a more powerful teacher model using modelopt.torch.distill:

  1. Convert your model via mtd.convert: Wrap both a teacher and student model into a larger meta-model which abstracts away the interaction between the two.

  2. Distillation training: Seamlessly use the meta-model in place of the original model and run the original script with only one additional line of code for loss calculation (a minimal end-to-end sketch follows this list).

  3. Checkpoint and re-load: Save the model via mto.save and restore via mto.restore. See saving and restoring to learn more.
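
The following is a minimal end-to-end sketch of these three steps. The names student, teacher_factory, train_loader, task_loss_fn, and optimizer are hypothetical placeholders for your own objects, and the snippet is illustrative rather than a drop-in recipe:

import modelopt.torch.distill as mtd
import modelopt.torch.opt as mto

# 1. Convert: wrap the student and teacher into a single meta-model.
distillation_model = mtd.convert(
    student,
    mode=[
        (
            "kd_loss",
            {
                "teacher_model": teacher_factory,
                "criterion": mtd.LogitsDistillationLoss(),
                "loss_balancer": mtd.StaticLossBalancer(),
            },
        )
    ],
)

# 2. Train: use the meta-model as a drop-in replacement for the student.
for inputs, labels in train_loader:
    optimizer.zero_grad()
    outputs = distillation_model(inputs)  # forward pass (teacher outputs are captured internally)
    task_loss = task_loss_fn(outputs, labels)
    loss = distillation_model.compute_kd_loss(student_loss=task_loss)  # the one additional line
    loss.backward()
    optimizer.step()

# 3. Checkpoint: save and later restore with the ModelOpt-aware utilities.
mto.save(distillation_model, "distillation_checkpoint.pth")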

To find out more about Distillation and related concepts, please refer to the Distillation Concepts section below.

Convert and integrate

You can convert your model into a DistillationModel using mtd.convert().

Example usage:

import modelopt.torch.distill as mtd
from torchvision.models import resnet50, resnet152

# User-defined model (student)
model = resnet50()

# Teacher passed as a class/callable rather than an instance (resnet152 here is
# just illustrative); in practice this returns a pretrained, more capable teacher.
teacher_model = resnet152

# Configure and convert for distillation
distillation_config = {
    # `teacher_model` is a model class or callable, or a tuple.
    # If a tuple, it must be of the form (model_cls_or_callable,) or
    # (model_cls_or_callable, args) or (model_cls_or_callable, args, kwargs).
    "teacher_model": teacher_model,
    "criterion": mtd.LogitsDistillationLoss(),
    "loss_balancer": mtd.StaticLossBalancer(),
}
distillation_model = mtd.convert(model, mode=[("kd_loss", distillation_config)])

# Export model in original class form
model_exported = mtd.export(distillation_model)

Note

The config requires a (non-lambda) Callable that returns the teacher model, rather than the teacher model instance itself. This avoids re-saving the teacher state dict when the Distillation meta-model is saved. Consequently, the same callable must be available in the namespace when restoring via the mto.restore utility.
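
For example, a named factory function like the hypothetical one below satisfies this requirement (the architecture and checkpoint path are placeholders):

import torch
from torchvision.models import resnet152

def build_teacher():
    """Hypothetical factory: builds the teacher and loads its pretrained weights."""
    teacher = resnet152()
    teacher.load_state_dict(torch.load("teacher_checkpoint.pth"))  # placeholder path
    return teacher

# Passed uncalled in the config, so the teacher can also be rebuilt on mto.restore:
distillation_config = {"teacher_model": build_teacher, "criterion": mtd.LogitsDistillationLoss()}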

Note

Since the converted model is no longer of its original class, calling type() on it after conversion will not return the original class as expected. isinstance() checks will still work, however, as the model dynamically becomes a subclass of the original class.
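
Continuing the ResNet example above, this behavior can be seen as follows:

from torchvision.models import ResNet

assert type(distillation_model) is not ResNet  # type() reports the dynamic subclass
assert isinstance(distillation_model, ResNet)  # isinstance() still succeeds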

Distillation Concepts

Below, we will provide an overview of ModelOpt’s distillation feature as well as its basic concepts and terminology.

Overview

Glossary

Knowledge Distillation

The transfer of learnable feature information from a teacher model to a student.

Student

The model to be trained (can either start from scratch or pre-trained).

Teacher

The fixed, pre-trained model used as the example the student will “learn” from.

Distillation loss

A loss function used between the features of a student and teacher to perform Knowledge Distillation, separate from the student’s original task loss.

Loss Balancer

A utility that determines how to combine the Distillation loss(es) and the original student task loss into a single scalar.

Soft-label Distillation

The specific process of performing Knowledge Distillation between the output logits of the teacher and student models.

Concepts

Knowledge Distillation

Distillation, in the broadest sense, can refer to any sort of information transfer between models, but in this case we refer to basic teacher-student Knowledge Distillation. The process creates an auxiliary loss (or can replace the original one) between a model which is already trained (the teacher) and a model which is not (the student), in the hope of making the student learn information (e.g. feature maps or logits) which the teacher has already mastered. This can serve multiple purposes:

A. Model-size reduction: A smaller, efficient student model (potentially a pruned teacher) reaching accuracies near or exceeding that of the larger, slower teacher model. (See the Lottery Ticket Hypothesis for reasoning behind this, which also applies to pruning)

B. An alternative to pure training: Distilling a model from an existing one (and then fine-tuning) can often be faster than training it from scratch.

C. Module replacement: One can replace a single module within a model with a more efficient one and use distillation on its original outputs to effectively re-integrate it into the whole model.

Student

This is the model we wish to train and use in the end. It ideally meets the desired architectural and computational requirements, but is either untrained or requires a boost in accuracy.

Teacher

This is the model from which learned features/information are used to create a loss for the student. Usually it is larger and/or slower than desired, but possesses a satisfactory accuracy.

Distillation loss

To actually “transfer” knowledge from a teacher to a student, we need to add an optimization objective to (or replace) the student’s original loss function(s). This can be as simple as applying MSE to two same-sized activation tensors from the teacher and student, under the assumption that the features learned by the teacher are of high quality and should be imitated as closely as possible.

ModelOpt supports specifying a different loss function per layer-output pair and includes a few pre-defined loss functions, though users may often need to define their own. Mappings from module pairs to loss functions are specified via the criterion key of the configuration dictionary, with each pair ordered as (student layer, teacher layer); the loss function itself should accept its outputs in the same order:

# Example using pairwise-mapped criterion.
# Will perform the loss on the output of ``student_model.classifier`` and ``teacher_model.layers.18``
distillation_config = {
    "teacher_model": teacher_model,
    "criterion": {("classifier", "layers.18"): mtd.LogitsDistillationLoss()},
}
distillation_model = mtd.convert(student_model, mode=[("kd_loss", distillation_config)])

The intermediate outputs needed for the losses are captured by the DistillationModel, and the loss(es) are then computed by calling DistillationModel.compute_kd_loss(). If present, the original student’s non-distillation loss is passed in as an argument.

Writing a custom loss function is often necessary, especially to handle outputs that need to be processed to obtain the logits and activations.
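
As an illustrative sketch (not a ModelOpt built-in), a custom loss for models that return (logits, auxiliary) tuples might extract the logits and apply a temperature-scaled KL divergence, accepting the student output first and the teacher output second to match the criterion mapping:

import torch.nn as nn
import torch.nn.functional as F

class TupleLogitsKDLoss(nn.Module):
    """Hypothetical custom loss: extracts logits from tuple outputs, then applies
    a temperature-scaled KL divergence."""

    def __init__(self, temperature: float = 2.0):
        super().__init__()
        self.temperature = temperature

    def forward(self, out_student, out_teacher):
        student_logits = out_student[0] / self.temperature
        teacher_logits = out_teacher[0] / self.temperature
        return self.temperature**2 * F.kl_div(
            F.log_softmax(student_logits, dim=-1),
            F.softmax(teacher_logits, dim=-1),
            reduction="batchmean",
        )

An instance of such a class can then be used as a value in the criterion mapping shown above.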

Loss Balancer

As Distillation losses may be applied to several pairs of layers, the losses are returned in the form of a dictionary which should be reduced into a scalar value for backpropagation. A Loss Balancer (whose interface is defined by DistillationLossBalancer) serves to fill this purpose.

If Distillation loss is only applied to a single pair of layer outputs, and no student loss is available, a Loss Balancer should not be provided.

ModelOpt provides a simple Balancer implementation, and the aforementioned interface can be used to create custom ones.
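
As a rough sketch of a custom Balancer, the class below reduces the loss dictionary to a single scalar; the exact contents of the dictionary and the forward signature are assumptions here, so consult the DistillationLossBalancer API reference before relying on them:

import torch
import modelopt.torch.distill as mtd

class MeanBalancer(mtd.DistillationLossBalancer):
    """Hypothetical balancer: averages all losses into a single scalar."""

    def forward(self, loss_dict: dict) -> torch.Tensor:
        # Assumption: loss_dict holds the per-pair distillation loss tensors
        # (and the student task loss, if one was passed to compute_kd_loss).
        losses = list(loss_dict.values())
        return sum(losses) / len(losses)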

Soft-label Distillation

The scenario in which distillation is applied only to the output logits of the student and teacher classification models is known as Soft-label Distillation. In this case, one could even omit the student’s original classification loss altogether, if the teacher’s outputs are preferred outright over whatever the ground-truth labels may be.
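
A minimal soft-label setup might then look like the following sketch, reusing teacher_model and student_model from the earlier examples and omitting the Loss Balancer since only a single loss is involved:

distillation_config = {
    "teacher_model": teacher_model,
    "criterion": mtd.LogitsDistillationLoss(),  # KD on the final output logits only
}
distillation_model = mtd.convert(student_model, mode=[("kd_loss", distillation_config)])

# With no student task loss provided, the distillation loss alone is backpropagated.
loss = distillation_model.compute_kd_loss()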