factory
Convenience factory helpers for constructing the auxiliary DMD networks.
These helpers are intentionally tiny; the training framework is free to build the
fake score directly (e.g. under a meta-init context for FSDP2) instead of calling
create_fake_score(). See the ModelOpt ↔ FastGen design doc (FASTGEN_MODELOPT.md,
section “How the framework can build the fake_score”) for both options.
Functions
- create_fake_score(teacher, *, deep_copy=True)
Return a trainable fake-score network initialized from the teacher.
This is the unit-test / single-script path. Frameworks that use meta-init and FSDP2 wrapping will typically construct the fake score themselves and pass it directly into DMDPipeline.
- Parameters:
teacher (nn.Module) – The already-built teacher module, with its weights already loaded.
deep_copy (bool) – If True, copy.deepcopy() the teacher; if False, reuse the same instance (only sensible if the caller can guarantee it is no longer held elsewhere as the frozen teacher).
- Returns:
A copy of teacher (or teacher itself when deep_copy=False) in training mode, with all parameters requiring gradients.
- Return type:
nn.Module
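To make the documented contract concrete, here is a minimal sketch of what create_fake_score is described as returning, assuming plain (unwrapped) PyTorch modules. The name create_fake_score_sketch is hypothetical; this is not the library's implementation, and it omits the FSDP-wrapping check described under Raises.

```python
import copy

import torch
from torch import nn


def create_fake_score_sketch(teacher: nn.Module, *, deep_copy: bool = True) -> nn.Module:
    """Hypothetical illustration of the documented contract, not the library code."""
    # deep_copy=True gives independent parameter tensors, so optimizer steps on
    # the fake score never touch the frozen teacher. deep_copy=False reuses it.
    fake_score = copy.deepcopy(teacher) if deep_copy else teacher
    fake_score.train()               # training mode (dropout, norm statistics, ...)
    fake_score.requires_grad_(True)  # every parameter becomes trainable
    return fake_score


# Usage: a frozen, eval-mode teacher and a trainable fake score built from it.
teacher = nn.Linear(4, 4)
teacher.eval()
teacher.requires_grad_(False)

fake_score = create_fake_score_sketch(teacher)
```

Because the deep copy clones every parameter, the two modules start with identical weights but separate storage; with deep_copy=False the teacher object itself is flipped to training mode, which is why that branch is only safe when nothing else still treats it as the frozen teacher.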
FSDP2 caveat
copy.deepcopy(teacher) is not safe when the teacher is already FSDP2-wrapped (DTensor parameters, FSDP pre/post hooks, and meta-init bookkeeping). For Stage-2 FSDP2 training, skip this factory: construct the fake score under meta-init, load the weights on rank 0, and let sync_module_states broadcast them:

    # meta_init_context, build_teacher_from_config and is_rank0 are framework
    # helpers; they are not defined in this module.
    with meta_init_context():
        fake_score = build_teacher_from_config(teacher_config)
    if is_rank0():
        fake_score.load_state_dict(teacher.state_dict(), strict=False)
    # Wrap with FSDP2(..., sync_module_states=True) to broadcast from rank 0.
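As a single-process analog of the meta-init recipe, the sketch below uses only stock PyTorch: the torch.device("meta") context manager (PyTorch 2.0 or later) stands in for meta_init_context, a hypothetical build_network() stands in for build_teacher_from_config, and nn.Module.to_empty allocates real storage before the weights are loaded. FSDP2 wrapping and the rank-0 broadcast are deliberately left out, so treat this as an illustration of the ordering, not the framework's code.

```python
import torch
from torch import nn


def build_network() -> nn.Module:
    # Hypothetical stand-in for build_teacher_from_config(teacher_config).
    return nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))


# Stand-in for the already-loaded teacher.
teacher = build_network()

# 1) Build the fake score on the meta device: shapes and dtypes, no storage.
with torch.device("meta"):
    fake_score = build_network()
assert all(p.is_meta for p in fake_score.parameters())

# 2) Allocate real (uninitialized) storage, then copy the teacher's weights in.
#    In the distributed recipe only rank 0 loads; FSDP then broadcasts.
fake_score = fake_score.to_empty(device="cpu")
fake_score.load_state_dict(teacher.state_dict())

# 3) Make it trainable, matching what create_fake_score() returns.
fake_score.train()
fake_score.requires_grad_(True)
```

Loading has to happen after to_empty: copying a state dict into meta tensors would not materialize any data.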
The pattern mirrors FastGen’s methods/distribution_matching/dmd2.py::DMD2Model.build_model. A dedicated create_fake_score_meta factory is planned alongside the Stage-2 training example.
- Raises RuntimeError:
When deep_copy=True and the teacher looks FSDP-wrapped (either FSDP1, detected via _fsdp_wrapped_module, or FSDP2, detected via DTensor parameters). The deep_copy=False branch skips the check because reusing the teacher directly is compatible with an FSDP-wrapped input.