model_sparsify
Entry points for KV cache sparsity: sparsify() and calibrate().
Functions
calibrate | Calibrate TriAttention frequency statistics.
sparsify | Apply KV cache sparsity optimization to a model.
- calibrate(model, config=None, forward_loop=None)
Calibrate TriAttention frequency statistics.
Runs a forward pass with hooks to capture pre-RoPE Q vectors, inverts RoPE, and computes per-head frequency centers. Results are stored in the model’s modelopt_state metadata.
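To make the "inverts RoPE" step concrete, here is a minimal pure-Python sketch of applying and then undoing a rotary position embedding on a single query vector. It is not the library's implementation: the interleaved pairing of dimensions, the base of 10000, and the helper names `rope_angles` and `rotate` are all assumptions made for illustration.

```python
import cmath


def rope_angles(position, head_dim, base=10000.0):
    """Rotation angle for each dimension pair: position * base^(-2i / head_dim)."""
    return [position * base ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]


def rotate(vec, position, base=10000.0, inverse=False):
    """Apply RoPE to `vec` (or undo it when inverse=True).

    Assumes adjacent dimensions (2i, 2i+1) form one rotated pair; some models
    pair dimension i with i + head_dim/2 instead.
    """
    sign = -1.0 if inverse else 1.0
    out = []
    for i, angle in enumerate(rope_angles(position, len(vec), base)):
        # Treat the pair as a complex number and rotate it by +/- angle.
        pair = complex(vec[2 * i], vec[2 * i + 1]) * cmath.exp(1j * sign * angle)
        out.extend([pair.real, pair.imag])
    return out


q = [0.5, -1.0, 2.0, 0.25]            # toy query vector with head_dim = 4
rotated = rotate(q, position=7)        # what attention sees after RoPE
recovered = rotate(rotated, position=7, inverse=True)  # position-free vector again
```

Because each pair is only rotated, inverting RoPE is the same rotation with the angle negated, and `recovered` matches `q` up to floating-point error.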
- Parameters:
model (nn.Module) – Model with TriAttention mode applied.
config (dict[str, Any] | TriAttentionConfig | None) – Optional config override.
forward_loop (ForwardLoop | None) – Callable that runs forward passes on calibration data. If None, calibration is skipped (no-op).
- Returns:
The model with calibration data stored in metadata.
- Return type:
nn.Module
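The `forward_loop` argument can be sketched as follows. This example assumes a `ForwardLoop` is a callable that receives the model as its only argument and returns nothing; that signature is not stated on this page. `make_forward_loop` and `RecordingModel` are hypothetical helpers, and the toy model stands in for a real `nn.Module`.

```python
def make_forward_loop(batches):
    """Build a callable that feeds every calibration batch through the model.

    Assumption: the loop is invoked as forward_loop(model). With a real
    PyTorch model you would typically run this under torch.no_grad().
    """
    def forward_loop(model):
        for batch in batches:
            model(batch)  # outputs are discarded; only the forward pass matters
    return forward_loop


class RecordingModel:
    """Toy stand-in for an nn.Module that records the batches it receives."""

    def __init__(self):
        self.seen = []

    def __call__(self, batch):
        self.seen.append(batch)
        return batch


calib_batches = [[1, 2, 3], [4, 5, 6]]   # placeholder token-id batches
forward_loop = make_forward_loop(calib_batches)

toy_model = RecordingModel()
forward_loop(toy_model)                   # runs one forward pass per batch
```

Keeping the data inside a closure means the same loop object can be handed to either `sparsify()` or `calibrate()` without those functions needing to know where the calibration data comes from.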
- sparsify(model, config, forward_loop=None)
Apply KV cache sparsity optimization to a model.
Registers the TriAttention mode on the model. Call calibrate() afterwards to compute frequency statistics from calibration data.
- Parameters:
model (nn.Module) – The model to optimize.
config (dict[str, Any] | TriAttentionConfig) – TriAttentionConfig or dict with config values.
forward_loop (ForwardLoop | None) – Optional forward loop for integrated calibration.
- Returns:
The model with TriAttention mode applied (in-place).
- Return type:
nn.Module
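The two-step flow described above (apply the mode, then calibrate) can be sketched as a small wrapper. The import path of this module is not shown on this page, so the wrapper takes `sparsify` and `calibrate` as arguments rather than importing them; `sparsify_then_calibrate` and the `fake_*` stand-ins are hypothetical and exist only so the sketch runs on its own. The config is left empty because the available `TriAttentionConfig` fields are not listed here.

```python
def sparsify_then_calibrate(model, config, forward_loop, *, sparsify, calibrate):
    """Apply the sparsity mode first, then compute calibration statistics."""
    model = sparsify(model, config)
    return calibrate(model, forward_loop=forward_loop)


# Stand-ins that mirror the documented signatures and record the call order.
calls = []


def fake_sparsify(model, config, forward_loop=None):
    calls.append("sparsify")
    return model


def fake_calibrate(model, config=None, forward_loop=None):
    calls.append("calibrate")
    if forward_loop is not None:   # documented: None means calibration is skipped
        forward_loop(model)
    return model


result = sparsify_then_calibrate(
    model="toy-model",
    config={},                                   # real config keys not shown here
    forward_loop=lambda m: calls.append("forward"),
    sparsify=fake_sparsify,
    calibrate=fake_calibrate,
)
```

In real use, the two keyword arguments would be the `sparsify` and `calibrate` functions documented on this page, and `model` would be an `nn.Module`.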