discriminators

Discriminator modules for the DMD2 GAN branch.

Ports FastGen’s image-DiT discriminator from source/FastGen/fastgen/networks/discriminators.py so that ModelOpt’s DMDPipeline can run the GAN branch without a FastGen dependency. The discriminator is model-agnostic: it takes a list of spatial feature tensors [B, C, H, W] and returns concatenated logits [B, num_heads]. The model-specific work of producing those tensors (installing forward hooks, reshaping packed-token streams into spatial maps) lives in the per-model plugins (for example, plugins/qwen_image.py).
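The plugin side of this split can be sketched with standard PyTorch forward hooks. The helpers below (tokens_to_spatial, capture_block_features) are hypothetical and not part of ModelOpt. They assume each hooked block returns a single [B, H*W, C] token tensor in row-major (H, W) order. The real plugins/qwen_image.py logic may differ, for example if blocks return tuples or if tokens are packed in patches.

```python
import torch
import torch.nn as nn


def tokens_to_spatial(tokens: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """Hypothetical helper: reshape packed tokens [B, H*W, C] into a map [B, C, H, W].

    Assumes tokens are laid out row by row (row-major over H, W).
    """
    b, n, c = tokens.shape
    if n != height * width:
        raise ValueError(f"expected {height * width} tokens, got {n}")
    # [B, H*W, C] -> [B, C, H*W] -> [B, C, H, W]
    return tokens.transpose(1, 2).reshape(b, c, height, width)


def capture_block_features(blocks: nn.ModuleList, feature_indices, height: int, width: int):
    """Hypothetical helper: hook the selected blocks and store their outputs as spatial maps.

    Returns (store, handles); call handle.remove() on each handle when done.
    """
    store: dict = {}
    handles = []
    for idx in sorted(feature_indices):
        # Bind idx as a default argument so each hook writes to its own key.
        def hook(module, inputs, output, idx=idx):
            # Assumption: the block returns the hidden-state tensor directly.
            store[idx] = tokens_to_spatial(output, height, width)

        handles.append(blocks[idx].register_forward_hook(hook))
    return store, handles
```

After the teacher forward pass, `[store[i] for i in sorted(store)]` gives the list of [B, C, H, W] tensors that the discriminator consumes.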

Classes

Discriminator

Base class for DMD2 discriminators.

Discriminator_ImageDiT

Image-DiT discriminator with one lightweight conv head per captured block.

class Discriminator

Bases: Module

Base class for DMD2 discriminators.

__init__(feature_indices=None)

Store the teacher block indices whose features feed the discriminator.

Parameters:

feature_indices (set[int] | None) – teacher block indices whose features feed the discriminator

Return type:

None

forward(feats)

Map captured teacher features to discriminator logits (overridden by subclasses).

Parameters:

feats (list[Tensor]) – captured teacher features, each of shape [B, C, H, W]

Return type:

Tensor
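A minimal sketch of the contract this base class describes: store the block indices, and leave forward to subclasses, which must return [B, num_heads] logits. This is an illustration, not the ModelOpt source. Storing None as an empty set is an assumption, and MeanPoolDiscriminator is a hypothetical, parameter-free toy subclass.

```python
from __future__ import annotations

import torch
import torch.nn as nn


class Discriminator(nn.Module):
    """Minimal sketch of the base-class contract (not the ModelOpt source)."""

    def __init__(self, feature_indices: set[int] | None = None) -> None:
        super().__init__()
        # Assumption: None is stored as an empty set.
        self.feature_indices = set(feature_indices) if feature_indices is not None else set()

    def forward(self, feats: list[torch.Tensor]) -> torch.Tensor:
        # Subclasses map the captured features to [B, num_heads] logits.
        raise NotImplementedError


class MeanPoolDiscriminator(Discriminator):
    """Hypothetical toy subclass: one parameter-free 'head' per feature map."""

    def forward(self, feats: list[torch.Tensor]) -> torch.Tensor:
        # Each [B, C, H, W] map -> one logit column [B, 1]; concat -> [B, len(feats)].
        return torch.cat([f.mean(dim=(1, 2, 3)).unsqueeze(1) for f in feats], dim=1)
```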

class Discriminator_ImageDiT

Bases: Discriminator

Image-DiT discriminator with one lightweight conv head per captured block.

Input: list of feature tensors with shape [B, inner_dim, H, W], one per block index in feature_indices.

Output: concatenated logits [B, num_heads] (one column per head). The DMD2 generator/discriminator losses read this as a 2D tensor.
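To show how a loss might consume the [B, num_heads] output, here is a sketch that assumes the standard non-saturating GAN objective, averaged over both the batch and the heads. The function names are hypothetical, and the actual DMD2 losses in DMDPipeline may weight or reduce the heads differently.

```python
import torch
import torch.nn.functional as F


def _check_logits(logits: torch.Tensor) -> None:
    # The discriminator output is expected to be 2D: [B, num_heads].
    if logits.dim() != 2:
        raise ValueError(f"expected [B, num_heads] logits, got shape {tuple(logits.shape)}")


def discriminator_loss(real_logits: torch.Tensor, fake_logits: torch.Tensor) -> torch.Tensor:
    # -log(sigmoid(real)) - log(1 - sigmoid(fake)), averaged over batch and heads.
    # Callers would typically detach the generator output before computing fake_logits.
    _check_logits(real_logits)
    _check_logits(fake_logits)
    return F.softplus(-real_logits).mean() + F.softplus(fake_logits).mean()


def generator_gan_loss(fake_logits: torch.Tensor) -> torch.Tensor:
    # Non-saturating generator term: -log(sigmoid(fake)), averaged over batch and heads.
    _check_logits(fake_logits)
    return F.softplus(-fake_logits).mean()
```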

Per-head parameter count is roughly inner_dim * (inner_dim // 2) * 16 + …; for inner_dim=3072 (Flux / Qwen-Image) that is about 75 M parameters per head, so keep len(feature_indices) small (≤3 heads is typical).
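The arithmetic behind the ~75 M figure can be checked directly. The sketch below counts only the dominant term quoted above and omits the trailing "+ …". As a point of reference, a 4×4 convolution from inner_dim to inner_dim // 2 channels has exactly this many weights, but that reading of the factor 16 is an assumption. approx_head_params is a hypothetical helper.

```python
def approx_head_params(inner_dim: int) -> int:
    # Dominant term from the docstring: inner_dim * (inner_dim // 2) * 16.
    # The trailing "+ …" (smaller layers, biases) is deliberately left out.
    return inner_dim * (inner_dim // 2) * 16


per_head = approx_head_params(3072)  # 75_497_472
for num_heads in (1, 2, 3):
    total = num_heads * per_head
    print(f"{num_heads} head(s): ~{total / 1e6:.1f} M params")
# 1 head(s): ~75.5 M params
# 2 head(s): ~151.0 M params
# 3 head(s): ~226.5 M params
```

At three heads the discriminator alone is already over 200 M parameters, which is why the head count is kept small.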

__init__(feature_indices=None, num_blocks=57, inner_dim=3072)

Build one lightweight conv classification head per captured block.

Parameters:
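  • feature_indices (set[int] | None) – teacher block indices that each get one classification head

  • num_blocks (int)

  • inner_dim (int) – channel width of each captured feature map (3072 for Flux / Qwen-Image)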

Return type:

None

forward(feats)

Run each per-block conv head and concatenate their logits to [B, num_heads].

Parameters:

feats (list[Tensor]) – one [B, inner_dim, H, W] feature map per block index in feature_indices

Return type:

Tensor
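The structure described for Discriminator_ImageDiT (one lightweight conv head per captured block, logits concatenated to [B, num_heads]) can be sketched as follows. Everything inside each head is an assumption chosen to match the parameter estimate above: a strided 4×4 conv to inner_dim // 2 channels, normalization, a 1×1 conv to one logit map, and global average pooling. The default used when feature_indices is None and the ordering of feats by sorted block index are also assumptions. ImageDiTDiscriminatorSketch is a hypothetical name, and it subclasses nn.Module directly so the block stays self-contained.

```python
from __future__ import annotations

import torch
import torch.nn as nn


class ImageDiTDiscriminatorSketch(nn.Module):
    """Hypothetical stand-in for Discriminator_ImageDiT (the real class derives from Discriminator)."""

    def __init__(self, feature_indices: set[int] | None = None,
                 num_blocks: int = 57, inner_dim: int = 3072) -> None:
        super().__init__()
        # Assumption: default to the last block when no indices are given.
        if feature_indices is None:
            feature_indices = {num_blocks - 1}
        self.feature_indices = set(feature_indices)
        bad = [i for i in self.feature_indices if not 0 <= i < num_blocks]
        if bad:
            raise ValueError(f"feature indices {bad} out of range for {num_blocks} blocks")
        hidden = inner_dim // 2
        # One head per captured block, in sorted block-index order.
        self.heads = nn.ModuleList(
            nn.Sequential(
                # 4x4 conv: inner_dim * hidden * 16 weights, the dominant cost.
                nn.Conv2d(inner_dim, hidden, kernel_size=4, stride=2, padding=1),
                nn.GroupNorm(1, hidden),
                nn.SiLU(),
                nn.Conv2d(hidden, 1, kernel_size=1),  # per-location logit map
                nn.AdaptiveAvgPool2d(1),               # [B, 1, h, w] -> [B, 1, 1, 1]
                nn.Flatten(),                          # -> [B, 1]
            )
            for _ in sorted(self.feature_indices)
        )

    def forward(self, feats: list[torch.Tensor]) -> torch.Tensor:
        # feats[k] is assumed to come from the k-th smallest index in feature_indices.
        if len(feats) != len(self.heads):
            raise ValueError(f"expected {len(self.heads)} feature maps, got {len(feats)}")
        return torch.cat([head(f) for head, f in zip(self.heads, feats)], dim=1)
```

Because each head ends in global average pooling, the sketch accepts any spatial size of at least 2×2 (the strided 4×4 conv needs it). That lets the same heads serve different image resolutions without changing the [B, num_heads] output shape.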