base_hooks_analysis

Analysis tools for evaluating importance scores from hooks.

Functions

evaluate_importance_scores

Compute reconstruction error after pruning input channels of a linear layer.

evaluate_importance_scores(linear_layer, activations_batches, importance_scores, prune_ratio=0.2)

Compute reconstruction error after pruning input channels of a linear layer.

This function simulates channel pruning by zeroing out input channels identified as least important, then measures how much the layer’s output changes.

Parameters:
  • linear_layer (Linear) – The linear layer to analyze. Its weight matrix has shape (out_features, in_features). For example: nn.Linear(in_features=1024, out_features=4096)

  • activations_batches (list[Tensor]) – List of input activation tensors. Each tensor has shape [seq_len, batch_size, in_features]. The last dimension must match linear_layer.in_features. Example: List of [16, 8, 1024] tensors

  • importance_scores (Tensor) – Importance score for each input channel (feature). Shape: [in_features]. Lower scores = less important. Example: [1024] tensor with one score per input feature

  • prune_ratio (float) – Fraction of input channels to prune. Defaults to 0.2, which prunes 20% of the input channels.
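How importance_scores and prune_ratio combine to pick channels can be sketched as follows. This is a minimal illustration, not the function's actual code: select_pruned_channels is a hypothetical helper, and rounding the channel count down is an assumption, since the source does not state the rounding rule.

```python
import torch

def select_pruned_channels(importance_scores: torch.Tensor,
                           prune_ratio: float = 0.2) -> torch.Tensor:
    """Return indices of the lowest-scoring input channels (hypothetical helper)."""
    in_features = importance_scores.numel()
    # Assumption: the count is rounded down; the real function may round differently.
    num_pruned = int(in_features * prune_ratio)
    # largest=False selects the smallest scores, i.e. the least important channels.
    _, pruned_idx = torch.topk(importance_scores, k=num_pruned, largest=False)
    return pruned_idx

scores = torch.tensor([0.9, 0.1, 0.5, 0.3, 0.7])
pruned_idx = select_pruned_channels(scores, prune_ratio=0.4)
print(sorted(pruned_idx.tolist()))
```

With five channels and a ratio of 0.4, the two lowest-scoring channels (indices 1 and 3) are selected.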

Returns:

Dictionary of metrics averaged across all activation batches:

  • rmse: Root mean squared error between original and pruned output

  • cosine_similarity: Cosine similarity between original and pruned output

  • num_pruned: Number of input channels pruned

Return type:

dict
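The sketch below shows one plausible way to compute and average such metrics. It is not the library's implementation: pruning_metrics is a hypothetical helper, and taking the cosine similarity over the flattened outputs (rather than per token) is an assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def pruning_metrics(layer, batches, pruned_idx):
    """Average RMSE and cosine similarity over batches (hypothetical helper)."""
    mask = torch.ones(layer.in_features)
    mask[pruned_idx] = 0.0  # zero the pruned input channels
    rmse_sum, cos_sum = 0.0, 0.0
    with torch.no_grad():
        for x in batches:
            original = layer(x)
            pruned = layer(x * mask)  # mask broadcasts over the last dimension
            rmse_sum += torch.sqrt(F.mse_loss(pruned, original)).item()
            # Assumption: similarity is taken over the flattened outputs.
            cos_sum += F.cosine_similarity(
                original.flatten(), pruned.flatten(), dim=0
            ).item()
    n = len(batches)
    return {
        "rmse": rmse_sum / n,
        "cosine_similarity": cos_sum / n,
        "num_pruned": len(pruned_idx),
    }

torch.manual_seed(0)
layer = nn.Linear(in_features=8, out_features=4)
batches = [torch.randn(3, 2, 8) for _ in range(5)]
metrics = pruning_metrics(layer, batches, torch.tensor([1, 3]))
no_prune = pruning_metrics(layer, batches, torch.tensor([], dtype=torch.long))
```

Pruning nothing gives an RMSE of zero and a cosine similarity of one, which is a useful sanity check before comparing different importance scores.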

Example

>>> import torch
>>> import torch.nn as nn
>>> # evaluate_importance_scores is assumed to be imported from this module
>>> layer = nn.Linear(in_features=1024, out_features=4096)
>>> # Collect multiple batches for robust evaluation
>>> activations_list = [torch.randn(16, 8, 1024) for _ in range(100)]
>>> scores = torch.randn(1024)  # one score per input feature
>>> metrics = evaluate_importance_scores(layer, activations_list, scores, 0.2)
>>> print(f"RMSE: {metrics['rmse']:.4f}, Pruned: {metrics['num_pruned']} channels")

Note

  • This simulates pruning (zeros out inputs) without modifying layer weights

  • “Channels” refers to INPUT features, not output features
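To make the notes above concrete, the following sketch checks that zeroing input channels gives the same output as zeroing the matching weight columns, while leaving the layer's own weights untouched. The layer size and the pruned channel indices (0 and 5) are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
layer = nn.Linear(in_features=6, out_features=3)
x = torch.randn(4, 2, 6)  # [seq_len, batch_size, in_features]
weight_before = layer.weight.detach().clone()

mask = torch.ones(6)
mask[[0, 5]] = 0.0  # pretend input channels 0 and 5 were selected for pruning

with torch.no_grad():
    # Simulated pruning: zero the inputs, keep the layer as it is.
    out_masked_input = layer(x * mask)
    # Equivalent structural view: zero the matching columns of the weight matrix.
    out_masked_weight = F.linear(x, layer.weight * mask, layer.bias)

same = torch.allclose(out_masked_input, out_masked_weight, atol=1e-6)
unchanged = torch.equal(layer.weight.detach(), weight_before)
print(same, unchanged)
```

The mask has one entry per input feature, so it lines up with the last dimension of the activations and with the columns of the weight matrix, not with the output features.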