pipeline

Base class for diffusion step-distillation pipelines.

DistillationPipeline is deliberately minimal: it is not an nn.Module, does not wrap the student or teacher, does not manage optimizers or lifecycle state, and does not register itself in any mode registry. It exists only to hold references to the student / teacher and to freeze the teacher in a single place.

Concrete distillation methods (currently only DMDPipeline) subclass DistillationPipeline and add compute_*_loss methods.
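The split between a reference-holding base class and loss-defining subclasses can be sketched as follows. This is a minimal illustration under stated assumptions, not the library's code: DistillationPipelineSketch, ToyPipeline and compute_regression_loss are hypothetical names, and the plain MSE regression onto teacher outputs is a stand-in loss, not the DMD objective.

```python
import torch
from torch import nn


class DistillationPipelineSketch:
    """Hypothetical stand-in for DistillationPipeline: it only holds references."""

    def __init__(self, student: nn.Module, teacher: nn.Module, config: object) -> None:
        self.student = student
        self.teacher = teacher
        self.config = config
        # Freeze the teacher in one place: eval mode and no parameter gradients.
        self.teacher.eval()
        self.teacher.requires_grad_(False)


class ToyPipeline(DistillationPipelineSketch):
    """Hypothetical subclass adding one compute_*_loss method (not the DMD loss)."""

    def compute_regression_loss(self, x: torch.Tensor) -> torch.Tensor:
        # Teacher outputs serve only as targets, so no graph is built through them.
        with torch.no_grad():
            target = self.teacher(x)
        return nn.functional.mse_loss(self.student(x), target)


student = nn.Linear(4, 4)
teacher = nn.Linear(4, 4)
pipe = ToyPipeline(student, teacher, config=None)
loss = pipe.compute_regression_loss(torch.randn(8, 4))
loss.backward()  # gradients reach the student only
```

Keeping the base class free of loss logic means each method (such as DMDPipeline) only has to add its own compute_*_loss methods on top of the shared references.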

Classes

DistillationPipeline

Hold student/teacher references and expose shared utilities.

class DistillationPipeline

Bases: object

Hold student/teacher references and expose shared utilities.

Parameters:
  • student – Trainable student module. The pipeline does not wrap it — its lifecycle (train() / eval(), requires_grad_, sharding, optimizer) remains owned by the caller.

  • teacher – Reference module. Frozen here via eval() + requires_grad_(False).

  • config – A DistillationConfig (or subclass).
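Because the pipeline stores references rather than wrapping the modules, the caller keeps full control of the student. The sketch below illustrates that division of responsibility, assuming a hypothetical PipelineSketch class that mirrors the documented contract; the AdamW optimizer, learning rate and toy loss are illustrative choices, not part of the documented API.

```python
import torch
from torch import nn


class PipelineSketch:
    # Hypothetical stand-in mirroring the documented contract:
    # store references and freeze the teacher, nothing else.
    def __init__(self, student: nn.Module, teacher: nn.Module, config: object) -> None:
        self.student, self.teacher, self.config = student, teacher, config
        teacher.eval()
        teacher.requires_grad_(False)


student = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.1))
teacher = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.1))
pipe = PipelineSketch(student, teacher, config=None)

# The caller, not the pipeline, owns the student's mode and optimizer.
student.train()
optimizer = torch.optim.AdamW(student.parameters(), lr=1e-4)

x = torch.randn(2, 4)
loss = (pipe.student(x) - pipe.teacher(x)).pow(2).mean()  # toy loss for illustration
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Since pipe.student is the caller's own object, later calls such as student.eval() or wrapping it for sharding take effect without any pipeline involvement.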

__init__(student, teacher, config)

Store student / teacher references and freeze the teacher.

Parameters:
  • student, teacher, config – As documented for the class above.
Return type:

None

property device: device

Device of the first student parameter (best-effort; falls back to CPU).

property dtype: dtype

Dtype of the first student parameter (best-effort; falls back to float32).
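The "best-effort" behavior of the device and dtype properties can be approximated with the helpers below. This is a minimal sketch assuming the properties inspect only the first parameter; first_param_device and first_param_dtype are hypothetical names, not part of the documented API.

```python
import torch
from torch import nn


def first_param_device(module: nn.Module) -> torch.device:
    # Best-effort: device of the first parameter, else CPU.
    param = next(module.parameters(), None)
    return param.device if param is not None else torch.device("cpu")


def first_param_dtype(module: nn.Module) -> torch.dtype:
    # Best-effort: dtype of the first parameter, else float32.
    param = next(module.parameters(), None)
    return param.dtype if param is not None else torch.float32
```

For example, a parameter-free module such as nn.ReLU() reports CPU and float32. Because only one parameter is inspected, a student whose parameters span several devices or dtypes (for example under sharding, offloading or mixed precision) may report a value that does not hold for all of its parameters.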

sample_timesteps(n, *, device=None, dtype=torch.float32)

Sample n training timesteps according to config.sample_t_cfg.

Parameters:
  • n (int) – Number of timesteps to sample.

  • device (device | None) – Device for the returned tensor.

  • dtype (dtype) – Dtype of the returned tensor.

Return type:

Tensor
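The schema of sample_t_cfg is not documented here, so the following is only a sketch of what config-driven timestep sampling could look like. SampleTCfg, its fields (kind, t_min, t_max, mean, std) and the two distributions (uniform, and logit-normal via a sigmoid of a Gaussian) are all hypothetical assumptions. As a free function it falls back to CPU when device is None, whereas the real method presumably has access to the pipeline's own device.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class SampleTCfg:
    # Hypothetical fields; the real sample_t_cfg schema is not documented here.
    kind: str = "uniform"  # "uniform" or "logit_normal"
    t_min: float = 0.0
    t_max: float = 1.0
    mean: float = 0.0  # used by "logit_normal" only
    std: float = 1.0   # used by "logit_normal" only


def sample_timesteps(
    cfg: SampleTCfg,
    n: int,
    *,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    device = device if device is not None else torch.device("cpu")
    if cfg.kind == "uniform":
        u = torch.rand(n, device=device, dtype=dtype)
    elif cfg.kind == "logit_normal":
        # Sigmoid of a Gaussian concentrates samples around mid-range timesteps.
        u = torch.sigmoid(cfg.mean + cfg.std * torch.randn(n, device=device, dtype=dtype))
    else:
        raise ValueError(f"unknown sampling kind: {cfg.kind!r}")
    # Map u from [0, 1] onto [t_min, t_max].
    return cfg.t_min + (cfg.t_max - cfg.t_min) * u
```

Usage: sample_timesteps(SampleTCfg(kind="logit_normal", t_min=0.02, t_max=0.98), 32) returns a 1-D float32 tensor of 32 timesteps inside [0.02, 0.98].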