Quick Start: Distillation

ModelOpt’s Distillation is a set of wrappers and utilities for easily performing Knowledge Distillation between a teacher and a student model. Given a pretrained teacher model, Distillation can train a smaller student model faster and/or to a higher accuracy than the student could achieve on its own.

This quick-start guide shows the necessary steps to integrate Distillation into your training pipeline.

Set up your base models

First obtain both a pretrained model to act as the teacher and a (usually smaller) model to serve as the student.

from torchvision.models import resnet50, resnet18

# Define student
student_model = resnet18()


# Define callable which returns teacher
def teacher_factory():
    teacher_model = resnet50()
    # `pretrained_weights` is the teacher's pretrained state dict, obtained beforehand
    teacher_model.load_state_dict(pretrained_weights)
    return teacher_model
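
The snippet above assumes pretrained_weights is already available. For illustration only, it could come from torchvision's bundled ImageNet checkpoint; in practice it would usually be your own teacher checkpoint (e.g. loaded with torch.load):

from torchvision.models import ResNet50_Weights, resnet50

# Illustrative source of teacher weights: torchvision's ImageNet checkpoint for ResNet-50
pretrained_weights = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1).state_dict()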

Set up the meta model

As Knowledge Distillation involves (at least) two models, ModelOpt simplifies the integration process by wrapping both student and teacher into one meta model.

Please see an example Distillation setup below. This example assumes the outputs of teacher_model and student_model are logits.

import modelopt.torch.distill as mtd

distillation_config = {
    "teacher_model": teacher_factory,  # model initializer
    "criterion": mtd.LogitsDistillationLoss(),  # callable receiving student and teacher outputs, in order
    "loss_balancer": mtd.StaticLossBalancer(),  # combines multiple losses; omit if only one distillation loss used
}

distillation_model = mtd.convert(student_model, mode=[("kd_loss", distillation_config)])

The teacher_model can be either a callable which returns an nn.Module or a tuple of (model_cls, args, kwargs). The criterion is the distillation loss used between student and teacher tensors. The loss_balancer determines how the original and distillation losses are combined (if needed).
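
For instance, the teacher could equivalently be specified as a (model_cls, args, kwargs) tuple rather than a factory callable. The sketch below is illustrative, and the constructor arguments are assumptions:

import modelopt.torch.distill as mtd
from torchvision.models import resnet50

distillation_config = {
    "teacher_model": (resnet50, (), {"num_classes": 1000}),  # (model_cls, args, kwargs)
    "criterion": mtd.LogitsDistillationLoss(),
    "loss_balancer": mtd.StaticLossBalancer(),
}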

See Distillation for more info.

Distill during training

To distill from teacher to student, simply use the meta model in the usual training loop and call its .compute_kd_loss() method to compute the distillation loss in addition to the original user loss.

An example of Distillation training is given below:

# Set up the data loader. For example:
train_loader = get_train_loader()

# Define the user loss function. For example:
loss_fn = get_user_loss_fn()

for inputs, labels in train_loader:
    distillation_model.zero_grad()
    # Forward through the wrapped models
    out = distillation_model(inputs)
    # Same loss as originally present
    loss = loss_fn(out, labels)
    # Combine distillation and user losses
    loss_total = distillation_model.compute_kd_loss(student_loss=loss)
    loss_total.backward()
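
The loop above omits the optimizer update for brevity. A complete training step might look like the following sketch, where the optimizer choice and learning rate are assumptions:

import torch

# Hypothetical optimizer over the trainable parameters (the teacher is typically frozen)
optimizer = torch.optim.AdamW(
    (p for p in distillation_model.parameters() if p.requires_grad), lr=1e-4
)

for inputs, labels in train_loader:
    optimizer.zero_grad()
    out = distillation_model(inputs)
    loss = loss_fn(out, labels)
    loss_total = distillation_model.compute_kd_loss(student_loss=loss)
    loss_total.backward()
    optimizer.step()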

Note

DataParallel may break ModelOpt’s Distillation feature. Note that HuggingFace Trainer uses DataParallel by default.
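
If you hit this in a multi-GPU setup, one simple workaround (a sketch, not an official ModelOpt recommendation) is to restrict the process to a single visible GPU before any CUDA initialization, so frameworks do not wrap the model in nn.DataParallel:

import os

# Must be set before CUDA is initialized (e.g. before constructing the Trainer)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"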

Export trained model

The model can easily be reverted to its original class for further use (e.g., deployment) without any ModelOpt modifications attached.

model = mtd.export(distillation_model)
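
The exported model is once again a plain student module (a resnet18 in this example) and can be handled like any other PyTorch model, for instance saved for deployment (the filename below is illustrative):

import torch

# Save the distilled student like any ordinary PyTorch checkpoint
torch.save(model.state_dict(), "student_distilled.pt")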

Next steps