ema
Exponential moving average of a student network, FSDP2 DTensor aware.
Ported from FastGen/fastgen/callbacks/ema.py (lines 20-169) but exposed as a plain
class rather than a framework-specific callback. The caller decides when to call
update() (typically after optimizer.step()), how to persist the shadow state
(via state_dict()), and when to publish the EMA weights back to a target module
(via copy_to()).
Classes
ExponentialMovingAverage: FSDP2-aware EMA tracker for a PyTorch module.
- class ExponentialMovingAverage
Bases: object
FSDP2-aware EMA tracker for a PyTorch module.
The tracker stores a shadow state dict: parameters are promoted per EMAConfig.dtype (default fp32), while buffers are kept in the live module’s dtype. Buffers are replicated across ranks and stepped via copy_ rather than lerp_, so the bf16-roundoff argument that motivates parameter promotion doesn’t apply; preserving the live dtype makes the buffer restore exact.

By default the tracker materialises the full tensor per parameter (mode='full_tensor') so the EMA represents the globally averaged weights even when the model is sharded across ranks. A mode='local_shard' fallback is available for memory-constrained settings; it does not all-gather, and therefore each rank holds an EMA of its local shard only.

Example:
    ema = ExponentialMovingAverage(student, EMAConfig(decay=0.999))
    for step in range(max_steps):
        ...  # compute loss, backward, optimizer.step()
        ema.update(student, iteration=step)
    ema.copy_to(student_for_eval)  # publish for inference
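The dtype split between parameters and buffers can be illustrated with a minimal sketch. This is not the tracker's actual implementation: the helpers make_shadow and ema_step are hypothetical names, and the code assumes plain unsharded tensors, so no DTensor or all-gather handling is shown.

```python
import torch
from torch import nn


def make_shadow(model: nn.Module, dtype: torch.dtype = torch.float32) -> dict:
    """Hypothetical helper: clone parameters in `dtype`, buffers in their live dtype."""
    shadow = {}
    for name, p in model.named_parameters():
        shadow[name] = p.detach().clone().to(dtype)
    for name, b in model.named_buffers():
        shadow[name] = b.detach().clone()
    return shadow


@torch.no_grad()
def ema_step(shadow: dict, model: nn.Module, beta: float) -> None:
    """Hypothetical helper: lerp_ parameters toward the live weights, copy_ buffers."""
    for name, p in model.named_parameters():
        # Equivalent to shadow = beta * shadow + (1 - beta) * live.
        shadow[name].lerp_(p.detach().to(shadow[name].dtype), 1.0 - beta)
    for name, b in model.named_buffers():
        # Buffers are overwritten, not averaged, so no precision is lost.
        shadow[name].copy_(b)


model = nn.BatchNorm1d(4).to(torch.bfloat16)
shadow = make_shadow(model)
with torch.no_grad():
    model.weight.fill_(2.0)        # pretend an optimizer step moved the weights
    model.running_mean.fill_(3.0)  # pretend a forward pass updated the buffer
ema_step(shadow, model, beta=0.9)
```

With a decay close to 1, the per-step increment (1 - beta) * (live - shadow) can fall below the resolution of bf16, which carries only 8 bits of significand precision. That is the round-off concern behind keeping parameter shadows in fp32, and it does not arise for buffers because they are copied verbatim.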
- __init__(model, config)
Pre-allocate the shadow state from model’s parameters and buffers.
- Parameters:
model (nn.Module)
config (EMAConfig)
- Return type:
None
- copy_to(target)
Load the shadow state into target (which should share the tracked module’s structure).
The target is expected to be an unsharded module (i.e. the caller has unwrapped any FSDP2 wrappers before calling). For sharded targets, prefer saving the shadow via state_dict() and reloading it through the framework’s usual checkpoint path.
- Parameters:
target (Module)
- Return type:
None
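As a rough illustration of what publishing to an unsharded target involves, the sketch below writes a shadow dict into a module in place. The function publish_shadow is a hypothetical stand-in for copy_to(), not its real body, and the shadow dict is built by hand so the example does not depend on the tracker itself.

```python
import torch
from torch import nn


@torch.no_grad()
def publish_shadow(shadow: dict, target: nn.Module) -> None:
    """Hypothetical stand-in for copy_to(): write shadow tensors into `target` in place."""
    live = dict(target.named_parameters())
    live.update(dict(target.named_buffers()))
    for name, tensor in shadow.items():
        # copy_ casts to the destination dtype, so an fp32 shadow can fill a bf16 target.
        live[name].copy_(tensor)


target = nn.Linear(3, 2).to(torch.bfloat16)
shadow = {
    "weight": torch.full((2, 3), 0.5, dtype=torch.float32),
    "bias": torch.zeros(2, dtype=torch.float32),
}
publish_shadow(shadow, target)
```

The target keeps its own dtype after the copy; only the values change.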
- load_state_dict(state)
Restore the shadow state from a previously saved dict.
- Parameters:
state (dict[str, Tensor])
- Return type:
None
- state_dict()
Return the shadow state (parameters + buffers) for checkpointing.
- Return type:
dict[str, Tensor]
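Because state_dict() returns a plain dict[str, Tensor], it can sit alongside other entries in an ordinary checkpoint. The sketch below shows one possible round trip through torch.save and torch.load. The "ema" checkpoint key is an assumption, and a hand-built dict stands in for the value that ema.state_dict() would return.

```python
import io

import torch

# Stand-in for ema.state_dict(): fp32 parameter shadow, bf16 buffer shadow.
shadow = {
    "weight": torch.randn(2, 3),
    "running_mean": torch.zeros(3, dtype=torch.bfloat16),
}

# Save under a hypothetical "ema" key; an in-memory buffer replaces a file path here.
buffer = io.BytesIO()
torch.save({"ema": shadow}, buffer)

# Reload; the result is what would be passed to ema.load_state_dict(...).
buffer.seek(0)
restored = torch.load(buffer, map_location="cpu")["ema"]
```

The per-tensor dtypes survive the round trip, so the fp32 parameter shadow and the live-dtype buffers are restored as they were saved.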
- update(model, *, iteration)
Update the shadow state from model at the given iteration.
Skips updates before EMAConfig.start_iter. On the iteration that equals start_iter the shadow is (re-)initialised from the live weights; after that it is updated with shadow = beta * shadow + (1 - beta) * live.
- Parameters:
model (Module)
iteration (int)
- Return type:
None
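The three phases of update() (skip, initialise, blend) can be traced with a scalar sketch. The function ema_update is a hypothetical stand-in that uses Python floats in place of tensors, and the starting shadow value of 0.0 is only a placeholder for whatever __init__ pre-allocated.

```python
def ema_update(shadow: float, live: float, iteration: int,
               start_iter: int, beta: float) -> float:
    """Scalar stand-in for update(); returns the new shadow value."""
    if iteration < start_iter:
        return shadow  # before start_iter: skip, shadow unchanged
    if iteration == start_iter:
        return live    # at start_iter: (re-)initialise from the live weights
    return beta * shadow + (1.0 - beta) * live  # afterwards: standard EMA blend


shadow = 0.0  # placeholder initial shadow
history = []
for step, live in enumerate([1.0, 2.0, 3.0, 4.0]):
    shadow = ema_update(shadow, live, iteration=step, start_iter=2, beta=0.5)
    history.append(shadow)

print(history)  # [0.0, 0.0, 3.0, 3.5]
```

Steps 0 and 1 leave the shadow untouched, step 2 snaps it to the live value, and step 3 blends 3.0 and 4.0 with equal weight.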