base_hooks_analysis
Analysis tools for evaluating importance scores from hooks.
Functions
- evaluate_importance_scores(linear_layer, activations_batches, importance_scores, prune_ratio=0.2)
Compute reconstruction error after pruning input channels of a linear layer.
This function simulates channel pruning by zeroing out input channels identified as least important, then measures how much the layer’s output changes.
- Parameters:
linear_layer (Linear) – The linear layer to analyze with shape (out_features, in_features). For example: nn.Linear(in_features=1024, out_features=4096)
activations_batches (list[Tensor]) – List of input activation tensors. Each tensor has shape [seq_len, batch_size, in_features]. The last dimension must match linear_layer.in_features. Example: List of [16, 8, 1024] tensors
importance_scores (Tensor) – Importance score for each input channel (feature). Shape: [in_features]. Lower scores = less important. Example: [1024] tensor with one score per input feature
prune_ratio (float) – Fraction of input channels to prune (default: 0.2 means prune 20%).
- Returns:
Dictionary containing averaged metrics across all activation batches, with keys:
rmse – Root mean squared error between the original and pruned outputs
cosine_similarity – Cosine similarity between the original and pruned outputs
num_pruned – Number of input channels pruned
- Return type:
dict
Example
>>> layer = nn.Linear(in_features=1024, out_features=4096)
>>> # Collect multiple batches for robust evaluation
>>> activations_list = [torch.randn(16, 8, 1024) for _ in range(100)]
>>> scores = torch.randn(1024)  # one score per input feature
>>> metrics = evaluate_importance_scores(layer, activations_list, scores, 0.2)
>>> print(f"RMSE: {metrics['rmse']:.4f}, Pruned: {metrics['num_pruned']} channels")
Note
This simulates pruning (zeroing out inputs) without modifying the layer's weights.
"Channels" refers to input features, not output features.
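The behavior documented above can be sketched as follows. This is a hypothetical reference implementation based solely on the documented parameters and return keys, not the module's actual code: it zeroes the lowest-scoring input channels, runs the layer on both the original and masked activations, and averages RMSE and cosine similarity across batches.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def evaluate_importance_scores(linear_layer, activations_batches,
                               importance_scores, prune_ratio=0.2):
    """Sketch: simulate input-channel pruning and measure output drift."""
    num_pruned = int(linear_layer.in_features * prune_ratio)
    # Indices of the least important input channels (lowest scores first).
    prune_idx = torch.argsort(importance_scores)[:num_pruned]

    rmse_vals, cos_vals = [], []
    with torch.no_grad():
        for acts in activations_batches:
            original = linear_layer(acts)
            masked = acts.clone()
            masked[..., prune_idx] = 0.0      # simulate pruning: zero inputs only
            pruned = linear_layer(masked)     # layer weights stay untouched
            rmse_vals.append(torch.sqrt(F.mse_loss(pruned, original)).item())
            cos_vals.append(F.cosine_similarity(
                pruned.flatten(), original.flatten(), dim=0).item())

    return {
        "rmse": sum(rmse_vals) / len(rmse_vals),
        "cosine_similarity": sum(cos_vals) / len(cos_vals),
        "num_pruned": num_pruned,
    }
```

Note the design choice this doc implies: because only the inputs are masked, the layer's bias and weights are untouched, so the measured error isolates the contribution of the pruned channels.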