distillation_model
Meta-model wrapper to support knowledge-distillation learning.
Classes
- DistillationModel – Class to encapsulate multiple teacher and student models as a single model.
- class DistillationModel
Bases: DynamicModule
Class to encapsulate multiple teacher and student models as a single model.
- compute_kd_loss(student_loss=None, loss_reduction_fn=None, skip_balancer=False)
Compute total loss for distillation backpropagation.
- Parameters:
student_loss (Tensor | None) – Original loss computed from the student’s output.
loss_reduction_fn (Callable | None) – Callable applied to each loss tensor prior to balancing. Useful for loss-masking situations where the reduction changes each iteration.
skip_balancer (bool) – Whether or not to use the loss balancer to reduce the loss dict into a scalar.
- Returns:
If skip_balancer is False and a loss balancer is present, the scalar total loss weighted between student_loss and the distillation losses. Otherwise, a dict of the student model's output loss and the layer-wise distillation losses.
- Return type:
Tensor | Dict[str, Tensor]
- forward(*args, **kwargs)
Implement forward pass.
- Parameters:
*args – Positional inputs to the student and teacher models.
**kwargs – Named inputs to the student and teacher models.
- Returns:
The student model’s output.
- Return type:
Any
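For orientation, below is a minimal sketch of a training step using forward and compute_kd_loss. The surrounding objects (distill_model, inputs, labels, optimizer) and the use of cross-entropy as the student's task loss are assumptions for illustration; only the two DistillationModel calls come from this page.

```python
import torch.nn.functional as F

# Hypothetical training step: `distill_model` (a configured DistillationModel),
# `inputs`, `labels`, and `optimizer` are assumed to exist.
optimizer.zero_grad()

# forward() runs the teacher alongside the student and returns only the
# student's output; the tensors needed for the layer-wise distillation
# losses are captured internally (assumed).
student_out = distill_model(inputs)

# Original task loss computed from the student's output.
student_loss = F.cross_entropy(student_out, labels)

# Combine the task loss with the distillation losses. With a loss balancer
# configured (skip_balancer=False), this returns a single scalar.
loss = distill_model.compute_kd_loss(student_loss=student_loss)

loss.backward()
optimizer.step()
```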
- hide_loss_modules(enable=True)
Context manager to temporarily hide the loss modules from the model.
- hide_teacher_model(enable=True)
Context manager to temporarily hide the teacher model from the model.
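A sketch of how these context managers might be used, e.g. to inspect or export only the student portion; distill_model is an assumed, already-configured DistillationModel.

```python
# Hypothetical usage: `distill_model` is assumed to exist.
with distill_model.hide_teacher_model():
    # The teacher submodule is temporarily invisible here, so printing or
    # traversing the model shows only the student.
    print(distill_model)

with distill_model.hide_loss_modules():
    # Likewise, the distillation loss modules are temporarily hidden.
    print(distill_model)
```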
- load_state_dict(state_dict, *args, **kwargs)
Override to potentially load the state dict without the teacher's or the loss modules' states.
- Return type:
Any
- property loss_balancer: DistillationLossBalancer | None
Fetch the loss balancer, if any.
- property loss_modules: ModuleList
Fetch the loss modules list.
- modify(teacher_model, criterion, loss_balancer=None, expose_minimal_state_dict=True)
Constructor.
- Parameters:
teacher_model (Module) – A teacher model which this class would encapsulate.
criterion (Dict[Tuple[str, str], _Loss]) – A dictionary mapping the tuple of student and teacher model layer names to the loss function to apply to that layer pair.
loss_balancer (DistillationLossBalancer | None) – Instance of DistillationLossBalancer which reduces distillation and non-distillation losses into a single value using some weighting scheme.
expose_minimal_state_dict (bool) – If True, hides the teacher's state dict when calling state_dict on this class. This avoids saving the teacher state unnecessarily during checkpointing. Note: set to False if using FSDP.
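To make the criterion format concrete, here is a minimal sketch of a call to modify. In practice this method is typically invoked by the library's conversion machinery rather than called by hand; the layer names, loss choices, and the teacher and distill_model variables below are illustrative placeholders.

```python
import torch.nn as nn

# Placeholder criterion: each (student_layer_name, teacher_layer_name) pair
# maps to the loss applied between that layer pair's outputs.
criterion = {
    ("backbone.layer4", "backbone.layer4"): nn.MSELoss(),
    ("head", "head"): nn.KLDivLoss(reduction="batchmean"),
}

# `distill_model` and `teacher` are assumed to exist already.
distill_model.modify(
    teacher_model=teacher,
    criterion=criterion,
    loss_balancer=None,              # or a DistillationLossBalancer instance
    expose_minimal_state_dict=True,  # keep teacher weights out of state_dict()
)
```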
- state_dict(*args, **kwargs)
Override to potentially return the state dict without the teacher's state.
- Return type:
Dict[str, Any]
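A short sketch of the checkpointing behavior these overrides enable; the file name and objects are placeholders.

```python
import torch

# With expose_minimal_state_dict=True, the teacher's weights are excluded,
# so the checkpoint contains essentially only the student's state.
torch.save(distill_model.state_dict(), "student_ckpt.pt")

# load_state_dict accepts the same minimal state on restore.
distill_model.load_state_dict(torch.load("student_ckpt.pt"))
```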
- property teacher_model: ModuleList
Fetch the teacher model.