expert_removal_hooks
MoE expert-removal and ranked-choice importance hooks (uses Puzzletron BlockConfig).
Classes
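| NemotronHRemoveExpertsIndependentHook | Expert removal importance hook for NemotronH models. |
| Qwen3VLRemoveExpertsIndependentHook | Expert removal importance hook for Qwen3-VL models. |
| RankedChoiceVotingHook | Hook for ranking experts using a ranked choice voting algorithm. |
| RankedChoiceVotingHookNemotronH | Ranked choice voting hook for NemotronH models. |
| RemoveExpertsIndependentHook | Base hook for measuring expert importance in Mixture-of-Experts models. |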
- class NemotronHRemoveExpertsIndependentHook
Bases: RemoveExpertsIndependentHook
Expert removal importance hook for NemotronH models.
- get_router_logits_and_routed_experts(hidden_states, router_logits=None)
Extract router logits and expert outputs for NemotronH MoE.
Based on the NemotronHMOE forward pass; uses the minimum set of operations needed to obtain router_logits and routed_experts.
- Parameters:
hidden_states (Tensor)
router_logits (Tensor | None)
- Return type:
tuple[Tensor, Tensor]
- class Qwen3VLRemoveExpertsIndependentHook
Bases: RemoveExpertsIndependentHook
Expert removal importance hook for Qwen3-VL models.
- get_router_logits_and_routed_experts(hidden_states, router_logits=None)
Extract router logits and expert outputs for Qwen3-VL MoE.
Based on the Qwen3VLMoeSparseMoe forward pass.
- Parameters:
hidden_states (Tensor)
router_logits (Tensor | None)
- Return type:
tuple[Tensor, Tensor]
- class RankedChoiceVotingHook
Bases: ForwardHook
Hook for ranking experts using a ranked choice voting algorithm.
This hook tracks router decisions and uses ranked choice voting to determine which experts are least important, i.e. which can be pruned first.
- __init__(router, activation_hooks_kwargs)
Initialize the hook.
- Parameters:
router (Module) – The router module (typically nn.Linear)
activation_hooks_kwargs (dict) – Configuration dict containing block_config
- accumulate()
Return accumulated expert ranks.
- Return type:
Tensor
- get_progress_info()
Get progress information.
- Return type:
dict
- load_state_dict(state_dict)
Load the internal state from a checkpoint.
- Parameters:
state_dict (dict)
- Return type:
None
- state_dict()
Return the internal state for checkpointing.
- Return type:
dict
- to_dict()
Convert accumulated statistics to dict format using ranked choice voting.
- Return type:
dict[str, Tensor]
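The idea of ranking experts by ranked choice voting can be illustrated with a minimal sketch. It assumes an instant-runoff style elimination in which each token is a voter and its ballot is the list of experts sorted by descending router logit; the hook's actual tallying and tie-breaking rules are not documented here and may differ. The function name `ranked_choice_elimination_order` is hypothetical, and plain Python lists stand in for tensors so the sketch runs without PyTorch.

```python
from collections import Counter


def ranked_choice_elimination_order(router_logits):
    """Return expert indices ordered from least to most important.

    router_logits: list of per-token rows, one logit per expert.
    Assumption: instant-runoff elimination; not necessarily the hook's rule.
    """
    num_experts = len(router_logits[0])
    # Each token's ballot: experts sorted by descending router logit.
    ballots = [
        sorted(range(num_experts), key=lambda e: -row[e])
        for row in router_logits
    ]
    remaining = set(range(num_experts))
    order = []
    while remaining:
        # Count every ballot's highest-ranked expert that is still in the race.
        tally = Counter({e: 0 for e in remaining})
        for ballot in ballots:
            top = next(e for e in ballot if e in remaining)
            tally[top] += 1
        # Eliminate the expert with the fewest votes; ties go to the lowest index.
        loser = min(remaining, key=lambda e: (tally[e], e))
        order.append(loser)
        remaining.remove(loser)
    return order


# Five tokens voting over three experts (illustrative values only).
example_logits = [
    [2.0, 1.0, 0.1],
    [1.5, 0.2, 0.9],
    [0.3, 2.2, 0.4],
    [0.1, 0.2, 3.0],
    [0.2, 0.1, 1.0],
]
elimination_order = ranked_choice_elimination_order(example_logits)
```

Experts at the front of `elimination_order` are the first pruning candidates under this scheme. Votes from an eliminated expert transfer to each token's next preference, which is what distinguishes this from a simple first-choice count.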
- class RankedChoiceVotingHookNemotronH
Bases: RankedChoiceVotingHook
Ranked choice voting hook for NemotronH models.
In NemotronH, router_logits is an internal temporary state that never leaves the forward() function. We reconstruct router_logits from the input hidden_states.
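As a rough illustration of reconstructing logits from the router input, the sketch below assumes the router is a bias-free linear projection, so each logit is a dot product between a token's hidden state and one expert's weight row. The function name `reconstruct_router_logits` and the weight layout `(num_experts, hidden_dim)` are assumptions made for this example; the real NemotronH router may apply further steps.

```python
def reconstruct_router_logits(hidden_states, router_weight):
    """Recompute router logits as hidden_states @ router_weight^T.

    hidden_states: list of tokens, each a list of hidden_dim floats.
    router_weight: list of expert rows, each a list of hidden_dim floats.
    Assumption: bias-free linear router (hypothetical simplification).
    """
    return [
        [sum(h * w for h, w in zip(token, expert_row)) for expert_row in router_weight]
        for token in hidden_states
    ]


# Two tokens, hidden_dim = 2, three experts (illustrative values only).
hidden = [[1.0, 2.0], [0.0, 1.0]]
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
logits = reconstruct_router_logits(hidden, weight)
```

The result has one row per token and one column per expert, which is the shape a voting or importance hook needs in order to rank experts.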
- class RemoveExpertsIndependentHook
Bases: ForwardHook, ABC
Base hook for measuring expert importance in Mixture-of-Experts models.
This hook measures how much removing each expert affects the model output by comparing outputs with and without each expert.
- __init__(moe, activation_hooks_kwargs)
Initialize the hook.
- Parameters:
moe (Module) – The MoE module to analyze
activation_hooks_kwargs (dict) – Configuration dict containing block_config
- accumulate()
Return accumulated expert importance scores.
- Return type:
Tensor
- abstract get_router_logits_and_routed_experts(hidden_states, router_logits=None)
Extract router logits and expert outputs for measuring expert importance.
This method is called twice per forward pass:
1. First call (router_logits=None): compute the original routing and expert outputs.
2. Second call (router_logits provided): re-run with modified logits (expert disabled).
- Parameters:
hidden_states (Tensor) – Input tensor of shape (batch, seq_len, hidden_dim)
router_logits (Tensor | None) – Optional pre-computed router logits. If None, compute from hidden_states.
- Returns:
router_logits: Shape (num_tokens, num_local_experts)
routed_experts: Shape (num_tokens, hidden_dim)
- Return type:
tuple of (router_logits, routed_experts)
- load_state_dict(state_dict)
Load the internal state from a checkpoint.
- Parameters:
state_dict (dict)
- Return type:
None
- state_dict()
Return the internal state for checkpointing.
- Return type:
dict
- to_dict()
Convert accumulated statistics to dict format.
- Return type:
dict[str, Tensor]
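The two-call protocol of get_router_logits_and_routed_experts can be pictured with a toy example. The sketch assumes top-k routing with softmax weights over the selected logits, scalar expert outputs (hidden_dim of 1), at least two experts, and a squared-difference score; none of these details are confirmed by this page. The helper names `moe_output` and `expert_importance` are hypothetical.

```python
import math


def moe_output(logits, expert_outputs, top_k=2):
    """Combine the top-k experts, weighted by a softmax over their logits."""
    top = sorted(range(len(logits)), key=lambda e: -logits[e])[:top_k]
    weights = [math.exp(logits[e]) for e in top]  # exp(-inf) == 0.0
    total = sum(weights)
    return sum(w / total * expert_outputs[e] for w, e in zip(weights, top))


def expert_importance(token_logits, token_expert_outputs, top_k=2):
    """Accumulate, per expert, the squared output change when it is disabled."""
    num_experts = len(token_logits[0])
    scores = [0.0] * num_experts
    for logits, outputs in zip(token_logits, token_expert_outputs):
        # First pass: original routing.
        baseline = moe_output(logits, outputs, top_k)
        for e in range(num_experts):
            # Second pass: same input, expert e masked out of the routing.
            masked = list(logits)
            masked[e] = -math.inf
            ablated = moe_output(masked, outputs, top_k)
            scores[e] += (baseline - ablated) ** 2
    return scores


# One token, three experts (illustrative values only).
scores = expert_importance([[2.0, 1.0, -1.0]], [[1.0, 3.0, 10.0]])
```

In this toy case the third expert is never selected by the top-2 routing, so disabling it leaves the output unchanged and its score stays at zero. Experts with the smallest accumulated scores would be the first removal candidates.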