model_sparsify

Entry points for KV cache sparsity: sparsify() and calibrate().

Functions

calibrate – Calibrate TriAttention frequency statistics.

sparsify – Apply KV cache sparsity optimization to a model.

calibrate(model, config=None, forward_loop=None)

Calibrate TriAttention frequency statistics.

Runs the calibration forward passes with hooks that capture query (Q) vectors, inverts RoPE to recover their pre-RoPE values, and computes per-head frequency centers. Results are stored in the model’s modelopt_state metadata.
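The RoPE-inversion step can be illustrated with a small standard-library sketch. It assumes the interleaved pairing convention (dimensions 2j and 2j+1 form one rotary pair), the common frequency schedule theta_j = base^(-2j/d) with base 10000, and reads "frequency center" as the mean pre-RoPE complex value per frequency band. The real implementation may pair dimensions differently (many models use a half-split layout) and may compute a different statistic; every function name below is hypothetical.

```python
import cmath

# Minimal sketch (assumptions): interleaved RoPE pairing (dims 2j, 2j+1),
# theta_j = base ** (-2j / d) with base 10000, and "frequency center" taken
# to mean the mean pre-RoPE complex value per frequency band. Names are hypothetical.


def rope_thetas(head_dim: int, base: float = 10000.0) -> list:
    """Rotary frequency theta_j for each of the head_dim // 2 dimension pairs."""
    return [base ** (-2.0 * j / head_dim) for j in range(head_dim // 2)]


def as_pairs(vec: list) -> list:
    """View a real vector as complex numbers x[2j] + i * x[2j+1]."""
    return [complex(vec[2 * j], vec[2 * j + 1]) for j in range(len(vec) // 2)]


def apply_rope(q: list, pos: int, base: float = 10000.0) -> list:
    """Forward RoPE: rotate pair j by pos * theta_j (used here only to build test data)."""
    out = []
    for z, th in zip(as_pairs(q), rope_thetas(len(q), base)):
        r = z * cmath.exp(1j * pos * th)
        out.extend([r.real, r.imag])
    return out


def invert_rope(q_rot: list, pos: int, base: float = 10000.0) -> list:
    """Undo RoPE: rotate pair j back by -pos * theta_j, giving pre-RoPE complex pairs."""
    return [z * cmath.exp(-1j * pos * th)
            for z, th in zip(as_pairs(q_rot), rope_thetas(len(q_rot), base))]


def frequency_centers(samples: list) -> list:
    """Mean pre-RoPE value per frequency band over (position, post-RoPE query) samples."""
    pre = [invert_rope(q_rot, pos) for pos, q_rot in samples]
    return [sum(band) / len(pre) for band in zip(*pre)]


# The same query seen at 8 positions: RoPE scatters it, inversion re-aligns it.
q = [1.0, 0.0, 0.5, 0.5]
samples = [(pos, apply_rope(q, pos)) for pos in range(8)]
centers = frequency_centers(samples)
print(centers)  # approximately [(1+0j), (0.5+0.5j)], up to floating-point error
```

Averaging the post-RoPE vectors directly would smear the fast-rotating band 0 across many angles; undoing the rotation first is what lets a per-band statistic line up across positions.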

Parameters:
  • model (nn.Module) – Model with TriAttention mode applied.

  • config (dict[str, Any] | TriAttentionConfig | None) – Optional config override.

  • forward_loop (ForwardLoop | None) – Callable that runs forward passes on calibration data. If None, calibration is skipped (no-op).

Returns:

The model with calibration data stored in metadata.

Return type:

nn.Module
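The calibrate() contract (run forward passes with capture hooks, store results, and do nothing when forward_loop is None) can be sketched without PyTorch. This sketch assumes a ForwardLoop is a callable that receives the model and runs it over calibration batches; ToyModel, make_forward_loop, calibrate_sketch, and the metadata key are hypothetical stand-ins, and the stored value is a placeholder rather than real frequency statistics.

```python
from typing import Any, Callable, Iterable, Optional

# Assumed ForwardLoop shape: receives the model, runs it over calibration data.
ForwardLoopSketch = Callable[[Any], None]


class ToyModel:
    """Stand-in for nn.Module: each call passes the input to every registered hook."""

    def __init__(self) -> None:
        self.hooks: list = []
        self.metadata: dict = {}  # stand-in for the model's modelopt_state metadata

    def __call__(self, x: Any) -> Any:
        for hook in self.hooks:
            hook(x)
        return x


def make_forward_loop(batches: Iterable[Any]) -> ForwardLoopSketch:
    """Build a forward loop that feeds every calibration batch through the model."""
    data = list(batches)  # materialize so the loop can be run more than once

    def forward_loop(model: Any) -> None:
        for batch in data:
            model(batch)

    return forward_loop


def calibrate_sketch(model: Any, forward_loop: Optional[ForwardLoopSketch] = None) -> Any:
    """Mirror the documented contract: forward_loop=None makes calibration a no-op."""
    if forward_loop is None:
        return model
    captured: list = []
    hook = captured.append
    model.hooks.append(hook)  # attach a capture hook for the duration of the loop
    try:
        forward_loop(model)
    finally:
        model.hooks.remove(hook)  # always detach, even if the loop raises
    # Placeholder statistic; the real function computes per-head frequency centers.
    model.metadata["num_calibration_samples"] = len(captured)
    return model


model = calibrate_sketch(ToyModel(), make_forward_loop([[1, 2], [3, 4], [5, 6]]))
print(model.metadata)  # {'num_calibration_samples': 3}
```

With real PyTorch modules, a capture hook would typically be attached with nn.Module.register_forward_hook and detached through the returned handle's remove() method.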

sparsify(model, config, forward_loop=None)

Apply KV cache sparsity optimization to a model.

Registers the TriAttention mode on the model in place. If no forward_loop is passed, call calibrate() afterwards to compute frequency statistics from calibration data.

Parameters:
  • model (nn.Module) – The model to optimize.

  • config (dict[str, Any] | TriAttentionConfig) – TriAttentionConfig or dict with config values.

  • forward_loop (ForwardLoop | None) – Optional forward loop for integrated calibration.

Returns:

The model with TriAttention mode applied (in-place).

Return type:

nn.Module
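The accepted config forms and the two calibration paths of sparsify() can be sketched as below. This is an illustration under assumptions, not the library's code: TriAttentionConfigSketch and its fields (budget, enabled) are hypothetical, the model is a SimpleNamespace with a metadata dict, and "integrated calibration" is assumed to mean that the forward loop runs inside the same call.

```python
from dataclasses import dataclass, fields
from types import SimpleNamespace
from typing import Any, Callable, Optional, Union


@dataclass
class TriAttentionConfigSketch:
    """Hypothetical stand-in for TriAttentionConfig; both fields are illustrative only."""
    budget: int = 2048
    enabled: bool = True


def normalize_config(config: Union[dict, TriAttentionConfigSketch]) -> TriAttentionConfigSketch:
    """Accept either a config object or a plain dict, as the config parameter allows."""
    if isinstance(config, TriAttentionConfigSketch):
        return config
    known = {f.name for f in fields(TriAttentionConfigSketch)}
    unknown = set(config) - known
    if unknown:
        raise ValueError(f"unknown config keys: {sorted(unknown)}")
    return TriAttentionConfigSketch(**config)


def sparsify_sketch(
    model: Any,
    config: Union[dict, TriAttentionConfigSketch],
    forward_loop: Optional[Callable[[Any], None]] = None,
) -> Any:
    """Register the mode in place; calibrate in the same call only if a loop is given."""
    cfg = normalize_config(config)
    model.metadata["mode"] = ("triattention", cfg)  # hypothetical record of the applied mode
    if forward_loop is not None:
        forward_loop(model)  # assumed "integrated calibration" path
        model.metadata["calibrated"] = True
    return model  # same object: the change is in place


model = sparsify_sketch(SimpleNamespace(metadata={}), {"budget": 512})
print(model.metadata["mode"][1])  # TriAttentionConfigSketch(budget=512, enabled=True)
```

Rejecting unknown dict keys up front is one design choice for the dict path; it surfaces typos in config names immediately instead of silently ignoring them.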