expert_removal_hooks

MoE expert-removal and ranked-choice importance hooks (uses Puzzletron BlockConfig).

Classes

NemotronHRemoveExpertsIndependentHook

Expert removal importance hook for NemotronH models.

Qwen3VLRemoveExpertsIndependentHook

Expert removal importance hook for Qwen3-VL models.

RankedChoiceVotingHook

Hook for ranking experts using a ranked-choice voting algorithm.

RankedChoiceVotingHookNemotronH

Ranked choice voting hook for NemotronH models.

RemoveExpertsIndependentHook

Base hook for measuring expert importance in Mixture-of-Experts models.

class NemotronHRemoveExpertsIndependentHook

Bases: RemoveExpertsIndependentHook

Expert removal importance hook for NemotronH models.

get_router_logits_and_routed_experts(hidden_states, router_logits=None)

Extract router logits and expert outputs for NemotronH MoE.

Mirrors the NemotronHMOE forward pass, using the minimum operations needed to obtain router_logits and routed_experts.

Parameters:
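  • hidden_states (Tensor) – Input tensor of shape (batch, seq_len, hidden_dim)

  • router_logits (Tensor | None) – Optional pre-computed router logits. If None, they are computed from hidden_states.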

Return type:

tuple[Tensor, Tensor]

class Qwen3VLRemoveExpertsIndependentHook

Bases: RemoveExpertsIndependentHook

Expert removal importance hook for Qwen3-VL models.

get_router_logits_and_routed_experts(hidden_states, router_logits=None)

Extract router logits and expert outputs for Qwen3-VL MoE.

Based on the Qwen3VLMoeSparseMoe forward pass.

Parameters:
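  • hidden_states (Tensor) – Input tensor of shape (batch, seq_len, hidden_dim)

  • router_logits (Tensor | None) – Optional pre-computed router logits. If None, they are computed from hidden_states.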

Return type:

tuple[Tensor, Tensor]

class RankedChoiceVotingHook

Bases: ForwardHook

Hook for ranking experts using a ranked-choice voting algorithm.

This hook tracks router decisions and uses ranked-choice voting to determine which experts are least important (i.e., which can be pruned first).
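The exact voting procedure is not spelled out here. As a minimal sketch, assuming each token casts a ballot that ranks experts by descending router logit, and that experts are eliminated instant-runoff style (the remaining expert with the fewest first-choice votes goes first), the elimination order could be computed as follows. Both function names are hypothetical and are not part of this module.

```python
def ballots_from_logits(router_logits: list[list[float]]) -> list[list[int]]:
    """Turn each token's router logits into a ballot: expert ids, best first."""
    return [
        sorted(range(len(row)), key=lambda e: row[e], reverse=True)
        for row in router_logits
    ]


def ranked_choice_elimination_order(ballots: list[list[int]], num_experts: int) -> list[int]:
    """Return expert ids in elimination order (least important first)."""
    remaining = set(range(num_experts))
    order: list[int] = []
    while remaining:
        # Each ballot votes for its highest-ranked expert still in the race.
        votes = {e: 0 for e in remaining}
        for ballot in ballots:
            for expert in ballot:
                if expert in remaining:
                    votes[expert] += 1
                    break
        # Eliminate the expert with the fewest votes (ties broken by lowest id).
        loser = min(remaining, key=lambda e: (votes[e], e))
        order.append(loser)
        remaining.remove(loser)
    return order
```

An expert that never appears first on any ballot, even after other experts are eliminated, ends up near the front of the order and is therefore an early pruning candidate.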

__init__(router, activation_hooks_kwargs)

Initialize the hook.

Parameters:
  • router (Module) – The router module (typically nn.Linear)

  • activation_hooks_kwargs (dict) – Configuration dict containing block_config

accumulate()

Return accumulated expert ranks.

Return type:

Tensor

get_progress_info()

Get progress information.

Return type:

dict

load_state_dict(state_dict)

Load the internal state from a checkpoint.

Parameters:

state_dict (dict)

Return type:

None

state_dict()

Return the internal state for checkpointing.

Return type:

dict

to_dict()

Convert accumulated statistics to dict format using ranked-choice voting.

Return type:

dict[str, Tensor]

class RankedChoiceVotingHookNemotronH

Bases: RankedChoiceVotingHook

Ranked choice voting hook for NemotronH models.

In NemotronH, router_logits is a temporary value internal to forward() and is never exposed outside it, so this hook reconstructs router_logits from the input hidden_states.
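A minimal sketch of that reconstruction, assuming the router is a plain linear projection (the RankedChoiceVotingHook parameters note that the router is typically an nn.Linear). It uses nested lists instead of tensors so it runs without PyTorch; the function name is hypothetical and the real NemotronH router may add further steps.

```python
def reconstruct_router_logits(hidden_states, router_weight, router_bias=None):
    """Recompute router logits from (batch, seq_len, hidden_dim) hidden states.

    router_weight has shape (num_experts, hidden_dim), like nn.Linear.weight.
    Returns logits of shape (num_tokens, num_experts).
    """
    # Flatten the batch and sequence dimensions into a single token dimension.
    tokens = [tok for seq in hidden_states for tok in seq]
    num_experts = len(router_weight)
    bias = router_bias if router_bias is not None else [0.0] * num_experts
    # logits[t][e] = <token t, weight row e> + bias[e]
    return [
        [sum(h * w for h, w in zip(tok, row)) + b for row, b in zip(router_weight, bias)]
        for tok in tokens
    ]
```

With PyTorch tensors, the same computation is typically `torch.nn.functional.linear(hidden_states.reshape(-1, hidden_dim), weight, bias)`, which returns a (num_tokens, num_experts) tensor.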

class RemoveExpertsIndependentHook

Bases: ForwardHook, ABC

Base hook for measuring expert importance in Mixture-of-Experts models.

This hook measures how much removing each expert affects the model output by comparing outputs with and without each expert.
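As a minimal sketch of this leave-one-out idea, assuming softmax-over-top-k routing, that disabling an expert means setting its router logit to -inf, and that the distance is the squared error summed over the output vector and averaged over tokens (none of which this page specifies), a per-expert score could be computed as follows. The helper names are hypothetical.

```python
import math


def topk_moe_output(logits, expert_outputs, k):
    """Combine one token's expert outputs using softmax over the top-k logits."""
    top = sorted(range(len(logits)), key=lambda e: logits[e], reverse=True)[:k]
    m = max(logits[e] for e in top)  # subtract the max for numerical stability
    weights = {e: math.exp(logits[e] - m) for e in top}
    total = sum(weights.values())
    dim = len(expert_outputs[0])
    return [sum(weights[e] / total * expert_outputs[e][d] for e in top) for d in range(dim)]


def leave_one_out_importance(router_logits, expert_outputs, k):
    """Mean squared change in the MoE output when each expert is disabled.

    router_logits: per token, one logit per expert.
    expert_outputs: per token, one output vector per expert.
    """
    num_experts = len(router_logits[0])
    scores = [0.0] * num_experts
    for logits, outs in zip(router_logits, expert_outputs):
        baseline = topk_moe_output(logits, outs, k)
        for e in range(num_experts):
            # Disable expert e by pushing its logit to -inf, then re-route.
            masked = list(logits)
            masked[e] = -math.inf
            ablated = topk_moe_output(masked, outs, k)
            scores[e] += sum((a - b) ** 2 for a, b in zip(baseline, ablated))
    return [s / len(router_logits) for s in scores]
```

A higher score means the output depends more on that expert; under this metric, the experts with the lowest scores would be the first pruning candidates.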

__init__(moe, activation_hooks_kwargs)

Initialize the hook.

Parameters:
  • moe (Module) – The MoE module to analyze

  • activation_hooks_kwargs (dict) – Configuration dict containing block_config

accumulate()

Return accumulated expert importance scores.

Return type:

Tensor

abstract get_router_logits_and_routed_experts(hidden_states, router_logits=None)

Extract router logits and expert outputs for measuring expert importance.

This method is called twice per forward pass:

  1. First call (router_logits=None): compute the original routing and expert outputs.

  2. Second call (router_logits provided): re-run with modified logits (expert disabled).

Parameters:
  • hidden_states (Tensor) – Input tensor of shape (batch, seq_len, hidden_dim)

  • router_logits (Tensor | None) – Optional pre-computed router logits. If None, compute from hidden_states.

Returns:

  • router_logits: Shape (num_tokens, num_local_experts)

  • routed_experts: Shape (num_tokens, hidden_dim)

Return type:

tuple of (router_logits, routed_experts)
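To make the call contract and shapes concrete, here is a minimal sketch of an implementation for a hypothetical ToyMoE with a bias-free linear router, softmax-over-top-k routing, and experts that simply scale their input by a per-expert gain. It uses plain Python lists instead of tensors so it runs without PyTorch; it does not reflect the NemotronH or Qwen3-VL internals.

```python
import math


class ToyMoE:
    """Hypothetical top-k MoE with a linear router and scalar-gain experts."""

    def __init__(self, router_weight, expert_gains, top_k):
        self.router_weight = router_weight  # (num_experts, hidden_dim)
        self.expert_gains = expert_gains  # one scalar per expert
        self.top_k = top_k

    def get_router_logits_and_routed_experts(self, hidden_states, router_logits=None):
        # Flatten (batch, seq_len, hidden_dim) into (num_tokens, hidden_dim).
        tokens = [tok for seq in hidden_states for tok in seq]
        if router_logits is None:
            # First call: compute the original routing from the inputs.
            router_logits = [
                [sum(h * w for h, w in zip(tok, row)) for row in self.router_weight]
                for tok in tokens
            ]
        routed = []
        for tok, logits in zip(tokens, router_logits):
            top = sorted(range(len(logits)), key=lambda e: logits[e], reverse=True)[: self.top_k]
            m = max(logits[e] for e in top)
            weights = [math.exp(logits[e] - m) for e in top]
            total = sum(weights)
            # Weighted sum of the selected experts' outputs (gain * input).
            routed.append(
                [sum(w / total * self.expert_gains[e] * h for w, e in zip(weights, top)) for h in tok]
            )
        # Shapes: (num_tokens, num_experts) and (num_tokens, hidden_dim).
        return router_logits, routed
```

On the second call the caller passes modified logits, so the router is skipped and only the routing and expert combination are re-run.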

load_state_dict(state_dict)

Load the internal state from a checkpoint.

Parameters:

state_dict (dict)

Return type:

None

state_dict()

Return the internal state for checkpointing.

Return type:

dict

to_dict()

Convert accumulated statistics to dict format.

Return type:

dict[str, Tensor]
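state_dict() and load_state_dict() let importance collection be checkpointed and resumed. The sketch below shows that round trip on a hypothetical accumulator; the key names, the per-batch update() method, and averaging in accumulate() are assumptions for illustration, not this module's actual state format.

```python
class ToyImportanceAccumulator:
    """Hypothetical stand-in for a hook's accumulate/checkpoint/resume cycle."""

    def __init__(self, num_experts):
        self.score_sum = [0.0] * num_experts
        self.num_batches = 0

    def update(self, batch_scores):
        # Called once per forward pass with one score per expert.
        self.score_sum = [s + b for s, b in zip(self.score_sum, batch_scores)]
        self.num_batches += 1

    def accumulate(self):
        # Mean importance per expert over all batches seen so far.
        return [s / self.num_batches for s in self.score_sum]

    def state_dict(self):
        # Copy the list so later updates do not mutate a saved checkpoint.
        return {"score_sum": list(self.score_sum), "num_batches": self.num_batches}

    def load_state_dict(self, state_dict):
        self.score_sum = list(state_dict["score_sum"])
        self.num_batches = state_dict["num_batches"]
```

Resuming from a checkpoint and continuing to update yields the same result as an uninterrupted run over the same batches.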