ema

Exponential moving average (EMA) of a student network, aware of FSDP2 DTensor sharding.

Ported from FastGen/fastgen/callbacks/ema.py (lines 20-169) but exposed as a plain class rather than a framework-specific callback. The caller decides when to call update() (typically after optimizer.step()), how to persist the shadow state (via state_dict()), and when to publish the EMA weights back to a target module (via copy_to()).

Classes

ExponentialMovingAverage

FSDP2-aware EMA tracker for a PyTorch module.

class ExponentialMovingAverage

Bases: object

FSDP2-aware EMA tracker for a PyTorch module.

The tracker stores a shadow state dict. Parameters are promoted to EMAConfig.dtype (default fp32), while buffers keep the live module’s dtype. Buffers are replicated across ranks and are stepped with copy_ rather than lerp_, so the bf16 round-off argument that motivates promoting parameters does not apply to them; keeping the live dtype also makes restoring a buffer exact.
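To see why parameters are promoted, the sketch below (plain PyTorch on CPU, values chosen only for illustration) applies the update shadow = beta * shadow + (1 - beta) * live 100 times to a scalar that starts at 1.0, with the live value fixed at 2.0 and beta = 0.999. The helper name ema_steps is hypothetical. It writes the update as mul_/add_ rather than lerp_ so that every in-place op rounds its result to the shadow's dtype; the two forms are mathematically the same update.

```python
import torch


def ema_steps(shadow: torch.Tensor, live: torch.Tensor, beta: float, n: int) -> torch.Tensor:
    """Hypothetical helper: apply n EMA steps in place, in shadow's own dtype."""
    for _ in range(n):
        # shadow = beta * shadow + (1 - beta) * live; each in-place op
        # rounds its result to shadow.dtype.
        shadow.mul_(beta).add_(live.to(shadow.dtype), alpha=1 - beta)
    return shadow


beta = 0.999
live = torch.tensor([2.0])

shadow_bf16 = ema_steps(torch.tensor([1.0], dtype=torch.bfloat16), live, beta, n=100)
shadow_fp32 = ema_steps(torch.tensor([1.0], dtype=torch.float32), live, beta, n=100)

print(shadow_bf16.item())  # 1.0: each step's change is below half a bf16 ulp at 1.0
print(shadow_fp32.item())  # about 1.095, i.e. 2 - 0.999**100
```

With bf16 the spacing between representable values just above 1.0 is 2**-7, so an increment of roughly 0.001 per step is rounded away every time and the average never moves. That is the failure mode a promoted (fp32) shadow avoids.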

By default (mode='full_tensor') the tracker materialises the full tensor for each parameter, so the EMA covers the complete, unsharded weights even when the model is sharded across ranks. For memory-constrained settings, mode='local_shard' skips the all-gather; each rank then holds an EMA of its local shard only.
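The difference between the two modes comes down to how each live tensor is read before averaging. A minimal sketch, assuming FSDP2 exposes sharded parameters as DTensors; the helper ema_source_view is hypothetical, while DTensor.full_tensor() and DTensor.to_local() are the standard DTensor methods for "all-gather" and "this rank's shard".

```python
import torch

try:  # public location in recent PyTorch releases
    from torch.distributed.tensor import DTensor
except ImportError:  # older releases exposed it under a private module
    from torch.distributed._tensor import DTensor


@torch.no_grad()
def ema_source_view(t: torch.Tensor, mode: str = "full_tensor") -> torch.Tensor:
    """Hypothetical helper: the tensor the EMA reads for one parameter or buffer."""
    if mode not in ("full_tensor", "local_shard"):
        raise ValueError(f"unknown mode: {mode!r}")
    if isinstance(t, DTensor):
        if mode == "full_tensor":
            # Collective all-gather: every rank must call this for the same
            # tensors in the same order; each rank gets the full tensor.
            return t.full_tensor()
        # No communication: only this rank's shard is averaged.
        return t.to_local()
    # Plain (unsharded) tensors are read as-is in both modes.
    return t.detach()
```

Because full_tensor() is a collective, calling it on only some ranks (for example, inside a rank-0-only branch) would hang; local_shard mode has no such constraint but produces a per-rank shadow.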

Example:

ema = ExponentialMovingAverage(student, EMAConfig(decay=0.999))
for step in range(max_steps):
    ...  # compute loss, backward, optimizer.step()
    ema.update(student, iteration=step)

ema.copy_to(student_for_eval)  # publish for inference

__init__(model, config)

Pre-allocate the shadow state from model’s parameters and buffers.

Parameters:
  • model (Module)

  • config (EMAConfig)

Return type:

None
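For an unsharded model, the pre-allocation step can be pictured as follows. This is a sketch of the documented dtype policy, not the class's actual code: the function name allocate_shadow and its param_dtype argument (standing in for EMAConfig.dtype) are hypothetical, and a sharded model would additionally read each tensor through full_tensor() or to_local() depending on the mode.

```python
import torch
from torch import nn


def allocate_shadow(model: nn.Module, param_dtype: torch.dtype = torch.float32) -> dict[str, torch.Tensor]:
    """Hypothetical sketch of __init__'s pre-allocation (unsharded model only)."""
    shadow: dict[str, torch.Tensor] = {}
    with torch.no_grad():
        for name, p in model.named_parameters():
            # Parameters are promoted (default fp32) so small EMA increments survive.
            shadow[name] = p.detach().to(dtype=param_dtype, copy=True)
        for name, b in model.named_buffers():
            # Buffers keep the live dtype so copying them back is exact.
            shadow[name] = b.detach().clone()
    return shadow


model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4)).to(torch.bfloat16)
shadow = allocate_shadow(model)
# shadow["0.weight"] is fp32; shadow["1.running_mean"] stays bf16;
# shadow["1.num_batches_tracked"] stays int64.
```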

copy_to(target)

Load the shadow state into target (which should share the tracked module’s structure).

The target is expected to be an unsharded module (i.e. the caller has unwrapped any FSDP2 wrappers before calling). For sharded targets, prefer saving the shadow via state_dict() and reloading it through the framework’s usual checkpoint path.

Parameters:

target (Module)

Return type:

None
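A minimal sketch of what publishing to an unsharded target involves, assuming the shadow is a flat dict keyed by the target's parameter and buffer names. The function copy_shadow_to is hypothetical; it relies on Tensor.copy_ casting to the destination's dtype, so an fp32 shadow lands correctly in bf16 weights.

```python
import itertools

import torch
from torch import nn


@torch.no_grad()
def copy_shadow_to(shadow: dict[str, torch.Tensor], target: nn.Module) -> None:
    """Hypothetical sketch of copy_to for an unsharded target module."""
    live = dict(itertools.chain(target.named_parameters(), target.named_buffers()))
    missing = live.keys() - shadow.keys()
    if missing:
        raise KeyError(f"shadow has no entry for: {sorted(missing)}")
    for name, tensor in live.items():
        # copy_ casts to the target's dtype/device, e.g. fp32 shadow -> bf16 weights.
        tensor.copy_(shadow[name])


target = nn.Linear(2, 2).to(torch.bfloat16)
copy_shadow_to({"weight": torch.full((2, 2), 0.5), "bias": torch.zeros(2)}, target)
```

Writing in place (rather than reassigning parameters) keeps any references the target's optimizer or hooks hold to those tensors valid.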

load_state_dict(state)

Restore the shadow state from a previously saved dict.

Parameters:

state (dict[str, Tensor])

Return type:

None
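One reasonable way to restore into a pre-allocated shadow is a strict, in-place copy that rejects mismatched keys or shapes. This is a sketch under that assumption; the documentation above does not say how strict the real load_state_dict is, and load_shadow_ is a hypothetical name.

```python
import torch


@torch.no_grad()
def load_shadow_(shadow: dict[str, torch.Tensor], state: dict[str, torch.Tensor]) -> None:
    """Hypothetical strict loader: copy `state` into pre-allocated `shadow` in place."""
    if shadow.keys() != state.keys():
        missing = sorted(shadow.keys() - state.keys())
        unexpected = sorted(state.keys() - shadow.keys())
        raise KeyError(f"missing={missing} unexpected={unexpected}")
    for name, dst in shadow.items():
        src = state[name]
        if src.shape != dst.shape:
            raise ValueError(f"{name}: shape {tuple(src.shape)} != {tuple(dst.shape)}")
        # In-place copy keeps the shadow's own dtype/device (e.g. fp32 params).
        dst.copy_(src)


shadow = {"w": torch.zeros(2)}
load_shadow_(shadow, {"w": torch.tensor([1.0, 2.0], dtype=torch.bfloat16)})
# shadow["w"] is still fp32 and now holds [1.0, 2.0].
```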

state_dict()

Return the shadow state (parameters + buffers) for checkpointing.

Return type:

dict[str, Tensor]
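Since the shadow is a plain dict of tensors, it can be stored under its own key in the usual checkpoint. The sketch below uses an in-memory buffer and a hand-built dict (ema_state, hypothetical) standing in for ema.state_dict(). In local_shard mode each rank's shadow presumably differs, so it would need to be saved per rank rather than from a single rank.

```python
import io

import torch
from torch import nn

model = nn.Linear(3, 1)
# Stand-in for ema.state_dict(): a flat dict of fp32 tensors (hypothetical values).
ema_state = {name: p.detach().float().clone() for name, p in model.named_parameters()}

# Save the EMA shadow next to the live weights under its own key.
buf = io.BytesIO()
torch.save({"model": model.state_dict(), "ema": ema_state}, buf)

# ... later, on restart ...
buf.seek(0)
ckpt = torch.load(buf, weights_only=True)
restored_ema = ckpt["ema"]  # would be passed to ema.load_state_dict(...)
```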

update(model, *, iteration)

Update the shadow state from model at the given iteration.

Skips updates before EMAConfig.start_iter. On the iteration equal to start_iter, the shadow is (re-)initialised from the live weights; on later iterations it is updated as shadow = beta * shadow + (1 - beta) * live, where beta is the EMA decay factor.

Parameters:
  • model (Module)

  • iteration (int)

Return type:

None
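The update rule, including the start_iter handling and the copy-only treatment of buffers, can be sketched on unsharded tensors as follows. The function ema_update_ and its arguments are hypothetical, and beta is assumed constant here (the documentation does not say whether it is scheduled). Parameters use lerp_(live, 1 - beta), which computes shadow + (1 - beta) * (live - shadow), the same formula as above.

```python
import torch


@torch.no_grad()
def ema_update_(shadow: dict[str, torch.Tensor], live: dict[str, torch.Tensor],
                *, iteration: int, start_iter: int, beta: float,
                buffer_names: frozenset[str] = frozenset()) -> None:
    """Hypothetical sketch of update() on unsharded tensors."""
    if iteration < start_iter:
        return  # too early: leave the shadow untouched
    for name, dst in shadow.items():
        src = live[name].to(dst.dtype)
        if iteration == start_iter or name in buffer_names:
            # (Re-)initialise from the live weights; buffers are always copied.
            dst.copy_(src)
        else:
            # shadow = beta * shadow + (1 - beta) * live
            dst.lerp_(src, 1.0 - beta)


shadow = {"w": torch.zeros(1), "running": torch.zeros(1)}
bufs = frozenset({"running"})
ema_update_(shadow, {"w": torch.tensor([4.0]), "running": torch.tensor([9.0])},
            iteration=2, start_iter=2, beta=0.75, buffer_names=bufs)   # w = 4.0
ema_update_(shadow, {"w": torch.tensor([0.0]), "running": torch.tensor([5.0])},
            iteration=3, start_iter=2, beta=0.75, buffer_names=bufs)   # w = 0.75 * 4 = 3.0
```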