base_hooks

Forward hooks for activation-based importance estimation.

Classes

ForwardHook

Base class for PyTorch forward hooks.

IndependentChannelContributionHook

Hook for channel importance estimation using weight norms and activation magnitudes.

IndependentKvHeadContributionHook

Hook for estimating KV head importance based on contribution analysis.

IterativeChannelContributionHook

Hook for iterative channel pruning based on contribution analysis.

L2NormHook

Hook for accumulating activation statistics for importance estimation.

LayerNormContributionHook

Hook for estimating channel importance based on layer normalization activations.

class ForwardHook

Bases: ABC

Base class for PyTorch forward hooks.

This follows the PyTorch forward hook API where the second parameter is ‘args’ (a tuple of positional arguments passed to forward()).

Usage:

hook = MyHook()
module.register_forward_hook(hook)
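
For illustration, a minimal concrete hook following this calling convention might look as follows. MeanAbsHook and its per-channel statistic are hypothetical, not part of this module; the real ForwardHook base class additionally defines the checkpointing contract (state_dict()/load_state_dict()) documented below.

```python
import torch
import torch.nn as nn

# Hypothetical minimal hook following the PyTorch forward-hook signature
# (module, args, output), where args is the tuple of positional inputs
# passed to forward().
class MeanAbsHook:
    def __init__(self):
        self._sum = None
        self._count = 0

    def __call__(self, module, args, output):
        x = args[0].detach()
        # Per-channel mean absolute activation (mean over all leading dims).
        stats = x.abs().mean(dim=tuple(range(x.dim() - 1)))
        self._sum = stats if self._sum is None else self._sum + stats
        self._count += 1

    def accumulate(self):
        assert self._count > 0, "no activations collected"
        return self._sum / self._count

linear = nn.Linear(8, 4)
hook = MeanAbsHook()
handle = linear.register_forward_hook(hook)
linear(torch.randn(2, 8))
handle.remove()
scores = hook.accumulate()  # one score per input channel
```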

abstract accumulate()

Return accumulated importance scores.

This method should be called after all forward passes to retrieve the final importance scores for each channel/feature.

Returns:

Tensor of importance scores, one per channel/feature.

Raises:

AssertionError – If no activations have been collected yet.

Return type:

Tensor

classmethod dump_activations_logs(activation_hooks, activations_log_dir, args)

Default implementation for dumping final activation score logs to disk.

This is called only at the end of scoring to save final results.

Parameters:
  • activation_hooks (dict[str, ForwardHook])

  • activations_log_dir (Path | str)

  • args (DictConfig)

Return type:

None

get_progress_info()

Get progress information for this hook.

Returns:

Progress information (e.g., current iteration, samples processed).

The default implementation returns an empty dict.

Return type:

dict

abstract load_state_dict(state_dict)

Load the internal state from a checkpoint.

Parameters:

state_dict (dict) – State dictionary previously returned by state_dict()

Return type:

None

classmethod save_hook_states(activation_hooks, activations_log_dir)

Save hook states for checkpointing (separate from final results).

This can be called periodically during scoring. Note: Synchronization should be handled at a higher level to avoid deadlocks.

Parameters:
  • activation_hooks (dict[str, ForwardHook])

  • activations_log_dir (Path | str)

Return type:

None

abstract state_dict()

Return the internal state for checkpointing.

Returns:

State dictionary containing checkpoint data.

Can contain tensors, ints, lists, etc.

Return type:

dict
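
The state_dict()/load_state_dict() contract can be exercised with a round trip like the following. SimpleHook is a hypothetical stand-in for a concrete subclass, used only to show the pattern.

```python
import torch

# Hypothetical hook demonstrating the checkpointing contract described
# above: state_dict() returns a plain dict (tensors, ints, lists, ...)
# and load_state_dict() restores it.
class SimpleHook:
    def __init__(self):
        self.acc = torch.zeros(4)
        self.n = 0

    def state_dict(self):
        return {"acc": self.acc, "n": self.n}

    def load_state_dict(self, sd):
        self.acc, self.n = sd["acc"], sd["n"]

h = SimpleHook()
h.acc += 1.0
h.n = 5

# Restore the state into a fresh hook, e.g. after loading a checkpoint.
restored = SimpleHook()
restored.load_state_dict(h.state_dict())
```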

abstract to_dict()

Convert hook results to dictionary format for saving.

Returns:

Dictionary containing result tensors (e.g., “score”, “channels_importance_ascending”).

Return type:

dict

class IndependentChannelContributionHook

Bases: ForwardHook

Hook for channel importance estimation using weight norms and activation magnitudes.

Computes channel importance as the product of:
  • L2 norm of each column in the weight matrix (how much each input channel affects output)
  • Mean absolute activation for each channel (how strongly each channel is activated)

Parameters:

linear_layer – The linear projection layer to analyze. Must have a weight attribute and either in_features (nn.Linear) or input_size (Megatron RowParallelLinear).
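
The scoring rule described above can be sketched directly. Tensor names mirror the keys returned by to_dict(); the shapes are illustrative, not taken from the module.

```python
import torch

# Sketch of the scoring rule: per-channel weight column norm times
# per-channel mean absolute activation.
weight = torch.randn(4, 8)        # (out_features, in_features)
activations = torch.randn(16, 8)  # (batch, in_features)

weight_norm = weight.norm(dim=0)                       # L2 norm per input channel
agg_channel_activations = activations.abs().mean(dim=0)
score = weight_norm * agg_channel_activations          # one score per input channel
```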

__init__(linear_layer)

Initialize the independent channel contribution hook.

Parameters:

linear_layer (Module)

accumulate()

Return importance scores as a tensor.

Returns:

Tensor of importance scores (weight_norm * activations), one per channel.

Return type:

Tensor

load_state_dict(state_dict)

Load the internal state from a checkpoint.

Parameters:

state_dict (dict)

Return type:

None

state_dict()

Return the internal state for checkpointing.

Return type:

dict

to_dict()

Convert results to dict with channel importance scores.

Returns:

Dict with “score” (weight_norm * activations), “weight_norm”, and “agg_channel_activations”.

Return type:

dict[str, Tensor]

class IndependentKvHeadContributionHook

Bases: ForwardHook

Hook for estimating KV head importance based on contribution analysis.

Measures the contribution of each KV head group to the output projection by computing L2 norms of per-head outputs.

Parameters:
  • linear_layer – The output projection layer (o_proj).

  • activation_hooks_kwargs – Configuration dict with:
    - model: The model instance (to get config).
    - block_config: Block configuration with attention settings.
    - optimize_for (str, optional): “latency” or “memory”. Defaults to “memory”.
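
A rough sketch of the per-head contribution idea under assumed shapes and grouping (not the hook's exact implementation): split the o_proj weight and its input into KV head groups, form each group's partial output, and score the group by the L2 norm of that partial output.

```python
import torch

# Illustrative per-KV-head contribution estimate. All dimensions below
# are assumptions chosen for the example.
num_kv_heads, group_size, head_dim, hidden = 2, 2, 4, 16
o_proj = torch.randn(hidden, num_kv_heads * group_size * head_dim)
x = torch.randn(8, num_kv_heads * group_size * head_dim)  # flattened attn output

w = o_proj.view(hidden, num_kv_heads, group_size * head_dim)
xg = x.view(-1, num_kv_heads, group_size * head_dim)
# Partial output contributed by each KV head group: (batch, kv_head, hidden)
partial = torch.einsum("bkd,hkd->bkh", xg, w)
score = partial.norm(dim=-1).sum(dim=0)  # one score per KV head
```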

__init__(linear_layer, activation_hooks_kwargs)

Initialize the KV head contribution hook.

Parameters:
  • linear_layer (Linear)

  • activation_hooks_kwargs (dict)

accumulate()

Return accumulated KV head importance scores.

Returns:

Tensor of importance scores, one per KV head.

Return type:

Tensor

load_state_dict(state_dict)

Load the internal state from a checkpoint.

Parameters:

state_dict (dict)

Return type:

None

state_dict()

Return the internal state for checkpointing.

Return type:

dict

to_dict()

Convert to dict format for saving.

Returns:

Dict with “score” tensor containing KV head importance scores.

Return type:

dict[str, Tensor]

class IterativeChannelContributionHook

Bases: ForwardHook

Hook for iterative channel pruning based on contribution analysis.

Progressively identifies and removes the least important input channels of a linear layer by measuring channel contribution as the L2 norm of output change when removed.

Parameters:
  • linear_layer – The linear projection layer to analyze. Must have a weight attribute and either in_features (nn.Linear) or input_size (Megatron RowParallelLinear).

  • activation_hooks_kwargs – Configuration dict with:
    - validation_full_iters (int): Number of pruning iterations.
    - clear_gpu_memory (bool, optional): Clear GPU memory during computation.
    - calibration_method (str, optional): “scale_by_magnitude” or None.
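
A toy version of the iterative scheme, assuming the contribution measure described above (L2 norm of the output change when a channel is removed); this sketches the idea, not the hook's actual code.

```python
import torch

# Repeatedly remove the remaining input channel whose removal changes
# the layer output least, recording channels least-important-first.
weight = torch.randn(4, 6)  # (out_features, in_features)
x = torch.randn(32, 6)      # calibration activations
remaining = list(range(6))
pruned_order = []           # least important first

for _ in range(6):
    base = x[:, remaining] @ weight[:, remaining].T
    deltas = []
    for ch in remaining:
        keep = [c for c in remaining if c != ch]
        out = x[:, keep] @ weight[:, keep].T
        deltas.append((out - base).norm().item())  # output change if ch removed
    least = remaining[deltas.index(min(deltas))]
    pruned_order.append(least)
    remaining.remove(least)
```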

__init__(linear_layer, activation_hooks_kwargs)

Initialize the iterative channel contribution hook.

Parameters:
  • linear_layer (Module)

  • activation_hooks_kwargs (dict)

accumulate()

Return importance scores as a tensor.

Returns:

Tensor of importance scores, one per channel. Lower scores indicate less important channels.

Return type:

Tensor

get_progress_info()

Get progress information for this hook.

Returns:

Progress information including iteration count and pruned channels.

Return type:

dict

load_state_dict(state_dict)

Load the internal state from a checkpoint.

Parameters:

state_dict (dict)

Return type:

None

state_dict()

Return the internal state for checkpointing.

Return type:

dict

to_dict()

Convert pruning results to dict with channel importance rankings.

Returns:

Dict with “score” (importance rank per channel) and “channels_importance_ascending” (channel indices in ascending importance).

Return type:

dict[str, Tensor]

class L2NormHook

Bases: ForwardHook

Hook for accumulating activation statistics for importance estimation.

Activations are averaged over seq_len, then squared and summed over batch_size. The accumulate() method takes the square root of the accumulated sum to obtain the L2 norm.

This is the base version without tensor parallelism support. For Megatron with TP > 1, use MegatronL2NormHook instead.
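
The accumulation rule can be written out directly (illustrative shapes only):

```python
import torch

# Mean over seq_len, square, sum over batch, accumulated across passes;
# the square root at the end corresponds to accumulate().
acc = torch.zeros(16)              # one slot per channel
for _ in range(3):                 # three forward passes / batches
    x = torch.randn(4, 10, 16)     # (batch, seq_len, channels)
    per_sample = x.mean(dim=1)     # mean over seq_len -> (batch, channels)
    acc += per_sample.pow(2).sum(dim=0)

score = acc.sqrt()                 # L2 norm per channel
```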

__init__()

Initialize the L2NormHook.

accumulate()

Return the accumulated L2 norm of activations.

Returns:

Tensor of accumulated scores, one per channel

Raises:

AssertionError – If no activations have been collected yet

Return type:

Tensor

load_state_dict(state_dict)

Load activations from checkpoint.

Parameters:

state_dict (dict)

Return type:

None

state_dict()

Return the state dictionary containing activations.

Return type:

dict

to_dict()

Convert to dict format for saving.

Return type:

dict[str, Tensor]

class LayerNormContributionHook

Bases: ForwardHook

Hook for estimating channel importance based on layer normalization activations.

Aggregates mean absolute activation values per channel for a layer normalization layer.

Parameters:
  • layernorm_layer – The layer normalization layer.

  • activation_hooks_kwargs – The activation hooks kwargs (not used).
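
The aggregation described above might be sketched with a plain function hook (illustrative only, not this class's implementation):

```python
import torch
import torch.nn as nn

# Accumulate the per-channel mean absolute value of the layer-norm
# output, averaged over batches.
ln = nn.LayerNorm(16)
total, batches = torch.zeros(16), 0

def ln_hook(module, args, output):
    global total, batches
    total += output.detach().abs().mean(dim=(0, 1))  # per-channel mean |act|
    batches += 1

handle = ln.register_forward_hook(ln_hook)
ln(torch.randn(4, 10, 16))
handle.remove()
score = total / batches
```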

__init__(layernorm_layer, activation_hooks_kwargs)

Initialize the layer normalization contribution hook.

Parameters:
  • layernorm_layer (Module) – The layer normalization layer

  • activation_hooks_kwargs (dict) – The activation hooks kwargs (not used)

accumulate()

Return accumulated channel importance scores.

Return type:

Tensor

classmethod dump_activations_logs(activation_hooks, activations_log_dir, args)

Default implementation for dumping activation score logs to disk, called at the end of scoring.

Save aggregated channel importance results.

Parameters:
  • activation_hooks (dict[str, ForwardHook])

  • activations_log_dir (Path | str)

  • args (DictConfig)

Return type:

None

load_state_dict(state_dict)

Load the internal state from a checkpoint.

Parameters:

state_dict (dict)

Return type:

None

state_dict()

Return the internal state for checkpointing.

Return type:

dict

to_dict()

Convert to dict format for saving.

Return type:

dict[str, Tensor]