pipeline
Base class for diffusion step-distillation pipelines.
DistillationPipeline is deliberately minimal: it is not an nn.Module,
does not wrap the student or teacher, does not manage optimizers or lifecycle state,
and does not register itself in any mode registry. It exists only to hold references
to the student / teacher and to freeze the teacher in a single place.
Concrete methods (currently only DMDPipeline) subclass this and add compute_*_loss methods.
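The base-class/subclass split described above can be sketched as follows. This is a minimal, self-contained stand-in, not the library's code: SketchDistillationPipeline, ToyPipeline and compute_match_loss are hypothetical names, and the loss is a plain output-matching MSE used only for illustration (it is not the DMD objective).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchDistillationPipeline:
    """Hypothetical stand-in for the documented contract: a plain object
    (not an nn.Module) that stores references and freezes the teacher."""

    def __init__(self, student: nn.Module, teacher: nn.Module, config) -> None:
        self.student = student  # caller keeps ownership (train/eval, optimizer, sharding)
        self.teacher = teacher
        self.config = config
        # Freeze the teacher in a single place.
        self.teacher.eval()
        self.teacher.requires_grad_(False)


class ToyPipeline(SketchDistillationPipeline):
    """Hypothetical subclass that adds a compute_*_loss method.
    The loss is a simple output-matching MSE, NOT the DMD loss."""

    def compute_match_loss(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            target = self.teacher(x)  # teacher output treated as a constant
        return F.mse_loss(self.student(x), target)


torch.manual_seed(0)
student, teacher = nn.Linear(4, 4), nn.Linear(4, 4)
pipe = ToyPipeline(student, teacher, config=None)
loss = pipe.compute_match_loss(torch.randn(8, 4))
loss.backward()  # gradients reach the student only
```

Keeping the pipeline a plain object means student parameters are never re-registered under a second module, so the caller's optimizer and sharding setup stay the single source of truth.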
Classes
DistillationPipeline: Hold student/teacher references and expose shared utilities.
- class DistillationPipeline
Bases: object
Hold student/teacher references and expose shared utilities.
- Parameters:
student – Trainable student module. The pipeline does not wrap it; its lifecycle (train()/eval(), requires_grad_, sharding, optimizer) remains owned by the caller.
teacher – Reference module. Frozen here via eval() + requires_grad_(False).
config – A DistillationConfig (or subclass).
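The split of responsibilities in the parameters above can be illustrated with plain PyTorch. This sketch assumes nothing about the library itself: the toy teacher contains dropout to show why both calls matter. eval() makes the forward pass deterministic, and requires_grad_(False) keeps autograd from tracking teacher parameters, while the optimizer and train()/eval() mode for the student stay with the caller.

```python
import torch
import torch.nn as nn

# Toy modules for illustration; the teacher has dropout so eval() has a visible effect.
teacher = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
student = nn.Linear(4, 4)

# The two calls the docs describe for freezing the teacher.
teacher.eval()                 # disables dropout; BatchNorm would use running stats
teacher.requires_grad_(False)  # autograd no longer tracks teacher parameters

# The caller, not the pipeline, owns the student's optimizer and mode.
optimizer = torch.optim.SGD(student.parameters(), lr=0.1)
student.train()

x = torch.randn(2, 4)
out1, out2 = teacher(x), teacher(x)  # identical because dropout is off
```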
- __init__(student, teacher, config)
Store student / teacher references and freeze the teacher.
- Parameters:
student (nn.Module)
teacher (nn.Module)
config (DistillationConfig)
- Return type:
None
- property device: device
Device of the first student parameter (best-effort; falls back to CPU).
- property dtype: dtype
Dtype of the first student parameter (best-effort; falls back to float32).
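The best-effort lookup behind device and dtype can be sketched with two small helpers. The names first_param_device and first_param_dtype are hypothetical; the sketch only assumes that the properties inspect the first parameter returned by student.parameters() and use the documented fallbacks when there is none.

```python
import torch
import torch.nn as nn


def first_param_device(module: nn.Module) -> torch.device:
    """Hypothetical helper: device of the first parameter, else CPU."""
    try:
        return next(module.parameters()).device
    except StopIteration:  # module has no parameters
        return torch.device("cpu")


def first_param_dtype(module: nn.Module) -> torch.dtype:
    """Hypothetical helper: dtype of the first parameter, else float32."""
    try:
        return next(module.parameters()).dtype
    except StopIteration:  # module has no parameters
        return torch.float32


student = nn.Linear(2, 2).to(torch.float64)
empty = nn.Identity()  # a module with no parameters
```

Because only the first parameter is inspected, the result can be unrepresentative for a student whose parameters live on several devices or use mixed dtypes.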
- sample_timesteps(n, *, device=None, dtype=torch.float32)
Sample n training timesteps according to config.sample_t_cfg.
- Parameters:
n (int)
device (device | None)
dtype (dtype)
- Return type:
Tensor
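The contents of config.sample_t_cfg are not documented here, so the following is only a sketch of what a timestep sampler with this signature might look like. SampleTCfg, its fields, and the "uniform" / "logit_normal" options are hypothetical assumptions, as is the CPU fallback when device is None; the real method might instead use the pipeline's device property.

```python
import torch
from dataclasses import dataclass


@dataclass
class SampleTCfg:
    """Hypothetical stand-in for config.sample_t_cfg (real fields unknown)."""
    kind: str = "uniform"  # assumed options: "uniform" or "logit_normal"
    t_min: float = 0.0
    t_max: float = 1.0
    mean: float = 0.0      # used by "logit_normal" only
    std: float = 1.0       # used by "logit_normal" only


def sample_timesteps(n: int, cfg: SampleTCfg, *, device=None,
                     dtype=torch.float32) -> torch.Tensor:
    """Draw n continuous timesteps in [t_min, t_max] (sketch only)."""
    device = device if device is not None else torch.device("cpu")  # assumed fallback
    if cfg.kind == "uniform":
        u = torch.rand(n, device=device, dtype=dtype)  # values in [0, 1)
    elif cfg.kind == "logit_normal":
        # sigmoid of a Gaussian: values in (0, 1), concentrated near sigmoid(mean)
        u = torch.sigmoid(cfg.mean + cfg.std * torch.randn(n, device=device, dtype=dtype))
    else:
        raise ValueError(f"unknown sampling kind: {cfg.kind!r}")
    # Rescale from the unit interval to [t_min, t_max].
    return cfg.t_min + (cfg.t_max - cfg.t_min) * u
```

For example, sample_timesteps(16, SampleTCfg()) returns a float32 tensor of shape (16,) with values in [0, 1).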