fastgen

Modules

modelopt.torch.fastgen.flow_matching

Rectified-flow (RF) helpers: forward process, inversions, timestep sampling.

modelopt.torch.fastgen.losses

Pure loss functions used by the fastgen distillation pipelines.

modelopt.torch.fastgen.plugins

Optional plugins for the fastgen subpackage (gated via import_plugin).

modelopt.torch.fastgen.utils

Small tensor helpers shared across the fastgen subpackage.

Framework-agnostic diffusion step-distillation losses (FastGen port).

modelopt.torch.fastgen is a loss-computation library. It accepts already-built nn.Module references (student / teacher / fake-score / optional discriminator) and returns scalar loss tensors. It does not load models, manage optimizers, wrap anything as a DynamicModule, or register itself in any mode registry.

Typical usage with a YAML-driven config:

import modelopt.torch.fastgen as mtf

# build_wan_student_and_teacher and build_discriminator stand in for framework-specific builders.
student, teacher = build_wan_student_and_teacher(...)
fake_score = mtf.create_fake_score(teacher)

cfg = mtf.load_dmd_config("general/distillation/dmd2_wan22_5b")

# If GAN is enabled, expose intermediate teacher features to the discriminator.
disc = None
if cfg.gan_loss_weight_gen > 0:
    mtf.plugins.wan22.attach_feature_capture(teacher, feature_indices=[15, 22, 29])
    disc = build_discriminator(...)

pipeline = mtf.DMDPipeline(student, teacher, fake_score, cfg, discriminator=disc)

# Inside the training loop (framework-owned). Clear stale gradients before each
# backward; the GAN generator term may also leave gradients on the discriminator.
if step % cfg.student_update_freq == 0:
    student_opt.zero_grad(set_to_none=True)
    losses = pipeline.compute_student_loss(
        latents, noise, text_embeds, negative_encoder_hidden_states=neg_embeds
    )
    losses["total"].backward()
    student_opt.step()
    pipeline.update_ema()
else:
    fake_score_opt.zero_grad(set_to_none=True)
    f = pipeline.compute_fake_score_loss(latents, noise, text_embeds)
    f["total"].backward()
    fake_score_opt.step()
    if disc is not None:
        disc_opt.zero_grad(set_to_none=True)
        d = pipeline.compute_discriminator_loss(latents, noise, text_embeds)
        d["total"].backward()
        disc_opt.step()

Classes

DMDPipeline

DMD2 loss pipeline.

DistillationPipeline

Hold student/teacher references and expose shared utilities.

ExponentialMovingAverage

FSDP2-aware EMA tracker for a PyTorch module.

Functions

create_fake_score

Return a trainable fake-score network initialized from the teacher.

load_config

Load a YAML file and return the parsed mapping.

load_dmd_config

Load a YAML file and construct a DMDConfig.

ModeloptConfig DMDConfig

Bases: DistillationConfig

Hyperparameters for DMD / DMD2 distribution-matching distillation.

Default values are tuned for Wan 2.2 5B; callers fine-tune them per model. See FastGen/fastgen/configs/experiments/WanT2V/config_dmd2_wan22_5b.py.

Default config (JSON):

{
   "pred_type": "flow",
   "guidance_scale": null,
   "sample_t_cfg": null,
   "student_sample_steps": 1,
   "student_sample_type": "ode",
   "num_train_timesteps": null,
   "student_update_freq": 5,
   "fake_score_pred_type": "x0",
   "gan_loss_weight_gen": 0.0,
   "gan_use_same_t_noise": false,
   "gan_r1_reg_weight": 0.0,
   "gan_r1_reg_alpha": 0.1,
   "ema": null
}

field ema: EMAConfig | None

If set, an exponential moving average of the student is maintained and updated via DMDPipeline.update_ema.

field fake_score_pred_type: PredType | None

Parameterization used when training the fake score. If None, falls back to DistillationConfig.pred_type.

field gan_loss_weight_gen: float

Weight of the GAN generator term in the student loss. 0 disables the GAN branch.

field gan_r1_reg_alpha: float

Standard deviation of the perturbation applied to real data when computing the approximate R1 term.

field gan_r1_reg_weight: float

Weight of the approximate-R1 regularization term for the discriminator update. 0 disables R1. Recommended range when enabled: 100-1000.
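The approximate R1 term is commonly implemented as a finite difference: instead of penalizing the discriminator's input gradient (which needs a double backward), it perturbs real data by alpha * eps and penalizes how much the discriminator output changes. A minimal sketch on plain floats, assuming that form; how DMDPipeline reduces over the batch is an assumption, and `approx_r1` plus the toy linear discriminator are hypothetical:

```python
import random
from typing import Callable, Sequence


def approx_r1(
    disc: Callable[[Sequence[float]], float],
    real: Sequence[float],
    alpha: float,
    weight: float,
    rng: random.Random,
) -> float:
    """Finite-difference stand-in for the R1 penalty (hypothetical helper).

    Perturbs the real sample with Gaussian noise of std ``alpha`` (gan_r1_reg_alpha)
    and penalizes the change in the discriminator output, scaled by ``weight``
    (gan_r1_reg_weight).
    """
    if weight == 0.0:
        return 0.0  # gan_r1_reg_weight == 0 disables the term
    perturbed = [x + alpha * rng.gauss(0.0, 1.0) for x in real]
    return weight * (disc(real) - disc(perturbed)) ** 2


# Toy discriminator: a linear score over a flat "sample" (illustration only).
toy_disc = lambda xs: sum(xs)
penalty = approx_r1(toy_disc, [1.0, 2.0], alpha=0.1, weight=100.0, rng=random.Random(0))
```

A smooth discriminator changes little under small perturbations, so the penalty stays small; a sharp one is pushed back toward smoothness, which is the role R1 plays.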

field gan_use_same_t_noise: bool

If True, reuse the same t and eps for real and fake samples in the discriminator update.

field student_update_freq: int

One student update every student_update_freq iterations; the remaining iterations update the fake score (and the discriminator, if any). Matches FastGen’s DMD2 alternation. Not read by DMDPipeline; the training loop is expected to enforce the alternation.
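The alternation used in the training-loop example at the top of this page reduces to a small scheduling function. `update_kind` is a hypothetical helper for illustration, not part of the package:

```python
def update_kind(step: int, student_update_freq: int) -> str:
    """Return which network the training loop updates at ``step``."""
    if student_update_freq < 1:
        raise ValueError("student_update_freq must be >= 1")
    return "student" if step % student_update_freq == 0 else "fake_score"


# With the default student_update_freq=5, ten iterations give 2 student updates
# (steps 0 and 5) and 8 fake-score / discriminator updates.
schedule = [update_kind(s, 5) for s in range(10)]
```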

classmethod from_yaml(config_file)

Construct a DMDConfig from a YAML file.

Thin wrapper around modelopt.torch.fastgen.loader.load_dmd_config(). The resolver searches the built-in modelopt_recipes/ package first, then the filesystem. Suffixes (.yml / .yaml) may be omitted.

Parameters:

config_file (str | Path)

Return type:

DMDConfig

class DMDPipeline

Bases: DistillationPipeline

DMD2 loss pipeline.

Parameters:
  • student – Trainable student module. Must be callable with (hidden_states, timestep, encoder_hidden_states=..., **kwargs) and return either a Tensor, a (Tensor, ...) tuple (as diffusers returns with return_dict=False), or an object with a .sample attribute.

  • teacher – Frozen reference module with the same call signature. If discriminator is provided, feature-capture hooks must be attached to teacher before calling compute_*_loss — see modelopt.torch.fastgen.plugins.wan22.attach_feature_capture().

  • fake_score – Trainable auxiliary module (same signature as teacher/student). Used to approximate the student’s generated distribution for the VSD gradient.

  • config – DMDConfig with the hyperparameters.

  • discriminator – Optional discriminator. Required when config.gan_loss_weight_gen > 0. Must accept list[Tensor] (the captured teacher features) and return a 2D logit tensor.
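The three return conventions accepted for student, teacher, and fake_score can be normalized with a few lines. A minimal sketch; `unwrap_model_output` is a hypothetical name, and the pipeline may dispatch differently internally:

```python
from types import SimpleNamespace
from typing import Any


def unwrap_model_output(out: Any) -> Any:
    """Normalize the accepted return conventions to a single prediction."""
    if isinstance(out, tuple):
        return out[0]  # diffusers-style (Tensor, ...) with return_dict=False
    if hasattr(out, "sample"):
        return out.sample  # output object exposing .sample
    return out  # already a bare prediction


# Plain floats stand in for tensors here.
outputs = [(0.5, "aux"), SimpleNamespace(sample=0.5), 0.5]
preds = [unwrap_model_output(o) for o in outputs]
```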

__init__(student, teacher, fake_score, config, *, discriminator=None)

Wire up student / teacher / fake-score / discriminator and create the EMA tracker.

Parameters:
  • student (nn.Module)

  • teacher (nn.Module)

  • fake_score (nn.Module)

  • config (DMDConfig)

  • discriminator (nn.Module | None)

Return type:

None

compute_discriminator_loss(latents, noise, encoder_hidden_states=None, **model_kwargs)

Compute the discriminator update loss (GAN + optional R1).

Teacher and student forwards are wrapped in torch.no_grad(); gradient flows only through the discriminator.

Returns a dict with "gan_disc" and "total". When config.gan_r1_reg_weight > 0 the dict also contains "r1".

Parameters:
  • latents (Tensor)

  • noise (Tensor)

  • encoder_hidden_states (Tensor | None)

  • model_kwargs (Any)

Return type:

dict[str, Tensor]

compute_fake_score_loss(latents, noise, encoder_hidden_states=None, **model_kwargs)

Compute the fake-score (auxiliary) update loss.

The fake score is trained with denoising score matching against the student’s generated samples. The student forward is wrapped in torch.no_grad() — the gradient here is w.r.t. fake_score only.

Returns a dict with "fake_score" and "total" (both equal).

Parameters:
  • latents (Tensor)

  • noise (Tensor)

  • encoder_hidden_states (Tensor | None)

  • model_kwargs (Any)

Return type:

dict[str, Tensor]

compute_student_loss(latents, noise, encoder_hidden_states=None, *, negative_encoder_hidden_states=None, guidance_scale=None, **model_kwargs)

Compute the student update losses.

The returned dict always contains "vsd" and "total". When the GAN branch is enabled (discriminator is not None and config.gan_loss_weight_gen > 0), "gan_gen" is also present.

Gradient flow summary:

  • VSD gradient: flows through student only (teacher_x0 is detached, fake_score_x0 is computed under torch.no_grad()).

  • GAN generator gradient: flows through student via the feature-capture hooks on the teacher. The teacher forward is therefore not wrapped in torch.no_grad() when the GAN branch is active.
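A common way to apply the VSD (DMD) gradient while backpropagating only through the student is a stop-gradient surrogate: with g = fake_score_x0 - teacher_x0 held constant, the loss 0.5 * (x - sg(x - g))**2 has derivative g with respect to the student sample x. The sketch below checks this numerically on scalars. It assumes this surrogate form and omits any per-sample normalization DMDPipeline may apply; both helper names are hypothetical:

```python
def vsd_surrogate(x: float, target: float) -> float:
    """0.5 * (x - target)**2, where ``target`` is treated as a constant (stop-grad)."""
    return 0.5 * (x - target) ** 2


def surrogate_grad(x: float, teacher_x0: float, fake_x0: float, h: float = 1e-6) -> float:
    """Central finite difference of the surrogate w.r.t. x, with the target frozen."""
    g = fake_x0 - teacher_x0  # distribution-matching direction: fake minus real estimate
    target = x - g            # computed once and held fixed, like a .detach()
    return (vsd_surrogate(x + h, target) - vsd_surrogate(x - h, target)) / (2 * h)


grad = surrogate_grad(x=0.3, teacher_x0=0.1, fake_x0=0.4)  # approximately 0.3
```

Descending this gradient moves the student sample away from what the fake score predicts and toward what the teacher predicts, without ever differentiating through either of those networks.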

Parameters:
  • latents (Tensor) – Real clean-data latents x_0. Used only when student_sample_steps > 1 to construct input_student.

  • noise (Tensor) – Pure Gaussian noise tensor matching latents in shape/dtype.

  • encoder_hidden_states (Tensor | None) – Positive conditioning passed unchanged to all three models.

  • negative_encoder_hidden_states (Tensor | None) – Negative conditioning used by classifier-free guidance. Required when guidance_scale (or DMDConfig.guidance_scale) is not None.

  • guidance_scale (float | None) – Overrides DMDConfig.guidance_scale for this call. None keeps the config-level value.

  • **model_kwargs (Any) – Forwarded verbatim to student, teacher, and fake_score.
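When a guidance scale is active, the teacher is evaluated with both positive and negative conditioning and the two predictions are combined. A sketch of the standard CFG combination on scalar predictions, assuming the uncond + s * (cond - uncond) convention (some codebases use s - 1 in place of s, so check DMDPipeline before relying on exact values); `cfg_combine` is hypothetical:

```python
from typing import Optional


def cfg_combine(cond: float, uncond: float, guidance_scale: Optional[float]) -> float:
    """Classifier-free guidance: uncond + scale * (cond - uncond); None disables CFG."""
    if guidance_scale is None:
        return cond  # CFG disabled: use the positive-conditioned prediction only
    return uncond + guidance_scale * (cond - uncond)
```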

Returns:

Dictionary with keys "vsd", "total", and optionally "gan_gen".

Return type:

dict[str, Tensor]

config: DMDConfig

property ema: ExponentialMovingAverage | None

Reference to the student EMA tracker, if configured.

update_ema(*, iteration=None)

Update the student EMA tracker (no-op if config.ema is None).

Typically called after the student optimizer step. If iteration is not provided, an internal counter is auto-incremented.

Parameters:

iteration (int | None)

Return type:

None

ModeloptConfig DistillationConfig

Bases: ModeloptBaseConfig

Shared hyperparameters for diffusion step-distillation methods.

Concrete methods subclass this config to add method-specific fields (see DMDConfig).

Default config (JSON):

{
   "pred_type": "flow",
   "guidance_scale": null,
   "sample_t_cfg": null,
   "student_sample_steps": 1,
   "student_sample_type": "ode",
   "num_train_timesteps": null
}

field guidance_scale: float | None

Classifier-free guidance scale. If None, CFG is disabled.

field num_train_timesteps: int | None

If set, the pipeline rescales the continuous RF timestep t ∈ [0, 1] to num_train_timesteps * t before passing it to the model. Matches the diffusers convention used by Wan 2.2 / SD3 / Flux (num_train_timesteps = 1000). Leave None when the model wrapper already handles the rescaling internally.
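The rescaling rule amounts to a one-line conditional. A sketch on plain floats; `model_timestep` is a hypothetical helper:

```python
from typing import Optional


def model_timestep(t: float, num_train_timesteps: Optional[int]) -> float:
    """Map continuous RF time t in [0, 1] to the value passed to the network."""
    if not 0.0 <= t <= 1.0:
        raise ValueError(f"t must lie in [0, 1], got {t}")
    if num_train_timesteps is None:
        return t  # the model wrapper rescales internally
    return num_train_timesteps * t
```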

field pred_type: PredType

Quantity predicted by the teacher / student network.
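Because pred_type and fake_score_pred_type can differ, predictions are typically converted to a common parameterization. A sketch on scalars, assuming the rectified-flow convention used by Wan 2.2 / SD3, x_t = (1 - t) * x0 + t * eps with flow target v = eps - x0; the exact sign convention inside modelopt.torch.fastgen.flow_matching is an assumption:

```python
def rf_forward(x0: float, eps: float, t: float) -> float:
    """Rectified-flow forward process: interpolate from data (t=0) to noise (t=1)."""
    return (1.0 - t) * x0 + t * eps


def flow_to_x0(x_t: float, v: float, t: float) -> float:
    """Recover x0 from a flow (velocity) prediction v = eps - x0."""
    return x_t - t * v


def x0_to_flow(x_t: float, x0: float, t: float) -> float:
    """Inverse of flow_to_x0; undefined at t = 0, where x_t equals x0."""
    if t == 0.0:
        raise ValueError("flow is not identifiable from x0 at t = 0")
    return (x_t - x0) / t
```

Both conversions follow from x_t - x0 = t * (eps - x0) = t * v.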

field sample_t_cfg: SampleTimestepConfig [Optional]

Timestep distribution used for both the teacher forward and the VSD / DSM losses.

field student_sample_steps: int

Number of denoising steps the distilled student performs at inference.

field student_sample_type: Literal['sde', 'ode']

Integrator used when unrolling the student over student_sample_steps > 1 steps. Not read by DMDPipeline at training time; it is consumed by inference samplers.

class DistillationPipeline

Bases: object

Hold student/teacher references and expose shared utilities.

Parameters:
  • student – Trainable student module. The pipeline does not wrap it — its lifecycle (train() / eval(), requires_grad_, sharding, optimizer) remains owned by the caller.

  • teacher – Reference module. Frozen here via eval() + requires_grad_(False).

  • config – A DistillationConfig (or subclass).

__init__(student, teacher, config)

Store student / teacher references and freeze the teacher.

Parameters:
  • student (nn.Module)

  • teacher (nn.Module)

  • config (DistillationConfig)

Return type:

None

property device: device

Device of the first student parameter (best-effort; falls back to CPU).

property dtype: dtype

Dtype of the first student parameter (best-effort; falls back to float32).

sample_timesteps(n, *, device=None, dtype=torch.float32)

Sample n training timesteps according to config.sample_t_cfg.

Parameters:
  • n (int)

  • device (device | None)

  • dtype (dtype)

Return type:

Tensor

ModeloptConfig EMAConfig

Bases: ModeloptBaseConfig

Exponential moving average (EMA) hyperparameters for the student network.

Default config (JSON):

{
   "decay": 0.9999,
   "type": "constant",
   "start_iter": 0,
   "gamma": 16.97,
   "halflife_kimg": 500.0,
   "rampup_ratio": 0.05,
   "batch_size": 1,
   "fsdp2": true,
   "mode": "full_tensor",
   "dtype": "float32"
}

field batch_size: int

Per-step global batch size used to convert iterations to nimg for the halflife schedule.

field decay: float

Decay coefficient for type='constant'. Ignored for halflife/power.

field dtype: Literal['float32', 'bfloat16', 'float16'] | None

Precision of the EMA parameter shadows. Defaults to float32 so EMA updates remain numerically meaningful even when the live model is bf16/fp16 (cf. FastGen, which instantiates its EMA module in the net’s construction dtype — typically fp32). Pass None to keep param shadows in the live parameter’s dtype. Buffer shadows always track the live dtype regardless of this setting.

field fsdp2: bool

If True, the EMA uses DTensor.full_tensor() to gather sharded parameters before updating.

field gamma: float

Exponent for type='power' (beta = (1 - 1/iter)**(gamma + 1)).

field halflife_kimg: float

Halflife in thousands of images for type='halflife'.

field mode: Literal['full_tensor', 'local_shard']

full_tensor performs an all_gather per parameter (higher memory, exact global EMA). local_shard updates each rank’s local DTensor shard in place (low memory fallback).

field rampup_ratio: float | None

Rampup fraction for type='halflife'; pass None to disable rampup.

field start_iter: int

Iteration at which EMA tracking begins (EMA is initialized from the live weights at this step).

field type: Literal['constant', 'halflife', 'power']

Schedule used to compute the per-step decay coefficient.
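The three schedules can be summarized as one function returning the per-step decay beta. The power formula comes from the gamma field above; the halflife branch is an assumption modeled on an EDM-style schedule (halflife measured in images, optionally capped at a fraction of the images seen so far), so treat it as a sketch rather than the tracker's exact code. `ema_beta` is a hypothetical helper whose defaults mirror the JSON above:

```python
from typing import Optional


def ema_beta(
    kind: str,
    iteration: int,
    *,
    decay: float = 0.9999,
    gamma: float = 16.97,
    halflife_kimg: float = 500.0,
    rampup_ratio: Optional[float] = 0.05,
    batch_size: int = 1,
) -> float:
    """Per-step EMA decay coefficient for type='constant' | 'power' | 'halflife'."""
    if kind == "constant":
        return decay
    if kind == "power":
        it = max(iteration, 1)  # avoid 1/0 on the very first step
        return (1.0 - 1.0 / it) ** (gamma + 1.0)
    if kind == "halflife":
        nimg = iteration * batch_size  # images seen so far (assumed definition)
        halflife_nimg = halflife_kimg * 1000.0
        if rampup_ratio is not None:
            halflife_nimg = min(halflife_nimg, nimg * rampup_ratio)
        # The old shadow's weight halves every `halflife_nimg` images.
        return 0.5 ** (batch_size / max(halflife_nimg, 1e-8))
    raise ValueError(f"unknown EMA type: {kind!r}")
```

Early in training, both the power and the ramped-up halflife schedules return a small beta, so the shadow tracks the live weights closely before it starts averaging over a long window.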

class ExponentialMovingAverage

Bases: object

FSDP2-aware EMA tracker for a PyTorch module.

The tracker stores a shadow state dict: parameters are promoted per EMAConfig.dtype (default fp32) while buffers are kept in the live module’s dtype. Buffers are replicated across ranks and stepped via copy_ rather than lerp_, so the bf16-roundoff argument that motivates parameter promotion doesn’t apply — preserving the live dtype makes the buffer restore exact.

By default the tracker materialises the full tensor per parameter (mode='full_tensor') so the EMA represents the globally averaged weights even when the model is sharded across ranks. A mode='local_shard' fallback is available for memory-constrained settings — it does not all-gather and therefore each rank holds an EMA of its local shard only.

Example:

from modelopt.torch.fastgen import EMAConfig, ExponentialMovingAverage

ema = ExponentialMovingAverage(student, EMAConfig(decay=0.999))
for step in range(max_steps):
    ...  # compute loss, backward, optimizer.step()
    ema.update(student, iteration=step)

ema.copy_to(student_for_eval)  # publish for inference; target must be an unsharded module

__init__(model, config)

Pre-allocate the shadow state from model’s parameters and buffers.

Parameters:
  • model (Module)

  • config (EMAConfig)

Return type:

None

copy_to(target)

Load the shadow state into target (which should share the tracked module’s structure).

The target is expected to be an unsharded module (i.e. the caller has unwrapped any FSDP2 wrappers before calling). For sharded targets, prefer saving the shadow via state_dict() and reloading it through the framework’s usual checkpoint path.

Parameters:

target (Module)

Return type:

None

load_state_dict(state)

Restore the shadow state from a previously saved dict.

Parameters:

state (dict[str, Tensor])

Return type:

None

state_dict()

Return the shadow state (parameters + buffers) for checkpointing.

Return type:

dict[str, Tensor]

update(model, *, iteration)

Update the shadow state from model at the given iteration.

Skips updates before EMAConfig.start_iter. On the iteration that equals start_iter the shadow is (re-)initialised from the live weights; after that it is updated with shadow = beta * shadow + (1 - beta) * live.

Parameters:
  • model (Module)

  • iteration (int)

Return type:

None
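The update rule and the start_iter behavior can be sketched with plain dicts of floats. This is a simplified single-process illustration: `TinyEMA` is hypothetical, ignores FSDP2 sharding and dtype promotion, takes beta as an argument instead of deriving it from EMAConfig, and follows the documented rule that buffers are copied rather than averaged:

```python
from typing import Dict


class TinyEMA:
    """Single-process EMA over named scalar parameters and buffers."""

    def __init__(self, start_iter: int = 0) -> None:
        self.start_iter = start_iter
        self.params: Dict[str, float] = {}
        self.buffers: Dict[str, float] = {}

    def update(self, params: Dict[str, float], buffers: Dict[str, float],
               iteration: int, beta: float) -> None:
        if iteration < self.start_iter:
            return  # tracking has not started yet
        if iteration == self.start_iter or not self.params:
            # (Re-)initialize the shadow from the live weights.
            self.params = dict(params)
        else:
            for name, live in params.items():
                # shadow = beta * shadow + (1 - beta) * live
                self.params[name] = beta * self.params[name] + (1.0 - beta) * live
        # Buffers are copied verbatim, never averaged.
        self.buffers = dict(buffers)
```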

ModeloptConfig SampleTimestepConfig

Bases: ModeloptBaseConfig

Timestep sampling distribution for diffusion training.

Default config (JSON):

{
   "time_dist_type": "shifted",
   "min_t": 0.001,
   "max_t": 0.999,
   "shift": 5.0,
   "p_mean": 0.0,
   "p_std": 1.0,
   "t_list": null
}

field max_t: float

Upper bound of the sampling range (clamped before use).

field min_t: float

Lower bound of the sampling range (clamped before use).

field p_mean: float

Mean of the underlying normal for logitnormal / lognormal.

field p_std: float

Standard deviation of the underlying normal for logitnormal / lognormal.

field shift: float

Shift factor for time_dist_type='shifted'; must be >= 1.

field t_list: list[float] | None

Explicit timestep schedule used when DMDConfig.student_sample_steps > 1. The final element must be 0.0.
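A quick sanity check on t_list can catch malformed schedules before training starts. Only the final-element rule comes from the field description. The strictly decreasing order and the len(t_list) == student_sample_steps + 1 relation are assumptions inferred from the example YAML under load_dmd_config (two steps, three entries), and `t_list_problems` is hypothetical:

```python
from typing import List


def t_list_problems(t_list: List[float], student_sample_steps: int) -> List[str]:
    """Return problems found in ``t_list``; an empty list means it looks valid."""
    if not t_list:
        return ["t_list is empty"]
    problems = []
    if t_list[-1] != 0.0:
        problems.append("final element must be 0.0")  # documented rule
    if any(a <= b for a, b in zip(t_list, t_list[1:])):
        problems.append("not strictly decreasing")  # assumption
    if len(t_list) != student_sample_steps + 1:
        problems.append("expected student_sample_steps + 1 entries")  # assumption
    return problems
```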

field time_dist_type: TimeDistType

Distribution used to sample the training timestep t. Rectified-flow models typically use shifted (Wan 2.2) or logitnormal (SD3, Flux).
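The shifted and logit-normal options can be sketched with the standard library. The shift formula t' = shift * t / (1 + (shift - 1) * t) is the one diffusers uses for SD3-style flow-matching schedulers; that the 'shifted' option applies it to a uniform draw, and that clamping to [min_t, max_t] happens after the transform, are assumptions. `shift_t` and `sample_t` are hypothetical, with defaults mirroring the JSON above:

```python
import math
import random


def shift_t(t: float, shift: float) -> float:
    """Push t toward 1 (noisier) for shift > 1; identity for shift == 1."""
    if shift < 1.0:
        raise ValueError("shift must be >= 1")
    return shift * t / (1.0 + (shift - 1.0) * t)


def sample_t(rng: random.Random, dist: str = "shifted", *, min_t: float = 0.001,
             max_t: float = 0.999, shift: float = 5.0, p_mean: float = 0.0,
             p_std: float = 1.0) -> float:
    """Draw one training timestep and clamp it to [min_t, max_t]."""
    if dist == "shifted":
        t = shift_t(rng.random(), shift)
    elif dist == "logitnormal":
        t = 1.0 / (1.0 + math.exp(-rng.gauss(p_mean, p_std)))  # sigmoid of a normal draw
    else:
        raise ValueError(f"distribution not covered by this sketch: {dist!r}")
    return min(max(t, min_t), max_t)
```

With the default shift=5.0, a uniform draw of 0.5 maps to about 0.833, and roughly five in six draws land above t = 0.5, so training concentrates on the noisier end of the trajectory.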

create_fake_score(teacher, *, deep_copy=True)

Return a trainable fake-score network initialized from the teacher.

This is the unit-test / single-script path; frameworks that do meta-init + FSDP2 wrapping will typically construct the fake score themselves and pass it directly into DMDPipeline.

Parameters:
  • teacher (nn.Module) – The already-built teacher module. Must already have its weights loaded.

  • deep_copy (bool) – If True, copy.deepcopy() the teacher; if False, reuse the same instance (only sensible if the caller can guarantee it is no longer held elsewhere as the frozen teacher).

Returns:

A copy of teacher (or, with deep_copy=False, teacher itself) in training mode with all parameters requiring gradients.

Return type:

nn.Module

FSDP2 caveat

copy.deepcopy(teacher) is not safe when the teacher is already FSDP2-wrapped (DTensor parameters, FSDP pre/post hooks, and meta-init bookkeeping). For Stage-2 FSDP2 training, skip this factory: construct the fake score under meta-init, load the teacher weights on rank 0, and let sync_module_states broadcast them:

with meta_init_context():
    fake_score = build_teacher_from_config(teacher_config)
if is_rank0():
    fake_score.load_state_dict(teacher.state_dict(), strict=False)
# Wrap with FSDP2(..., sync_module_states=True) to broadcast from rank 0.

The pattern mirrors FastGen’s methods/distribution_matching/dmd2.py::DMD2Model.build_model. A dedicated create_fake_score_meta factory is planned alongside the Stage-2 training example.

raises RuntimeError:

When deep_copy=True and the teacher looks FSDP-wrapped (either FSDP1 via _fsdp_wrapped_module or FSDP2 via DTensor parameters). The deep_copy=False branch skips the check because reusing the teacher directly is compatible with an FSDP-wrapped input.

load_config(config_file)

Load a YAML file and return the parsed mapping.

Mirrors modelopt.recipe._config_loader.load_config() in spirit but without the ExMy-num-bits post-processing that is specific to quantization recipes.

Parameters:

config_file (str | Path) – YAML path. Suffix is optional; resolution searches the built-in modelopt_recipes/ package first, then the filesystem.

Returns:

The parsed dictionary. An empty file yields {}.

Return type:

dict[str, Any]
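The suffix-optional lookup can be pictured as a search over ordered roots. A sketch under assumptions: the candidate order (exact name, then .yaml, then .yml) and treating the built-in modelopt_recipes/ package as an ordinary directory placed first in the search list are guesses, and `resolve_config_path` is hypothetical:

```python
from pathlib import Path
from typing import Iterable, Union

# Candidate suffixes tried in order; "" lets an explicit .yaml/.yml name match as-is.
_SUFFIXES = ("", ".yaml", ".yml")


def resolve_config_path(name: Union[str, Path],
                        search_roots: Iterable[Union[str, Path]]) -> Path:
    """Return the first existing YAML file for ``name`` under the ordered ``search_roots``."""
    for root in search_roots:
        for suffix in _SUFFIXES:
            candidate = Path(root) / f"{name}{suffix}"
            if candidate.is_file():
                return candidate
    raise FileNotFoundError(f"no YAML config found for {name!r}")
```

Putting the built-in root first means a recipe name such as "general/distillation/dmd2_wan22_5b" resolves to the packaged file even if a same-named file exists in the working directory.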

load_dmd_config(config_file)

Load a YAML file and construct a DMDConfig.

The YAML is validated against DMDConfig’s Pydantic schema — unknown keys raise ValidationError.

Example YAML:

pred_type: flow
guidance_scale: 5.0
student_sample_steps: 2
gan_loss_weight_gen: 0.03
sample_t_cfg:
  time_dist_type: shifted
  t_list: [0.999, 0.833, 0.0]
ema:
  decay: 0.9999

Parameters:

config_file (str | Path)

Return type:

DMDConfig