config

Pydantic configuration classes for the fastgen distillation pipelines.

Configurations are layered so that a method-specific config (e.g. DMDConfig) inherits the shared diffusion-distillation hyperparameters from DistillationConfig. All classes inherit from modelopt.torch.opt.config.ModeloptBaseConfig, which provides torch-safe serialization and dict-like iteration.

The default values in DMDConfig mirror the FastGen Wan 2.2 5B experiment at FastGen/fastgen/configs/experiments/WanT2V/config_dmd2_wan22_5b.py.
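The sketch below illustrates the layering. It uses plain dataclasses as a stand-in for ModeloptBaseConfig, which is an assumption made so the snippet runs without modelopt installed. The sketch classes are hypothetical and show only a subset of fields, using the defaults from the JSON dumps on this page. The real classes are Pydantic models with validation.

```python
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class DistillationConfigSketch:
    # Shared diffusion-distillation fields (subset; same defaults as DistillationConfig).
    pred_type: str = "flow"
    guidance_scale: Optional[float] = None
    student_sample_steps: int = 1
    student_sample_type: str = "ode"
    num_train_timesteps: Optional[int] = None


@dataclass
class DMDConfigSketch(DistillationConfigSketch):
    # Method-specific fields layered on top (subset; same defaults as DMDConfig).
    student_update_freq: int = 5
    fake_score_pred_type: Optional[str] = "x0"
    backward_simulation: bool = False
    gan_loss_weight_gen: float = 0.0


cfg = DMDConfigSketch(student_sample_steps=2)
# asdict() stands in for the dict-like iteration that ModeloptBaseConfig provides.
for name, value in asdict(cfg).items():
    print(f"{name}: {value}")
```

Inherited fields come first and keep their defaults unless overridden, which mirrors how DMDConfig reuses DistillationConfig.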

ModeloptConfig DMDConfig

Bases: DistillationConfig

Hyperparameters for DMD / DMD2 distribution-matching distillation.

Default values are tuned for Wan 2.2 5B; callers should retune them for other models. See FastGen/fastgen/configs/experiments/WanT2V/config_dmd2_wan22_5b.py.

Default config (JSON):

{
   "pred_type": "flow",
   "guidance_scale": null,
   "sample_t_cfg": null,
   "student_sample_steps": 1,
   "student_sample_type": "ode",
   "num_train_timesteps": null,
   "student_update_freq": 5,
   "fake_score_pred_type": "x0",
   "backward_simulation": false,
   "gan_loss_weight_gen": 0.0,
   "gan_use_same_t_noise": false,
   "gan_r1_reg_weight": 0.0,
   "gan_r1_reg_alpha": 0.1,
   "ema": null
}

field backward_simulation: bool

When True and the student is multi-step, the student input at the selected rung is built by unrolling the current student without gradients from the first schedule rung through the earlier rungs, then re-noising the generated x0 at the selected rung. When False, FastGen’s Qwen-style noised-real-latent path is used instead.
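The backward-simulation path can be sketched in plain Python. Several parts of this sketch are assumptions made for illustration:

- the rectified-flow forward process x_t = (1 - t) * x0 + t * eps,
- treating rung 0 as pure noise,
- the ode/sde transition rules between rungs.

The names student, noise_fn, and backward_simulate are hypothetical and are not DMDPipeline APIs. In PyTorch, the unrolled student calls would run under torch.no_grad().

```python
from typing import Callable, List

Vec = List[float]


def renoise(x0: Vec, t: float, eps: Vec) -> Vec:
    # Assumed rectified-flow forward process: x_t = (1 - t) * x0 + t * eps.
    return [(1.0 - t) * a + t * e for a, e in zip(x0, eps)]


def backward_simulate(
    student: Callable[[Vec, float], Vec],  # hypothetical: returns an x0 prediction
    t_list: Vec,                           # schedule rungs, final element 0.0
    k: int,                                # index of the selected rung
    noise_fn: Callable[[], Vec],           # hypothetical: draws fresh Gaussian noise
    sample_type: str = "ode",
) -> Vec:
    """Build the student input at rung k by unrolling the student over rungs 0..k-1."""
    assert 0 <= k < len(t_list) - 1, "selected rung must have t > 0"
    x = noise_fn()  # rung 0 treated as pure noise (assumes t_list[0] is close to 1)
    if k == 0:
        return x
    x0_hat = student(x, t_list[0])
    for i in range(1, k):
        t_prev, t_cur = t_list[i - 1], t_list[i]
        if sample_type == "sde":
            # Stochastic transition: re-noise the prediction with fresh noise.
            x = renoise(x0_hat, t_cur, noise_fn())
        else:
            # Deterministic transition: recover the implied noise and reuse it.
            eps_hat = [(xt - (1.0 - t_prev) * a) / t_prev for xt, a in zip(x, x0_hat)]
            x = renoise(x0_hat, t_cur, eps_hat)
        x0_hat = student(x, t_cur)
    # Re-noise the last generated x0 at the selected rung with fresh noise.
    return renoise(x0_hat, t_list[k], noise_fn())
```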

field ema: EMAConfig | None

If set, an exponential moving average of the student is maintained and updated via DMDPipeline.update_ema.

field fake_score_pred_type: PredType | None

Parameterization used when training the fake score. If None, falls back to DistillationConfig.pred_type.

field gan_loss_weight_gen: float

Weight of the GAN generator term in the student loss. A value of 0 disables the GAN branch.

field gan_r1_reg_alpha: float

Standard deviation of the perturbation applied to real data when computing the approximate R1 term.

field gan_r1_reg_weight: float

Weight of the approximate-R1 regularization term for the discriminator update. 0 disables R1. Recommended range when enabled: 100-1000.
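The following is a minimal sketch of an approximate R1 penalty. It assumes a finite-difference form: the squared change in the discriminator output when real samples are perturbed with Gaussian noise of standard deviation gan_r1_reg_alpha, scaled by gan_r1_reg_weight. The discriminator is a hypothetical scalar function, and the exact formulation used by FastGen may differ.

```python
import random
from typing import Callable, List, Optional, Sequence


def approx_r1_penalty(
    disc: Callable[[Sequence[float]], float],  # hypothetical scalar discriminator
    real_batch: List[List[float]],
    alpha: float,                    # gan_r1_reg_alpha: std of the perturbation
    weight: float,                   # gan_r1_reg_weight (0 disables the term)
    rng: Optional[random.Random] = None,
) -> float:
    if weight == 0.0:
        return 0.0
    rng = rng or random.Random(0)
    total = 0.0
    for x in real_batch:
        # Perturb the real sample and measure how far the discriminator output moves.
        x_pert = [v + alpha * rng.gauss(0.0, 1.0) for v in x]
        diff = disc(x) - disc(x_pert)
        total += diff * diff
    return weight * total / len(real_batch)
```

A flat discriminator incurs no penalty. One whose output changes sharply around real data is penalized, which is the smoothing effect R1 aims for.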

field gan_use_same_t_noise: bool

If True, reuse the same t and eps for real and fake samples in the discriminator update.
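The effect of gan_use_same_t_noise can be sketched as follows. The sketch assumes the rectified-flow forward process x_t = (1 - t) * x0 + t * eps. It also draws t uniformly as a simplification, whereas the pipeline samples t from sample_t_cfg. The function name is hypothetical.

```python
import random
from typing import List, Tuple

Vec = List[float]


def noise_for_discriminator(
    real: Vec, fake: Vec, same_t_noise: bool, rng: random.Random
) -> Tuple[Vec, Vec]:
    def draw() -> Tuple[float, Vec]:
        # Uniform t is a simplification; eps is standard Gaussian noise.
        return rng.random(), [rng.gauss(0.0, 1.0) for _ in real]

    t_r, eps_r = draw()
    # Reuse (t, eps) for the fake sample when same_t_noise is True.
    t_f, eps_f = (t_r, eps_r) if same_t_noise else draw()
    real_t = [(1 - t_r) * a + t_r * e for a, e in zip(real, eps_r)]
    fake_t = [(1 - t_f) * a + t_f * e for a, e in zip(fake, eps_f)]
    return real_t, fake_t
```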

field student_update_freq: int

The student is updated once for every student_update_freq fake-score / discriminator updates, matching FastGen’s DMD2 alternation. This field is not read by DMDPipeline; the training loop is expected to enforce the alternation.
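The sketch below shows one way a training loop could enforce this alternation. It is modeled on the original DMD2 schedule, in which the fake score (and discriminator) is updated every iteration and the student is additionally updated every student_update_freq iterations. plan_updates is a hypothetical helper, and FastGen’s own loop may order or split the steps differently.

```python
from collections import Counter
from typing import List


def plan_updates(num_iters: int, student_update_freq: int = 5) -> List[List[str]]:
    """Return, per iteration, the networks a DMD2-style loop would update."""
    plan = []
    for it in range(num_iters):
        steps = ["fake_score"]  # fake score / discriminator step every iteration
        if it % student_update_freq == 0:
            steps.append("student")  # one student step per student_update_freq iterations
        plan.append(steps)
    return plan


counts = Counter(step for steps in plan_updates(20) for step in steps)
print(dict(counts))  # {'fake_score': 20, 'student': 4}
```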

classmethod from_yaml(config_file)

Construct a DMDConfig from a YAML file.

Thin wrapper around modelopt.torch.fastgen.loader.load_dmd_config(). The resolver searches the built-in modelopt_recipes/ package first, then the filesystem. Suffixes (.yml / .yaml) may be omitted.

Parameters:

config_file (str | Path)

Return type:

DMDConfig
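The resolution order described above can be sketched with pathlib. This is a hypothetical illustration of the search order only, not the modelopt.torch.fastgen.loader implementation. Here builtin_root stands in for the location of the modelopt_recipes/ package. In practice you would call DMDConfig.from_yaml(...) directly.

```python
from pathlib import Path
from typing import Union

SUFFIXES = ("", ".yaml", ".yml")  # the suffix may be omitted


def resolve_config_path(
    config_file: Union[str, Path], builtin_root: Union[str, Path]
) -> Path:
    """Search the built-in recipe directory first, then the filesystem."""
    name = str(config_file)
    for base in (Path(builtin_root) / name, Path(name)):
        for suffix in SUFFIXES:
            candidate = Path(str(base) + suffix)
            if candidate.is_file():
                return candidate
    raise FileNotFoundError(f"no YAML config found for {name!r}")
```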

ModeloptConfig DistillationConfig

Bases: ModeloptBaseConfig

Shared hyperparameters for diffusion step-distillation methods.

Concrete methods subclass this config to add method-specific fields (see DMDConfig).

Default config (JSON):

{
   "pred_type": "flow",
   "guidance_scale": null,
   "sample_t_cfg": null,
   "student_sample_steps": 1,
   "student_sample_type": "ode",
   "num_train_timesteps": null
}

field guidance_scale: float | None

Classifier-free guidance (CFG) scale. If None, CFG is disabled.
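For reference, the common classifier-free guidance combination is pred = uncond + s * (cond - uncond), where s is guidance_scale; s = 1 recovers the conditional prediction. The sketch assumes this standard form and uses plain lists as predictions. It does not claim to reproduce how the pipeline batches the conditional and unconditional passes, and apply_cfg is a hypothetical name.

```python
from typing import List, Optional


def apply_cfg(
    cond: List[float], uncond: List[float], guidance_scale: Optional[float]
) -> List[float]:
    if guidance_scale is None:
        return list(cond)  # CFG disabled: use the conditional prediction only
    return [u + guidance_scale * (c - u) for c, u in zip(cond, uncond)]


print(apply_cfg([1.0, 2.0], [0.0, 1.0], 5.0))  # [5.0, 6.0]
```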

field num_train_timesteps: int | None

If set, the pipeline rescales the continuous rectified-flow (RF) timestep t ∈ [0, 1] to num_train_timesteps * t before passing it to the model. This matches the diffusers convention used by Wan 2.2 / SD3 / Flux (num_train_timesteps = 1000). Leave it as None when the model wrapper already handles the rescaling internally.
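The rescaling amounts to a single multiplication. model_timestep is a hypothetical helper shown only to make the two cases concrete.

```python
from typing import Optional


def model_timestep(t: float, num_train_timesteps: Optional[int]) -> float:
    if not 0.0 <= t <= 1.0:
        raise ValueError("expected a continuous RF timestep in [0, 1]")
    if num_train_timesteps is None:
        return t  # the model wrapper rescales internally
    return num_train_timesteps * t


print(model_timestep(0.25, 1000))  # 250.0
print(model_timestep(0.25, None))  # 0.25
```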

field pred_type: PredType

Quantity predicted by the teacher / student network.
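Consider the common rectified-flow convention x_t = (1 - t) * x0 + t * eps with velocity target v = eps - x0. This convention is an assumption; check the model's own. Under it, a flow prediction and an x0 prediction are interchangeable via x0 = x_t - t * v. A conversion of this kind is needed when, for example, the fake score is trained with fake_score_pred_type='x0' while pred_type='flow'. The helper names are hypothetical.

```python
def flow_to_x0(x_t: float, v: float, t: float) -> float:
    # Invert x_t = (1 - t) * x0 + t * eps using v = eps - x0.
    return x_t - t * v


def x0_to_flow(x_t: float, x0: float, t: float) -> float:
    # eps = (x_t - (1 - t) * x0) / t, so v = eps - x0 = (x_t - x0) / t (requires t > 0).
    return (x_t - x0) / t


x0, eps, t = 2.0, -1.0, 0.25
x_t = (1 - t) * x0 + t * eps   # 1.25
v = eps - x0                   # -3.0
print(flow_to_x0(x_t, v, t))   # 2.0
print(x0_to_flow(x_t, x0, t))  # -3.0
```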

field sample_t_cfg: SampleTimestepConfig [Optional]

Timestep distribution used for both the teacher forward and the VSD / DSM losses.

field student_sample_steps: int

Number of denoising steps the distilled student performs at inference.

field student_sample_type: Literal['sde', 'ode']

Integrator used when unrolling the student over student_sample_steps > 1 steps. Consumed by inference samplers and by DMDPipeline when DMDConfig.backward_simulation is enabled.

ModeloptConfig EMAConfig

Bases: ModeloptBaseConfig

Exponential moving average (EMA) hyperparameters for the student network.

Default config (JSON):

{
   "decay": 0.9999,
   "type": "constant",
   "start_iter": 0,
   "gamma": 16.97,
   "halflife_kimg": 500.0,
   "rampup_ratio": 0.05,
   "batch_size": 1,
   "fsdp2": true,
   "mode": "full_tensor",
   "dtype": "float32"
}

field batch_size: int

Per-step global batch size, used to convert iterations into the number of images seen (nimg) for the halflife schedule.

field decay: float

Decay coefficient for type='constant'. Ignored for type='halflife' and type='power'.

field dtype: Literal['float32', 'bfloat16', 'float16'] | None

Precision of the EMA parameter shadows. Defaults to float32 so that EMA updates remain numerically meaningful even when the live model is bf16/fp16 (cf. FastGen, which instantiates its EMA module in the net’s construction dtype, typically fp32). Pass None to keep parameter shadows in the live parameter’s dtype. Buffer shadows always track the live dtype regardless of this setting.
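The reason for keeping float32 shadows can be demonstrated with the stdlib alone. Python's struct module can round values to IEEE float16 (format 'e'), and Python floats stand in for float32 here. bfloat16 is not available in struct, but it has fewer mantissa bits than float16, so it stalls at least as readily. With decay 0.9999 the per-step increment is far below float16's spacing near 1.0 (about 0.001), so the half-precision EMA never moves. The helper names are hypothetical.

```python
import struct


def to_fp16(x: float) -> float:
    # Round a Python float to the nearest IEEE half-precision value.
    return struct.unpack("<e", struct.pack("<e", x))[0]


def run_ema(steps: int, decay: float, live: float, shadow: float, half: bool) -> float:
    for _ in range(steps):
        shadow = decay * shadow + (1.0 - decay) * live
        if half:
            shadow = to_fp16(shadow)  # store the shadow in half precision
    return shadow


full = run_ema(1000, 0.9999, live=2.0, shadow=1.0, half=False)
half = run_ema(1000, 0.9999, live=2.0, shadow=1.0, half=True)
print(f"{full:.4f} {half:.4f}")  # 1.0952 1.0000
```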

field fsdp2: bool

If True, the EMA uses DTensor.full_tensor() to gather sharded parameters before updating.

field gamma: float

Exponent for type='power' (beta = (1 - 1/iter)**(gamma + 1)).

field halflife_kimg: float

Halflife in thousands of images for type='halflife'.

field mode: Literal['full_tensor', 'local_shard']

full_tensor performs an all_gather per parameter (higher memory, exact global EMA). local_shard updates each rank’s local DTensor shard in place (low-memory fallback).

field rampup_ratio: float | None

Rampup fraction for type='halflife'; pass None to disable rampup.

field start_iter: int

Iteration at which EMA tracking begins (EMA is initialized from the live weights at this step).
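The following minimal sketch shows how start_iter could interact with the update. It assumes the standard EMA rule shadow = decay * shadow + (1 - decay) * live, applied per parameter with a constant decay. Parameters are modeled as a dict of floats, and the class name is hypothetical. The real EMA also handles dtype, buffers, and FSDP2 sharding as described on this page.

```python
from typing import Dict, Optional


class EMASketch:
    def __init__(self, start_iter: int = 0, decay: float = 0.9999):
        self.start_iter = start_iter
        self.decay = decay
        self.shadow: Optional[Dict[str, float]] = None

    def update(self, params: Dict[str, float], iteration: int) -> None:
        if iteration < self.start_iter:
            return  # EMA tracking has not started yet
        if self.shadow is None:
            # At start_iter, initialize the EMA from the live weights.
            self.shadow = dict(params)
            return
        for name, value in params.items():
            self.shadow[name] = self.decay * self.shadow[name] + (1.0 - self.decay) * value
```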

field type: Literal['constant', 'halflife', 'power']

Schedule used to compute the per-step decay coefficient.
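The three schedules can be sketched as per-step decay functions. The constant and power rules follow the field descriptions above. The halflife rule is an assumption modeled on the EDM training loop: the halflife in images, optionally capped by rampup_ratio times the images seen so far, is converted to a per-step beta using batch_size. Details such as how iterations are counted may differ in the actual implementation.

```python
from typing import Optional


def ema_beta(
    iteration: int,
    ema_type: str = "constant",
    decay: float = 0.9999,
    gamma: float = 16.97,
    halflife_kimg: float = 500.0,
    rampup_ratio: Optional[float] = 0.05,
    batch_size: int = 1,
) -> float:
    if ema_type == "constant":
        return decay
    if ema_type == "power":
        # beta = (1 - 1/iter) ** (gamma + 1); iteration is assumed to start at 1.
        return (1.0 - 1.0 / iteration) ** (gamma + 1.0)
    if ema_type == "halflife":
        halflife_nimg = halflife_kimg * 1000.0
        if rampup_ratio is not None:
            cur_nimg = iteration * batch_size  # images seen so far
            halflife_nimg = min(halflife_nimg, cur_nimg * rampup_ratio)
        return 0.5 ** (batch_size / max(halflife_nimg, 1e-8))
    raise ValueError(f"unknown EMA type: {ema_type!r}")
```

With rampup enabled, the effective halflife starts small, so early EMA weights follow the live model closely before settling into the full halflife.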

ModeloptConfig SampleTimestepConfig

Bases: ModeloptBaseConfig

Timestep sampling distribution for diffusion training.

Default config (JSON):

{
   "time_dist_type": "shifted",
   "min_t": 0.001,
   "max_t": 0.999,
   "shift": 5.0,
   "p_mean": 0.0,
   "p_std": 1.0,
   "t_list": null
}

field max_t: float

Upper bound of the sampling range (clamped before use).

field min_t: float

Lower bound of the sampling range (clamped before use).

field p_mean: float

Mean of the underlying normal for logitnormal / lognormal.

field p_std: float

Standard deviation of the underlying normal for logitnormal / lognormal.

field shift: float

Shift factor for time_dist_type='shifted'; must be >= 1.

field t_list: list[float] | None

Explicit timestep schedule used when DMDConfig.student_sample_steps > 1. The final element must be 0.0.

field time_dist_type: TimeDistType

Distribution used to sample the training timestep t. Rectified-flow models typically use shifted (Wan 2.2) or logitnormal (SD3, Flux).
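The following stdlib sketch covers two of the distributions. It assumes two formulas:

- the shift transform used in SD3-style rectified-flow training, t = shift * u / (1 + (shift - 1) * u) with u uniform,
- logitnormal as the sigmoid of a normal draw with p_mean and p_std.

The full set of TimeDistType values is not listed in this reference, so the string names below (including "uniform") are assumptions. The same applies to clamping the final sample to [min_t, max_t].

```python
import math
import random


def sample_t(
    rng: random.Random,
    time_dist_type: str = "shifted",
    min_t: float = 0.001,
    max_t: float = 0.999,
    shift: float = 5.0,
    p_mean: float = 0.0,
    p_std: float = 1.0,
) -> float:
    if time_dist_type == "uniform":
        t = rng.random()
    elif time_dist_type == "shifted":
        if shift < 1.0:
            raise ValueError("shift must be >= 1")
        u = rng.random()
        # shift > 1 moves probability mass toward t = 1 (the noisier end).
        t = shift * u / (1.0 + (shift - 1.0) * u)
    elif time_dist_type == "logitnormal":
        t = 1.0 / (1.0 + math.exp(-rng.gauss(p_mean, p_std)))
    else:
        raise ValueError(f"unsupported time_dist_type: {time_dist_type!r}")
    return min(max(t, min_t), max_t)  # clamp to [min_t, max_t]
```

With shift = 1 the shifted transform reduces to the uniform draw. With the default shift = 5.0, the expected value of t rises to roughly 0.75.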