base_hooks
Forward hooks for activation-based importance estimation.
Classes
- ForwardHook: Base class for PyTorch forward hooks.
- IndependentChannelContributionHook: Hook for channel importance estimation using weight norms and activation magnitudes.
- IndependentKvHeadContributionHook: Hook for estimating KV head importance based on contribution analysis.
- IterativeChannelContributionHook: Hook for iterative channel pruning based on contribution analysis.
- L2NormHook: Hook for accumulating activation statistics for importance estimation.
- LayerNormContributionHook: Hook for estimating channel importance based on layer normalization activations.
- class ForwardHook
Bases:
ABC
Base class for PyTorch forward hooks.
This follows the PyTorch forward hook API, where the second parameter is args (a tuple of the positional arguments passed to forward()).
- Usage:
hook = MyHook()
module.register_forward_hook(hook)
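As a concrete illustration of this API, here is a minimal sketch of a standalone hook that records mean absolute activations per channel. The class name MeanAbsHook and its internals are hypothetical and not part of this module; the sketch only mirrors the (module, args, output) callback signature and the accumulate() contract described above.

```python
import torch
import torch.nn as nn

class MeanAbsHook:
    """Toy hook (not this library's ForwardHook) recording mean |activation| per channel."""

    def __init__(self):
        self.total = None
        self.count = 0

    def __call__(self, module, args, output):
        # args is the tuple of positional inputs to forward(); args[0] is the input tensor
        x = args[0].detach()
        score = x.abs().mean(dim=tuple(range(x.dim() - 1)))  # mean over all but the channel dim
        self.total = score if self.total is None else self.total + score
        self.count += 1

    def accumulate(self):
        assert self.count > 0, "no activations collected yet"
        return self.total / self.count

layer = nn.Linear(8, 4)
hook = MeanAbsHook()
handle = layer.register_forward_hook(hook)
layer(torch.randn(2, 8))
scores = hook.accumulate()  # one score per input channel
handle.remove()
```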
- abstract accumulate()
Return accumulated importance scores.
This method should be called after all forward passes to retrieve the final importance scores for each channel/feature.
- Returns:
Tensor of importance scores, one per channel/feature.
- Raises:
AssertionError – If no activations have been collected yet.
- Return type:
Tensor
- classmethod dump_activations_logs(activation_hooks, activations_log_dir, args)
Default implementation for dumping final activation score logs to disk.
This is called only at the end of scoring to save final results.
- Parameters:
activation_hooks (dict[str, ForwardHook])
activations_log_dir (Path | str)
args (DictConfig)
- Return type:
None
- get_progress_info()
Get progress information for this hook.
- Returns:
- Progress information (e.g., current iteration, samples processed).
Default implementation returns empty dict.
- Return type:
dict
- abstract load_state_dict(state_dict)
Load the internal state from a checkpoint.
- Parameters:
state_dict (dict) – State dictionary previously returned by state_dict()
- Return type:
None
- classmethod save_hook_states(activation_hooks, activations_log_dir)
Save hook states for checkpointing (separate from final results).
This can be called periodically during scoring. Note: Synchronization should be handled at a higher level to avoid deadlocks.
- Parameters:
activation_hooks (dict[str, ForwardHook])
activations_log_dir (Path | str)
- Return type:
None
- abstract state_dict()
Return the internal state for checkpointing.
- Returns:
- State dictionary containing checkpoint data.
Can contain tensors, ints, lists, etc.
- Return type:
dict
- abstract to_dict()
Convert hook results to dictionary format for saving.
- Returns:
Dictionary containing result tensors (e.g., “score”, “channels_importance_ascending”).
- Return type:
dict
- class IndependentChannelContributionHook
Bases:
ForwardHook
Hook for channel importance estimation using weight norms and activation magnitudes.
Computes channel importance as the product of:
- L2 norm of each column in the weight matrix (how much each input channel affects the output)
- Mean absolute activation for each channel (how strongly each channel is activated)
- Parameters:
linear_layer – The linear projection layer to analyze. Must have a weight attribute and either in_features (nn.Linear) or input_size (Megatron RowParallelLinear).
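The score this hook accumulates can be sketched outside the hook machinery as follows. This is a minimal illustration of the weight_norm * activations product described above, with made-up tensor shapes; it is not the hook's actual implementation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(8, 4)
x = torch.randn(16, 8)  # calibration activations entering the layer

# L2 norm of each weight column: how much each input channel can affect the output
weight_norm = layer.weight.detach().norm(dim=0)   # shape (in_features,)
# Mean absolute activation per input channel: how strongly each channel fires
agg_channel_activations = x.abs().mean(dim=0)     # shape (in_features,)

score = weight_norm * agg_channel_activations     # one importance score per channel
```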
- __init__(linear_layer)
Initialize the independent channel contribution hook.
- Parameters:
linear_layer (Module)
- accumulate()
Return importance scores as a tensor.
- Returns:
Tensor of importance scores (weight_norm * activations), one per channel.
- Return type:
Tensor
- load_state_dict(state_dict)
Load the internal state from a checkpoint.
- Parameters:
state_dict (dict)
- Return type:
None
- state_dict()
Save the internal state for checkpointing.
- Return type:
dict
- to_dict()
Convert results to dict with channel importance scores.
- Returns:
Dict with “score” (weight_norm * activations), “weight_norm”, and “agg_channel_activations”.
- Return type:
dict[str, Tensor]
- class IndependentKvHeadContributionHook
Bases:
ForwardHook
Hook for estimating KV head importance based on contribution analysis.
Measures the contribution of each KV head group to the output projection by computing L2 norms of per-head outputs.
- Parameters:
linear_layer – The output projection layer (o_proj).
activation_hooks_kwargs – Configuration dict with:
- model: The model instance (to get config).
- block_config: Block configuration with attention settings.
- optimize_for (str, optional): “latency” or “memory”. Defaults to “memory”.
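The per-head L2-norm idea can be sketched as follows, assuming a grouped-query layout where each KV head serves a fixed group of query heads. All shapes and names here (num_kv_heads, group, head_dim) are illustrative, not the hook's actual internals.

```python
import torch

torch.manual_seed(0)
num_kv_heads, group, head_dim, seq = 4, 2, 16, 32
hidden = num_kv_heads * group * head_dim

# Attention output entering o_proj, flattened across heads
attn_out = torch.randn(seq, hidden)

# Regroup channels by KV head: each KV head owns group * head_dim channels
per_head = attn_out.view(seq, num_kv_heads, group * head_dim)

# Contribution of each KV head = L2 norm of its slice of the o_proj input
kv_head_score = per_head.pow(2).sum(dim=(0, 2)).sqrt()  # shape (num_kv_heads,)
```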
- __init__(linear_layer, activation_hooks_kwargs)
Initialize the KV head contribution hook.
- Parameters:
linear_layer (Linear)
activation_hooks_kwargs (dict)
- accumulate()
Return accumulated KV head importance scores.
- Returns:
Tensor of importance scores, one per KV head.
- Return type:
Tensor
- load_state_dict(state_dict)
Load the internal state from a checkpoint.
- Parameters:
state_dict (dict)
- Return type:
None
- state_dict()
Return the internal state for checkpointing.
- Return type:
dict
- to_dict()
Convert to dict format for saving.
- Returns:
Dict with “score” tensor containing KV head importance scores.
- Return type:
dict[str, Tensor]
- class IterativeChannelContributionHook
Bases:
ForwardHook
Hook for iterative channel pruning based on contribution analysis.
Progressively identifies and removes the least important input channels of a linear layer by measuring channel contribution as the L2 norm of output change when removed.
- Parameters:
linear_layer – The linear projection layer to analyze. Must have a weight attribute and either in_features (nn.Linear) or input_size (Megatron RowParallelLinear).
activation_hooks_kwargs – Configuration dict with:
- validation_full_iters (int): Number of pruning iterations.
- clear_gpu_memory (bool, optional): Clear GPU memory during computation.
- calibration_method (str, optional): “scale_by_magnitude” or None.
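The greedy loop described above (repeatedly drop the channel whose removal changes the output least, measured as an L2 norm) can be sketched directly on a weight matrix and calibration activations. This is a simplified toy version, not the hook's actual implementation, and it omits calibration and GPU-memory handling.

```python
import torch

torch.manual_seed(0)
W = torch.randn(4, 8)   # linear weight, shape (out_features, in_features)
x = torch.randn(16, 8)  # calibration activations
ref = x @ W.t()         # reference output with all channels present

pruned, remaining = [], list(range(8))
for _ in range(8):
    # For each surviving channel, measure the output change when it alone is removed
    costs = {}
    for c in remaining:
        keep = [i for i in remaining if i != c]
        out = x[:, keep] @ W[:, keep].t()
        costs[c] = (ref - out).norm().item()
    # Greedily prune the channel whose removal perturbs the output the least
    least = min(costs, key=costs.get)
    remaining.remove(least)
    pruned.append(least)

# Channel indices in ascending order of importance (least important first)
channels_importance_ascending = torch.tensor(pruned)
```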
- __init__(linear_layer, activation_hooks_kwargs)
Initialize the iterative channel contribution hook.
- Parameters:
linear_layer (Module)
activation_hooks_kwargs (dict)
- accumulate()
Return importance scores as a tensor.
- Returns:
Tensor of importance scores, one per channel. Lower scores indicate less important channels.
- Return type:
Tensor
- get_progress_info()
Get progress information for this hook.
- Returns:
Progress information including iteration count and pruned channels.
- Return type:
dict
- load_state_dict(state_dict)
Load the internal state from a checkpoint.
- Parameters:
state_dict (dict)
- Return type:
None
- state_dict()
Save the internal state for checkpointing.
- Return type:
dict
- to_dict()
Convert pruning results to dict with channel importance rankings.
- Returns:
Dict with “score” (importance rank per channel) and “channels_importance_ascending” (channel indices in ascending importance).
- Return type:
dict[str, Tensor]
- class L2NormHook
Bases:
ForwardHook
Hook for accumulating activation statistics for importance estimation.
Activations are averaged over seq_len, then squared and summed over batch_size. The accumulate() method takes the square root of the sum to obtain the L2 norm.
This is the base version without tensor parallelism support. For megatron with TP > 1, use MegatronL2NormHook instead.
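The mean-then-square-then-sum recipe above can be written out directly. This sketch accumulates over two toy calibration batches with made-up shapes, mirroring what the hook does across forward passes and what accumulate() returns at the end.

```python
import torch

torch.manual_seed(0)
# Two calibration batches of shape (batch_size, seq_len, hidden)
batches = [torch.randn(4, 10, 8) for _ in range(2)]

acc = torch.zeros(8)
for x in batches:
    per_sample = x.mean(dim=1)           # mean over seq_len -> (batch_size, hidden)
    acc += per_sample.pow(2).sum(dim=0)  # square, then sum over batch_size

l2_norm = acc.sqrt()  # what accumulate() returns: one L2 norm per channel
```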
- __init__()
Initialize the L2NormHook.
- accumulate()
Return the accumulated L2 norm of activations.
- Returns:
Tensor of accumulated scores, one per channel
- Raises:
AssertionError – If no activations have been collected yet
- Return type:
Tensor
- load_state_dict(state_dict)
Load activations from checkpoint.
- Parameters:
state_dict (dict)
- Return type:
None
- state_dict()
Return the state dictionary containing activations.
- Return type:
dict
- to_dict()
Convert to dict format for saving.
- Return type:
dict[str, Tensor]
- class LayerNormContributionHook
Bases:
ForwardHook
Hook for estimating channel importance based on layer normalization activations.
Aggregates mean absolute activation values per channel for a layer normalization layer.
- Parameters:
layernorm_layer – The layer normalization layer.
activation_hooks_kwargs – The activation hooks kwargs (not used).
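A minimal sketch of the aggregation this hook performs, using a plain function hook on nn.LayerNorm. The function ln_hook and the averaging across calls are illustrative assumptions, not the class's actual code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
scores = []

def ln_hook(module, args, output):
    # Mean |activation| per channel of the layer norm output
    scores.append(output.detach().abs().mean(dim=(0, 1)))

ln = nn.LayerNorm(8)
handle = ln.register_forward_hook(ln_hook)
ln(torch.randn(4, 10, 8))   # (batch_size, seq_len, hidden)
ln(torch.randn(4, 10, 8))
handle.remove()

# Aggregate per-channel importance across forward passes
channel_importance = torch.stack(scores).mean(dim=0)
```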
- __init__(layernorm_layer, activation_hooks_kwargs)
Aggregates mean absolute activation values per channel for a layer normalization layer.
- Parameters:
layernorm_layer (Module) – The layer normalization layer
activation_hooks_kwargs (dict) – The activation hooks kwargs (not used)
- accumulate()
Return accumulated channel importance scores.
- Return type:
Tensor
- classmethod dump_activations_logs(activation_hooks, activations_log_dir, args)
Dump final activation score logs to disk at the end of scoring.
Saves aggregated channel importance results.
- Parameters:
activation_hooks (dict[str, ForwardHook])
activations_log_dir (Path | str)
args (DictConfig)
- Return type:
None
- load_state_dict(state_dict)
Load the internal state from a checkpoint.
- Parameters:
state_dict (dict)
- Return type:
None
- state_dict()
Return the internal state for checkpointing.
- Return type:
dict
- to_dict()
Convert to dict format for saving.
- Return type:
dict[str, Tensor]