Distillation
Introduction
ModelOpt’s Distillation API (modelopt.torch.distill) allows you to enable a
knowledge-distillation training pipeline with minimal script modification.
Follow the steps below to obtain a model trained with knowledge transferred directly from
a more powerful teacher model using modelopt.torch.distill:
1. Convert your model via mtd.convert: Wrap both a teacher and a student model into a larger meta-model that abstracts away the interaction between the two.
2. Distillation training: Use the meta-model in place of the original model and run the original script with only one additional line of code for loss calculation.
3. Checkpoint and re-load: Save the model via mto.save and restore it via mto.restore. See saving and restoring to learn more.
To find out more about Distillation and related concepts, see the Distillation Concepts section below.
Convert and integrate
You can convert your model into a DistillationModel using mtd.convert().
Example usage:
import modelopt.torch.distill as mtd
from torchvision.models import resnet50

# User-defined model (student)
model = resnet50()

# Configure and convert for distillation
# NOTE: `teacher_model` is not defined in this snippet; it must be provided by the user.
distillation_config = {
    # `teacher_model` is a model class or callable, or a tuple.
    # If a tuple, it must be of the form (model_cls_or_callable,) or
    # (model_cls_or_callable, args) or (model_cls_or_callable, args, kwargs).
    "teacher_model": teacher_model,
    "criterion": mtd.LogitsDistillationLoss(),
    "loss_balancer": mtd.StaticLossBalancer(),
}
distillation_model = mtd.convert(model, mode=[("kd_loss", distillation_config)])

# Export model in original class form
model_exported = mtd.export(distillation_model)
Note
The config requires a (non-lambda) Callable that returns a teacher model, rather than the model
itself. This avoids re-saving the teacher state dict when saving the Distillation
meta-model. Consequently, the same callable must be available in the namespace when restoring via
the mto.restore utility.
Note
Because the model is no longer of the same class, calling type() on the model after conversion
will not return the original class. isinstance() will still work, however, because the model
dynamically becomes an instance of a subclass of the original class.
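The behaviour described in this note can be reproduced in plain Python. The following is a minimal sketch, assuming the conversion swaps the instance's class for a dynamically created subclass. The names Student, DistillationWrapper and convert are hypothetical stand-ins, and this is not claimed to be how ModelOpt implements the conversion internally.

```python
class Student:
    """Hypothetical stand-in for the user's original model class."""

    def forward(self, x):
        return x * 2


def convert(model):
    """Swap the instance's class for a dynamically created subclass of its own class."""
    original_cls = type(model)
    wrapper_cls = type("DistillationWrapper", (original_cls,), {})
    model.__class__ = wrapper_cls
    return model


model = convert(Student())

# The exact-type check no longer matches the original class ...
print(type(model) is Student)      # False
# ... but the subclass relationship keeps isinstance() working.
print(isinstance(model, Student))  # True
```

In practice this means that code branching on the model's class should use isinstance() rather than comparing type() directly.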
—
Distillation Concepts
This section gives an overview of ModelOpt’s distillation feature, along with its basic concepts and terminology.
Overview
Knowledge Distillation | The transfer of learnable feature information from a teacher model to a student.
Student | The model to be trained (can either start from scratch or pre-trained).
Teacher | The fixed, pre-trained model used as the example the student will “learn” from.
Distillation loss | A loss function used between the features of a student and teacher to perform Knowledge Distillation, separate from the student’s original task loss.
Loss Balancer | An implementation for a utility which determines how to combine Distillation loss(es) and original student task loss into a single scalar.
Soft-label Distillation | The specific process of performing Knowledge Distillation between output logits of a teacher and student models.
Concepts
Knowledge Distillation
Distillation is a broad term that can describe any kind of information transfer between models; here it refers to basic teacher-student Knowledge Distillation. The process adds an auxiliary loss (which can also replace the original one) between a model that is already trained (the teacher) and a model that is not (the student), with the aim of having the student learn information (e.g. feature maps or logits) that the teacher has already mastered. This can serve multiple purposes:
A. Model-size reduction: A smaller, more efficient student model (potentially a pruned teacher) can reach accuracy near or exceeding that of the larger, slower teacher model. (See the Lottery Ticket Hypothesis for the reasoning behind this, which also applies to pruning.)
B. An alternative to pure training: Distilling a model from an existing one (and then fine-tuning) can often be faster than training it from scratch.
C. Module replacement: One can replace a single module within a model with a more efficient one and use distillation on its original outputs to effectively re-integrate it into the whole model.
Student
This is the model we wish to train and use in the end. It ideally meets the desired architectural and computational requirements, but is either untrained or requires a boost in accuracy.
Teacher
This is the model from which learned features/information are used to create a loss for the student. Usually it is larger and/or slower than desired, but possesses a satisfactory accuracy.
Distillation loss
To actually “transfer” knowledge from a teacher to a student, we need to add an optimization objective to the student’s original loss function(s), or replace them with it. This can be as simple as applying an MSE loss to two same-sized activation tensors from the teacher and student, under the assumption that the features learned by the teacher are of high quality and should be imitated as closely as possible.
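As a concrete illustration of the MSE case, here is a minimal framework-free sketch that treats activations as flat lists of floats. The function name mse_feature_loss and the sample values are hypothetical; a real pipeline would compute this on tensors with a deep-learning framework.

```python
def mse_feature_loss(student_activations, teacher_activations):
    """Mean squared error between two same-sized activation vectors."""
    if len(student_activations) != len(teacher_activations):
        raise ValueError("student and teacher activations must have the same size")
    squared_errors = [
        (s - t) ** 2 for s, t in zip(student_activations, teacher_activations)
    ]
    return sum(squared_errors) / len(squared_errors)


# Hypothetical activations captured from one student layer and one teacher layer.
student = [0.5, 1.0, -1.0]
teacher = [1.0, 1.0, 0.0]

loss = mse_feature_loss(student, teacher)
print(loss)
```

The loss is zero only when the student reproduces the teacher's activations exactly, which is the sense in which the student is pushed to imitate the teacher.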
ModelOpt supports specifying a different loss function per layer-output pair, and includes a few
pre-defined functions for use, though users may often need to define their own.
Module-pair-to-loss-function mappings are specified via the criterion key of the configuration
dictionary. Each pair is ordered as (student, teacher), and the loss function itself should accept
outputs in the same order:
# Example using pairwise-mapped criterion.
# Will perform the loss on the output of ``student_model.classifier`` and ``teacher_model.layers.18``
distillation_config = {
    "teacher_model": teacher_model,
    "criterion": {("classifier", "layers.18"): mtd.LogitsDistillationLoss()},
}
distillation_model = mtd.convert(student_model, mode=[("kd_loss", distillation_config)])
The intermediate outputs for the losses are captured by the DistillationModel, and the loss(es) are
then invoked using DistillationModel.compute_kd_loss().
If present, the original student’s non-distillation loss is passed in as an argument.
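A single training step with the converted model might look like the sketch below. It is written against duck-typed objects so it does not depend on any particular framework being installed. The function name distillation_train_step is hypothetical, and passing the task loss under the keyword student_loss is an assumption about the argument name, so check the API reference for the exact signature.

```python
def distillation_train_step(distillation_model, optimizer, loss_fn, inputs, labels):
    """Run one optimisation step on a converted distillation model."""
    optimizer.zero_grad()

    # Forward pass through the meta-model; per the text above, the
    # DistillationModel captures the outputs needed for the loss(es).
    outputs = distillation_model(inputs)

    # The student's original (non-distillation) task loss.
    student_loss = loss_fn(outputs, labels)

    # The one additional line: compute the distillation loss(es) and
    # combine them with the task loss.
    # ASSUMPTION: the keyword name `student_loss` is not confirmed by this page.
    total_loss = distillation_model.compute_kd_loss(student_loss=student_loss)

    total_loss.backward()
    optimizer.step()
    return total_loss
```

Everything except the compute_kd_loss() call is an ordinary training step, which is why an existing script needs so little modification.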
Writing a custom loss function is often necessary, especially to handle outputs that need to be processed to obtain the logits and activations.
Loss Balancer
As Distillation losses may be applied to several pairs of layers, the losses are returned in the
form of a dictionary, which must be reduced to a scalar value for backpropagation. A Loss
Balancer (whose interface is defined by DistillationLossBalancer) serves this purpose.
If Distillation loss is only applied to a single pair of layer outputs, and no student loss is available, a Loss Balancer should not be provided.
ModelOpt provides a simple Balancer implementation, and the aforementioned interface can be used to create custom ones.
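To make the reduction concrete, the sketch below shows one possible static weighting scheme in plain Python. The function balance_losses, its kd_weight parameter and the dictionary keys are all hypothetical; this is not the implementation of ModelOpt's built-in balancer, nor does it follow the DistillationLossBalancer interface.

```python
def balance_losses(kd_losses, student_loss=None, kd_weight=0.5):
    """Reduce a dict of distillation losses (and an optional task loss) to one scalar.

    The distillation losses are summed, then mixed with the student task loss
    using a fixed weight. Without a task loss, the summed distillation loss is
    returned unchanged.
    """
    if not 0.0 <= kd_weight <= 1.0:
        raise ValueError("kd_weight must be in [0, 1]")
    kd_total = sum(kd_losses.values())
    if student_loss is None:
        return kd_total
    return kd_weight * kd_total + (1.0 - kd_weight) * student_loss


# Hypothetical per-pair losses keyed by (student_layer, teacher_layer).
kd_losses = {("block1", "layers.6"): 1.0, ("classifier", "layers.18"): 3.0}

total = balance_losses(kd_losses, student_loss=2.0, kd_weight=0.5)
print(total)
```

A fixed weight is the simplest choice; a custom balancer could instead vary the weighting over the course of training.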
Soft-label Distillation
The scenario involving distillation only on the output logits of student/teacher classification models is known as Soft-label Distillation. In this case, one could even omit the student’s original classification loss altogether if the teacher’s outputs are purely preferred over whatever the ground truth labels may be.
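A common formulation of soft-label distillation compares temperature-softened probability distributions using KL divergence. The sketch below implements that formulation in plain Python for a single example. The function names and the temperature-squared scaling are assumptions taken from the widely used formulation, and this is not claimed to match what mtd.LogitsDistillationLoss computes.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of logits, softened by a temperature."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def soft_label_kd_loss(student_logits, teacher_logits, temperature=1.0):
    """KL(teacher || student) on softened distributions, scaled by temperature squared."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(
        pt * (math.log(pt) - math.log(ps))
        for pt, ps in zip(p_teacher, p_student)
        if pt > 0.0
    )
    return kl * temperature ** 2


# Hypothetical logits for a three-class problem.
teacher_logits = [2.0, 1.0, 0.1]
student_logits = [1.0, 1.5, 0.3]

print(soft_label_kd_loss(student_logits, teacher_logits, temperature=2.0))
```

Note that no ground-truth label appears anywhere in this loss: the teacher's softened output distribution is the only training target, which is what allows the student's original classification loss to be dropped.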