discriminators
Discriminator modules for the DMD2 GAN branch.
Ports FastGen’s image-DiT discriminator from
source/FastGen/fastgen/networks/discriminators.py so that ModelOpt’s
DMDPipeline can run the GAN branch
without a FastGen dependency. The discriminator is model-agnostic: it takes
a list of spatial feature tensors [B, C, H, W] and returns concatenated
logits [B, num_heads]. The model-specific work of producing those tensors
(installing forward hooks, reshaping packed-token streams into spatial maps)
lives in the per-model plugins (plugins/qwen_image.py).
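The split between plugin and discriminator can be illustrated with a minimal sketch. It assumes a toy transformer whose blocks emit packed tokens of shape [B, H*W, C] in simple row-major order; `tokens_to_spatial`, `FeatureCapture`, and the `nn.Linear` stand-in blocks are hypothetical names for illustration and are not the actual plugin code. Real models may pack tokens differently, so the reshape would need to match the model's own layout.

```python
import torch
from torch import nn


def tokens_to_spatial(tokens: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """Reshape packed tokens [B, H*W, C] into a spatial map [B, C, H, W].

    Assumes row-major token order (token h*W + w sits at pixel (h, w)).
    """
    batch, seq_len, channels = tokens.shape
    if seq_len != height * width:
        raise ValueError(f"expected {height * width} tokens, got {seq_len}")
    return tokens.transpose(1, 2).reshape(batch, channels, height, width)


class FeatureCapture:
    """Hypothetical helper: hooks selected blocks and stores spatial features."""

    def __init__(self, blocks, feature_indices, height, width):
        self.height, self.width = height, width
        self.feats = {}
        # One forward hook per captured block index.
        self.handles = [
            blocks[idx].register_forward_hook(self._make_hook(idx))
            for idx in sorted(feature_indices)
        ]

    def _make_hook(self, idx):
        def hook(module, inputs, output):
            # Returning None leaves the block output unchanged.
            self.feats[idx] = tokens_to_spatial(output, self.height, self.width)
        return hook

    def pop(self):
        """Return captured features ordered by block index, then reset."""
        feats = [self.feats[idx] for idx in sorted(self.feats)]
        self.feats.clear()
        return feats

    def remove(self):
        for handle in self.handles:
            handle.remove()


torch.manual_seed(0)
blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])  # stand-in blocks
capture = FeatureCapture(blocks, feature_indices={1, 3}, height=2, width=3)

hidden = torch.randn(2, 6, 8)  # [B=2, H*W=6, C=8]
with torch.no_grad():
    for block in blocks:
        hidden = block(hidden)

feats = capture.pop()
capture.remove()
print([tuple(f.shape) for f in feats])  # [(2, 8, 2, 3), (2, 8, 2, 3)]
```

The resulting list is the kind of input the discriminator expects: one [B, C, H, W] tensor per captured block, ordered by block index.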
Classes
Discriminator | Base class for DMD2 discriminators.
Discriminator_ImageDiT | Image-DiT discriminator with one lightweight conv head per captured block.
- class Discriminator
Bases:
Module
Base class for DMD2 discriminators.
- __init__(feature_indices=None)
Store the teacher block indices whose features feed the discriminator.
- Parameters:
feature_indices (set[int] | None)
- Return type:
None
- forward(feats)
Map captured teacher features to discriminator logits (overridden by subclasses).
- Parameters:
feats (list[Tensor])
- Return type:
Tensor
- class Discriminator_ImageDiT
Bases:
Discriminator
Image-DiT discriminator with one lightweight conv head per captured block.
Input: list of feature tensors with shape [B, inner_dim, H, W], one per block index in feature_indices.
Output: concatenated logits [B, num_heads] (one column per head). The DMD2 generator/discriminator losses read this as a 2D tensor.
Per-head parameter count is ~inner_dim * (inner_dim // 2) * 16 + …; for inner_dim=3072 (Flux / Qwen-Image) that’s ~75 M params per head, so keep len(feature_indices) small (≤3 heads is typical).
- __init__(feature_indices=None, num_blocks=57, inner_dim=3072)
Build one lightweight conv classification head per captured block.
- Parameters:
feature_indices (set[int] | None)
num_blocks (int)
inner_dim (int)
- Return type:
None
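The per-head estimate quoted in the class description can be checked with plain arithmetic. The sketch below evaluates only the leading term inner_dim * (inner_dim // 2) * 16; the remaining terms are elided in the description and are not estimated here. `approx_head_params` is a hypothetical helper name.

```python
def approx_head_params(inner_dim: int) -> int:
    """Leading term of the per-head parameter count from the class description."""
    return inner_dim * (inner_dim // 2) * 16


per_head = approx_head_params(3072)
print(f"{per_head:,}")  # 75,497,472

# Total grows linearly with the number of captured blocks (heads).
totals = {n: n * per_head for n in (1, 2, 3)}
for num_heads, total in totals.items():
    print(num_heads, f"{total:,}")
```

With three heads the leading term alone is roughly 226 M parameters, which is why a small feature_indices set is recommended.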
- forward(feats)
Run each per-block conv head and concatenate their logits to [B, num_heads].
- Parameters:
feats (list[Tensor])
- Return type:
Tensor
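The input/output contract described above (a list of [B, inner_dim, H, W] tensors in, [B, num_heads] logits out) can be sketched as follows. This is not the ported FastGen implementation: the head layout is an assumption, chosen only so that its first convolution (4x4 kernel, inner_dim to inner_dim // 2 channels) matches the leading term of the documented parameter estimate. `ToyImageDiTDiscriminator` and the pooling/linear tail are hypothetical.

```python
import torch
from torch import nn


class ToyImageDiTDiscriminator(nn.Module):
    """Illustrative stand-in: one small conv head per captured block."""

    def __init__(self, feature_indices, num_blocks=57, inner_dim=3072):
        super().__init__()
        if not feature_indices:
            raise ValueError("feature_indices must name at least one block")
        if any(not 0 <= idx < num_blocks for idx in feature_indices):
            raise ValueError("feature index out of range")
        self.feature_indices = set(feature_indices)
        hidden = inner_dim // 2
        # Assumed head layout; the real heads may differ.
        self.heads = nn.ModuleList(
            nn.Sequential(
                nn.Conv2d(inner_dim, hidden, kernel_size=4, stride=2, padding=1),
                nn.SiLU(),
                nn.AdaptiveAvgPool2d(1),  # [B, hidden, 1, 1]
                nn.Flatten(),             # [B, hidden]
                nn.Linear(hidden, 1),     # [B, 1] logit per head
            )
            for _ in self.feature_indices
        )

    def forward(self, feats):
        if len(feats) != len(self.heads):
            raise ValueError(f"expected {len(self.heads)} feature maps, got {len(feats)}")
        logits = [head(feat) for head, feat in zip(self.heads, feats)]
        return torch.cat(logits, dim=1)  # [B, num_heads]


# Tiny dimensions so the example runs quickly on CPU.
torch.manual_seed(0)
disc = ToyImageDiTDiscriminator(feature_indices={1, 3}, num_blocks=4, inner_dim=8)
toy_feats = [torch.randn(2, 8, 8, 8) for _ in range(2)]
toy_logits = disc(toy_feats)
print(tuple(toy_logits.shape))  # (2, 2)

conv_params = sum(p.numel() for p in disc.heads[0][0].parameters())
print(conv_params)  # 516 = 8 * 4 * 16 weights + 4 biases
```

Each head emits a single column, so the number of output columns always equals the number of captured blocks.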