fastgen

Modules

- Rectified-flow (RF) helpers: forward process, inversions, timestep sampling.
- Pure loss functions used by the fastgen distillation pipelines.
- Optional plugins for the fastgen subpackage (gated via
- Small tensor helpers shared across the fastgen subpackage.
Framework-agnostic diffusion step-distillation losses (FastGen port).
modelopt.torch.fastgen is a loss-computation library. It accepts already-built
nn.Module references (student / teacher / fake-score / optional discriminator) and
returns scalar loss tensors. It does not load models, manage optimizers, wrap
anything as a DynamicModule, or register itself in any mode registry.
Typical usage with a YAML-driven config:
```python
import modelopt.torch.fastgen as mtf

student, teacher = build_wan_student_and_teacher(...)
fake_score = mtf.create_fake_score(teacher)
cfg = mtf.load_dmd_config("general/distillation/dmd2_wan22_5b")

# If GAN is enabled, expose intermediate teacher features to the discriminator.
if cfg.gan_loss_weight_gen > 0:
    mtf.plugins.wan22.attach_feature_capture(teacher, feature_indices=[15, 22, 29])

pipeline = mtf.DMDPipeline(student, teacher, fake_score, cfg, discriminator=disc)

# Inside the training loop (framework-owned):
if step % cfg.student_update_freq == 0:
    losses = pipeline.compute_student_loss(
        latents, noise, text_embeds, negative_encoder_hidden_states=neg_embeds
    )
    losses["total"].backward()
    student_opt.step()
    pipeline.update_ema()
else:
    f = pipeline.compute_fake_score_loss(latents, noise, text_embeds)
    f["total"].backward()
    fake_score_opt.step()
    if disc is not None:
        d = pipeline.compute_discriminator_loss(latents, noise, text_embeds)
        d["total"].backward()
        disc_opt.step()
```
Classes

- DMD2 loss pipeline.
- Hold student/teacher references and expose shared utilities.
- FSDP2-aware EMA tracker for a PyTorch module.

Functions

- Return a trainable fake-score network initialized from the teacher.
- Load a YAML file and return the parsed mapping.
- Load a YAML file and construct a DMDConfig.
- ModeloptConfig DMDConfig
Bases: DistillationConfig

Hyperparameters for DMD / DMD2 distribution-matching distillation. Default values are tuned for Wan 2.2 5B; callers fine-tune them per model. See FastGen/fastgen/configs/experiments/WanT2V/config_dmd2_wan22_5b.py.
- Default config (JSON):
{ "pred_type": "flow", "guidance_scale": null, "sample_t_cfg": null, "student_sample_steps": 1, "student_sample_type": "ode", "num_train_timesteps": null, "student_update_freq": 5, "fake_score_pred_type": "x0", "gan_loss_weight_gen": 0.0, "gan_use_same_t_noise": false, "gan_r1_reg_weight": 0.0, "gan_r1_reg_alpha": 0.1, "ema": null }
- field ema: EMAConfig | None
If set, an exponential moving average of the student is maintained and updated via
DMDPipeline.update_ema.
- field fake_score_pred_type: PredType | None
Parameterization used when training the fake score. If None, falls back to DistillationConfig.pred_type.
- field gan_loss_weight_gen: float
Weight of the GAN generator term in the student loss. 0 disables the GAN branch.
- field gan_r1_reg_alpha: float
Standard deviation of the perturbation applied to real data when computing the approximate R1 term.
- field gan_r1_reg_weight: float
Weight of the approximate-R1 regularization term for the discriminator update. 0 disables R1. Recommended range when enabled: 100-1000.
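fastgen's exact R1 approximation is not spelled out here; the scalar sketch below (function and variable names are illustrative, not the library's API) only shows the perturbation idea that gan_r1_reg_alpha controls: the discriminator output is compared at the real input and at a copy perturbed with Gaussian noise of standard deviation alpha.

```python
import random

def approx_r1_penalty(disc, x_real, alpha, rng):
    # Perturb the real input with Gaussian noise of std `alpha` and
    # penalize the squared change in the discriminator output — a
    # finite-difference stand-in for the exact R1 gradient penalty.
    eps = rng.gauss(0.0, 1.0)
    return (disc(x_real) - disc(x_real + alpha * eps)) ** 2

rng = random.Random(0)
toy_disc = lambda x: 2.0 * x  # toy linear "discriminator" on scalars
penalty = approx_r1_penalty(toy_disc, 0.5, alpha=0.1, rng=rng)
```

With alpha = 0 the penalty vanishes, which matches gan_r1_reg_weight / gan_r1_reg_alpha jointly gating the term.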
- field gan_use_same_t_noise: bool
If True, reuse the same t and eps for real and fake samples in the discriminator update.
- field student_update_freq: int
One student step for every student_update_freq fake-score / discriminator steps. Matches FastGen's DMD2 alternation. Not read by DMDPipeline; the training loop is expected to enforce the alternation.
- classmethod from_yaml(config_file)
Construct a DMDConfig from a YAML file.

Thin wrapper around modelopt.torch.fastgen.loader.load_dmd_config(). The resolver searches the built-in modelopt_recipes/ package first, then the filesystem. Suffixes (.yml / .yaml) may be omitted.

- Parameters:
  config_file (str | Path)
- Return type:
  DMDConfig
- class DMDPipeline
Bases: DistillationPipeline

DMD2 loss pipeline.

- Parameters:
  - student – Trainable student module. Must be callable with (hidden_states, timestep, encoder_hidden_states=..., **kwargs) and return either a Tensor, a (Tensor, ...) tuple (as diffusers returns with return_dict=False), or an object with a .sample attribute.
  - teacher – Frozen reference module with the same call signature. If discriminator is provided, feature-capture hooks must be attached to teacher before calling compute_*_loss — see modelopt.torch.fastgen.plugins.wan22.attach_feature_capture().
  - fake_score – Trainable auxiliary module (same signature as teacher/student). Used to approximate the student’s generated distribution for the VSD gradient.
  - config – DMDConfig with the hyperparameters.
  - discriminator – Optional discriminator. Required when config.gan_loss_weight_gen > 0. Must accept list[Tensor] (the captured teacher features) and return a 2D logit tensor.
- __init__(student, teacher, fake_score, config, *, discriminator=None)
Wire up student / teacher / fake-score / discriminator and create the EMA tracker.
- Parameters:
student (nn.Module)
teacher (nn.Module)
fake_score (nn.Module)
config (DMDConfig)
discriminator (nn.Module | None)
- Return type:
None
- compute_discriminator_loss(latents, noise, encoder_hidden_states=None, **model_kwargs)
Compute the discriminator update loss (GAN + optional R1).
Teacher and student forwards are wrapped in torch.no_grad(); gradient flows only through the discriminator.

Returns a dict with "gan_disc" and "total". When config.gan_r1_reg_weight > 0 the dict also contains "r1".

- Parameters:
latents (Tensor)
noise (Tensor)
encoder_hidden_states (Tensor | None)
model_kwargs (Any)
- Return type:
dict[str, Tensor]
- compute_fake_score_loss(latents, noise, encoder_hidden_states=None, **model_kwargs)
Compute the fake-score (auxiliary) update loss.
The fake score is trained with denoising score matching against the student’s generated samples. The student forward is wrapped in torch.no_grad() — the gradient here is w.r.t. fake_score only.

Returns a dict with "fake_score" and "total" (both equal).

- Parameters:
latents (Tensor)
noise (Tensor)
encoder_hidden_states (Tensor | None)
model_kwargs (Any)
- Return type:
dict[str, Tensor]
- compute_student_loss(latents, noise, encoder_hidden_states=None, *, negative_encoder_hidden_states=None, guidance_scale=None, **model_kwargs)
Compute the student update losses.
The returned dict always contains "vsd" and "total". When the GAN branch is enabled (discriminator is not None and config.gan_loss_weight_gen > 0), "gan_gen" is also present.

Gradient flow summary:

- VSD gradient: flows through student only (teacher_x0 is detached, fake_score_x0 is computed under torch.no_grad()).
- GAN generator gradient: flows through student via the feature-capture hooks on the teacher. The teacher forward is therefore not wrapped in torch.no_grad() when the GAN branch is active.
- Parameters:
  - latents (Tensor) – Real clean-data latents x_0. Used only when student_sample_steps > 1 to construct input_student.
  - noise (Tensor) – Pure Gaussian noise tensor matching latents in shape/dtype.
  - encoder_hidden_states (Tensor | None) – Positive conditioning passed unchanged to all three models.
  - negative_encoder_hidden_states (Tensor | None) – Negative conditioning used by classifier-free guidance. Required when guidance_scale (or DMDConfig.guidance_scale) is not None.
  - guidance_scale (float | None) – Overrides DMDConfig.guidance_scale for this call. None keeps the config-level value.
  - **model_kwargs (Any) – Forwarded verbatim to student, teacher, and fake_score.
- Returns:
  Dictionary with keys "vsd", "total", and optionally "gan_gen".
- Return type:
dict[str, Tensor]
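The gradient-flow summary above follows the usual DMD construction: an MSE against a stop-gradient target whose gradient w.r.t. the student prediction equals (fake_score_x0 - teacher_x0). A scalar sketch of that pattern (normalization factors omitted; this mirrors the common DMD2 recipe, not necessarily fastgen's exact code):

```python
def vsd_loss(student_x0, teacher_x0, fake_score_x0):
    # Desired gradient direction: fake-score estimate minus teacher estimate.
    grad = fake_score_x0 - teacher_x0
    # Stop-gradient target: in the tensor version `target` is detached,
    # so d(loss)/d(student_x0) == grad exactly.
    target = student_x0 - grad
    return 0.5 * (student_x0 - target) ** 2

loss = vsd_loss(student_x0=0.8, teacher_x0=1.0, fake_score_x0=0.6)
```

When the teacher and fake-score predictions agree, grad is zero and the loss vanishes, which is the fixed point the alternating updates drive toward.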
- property ema: ExponentialMovingAverage | None
Reference to the student EMA tracker, if configured.
- update_ema(*, iteration=None)
Update the student EMA tracker (no-op if config.ema is None).

Typically called after the student optimizer step. If iteration is not provided, an internal counter is auto-incremented.

- Parameters:
iteration (int | None)
- Return type:
None
- ModeloptConfig DistillationConfig
Bases: ModeloptBaseConfig

Shared hyperparameters for diffusion step-distillation methods. Concrete methods subclass this config to add method-specific fields (see DMDConfig).
- Default config (JSON):
{ "pred_type": "flow", "guidance_scale": null, "sample_t_cfg": null, "student_sample_steps": 1, "student_sample_type": "ode", "num_train_timesteps": null }
- field guidance_scale: float | None
Classifier-free guidance scale. If None, CFG is disabled.
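For reference, the standard classifier-free guidance combination that a non-None scale implies (this is the textbook formula; whether fastgen applies it to velocity or x0 predictions depends on pred_type and is not specified here):

```python
def apply_cfg(cond_pred, uncond_pred, guidance_scale):
    # Standard CFG: extrapolate from the unconditional prediction toward
    # the conditional one. scale == 1.0 recovers cond_pred exactly.
    return uncond_pred + guidance_scale * (cond_pred - uncond_pred)
```

This is why negative_encoder_hidden_states is required whenever a guidance scale is set: the unconditional branch needs its own conditioning input.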
- field num_train_timesteps: int | None
If set, the pipeline rescales the continuous RF timestep t ∈ [0, 1] to num_train_timesteps * t before passing it to the model. Matches the diffusers convention used by Wan 2.2 / SD3 / Flux (num_train_timesteps = 1000). Leave None when the model wrapper already handles the rescaling internally.
- field pred_type: PredType
Quantity predicted by the teacher / student network.
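For orientation, the rectified-flow conventions that a 'flow' prediction type usually refers to: the forward process interpolates x_t = (1 - t) * x0 + t * noise, and the network predicts the velocity v = noise - x0. A scalar sketch, assuming fastgen follows the SD3/Flux convention:

```python
def rf_forward(x0, eps, t):
    # Rectified-flow interpolation: x_t = (1 - t) * x0 + t * eps.
    return (1.0 - t) * x0 + t * eps

def x0_from_flow(x_t, v, t):
    # With velocity v = eps - x0, the clean sample is x0 = x_t - t * v.
    return x_t - t * v

x_t = rf_forward(x0=2.0, eps=-1.0, t=0.3)   # 1.1
x0_rec = x0_from_flow(x_t, v=-3.0, t=0.3)   # recovers 2.0
```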
- field sample_t_cfg: SampleTimestepConfig [Optional]
Timestep distribution used for both the teacher forward and the VSD / DSM losses.
- field student_sample_steps: int
Number of denoising steps the distilled student performs at inference.
- field student_sample_type: Literal['sde', 'ode']
Integrator used when unrolling the student over student_sample_steps > 1 steps. Not read by DMDPipeline at training time — consumed by the inference samplers that perform the unrolling.
- class DistillationPipeline
Bases: object

Hold student/teacher references and expose shared utilities.

- Parameters:
  - student – Trainable student module. The pipeline does not wrap it — its lifecycle (train()/eval(), requires_grad_, sharding, optimizer) remains owned by the caller.
  - teacher – Reference module. Frozen here via eval() + requires_grad_(False).
  - config – A DistillationConfig (or subclass).
- __init__(student, teacher, config)
Store student / teacher references and freeze the teacher.
- Parameters:
student (nn.Module)
teacher (nn.Module)
config (DistillationConfig)
- Return type:
None
- property device: device
Device of the first student parameter (best-effort; falls back to CPU).
- property dtype: dtype
Dtype of the first student parameter (best-effort; falls back to float32).
- sample_timesteps(n, *, device=None, dtype=torch.float32)
Sample n training timesteps according to config.sample_t_cfg.

- Parameters:
n (int)
device (device | None)
dtype (dtype)
- Return type:
Tensor
- ModeloptConfig EMAConfig
Bases: ModeloptBaseConfig

Exponential moving average (EMA) hyperparameters for the student network.
- Default config (JSON):
{ "decay": 0.9999, "type": "constant", "start_iter": 0, "gamma": 16.97, "halflife_kimg": 500.0, "rampup_ratio": 0.05, "batch_size": 1, "fsdp2": true, "mode": "full_tensor", "dtype": "float32" }
- field batch_size: int
Per-step global batch size used to convert iterations to nimg for the halflife schedule.
- field decay: float
Decay coefficient for type='constant'. Ignored for halflife / power.
- field dtype: Literal['float32', 'bfloat16', 'float16'] | None
Precision of the EMA parameter shadows. Defaults to float32 so EMA updates remain numerically meaningful even when the live model is bf16/fp16 (cf. FastGen, which instantiates its EMA module in the net’s construction dtype — typically fp32). Pass None to keep param shadows in the live parameter’s dtype. Buffer shadows always track the live dtype regardless of this setting.
- field fsdp2: bool
If True, the EMA uses DTensor.full_tensor() to gather sharded parameters before updating.
- field gamma: float
Exponent for type='power' (beta = (1 - 1/iter)**(gamma + 1)).
- field halflife_kimg: float
Halflife in thousands of images for type='halflife'.
- field mode: Literal['full_tensor', 'local_shard']
full_tensor performs an all_gather per parameter (higher memory, exact global EMA). local_shard updates each rank’s local DTensor shard in place (low-memory fallback).
- field rampup_ratio: float | None
Rampup fraction for type='halflife'; pass None to disable rampup.
- field start_iter: int
Iteration at which EMA tracking begins (EMA is initialized from the live weights at this step).
- field type: Literal['constant', 'halflife', 'power']
Schedule used to compute the per-step decay coefficient.
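A sketch of how the three schedules can map to a per-step decay coefficient. The 'power' branch follows the formula documented above; the 'halflife' branch uses the common StyleGAN-style convention (halve the contribution of old weights every halflife_kimg thousand images, with linear rampup), which is an assumption here, not read from the fastgen source.

```python
def ema_beta(schedule, iteration, *, decay=0.9999, gamma=16.97,
             halflife_kimg=500.0, rampup_ratio=0.05, batch_size=1):
    if schedule == "constant":
        return decay
    if schedule == "power":
        # As documented for type='power': beta = (1 - 1/iter)**(gamma + 1).
        return (1.0 - 1.0 / max(iteration, 1)) ** (gamma + 1.0)
    if schedule == "halflife":
        # Assumed StyleGAN-style rule: nimg = iteration * batch_size,
        # effective halflife capped by the rampup fraction early on.
        halflife_nimg = halflife_kimg * 1000.0
        if rampup_ratio is not None:
            halflife_nimg = min(halflife_nimg,
                                iteration * batch_size * rampup_ratio)
        return 0.5 ** (batch_size / max(halflife_nimg, 1e-8))
    raise ValueError(f"unknown schedule: {schedule}")
```

Whatever the schedule, the resulting beta feeds the update shadow = beta * shadow + (1 - beta) * live described under ExponentialMovingAverage.update.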
- class ExponentialMovingAverage
Bases: object

FSDP2-aware EMA tracker for a PyTorch module.

The tracker stores a shadow state dict: parameters are promoted per EMAConfig.dtype (default fp32) while buffers are kept in the live module’s dtype. Buffers are replicated across ranks and stepped via copy_ rather than lerp_, so the bf16-roundoff argument that motivates parameter promotion doesn’t apply — preserving the live dtype makes the buffer restore exact.

By default the tracker materialises the full tensor per parameter (mode='full_tensor') so the EMA represents the globally averaged weights even when the model is sharded across ranks. A mode='local_shard' fallback is available for memory-constrained settings — it does not all-gather, so each rank holds an EMA of its local shard only.

Example:

```python
ema = ExponentialMovingAverage(student, EMAConfig(decay=0.999))
for step in range(max_steps):
    ...  # compute loss, backward, optimizer.step()
    ema.update(student, iteration=step)
ema.copy_to(student_for_eval)  # publish for inference
```
- __init__(model, config)
Pre-allocate the shadow state from model’s parameters and buffers.

- Parameters:
model (nn.Module)
config (EMAConfig)
- Return type:
None
- copy_to(target)
Load the shadow state into target (which should share the tracked module’s structure).

The target is expected to be an unsharded module (i.e. the caller has unwrapped any FSDP2 wrappers before calling). For sharded targets, prefer saving the shadow via state_dict() and reloading it through the framework’s usual checkpoint path.

- Parameters:
target (Module)
- Return type:
None
- load_state_dict(state)
Restore the shadow state from a previously saved dict.
- Parameters:
state (dict[str, Tensor])
- Return type:
None
- state_dict()
Return the shadow state (parameters + buffers) for checkpointing.
- Return type:
dict[str, Tensor]
- update(model, *, iteration)
Update the shadow state from model at the given iteration.

Skips updates before EMAConfig.start_iter. On the iteration that equals start_iter the shadow is (re-)initialised from the live weights; after that it is updated with shadow = beta * shadow + (1 - beta) * live.

- Parameters:
model (Module)
iteration (int)
- Return type:
None
- ModeloptConfig SampleTimestepConfig
Bases: ModeloptBaseConfig

Timestep sampling distribution for diffusion training.
- Default config (JSON):
{ "time_dist_type": "shifted", "min_t": 0.001, "max_t": 0.999, "shift": 5.0, "p_mean": 0.0, "p_std": 1.0, "t_list": null }
- field max_t: float
Upper bound of the sampling range (clamped before use).
- field min_t: float
Lower bound of the sampling range (clamped before use).
- field p_mean: float
Mean of the underlying normal for logitnormal / lognormal.
- field p_std: float
Standard deviation of the underlying normal for logitnormal / lognormal.
- field shift: float
Shift factor for time_dist_type='shifted'; must be >= 1.
- field t_list: list[float] | None
Explicit timestep schedule used when DMDConfig.student_sample_steps > 1. The final element must be 0.0.
- field time_dist_type: TimeDistType
Distribution used to sample the training timestep t. Rectified-flow models typically use shifted (Wan 2.2) or logitnormal (SD3, Flux).
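A scalar sketch of the two distributions named above. The shifted map is the SD3-style rule t = s*u / (1 + (s - 1)*u) applied to a uniform draw, and logitnormal is the sigmoid of a normal draw; both are standard conventions assumed here, not read from the fastgen source.

```python
import math
import random

def sample_t(dist, rng, *, shift=5.0, p_mean=0.0, p_std=1.0,
             min_t=0.001, max_t=0.999):
    u = rng.random()
    if dist == "shifted":
        # SD3-style shift: pushes mass toward t = 1 for shift > 1.
        t = shift * u / (1.0 + (shift - 1.0) * u)
    elif dist == "logitnormal":
        # Sigmoid of a normal draw: t concentrates around sigmoid(p_mean).
        t = 1.0 / (1.0 + math.exp(-(p_mean + p_std * rng.gauss(0.0, 1.0))))
    else:
        raise ValueError(f"unknown distribution: {dist}")
    return min(max(t, min_t), max_t)  # clamp to [min_t, max_t]

rng = random.Random(0)
ts = [sample_t("shifted", rng) for _ in range(100)]
```

The clamp mirrors the min_t / max_t fields, which bound the sampling range before use.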
- create_fake_score(teacher, *, deep_copy=True)
Return a trainable fake-score network initialized from the teacher.
This is the unit-test / single-script path; frameworks that do meta-init + FSDP2 wrapping will typically construct the fake score themselves and pass it directly into DMDPipeline.

- Parameters:
teacher (nn.Module) – The already-built teacher module. Must already have its weights loaded.
deep_copy (bool) – If True, copy.deepcopy() the teacher; if False, reuse the same instance (only sensible if the caller can guarantee it is no longer held elsewhere as the frozen teacher).
- Returns:
A copy of teacher in training mode with all parameters requiring gradients.

- Return type:
nn.Module
FSDP2 caveat

copy.deepcopy(teacher) is not safe when the teacher is already FSDP2-wrapped (DTensor parameters + FSDP pre/post hooks + meta-init bookkeeping). For Stage-2 FSDP2 training, skip this factory and construct the fake score under meta-init, then rank-0-load weights and let sync_module_states broadcast:

```python
with meta_init_context():
    fake_score = build_teacher_from_config(teacher_config)
if is_rank0():
    fake_score.load_state_dict(teacher.state_dict(), strict=False)
# Wrap with FSDP2(..., sync_module_states=True) to broadcast from rank 0.
```
The pattern mirrors FastGen’s methods/distribution_matching/dmd2.py::DMD2Model.build_model. A dedicated create_fake_score_meta factory is planned alongside the Stage-2 training example.

- raises RuntimeError:
  When deep_copy=True and the teacher looks FSDP-wrapped (either FSDP1 via _fsdp_wrapped_module or FSDP2 via DTensor parameters). The deep_copy=False branch skips the check because reusing the teacher directly is compatible with an FSDP-wrapped input.
- load_config(config_file)
Load a YAML file and return the parsed mapping.
Mirrors modelopt.recipe._config_loader.load_config() in spirit but without the ExMy-num-bits post-processing that is specific to quantization recipes.

- Parameters:
  config_file (str | Path) – YAML path. Suffix is optional; resolution searches the built-in modelopt_recipes/ package first, then the filesystem.
- Returns:
  The parsed dictionary. An empty file yields {}.
- Return type:
dict[str, Any]
- load_dmd_config(config_file)
Load a YAML file and construct a DMDConfig.

The YAML is validated against DMDConfig’s Pydantic schema — unknown keys raise ValidationError.

Example YAML:

```yaml
pred_type: flow
guidance_scale: 5.0
student_sample_steps: 2
gan_loss_weight_gen: 0.03
sample_t_cfg:
  time_dist_type: shifted
  t_list: [0.999, 0.833, 0.0]
ema:
  decay: 0.9999
```

- Parameters:
  config_file (str | Path)
- Return type:
  DMDConfig