config
Pydantic configuration classes for the fastgen distillation pipelines.
Configurations are layered so a method-specific config (e.g. DMDConfig) inherits
shared diffusion-distillation hyperparameters from DistillationConfig. All classes
inherit from modelopt.torch.opt.config.ModeloptBaseConfig, which provides torch-safe
serialization and dict-like iteration.
The default values in DMDConfig mirror the FastGen Wan 2.2 5B experiment at
FastGen/fastgen/configs/experiments/WanT2V/config_dmd2_wan22_5b.py.
- ModeloptConfig DMDConfig
Bases: DistillationConfig
Hyperparameters for DMD / DMD2 distribution-matching distillation.
Default values are tuned for Wan 2.2 5B; callers fine-tune them per model. See
FastGen/fastgen/configs/experiments/WanT2V/config_dmd2_wan22_5b.py.
- Default config (JSON):
{ "pred_type": "flow", "guidance_scale": null, "sample_t_cfg": null, "student_sample_steps": 1, "student_sample_type": "ode", "num_train_timesteps": null, "student_update_freq": 5, "fake_score_pred_type": "x0", "backward_simulation": false, "gan_loss_weight_gen": 0.0, "gan_use_same_t_noise": false, "gan_r1_reg_weight": 0.0, "gan_r1_reg_alpha": 0.1, "ema": null }
- field backward_simulation: bool
When True for multi-step students, build the selected student input by no-grad unrolling the current student from the first schedule rung through earlier rungs, then re-noising the generated x0 at the selected rung. When False, use FastGen’s Qwen-style noised-real latent path.
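The unrolling described for backward_simulation can be sketched on scalars. This is a minimal illustration, not the DMDPipeline implementation: the function names, the rectified-flow re-noising rule x_t = (1 - t) * x0 + t * eps, and the use of fresh noise at every rung are assumptions, and a real pipeline would run the loop on tensors under torch.no_grad().

```python
from typing import Callable, Sequence


def renoise(x0: float, t: float, eps: float) -> float:
    """Rectified-flow interpolation between a clean sample and noise (assumed)."""
    return (1.0 - t) * x0 + t * eps


def backward_simulate(
    student: Callable[[float, float], float],
    t_list: Sequence[float],
    rung: int,
    noise: Callable[[], float],
) -> float:
    """Build the student input at t_list[rung] by unrolling the earlier rungs."""
    x = noise()  # the first rung starts from pure noise
    for i in range(rung):
        x0 = student(x, t_list[i])  # student predicts a clean sample
        x = renoise(x0, t_list[i + 1], noise())  # re-noise at the next rung
    return x


# Toy usage: a "student" that always predicts x0 = 2.0, with constant noise 1.0.
toy_input = backward_simulate(
    lambda x, t: 2.0, [1.0, 0.5, 0.0], rung=1, noise=lambda: 1.0
)
```

With rung=0 the loop body never runs and the function returns the initial noise unchanged.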
- field ema: EMAConfig | None
If set, an exponential moving average of the student is maintained and updated via
DMDPipeline.update_ema.
- field fake_score_pred_type: PredType | None
Parameterization used when training the fake score. If None, falls back to DistillationConfig.pred_type.
- field gan_loss_weight_gen: float
Weight of the GAN generator term in the student loss. 0 disables the GAN branch.
- field gan_r1_reg_alpha: float
Standard deviation of the perturbation applied to real data when computing the approximate R1 term.
- field gan_r1_reg_weight: float
Weight of the approximate-R1 regularization term for the discriminator update. 0 disables R1. Recommended range when enabled: 100-1000.
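One common form of approximate R1 penalizes the change in discriminator output when the real data is perturbed by small Gaussian noise. The sketch below assumes that form; approx_r1_penalty and the toy linear discriminator are hypothetical, and the loss actually used by the pipeline may differ.

```python
import random
from typing import Callable, Sequence


def approx_r1_penalty(
    disc: Callable[[float], float],
    real: Sequence[float],
    alpha: float,
    weight: float,
    rng: random.Random,
) -> float:
    """weight * mean((D(x) - D(x + alpha * eps))**2) with eps ~ N(0, 1)."""
    if weight == 0:
        return 0.0  # a zero weight disables the term
    total = 0.0
    for x in real:
        perturbed = x + alpha * rng.gauss(0.0, 1.0)  # alpha is the noise std
        total += (disc(x) - disc(perturbed)) ** 2
    return weight * total / len(real)


# Toy usage with a linear "discriminator" D(x) = 3x.
penalty = approx_r1_penalty(
    lambda x: 3.0 * x, [1.0, 2.0], alpha=0.1, weight=100.0, rng=random.Random(0)
)
```

Here alpha plays the role of gan_r1_reg_alpha and weight the role of gan_r1_reg_weight.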
- field gan_use_same_t_noise: bool
If True, reuse the same t and eps for real and fake samples in the discriminator update.
- field student_update_freq: int
One student step for every student_update_freq fake-score / discriminator steps. Matches FastGen’s DMD2 alternation. Not read by DMDPipeline; the training loop is expected to enforce the alternation.
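Because the pipeline does not read student_update_freq, the training loop has to schedule the two phases itself. The following is a minimal sketch of one literal reading of the description (student_update_freq critic steps followed by one student step); the generator name and phase labels are hypothetical, and FastGen's own loop may interleave the updates differently.

```python
from itertools import islice
from typing import Iterator


def alternation(student_update_freq: int) -> Iterator[str]:
    """Yield 'critic' student_update_freq times, then 'student' once, forever."""
    if student_update_freq < 1:
        raise ValueError("student_update_freq must be >= 1")
    while True:
        for _ in range(student_update_freq):
            yield "critic"  # fake-score / discriminator update
        yield "student"  # generator update


# First 12 phases with the default student_update_freq = 5.
phases = list(islice(alternation(5), 12))
```

A training loop would branch on each yielded label to decide which optimizer to step.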
- classmethod from_yaml(config_file)
Construct a DMDConfig from a YAML file.
Thin wrapper around modelopt.torch.fastgen.loader.load_dmd_config(). The resolver searches the built-in modelopt_recipes/ package first, then the filesystem. Suffixes (.yml/.yaml) may be omitted.
- Parameters:
config_file (str | Path)
- Return type:
DMDConfig
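The suffix rule can be illustrated without the library. The sketch below only shows how an omitted suffix might expand into candidate file names; candidate_names is a hypothetical helper, the order in which .yml and .yaml are tried is an assumption, and the real resolver additionally searches the modelopt_recipes/ package before the filesystem.

```python
from pathlib import PurePosixPath
from typing import List, Union


def candidate_names(config_file: Union[str, PurePosixPath]) -> List[str]:
    """Expand a config name into the file names a resolver could try."""
    path = PurePosixPath(config_file)
    if path.suffix in (".yml", ".yaml"):
        return [str(path)]  # explicit suffix: use the name as given
    return [f"{path}.yml", f"{path}.yaml"]  # omitted suffix: try both


# "recipes/dmd" is a made-up name used only for illustration.
names = candidate_names("recipes/dmd")
```

Names that contain a dot but no YAML suffix (for example a hypothetical wan2.2_5b) are still treated as suffix-less in this sketch.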
- ModeloptConfig DistillationConfig
Bases: ModeloptBaseConfig
Shared hyperparameters for diffusion step-distillation methods.
Concrete methods subclass this config to add method-specific fields (see
DMDConfig).
- Default config (JSON):
{ "pred_type": "flow", "guidance_scale": null, "sample_t_cfg": null, "student_sample_steps": 1, "student_sample_type": "ode", "num_train_timesteps": null }
- field guidance_scale: float | None
Classifier-free guidance scale. If None, CFG is disabled.
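For reference, the widely used classifier-free guidance combination is shown below on scalars. It is a sketch of the common convention, assumed here rather than taken from the pipeline source; apply_cfg is a hypothetical name.

```python
from typing import Optional


def apply_cfg(cond: float, uncond: float, guidance_scale: Optional[float]) -> float:
    """Combine conditional and unconditional predictions."""
    if guidance_scale is None:
        return cond  # CFG disabled: use the conditional prediction as is
    return uncond + guidance_scale * (cond - uncond)


guided = apply_cfg(cond=3.0, uncond=1.0, guidance_scale=2.0)
```

Under this convention a scale of 1.0 reproduces the conditional prediction, so None mainly saves the extra unconditional forward pass.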
- field num_train_timesteps: int | None
If set, the pipeline rescales the continuous RF timestep t ∈ [0, 1] to num_train_timesteps * t before passing it to the model. Matches the diffusers convention used by Wan 2.2 / SD3 / Flux (num_train_timesteps = 1000). Leave None when the model wrapper already handles the rescaling internally.
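The rescaling rule is a single multiplication. A minimal sketch follows, with rescale_timestep as a hypothetical helper name:

```python
from typing import Optional


def rescale_timestep(t: float, num_train_timesteps: Optional[int]) -> float:
    """Map a continuous timestep in [0, 1] to the model's timestep scale."""
    if num_train_timesteps is None:
        return t  # the model wrapper is assumed to rescale internally
    return num_train_timesteps * t


model_t = rescale_timestep(0.25, 1000)
```

With num_train_timesteps = 1000, the continuous timestep 0.25 is passed to the model as 250.0.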
- field pred_type: PredType
Quantity predicted by the teacher / student network.
- field sample_t_cfg: SampleTimestepConfig [Optional]
Timestep distribution used for both the teacher forward and the VSD / DSM losses.
- field student_sample_steps: int
Number of denoising steps the distilled student performs at inference.
- field student_sample_type: Literal['sde', 'ode']
Integrator used when unrolling the student over student_sample_steps > 1 steps. Consumed by inference samplers and by DMDPipeline when DMDConfig.backward_simulation is enabled.
- ModeloptConfig EMAConfig
Bases: ModeloptBaseConfig
Exponential moving average (EMA) hyperparameters for the student network.
- Default config (JSON):
{ "decay": 0.9999, "type": "constant", "start_iter": 0, "gamma": 16.97, "halflife_kimg": 500.0, "rampup_ratio": 0.05, "batch_size": 1, "fsdp2": true, "mode": "full_tensor", "dtype": "float32" }
- field batch_size: int
Per-step global batch size used to convert iterations to nimg for the halflife schedule.
- field decay: float
Decay coefficient for type='constant'. Ignored for halflife / power.
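The standard EMA update with a constant decay can be written as follows. This is a generic sketch on plain Python lists, not the library's implementation; ema_update is a hypothetical name.

```python
from typing import List, Sequence


def ema_update(shadow: Sequence[float], live: Sequence[float], decay: float) -> List[float]:
    """Return decay * shadow + (1 - decay) * live, element by element."""
    return [decay * s + (1.0 - decay) * p for s, p in zip(shadow, live)]


new_shadow = ema_update([1.0, 2.0], [0.0, 0.0], decay=0.5)
```

With the default decay of 0.9999, each step moves the shadow only 0.01% of the way toward the live weights.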
- field dtype: Literal['float32', 'bfloat16', 'float16'] | None
Precision of the EMA parameter shadows. Defaults to float32 so EMA updates remain numerically meaningful even when the live model is bf16/fp16 (cf. FastGen, which instantiates its EMA module in the net’s construction dtype, typically fp32). Pass None to keep param shadows in the live parameter’s dtype. Buffer shadows always track the live dtype regardless of this setting.
- field fsdp2: bool
If True, the EMA uses DTensor.full_tensor() to gather sharded parameters before updating.
- field gamma: float
Exponent for type='power' (beta = (1 - 1/iter)**(gamma + 1)).
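The power schedule formula can be evaluated directly. The helper name power_decay and the guard against iteration < 1 are assumptions added for illustration.

```python
def power_decay(iteration: int, gamma: float) -> float:
    """Per-step EMA decay: beta = (1 - 1/iter) ** (gamma + 1)."""
    if iteration < 1:
        raise ValueError("iteration must be >= 1 to avoid division by zero")
    return (1.0 - 1.0 / iteration) ** (gamma + 1.0)


# Decay at a few iterations with the default gamma = 16.97.
decays = [power_decay(i, 16.97) for i in (1, 10, 1000)]
```

The decay starts at 0 on the first iteration (the shadow copies the live weights) and approaches 1 as training proceeds.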
- field halflife_kimg: float
Halflife in thousands of images for type='halflife'.
- field mode: Literal['full_tensor', 'local_shard']
full_tensor performs an all_gather per parameter (higher memory, exact global EMA). local_shard updates each rank’s local DTensor shard in place (low-memory fallback).
- field rampup_ratio: float | None
Rampup fraction for type='halflife'; pass None to disable rampup.
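The halflife fields (halflife_kimg, rampup_ratio, batch_size) fit the EDM-style schedule sketched below. That the library follows exactly this formula is an assumption; halflife_decay is a hypothetical helper.

```python
from typing import Optional


def halflife_decay(
    iteration: int,
    batch_size: int,
    halflife_kimg: float,
    rampup_ratio: Optional[float],
) -> float:
    """EDM-style per-step decay derived from a halflife measured in images."""
    cur_nimg = iteration * batch_size  # images seen so far
    halflife_nimg = halflife_kimg * 1000.0
    if rampup_ratio is not None:
        # Early in training, shorten the halflife so the EMA tracks faster.
        halflife_nimg = min(halflife_nimg, cur_nimg * rampup_ratio)
    return 0.5 ** (batch_size / max(halflife_nimg, 1e-8))


early = halflife_decay(100, batch_size=1, halflife_kimg=500.0, rampup_ratio=0.05)
late = halflife_decay(10**9, batch_size=1, halflife_kimg=500.0, rampup_ratio=0.05)
```

Early in training the rampup caps the halflife at cur_nimg * rampup_ratio, so the decay is smaller than its long-run value.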
- field start_iter: int
Iteration at which EMA tracking begins (EMA is initialized from the live weights at this step).
- field type: Literal['constant', 'halflife', 'power']
Schedule used to compute the per-step decay coefficient.
- ModeloptConfig SampleTimestepConfig
Bases: ModeloptBaseConfig
Timestep sampling distribution for diffusion training.
- Default config (JSON):
{ "time_dist_type": "shifted", "min_t": 0.001, "max_t": 0.999, "shift": 5.0, "p_mean": 0.0, "p_std": 1.0, "t_list": null }
- field max_t: float
Upper bound of the sampling range (clamped before use).
- field min_t: float
Lower bound of the sampling range (clamped before use).
- field p_mean: float
Mean of the underlying normal for logitnormal / lognormal.
- field p_std: float
Standard deviation of the underlying normal for logitnormal / lognormal.
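For the logitnormal case, p_mean and p_std parameterize a normal variable that is squashed through a sigmoid. The sketch below shows that standard construction with the standard library only; sample_logitnormal is a hypothetical name and the library's sampler may add clamping to [min_t, max_t].

```python
import math
import random


def sample_logitnormal(p_mean: float, p_std: float, rng: random.Random) -> float:
    """Draw z ~ N(p_mean, p_std**2) and return sigmoid(z) in (0, 1)."""
    z = rng.gauss(p_mean, p_std)
    return 1.0 / (1.0 + math.exp(-z))


rng = random.Random(0)
samples = [sample_logitnormal(0.0, 1.0, rng) for _ in range(1000)]
```

With the defaults p_mean = 0.0 and p_std = 1.0 the samples concentrate around t = 0.5.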
- field shift: float
Shift factor for time_dist_type='shifted'; must be >= 1.
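A common shift rule for rectified-flow models maps a timestep t to shift * t / (1 + (shift - 1) * t). The sketch below assumes that rule and clamps to the min_t / max_t defaults; shift_timestep is a hypothetical helper, and whether the library applies exactly this mapping is not confirmed here.

```python
def shift_timestep(
    t: float, shift: float, min_t: float = 0.001, max_t: float = 0.999
) -> float:
    """Apply the (assumed) flow-matching time shift, then clamp to [min_t, max_t]."""
    if shift < 1.0:
        raise ValueError("shift must be >= 1")
    shifted = shift * t / (1.0 + (shift - 1.0) * t)
    return min(max(shifted, min_t), max_t)


shifted_mid = shift_timestep(0.5, 5.0)
```

A shift above 1 pushes samples toward the high-noise end of the range, and shift = 1 leaves t unchanged.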
- field t_list: list[float] | None
Explicit timestep schedule used when DMDConfig.student_sample_steps > 1. The final element must be 0.0.
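A schedule can be checked before training. In the sketch below only the final-element rule comes from the description above; the strictly decreasing check is an added assumption, and both validate_t_list and the example values are hypothetical.

```python
from typing import Sequence


def validate_t_list(t_list: Sequence[float]) -> None:
    """Raise ValueError if the schedule is empty, unordered, or does not end at 0.0."""
    if not t_list:
        raise ValueError("t_list must not be empty")
    if t_list[-1] != 0.0:
        raise ValueError("the final element of t_list must be 0.0")
    # Assumption: rungs run from high noise to low noise.
    if any(a <= b for a, b in zip(t_list, t_list[1:])):
        raise ValueError("t_list must be strictly decreasing")


validate_t_list([0.999, 0.75, 0.5, 0.0])  # passes silently
```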
- field time_dist_type: TimeDistType
Distribution used to sample the training timestep t. Rectified-flow models typically use shifted (Wan 2.2) or logitnormal (SD3, Flux).