distillation_model

Meta-model wrapper to support knowledge-distillation learning.

Classes

DistillationModel

Class to encapsulate multiple teacher and student models as a single model.

class DistillationModel

Bases: DynamicModule

Class to encapsulate multiple teacher and student models as a single model.

compute_kd_loss(student_loss=None, loss_reduction_fn=None, skip_balancer=False)

Compute total loss for distillation backpropagation.

Parameters:
  • student_loss (Tensor | None) – Original loss computed from the student’s output.

  • loss_reduction_fn (Callable | None) – Callable applied to each loss tensor before balancing. Useful for loss masking, where the callable's arguments change every iteration.

  • skip_balancer (bool) – Whether to skip the loss balancer and return the unreduced loss dict instead of a scalar.

Returns:

If skip_balancer is False, the scalar total loss, weighted between student_loss and the distillation losses. If skip_balancer is True, a dict of the student model's output loss and the layer-wise distillation losses.

Return type:

Tensor | Dict[str, Tensor]
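The reduction described above can be sketched in plain Python, with floats standing in for tensors. This is a minimal sketch assuming a static weighting scheme: the hypothetical `kd_weight` scales the sum of the distillation losses and the student loss receives `1 - kd_weight`. Actual `DistillationLossBalancer` implementations may weight losses differently; `reduce_losses` and the `"student_loss"` key are illustrative names only.

```python
from typing import Callable, Dict, Optional, Union


def reduce_losses(
    kd_losses: Dict[str, float],
    student_loss: Optional[float] = None,
    loss_reduction_fn: Optional[Callable[[float], float]] = None,
    skip_balancer: bool = False,
    kd_weight: float = 0.5,  # hypothetical static weight, not a library default
) -> Union[float, Dict[str, float]]:
    """Hypothetical stand-in for the reduction step of compute_kd_loss."""
    # Apply the per-loss reduction (e.g. masking) before any balancing.
    if loss_reduction_fn is not None:
        kd_losses = {name: loss_reduction_fn(v) for name, v in kd_losses.items()}

    loss_dict = dict(kd_losses)
    if student_loss is not None:
        loss_dict["student_loss"] = student_loss  # hypothetical key name

    if skip_balancer:
        # Return the unreduced layer-wise losses plus the student loss.
        return loss_dict

    kd_total = sum(kd_losses.values())
    if student_loss is None:
        return kd_total
    # Static weighted sum of distillation and student losses.
    return kd_weight * kd_total + (1.0 - kd_weight) * student_loss
```

For example, distillation losses of 0.4 and 0.6 with a student loss of 2.0 reduce to 0.5 * 1.0 + 0.5 * 2.0 = 1.5, while `skip_balancer=True` returns all three values in a dict.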

forward(*args, **kwargs)

Implement the forward pass.

Parameters:
  • *args – Positional inputs to the student and teacher models.

  • **kwargs – Named inputs to the student and teacher models.

Returns:

The student model’s output.

Return type:

Any

hide_loss_modules(enable=True)

Context manager to temporarily hide the loss modules from the model.

hide_teacher_model(enable=True)

Context manager to temporarily hide teacher model from the model.
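Hiding a sub-module temporarily follows a familiar context-manager pattern: detach the child on entry and restore it on exit, even if the body raises. A minimal sketch in plain Python, assuming the children live in a dict; `Container`, `hide_child`, and the `"teacher"` key are hypothetical names, not ModelOpt internals.

```python
from contextlib import contextmanager
from typing import Any, Dict, Iterator


class Container:
    """Hypothetical parent object that registers named children in a dict."""

    def __init__(self) -> None:
        self.children: Dict[str, Any] = {}

    @contextmanager
    def hide_child(self, name: str, enable: bool = True) -> Iterator[None]:
        # With enable=False (or an unknown name) this is a no-op, mirroring
        # the `enable` argument of hide_teacher_model / hide_loss_modules.
        if not enable or name not in self.children:
            yield
            return
        hidden = self.children.pop(name)
        try:
            yield
        finally:
            # Restore the child even if the body raised.
            self.children[name] = hidden
```

Inside `with c.hide_child("teacher"):`, the teacher is absent from `c.children`; it reappears as soon as the block exits.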

load_state_dict(state_dict, *args, **kwargs)

Override to optionally load a state dict that omits the teacher's and loss modules' states.

Return type:

Any
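A checkpoint saved with `expose_minimal_state_dict=True` lacks the teacher's entries, so a strict load must not count those keys as missing. A sketch of that check over plain dicts; the key prefixes `_teacher_model.` and `_loss_modules.` and the helper `check_missing_keys` are assumptions for illustration, not confirmed attribute names.

```python
from typing import Any, Dict, Iterable, List, Tuple

# Assumed prefixes of keys that may legitimately be absent from a checkpoint.
HIDDEN_PREFIXES = ("_teacher_model.", "_loss_modules.")


def check_missing_keys(
    expected_keys: Iterable[str],
    state_dict: Dict[str, Any],
    hidden_prefixes: Tuple[str, ...] = HIDDEN_PREFIXES,
) -> List[str]:
    """Return keys that are genuinely missing, ignoring hidden sub-modules."""
    return [
        key
        for key in expected_keys
        if key not in state_dict and not key.startswith(hidden_prefixes)
    ]
```

A strict loader built this way would raise only when the returned list is non-empty.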

property loss_balancer: DistillationLossBalancer | None

Fetch the loss balancer, if any.

property loss_modules: ModuleList

Fetch the loss modules list.

modify(teacher_model, criterion, loss_balancer=None, expose_minimal_state_dict=True)

Constructor.

Parameters:
  • teacher_model (Module) – A teacher model to be encapsulated by this class.

  • criterion (Dict[Tuple[str, str], _Loss]) – A dictionary mapping the tuple of student and teacher model layer names to the loss function to apply to that layer pair.

  • loss_balancer (DistillationLossBalancer | None) – Instance of DistillationLossBalancer which reduces distillation and non-distillation losses into a single value using some weighting scheme.

  • expose_minimal_state_dict (bool) – If True, hides the teacher's state dict when state_dict is called on this class, which avoids saving the teacher's state unnecessarily during checkpointing. Note: set to False if using FSDP.
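The `criterion` mapping pairs a student layer with a teacher layer and assigns the loss applied between their outputs. The sketch below shows how such a mapping could be evaluated once both models' intermediate outputs have been captured; plain lists stand in for tensors, and `mse`, `layerwise_losses`, and the `"student->teacher"` key format are hypothetical helpers, not ModelOpt functions.

```python
from typing import Callable, Dict, List, Tuple

Vector = List[float]
LossFn = Callable[[Vector, Vector], float]


def mse(student_out: Vector, teacher_out: Vector) -> float:
    """Mean squared error between two equal-length vectors."""
    return sum((s - t) ** 2 for s, t in zip(student_out, teacher_out)) / len(student_out)


def layerwise_losses(
    criterion: Dict[Tuple[str, str], LossFn],
    student_outputs: Dict[str, Vector],
    teacher_outputs: Dict[str, Vector],
) -> Dict[str, float]:
    """Apply each loss to the captured outputs of its (student, teacher) layer pair."""
    losses = {}
    for (student_layer, teacher_layer), loss_fn in criterion.items():
        key = f"{student_layer}->{teacher_layer}"  # hypothetical naming scheme
        losses[key] = loss_fn(student_outputs[student_layer], teacher_outputs[teacher_layer])
    return losses
```

Because the keys are pairs rather than single names, a student layer can be matched to a differently named teacher layer, as in `{("head", "output"): mse}`.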

only_student_forward(enable=True)

Context manager to temporarily disable forward passes on the teacher model.

only_teacher_forward(enable=True)

Context manager to temporarily disable forward passes on the student model.
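`forward` feeds the same inputs to both models and returns the student's output; `only_student_forward` runs just the student and `only_teacher_forward` runs just the teacher. A plain-Python sketch of that control flow, assuming callables stand in for the two models; `ToyDistillationModel` and its private flags are hypothetical. What `forward` returns under `only_teacher_forward` is not stated above, so returning the teacher's output in that case is an assumption of this sketch.

```python
from contextlib import contextmanager
from typing import Any, Callable, Iterator


class ToyDistillationModel:
    """Hypothetical stand-in showing the forward/toggle control flow only."""

    def __init__(self, student: Callable[..., Any], teacher: Callable[..., Any]) -> None:
        self.student = student
        self.teacher = teacher
        self._only_student_fwd = False
        self._only_teacher_fwd = False

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        teacher_out = None
        if not self._only_student_fwd:
            # Run the teacher unless only the student should run.
            teacher_out = self.teacher(*args, **kwargs)
        if self._only_teacher_fwd:
            return teacher_out  # assumption: surface the teacher output here
        return self.student(*args, **kwargs)

    @contextmanager
    def only_student_forward(self, enable: bool = True) -> Iterator[None]:
        previous = self._only_student_fwd
        self._only_student_fwd = enable
        try:
            yield
        finally:
            self._only_student_fwd = previous  # restore the previous setting

    @contextmanager
    def only_teacher_forward(self, enable: bool = True) -> Iterator[None]:
        previous = self._only_teacher_fwd
        self._only_teacher_fwd = enable
        try:
            yield
        finally:
            self._only_teacher_fwd = previous  # restore the previous setting
```

Saving and restoring the previous flag, rather than resetting it to False, keeps nested uses of the same context manager well-behaved.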

state_dict(*args, **kwargs)

Override to optionally return the state without the teacher's.

Return type:

Dict[str, Any]
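Omitting the teacher's entries from the returned mapping can be done by filtering on a key prefix. The sketch assumes teacher parameters are stored under a `_teacher_model.` prefix, which is an illustrative guess rather than a documented name; `minimal_state_dict` is likewise hypothetical.

```python
from typing import Any, Dict

TEACHER_PREFIX = "_teacher_model."  # assumed key prefix for teacher entries


def minimal_state_dict(full_state: Dict[str, Any], prefix: str = TEACHER_PREFIX) -> Dict[str, Any]:
    """Drop every entry whose key belongs to the teacher sub-module."""
    return {k: v for k, v in full_state.items() if not k.startswith(prefix)}
```

The trailing dot in the prefix prevents accidentally dropping keys of a sibling module whose name merely starts with the same characters.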

property teacher_model: ModuleList

Fetch the teacher model.

train(mode=True)

Override to prevent warnings, in later forward passes, about previously stored intermediate outputs.

Parameters:

mode (bool) – Whether to set training mode (True) or evaluation mode (False).