factory

Convenience factory helpers for constructing the auxiliary DMD networks.

These helpers are intentionally tiny — the training framework is free to build the fake score directly (e.g. under a meta-init context for FSDP2) instead of calling create_fake_score(). See the ModelOpt ↔ FastGen design doc (FASTGEN_MODELOPT.md, section “How the framework can build the fake_score”) for both options.

Functions

create_fake_score

Return a trainable fake-score network initialized from the teacher.

create_fake_score(teacher, *, deep_copy=True)

Return a trainable fake-score network initialized from the teacher.

This is the unit-test / single-script path; frameworks that do meta-init + FSDP2 wrapping will typically construct the fake score themselves and pass it directly into DMDPipeline.

Parameters:
  • teacher (nn.Module) – The already-built teacher module. Must already have its weights loaded.

  • deep_copy (bool) – If True, copy.deepcopy() the teacher; if False, reuse the same instance. Because the returned module is put in training mode with gradients enabled, reuse is only sensible if the caller can guarantee the instance is no longer held elsewhere as the frozen teacher.

Returns:

A copy of teacher in training mode with all parameters requiring gradients.

Return type:

nn.Module
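The contract described above can be illustrated with a minimal sketch. It is not the library's implementation: create_fake_score_sketch is a hypothetical stand-in written only from the documented behaviour, it omits the FSDP check described under the RuntimeError entry, and a toy nn.Linear plays the role of the weight-loaded teacher.

```python
import copy

import torch
from torch import nn


def create_fake_score_sketch(teacher: nn.Module, *, deep_copy: bool = True) -> nn.Module:
    """Hypothetical stand-in mirroring the documented contract (no FSDP guard)."""
    # deep_copy=True leaves the teacher untouched; deep_copy=False aliases it.
    fake_score = copy.deepcopy(teacher) if deep_copy else teacher
    fake_score.train()               # switch to training mode
    fake_score.requires_grad_(True)  # every parameter receives gradients
    return fake_score


# A frozen toy teacher standing in for a real, weight-loaded model.
teacher = nn.Linear(4, 4).eval().requires_grad_(False)

fake_score = create_fake_score_sketch(teacher)
copy_is_independent = fake_score is not teacher
copy_is_trainable = fake_score.training and all(
    p.requires_grad for p in fake_score.parameters()
)
teacher_still_frozen = (not teacher.training) and not any(
    p.requires_grad for p in teacher.parameters()
)
weights_match = all(
    torch.equal(a, b) for a, b in zip(fake_score.parameters(), teacher.parameters())
)

# deep_copy=False returns the same object, so the "teacher" is unfrozen too.
second_teacher = nn.Linear(4, 4).eval().requires_grad_(False)
aliased = create_fake_score_sketch(second_teacher, deep_copy=False)
alias_unfroze_teacher = aliased is second_teacher and second_teacher.training
```

The last three lines show why deep_copy=False needs care: the object handed back is the teacher itself, now in training mode with gradients enabled.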

FSDP2 caveat

copy.deepcopy(teacher) is not safe when the teacher is already FSDP2-wrapped (DTensor parameters + FSDP pre/post hooks + meta-init bookkeeping). For Stage-2 FSDP2 training, skip this factory: construct the fake score under meta-init, load the weights on rank 0 only, and let sync_module_states broadcast them to the other ranks:

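# meta_init_context, build_teacher_from_config and is_rank0 stand for
# framework-side helpers; they are not defined in this module.
with meta_init_context():
    # Build the architecture without allocating real weights.
    fake_score = build_teacher_from_config(teacher_config)
if is_rank0():
    # Only rank 0 loads the teacher weights.
    fake_score.load_state_dict(teacher.state_dict(), strict=False)
# Wrap with FSDP2(..., sync_module_states=True) to broadcast from rank 0.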

The pattern mirrors FastGen’s methods/distribution_matching/dmd2.py::DMD2Model.build_model. A dedicated create_fake_score_meta factory is planned alongside the Stage-2 training example.

Raises:

RuntimeError – When deep_copy=True and the teacher looks FSDP-wrapped (either FSDP1 via _fsdp_wrapped_module or FSDP2 via DTensor parameters). The deep_copy=False branch skips the check because reusing the teacher directly is compatible with an FSDP-wrapped input.
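A guard of this kind might look roughly like the sketch below. The two signals it checks (a _fsdp_wrapped_module attribute and DTensor parameters) come from the description above; the function names looks_fsdp_wrapped and check_deep_copy_allowed, the error message, and the decision to scan every submodule are assumptions made for illustration, so the library's actual detection logic may differ.

```python
from torch import nn

try:  # public import path in recent PyTorch releases
    from torch.distributed.tensor import DTensor
except ImportError:
    try:  # private path used by older releases
        from torch.distributed._tensor import DTensor
    except ImportError:  # build without DTensor support
        DTensor = None


def looks_fsdp_wrapped(module: nn.Module) -> bool:
    """Hypothetical heuristic: True if the module shows FSDP1 or FSDP2 traces."""
    # FSDP1 wrappers keep the original module under `_fsdp_wrapped_module`.
    if any(hasattr(m, "_fsdp_wrapped_module") for m in module.modules()):
        return True
    # FSDP2 replaces parameters with DTensor-backed parameters.
    if DTensor is not None:
        return any(isinstance(p, DTensor) for p in module.parameters())
    return False


def check_deep_copy_allowed(teacher: nn.Module, *, deep_copy: bool = True) -> None:
    """Hypothetical guard: refuse to deep-copy an FSDP-wrapped teacher."""
    # deep_copy=False reuses the instance, so no check is needed.
    if deep_copy and looks_fsdp_wrapped(teacher):
        raise RuntimeError(
            "teacher looks FSDP-wrapped; build the fake score under meta-init instead"
        )


class FakeFsdp1Wrapper(nn.Module):
    """Test double that only imitates the FSDP1 attribute name."""

    def __init__(self) -> None:
        super().__init__()
        self._fsdp_wrapped_module = nn.Linear(2, 2)


plain_detected = looks_fsdp_wrapped(nn.Linear(2, 2))
wrapped_detected = looks_fsdp_wrapped(FakeFsdp1Wrapper())
```

Calling check_deep_copy_allowed(FakeFsdp1Wrapper()) raises RuntimeError, while passing deep_copy=False returns without checking, matching the two branches described above.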