triattention

TriAttention: Trigonometric KV cache compression.

Classes

HeadFrequencyStats

Per-head calibration statistics in frequency domain.

Functions

build_geometric_offsets

Build geometric offset sequence [1, 2, 4, 8, ..., max_length].

compute_frequency_statistics_from_means

Compute amplitude, phase, and MLR extra term from Q/K frequency statistics.

invert_rope

Invert RoPE rotation to recover pre-RoPE representation.

rotate_half

Rotate tensor for RoPE.

score_keys_for_round

Score cached keys for a single pruning round.

select_keys_to_keep

Select which keys to retain based on importance scores.

to_complex_pairs

Convert real tensor to complex representation for frequency analysis.

class HeadFrequencyStats

Bases: object

Per-head calibration statistics in frequency domain.

__init__(q_mean_complex, q_abs_mean)
Parameters:
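  • q_mean_complex (Tensor) – Mean of Q in complex frequency domain, shape (freq_count,).

  • q_abs_mean (Tensor) – Mean of |Q| in frequency domain, shape (freq_count,).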

Return type:

None

q_abs_mean: Tensor
q_mean_complex: Tensor
build_geometric_offsets(max_length, device)

Build geometric offset sequence [1, 2, 4, 8, …, max_length].

Used for multi-distance scoring in TriAttention: each offset is a future distance at which a key’s importance is evaluated.

Parameters:
  • max_length (int) – Maximum offset value (must be >= 1).

  • device (device) – Device for the output tensor.

Returns:

1D float tensor of powers of 2 up to max_length.

Return type:

Tensor
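
The offset sequence can be illustrated with a minimal sketch. It uses a plain Python list in place of a tensor so it runs without PyTorch, and the name geometric_offsets_sketch is hypothetical. It also assumes that when max_length is not a power of 2, the sequence stops at the largest power of 2 below it; the source does not state this case.

```python
from typing import List


def geometric_offsets_sketch(max_length: int) -> List[float]:
    """Return powers of 2 from 1 up to max_length (illustrative only)."""
    if max_length < 1:
        raise ValueError("max_length must be >= 1")
    offsets = []
    value = 1
    while value <= max_length:
        offsets.append(float(value))
        value *= 2  # each offset doubles the previous distance
    return offsets


print(geometric_offsets_sketch(16))  # [1.0, 2.0, 4.0, 8.0, 16.0]
```

Doubling keeps the number of sampled distances logarithmic in max_length, so scoring cost grows slowly as the horizon grows.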

compute_frequency_statistics_from_means(q_mean_complex, q_abs_mean, k_unrot, *, style='half', disable_mlr=False)

Compute amplitude, phase, and MLR extra term from Q/K frequency statistics.

Parameters:
  • q_mean_complex (Tensor) – Mean of Q in complex frequency domain, shape (freq_count,).

  • q_abs_mean (Tensor) – Mean of |Q| in frequency domain, shape (freq_count,).

  • k_unrot (Tensor) – Unrotated key vectors, shape (num_keys, head_dim).

  • style (str) – RoPE pairing style (‘half’ or ‘interleaved’).

  • disable_mlr (bool) – If True, use q_abs_mean directly instead of q_abs_mean - |q_mean|.

Returns:

  • amp – Amplitude, shape (num_keys, freq_count).

  • phi – Phase, shape (num_keys, freq_count).

  • extra – MLR extra term, shape (num_keys, freq_count).

Return type:

tuple of Tensor (amp, phi, extra)
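
The effect of disable_mlr on the per-frequency query coefficient can be sketched as below. This covers only the switch described for disable_mlr. How the library combines that coefficient with k_unrot to form extra is not stated here and is not modeled. The helper name mlr_coefficient_sketch is hypothetical, and plain Python lists stand in for tensors.

```python
from typing import List


def mlr_coefficient_sketch(
    q_mean_complex: List[complex],
    q_abs_mean: List[float],
    disable_mlr: bool = False,
) -> List[float]:
    """Per-frequency coefficient: q_abs_mean - |q_mean|, or q_abs_mean alone."""
    if disable_mlr:
        return list(q_abs_mean)
    return [a - abs(m) for m, a in zip(q_mean_complex, q_abs_mean)]


# Two calibration queries at a single frequency, pointing in different directions.
queries = [1 + 0j, 0 + 1j]
q_mean = sum(queries) / len(queries)                 # 0.5+0.5j
q_abs = sum(abs(q) for q in queries) / len(queries)  # 1.0

print(mlr_coefficient_sketch([q_mean], [q_abs]))                    # roughly [0.293]
print(mlr_coefficient_sketch([q_mean], [q_abs], disable_mlr=True))  # [1.0]
```

By the triangle inequality, |mean(Q)| never exceeds mean(|Q|), so the difference is non-negative. It is zero when all queries share one direction at that frequency and grows as their directions spread out.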

invert_rope(rotated, cos, sin, scale, *, style='half')

Invert RoPE rotation to recover pre-RoPE representation.

Parameters:
  • rotated (Tensor) – RoPE-rotated tensor.

  • cos (Tensor) – Cosine table from rotary embedding.

  • sin (Tensor) – Sine table from rotary embedding.

  • scale (float) – Attention scaling factor applied during RoPE.

  • style (str) – RoPE pairing style (‘half’ or ‘interleaved’).

Returns:

Pre-RoPE tensor with RoPE rotation undone.

Return type:

Tensor
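
As a rough illustration of what inverting RoPE involves, the sketch below applies and then undoes a ‘half’-style rotation on a single vector. It assumes the forward pass is scale * (x * cos + rotate_half(x) * sin), with cos and sin tables of length head_dim whose two halves repeat the same angles. That convention is common but is an assumption here, not taken from the library. All function names are hypothetical, and plain lists stand in for tensors.

```python
import math
from typing import List


def _rotate_half(x: List[float]) -> List[float]:
    """'half' style: (x1, x2) -> (-x2, x1)."""
    half = len(x) // 2
    return [-v for v in x[half:]] + list(x[:half])


def apply_rope_sketch(x, cos, sin, scale=1.0):
    """Assumed forward pass: scale * (x * cos + rotate_half(x) * sin)."""
    rot = _rotate_half(x)
    return [scale * (xi * c + ri * s) for xi, ri, c, s in zip(x, rot, cos, sin)]


def invert_rope_sketch(rotated, cos, sin, scale=1.0):
    """Undo the scale, then rotate by the negative angle."""
    unscaled = [v / scale for v in rotated]
    rot = _rotate_half(unscaled)
    return [ui * c - ri * s for ui, ri, c, s in zip(unscaled, rot, cos, sin)]


angles = [0.3, 1.1]                       # one angle per frequency
cos = [math.cos(a) for a in angles] * 2   # repeated for both halves
sin = [math.sin(a) for a in angles] * 2

key = [1.0, 2.0, 3.0, 4.0]
rotated = apply_rope_sketch(key, cos, sin, scale=0.5)
recovered = invert_rope_sketch(rotated, cos, sin, scale=0.5)
print(all(abs(a - b) < 1e-12 for a, b in zip(key, recovered)))  # True
```

Because a rotation by angle θ is undone by a rotation by −θ, the inverse reuses the same tables and only flips the sign of the sine term.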

rotate_half(x, *, style='half')

Rotate tensor for RoPE. Supports ‘half’ (front/back) and ‘interleaved’ (even/odd).

Parameters:
  • x (Tensor) – Tensor to rotate.

  • style (str) – RoPE pairing style (‘half’ or ‘interleaved’).

Return type:

Tensor
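
The two pairing styles can be compared with a minimal sketch on a flat list. It follows the usual RoPE convention in which each pair (a, b) maps to (-b, a). Whether the library uses exactly this sign convention is an assumption, and rotate_half_sketch is a hypothetical name.

```python
from typing import List


def rotate_half_sketch(x: List[float], style: str = "half") -> List[float]:
    """Map each RoPE pair (a, b) to (-b, a) under the given pairing style."""
    if style == "half":
        # Pairs are (x[i], x[i + half]): front half with back half.
        half = len(x) // 2
        return [-v for v in x[half:]] + list(x[:half])
    if style == "interleaved":
        # Pairs are (x[2i], x[2i + 1]): even index with odd index.
        out = []
        for even, odd in zip(x[0::2], x[1::2]):
            out.extend([-odd, even])
        return out
    raise ValueError(f"unknown style: {style!r}")


print(rotate_half_sketch([1, 2, 3, 4], style="half"))         # [-3, -4, 1, 2]
print(rotate_half_sketch([1, 2, 3, 4], style="interleaved"))  # [-2, 1, -4, 3]
```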

score_keys_for_round(key_indices, round_start, amp, phi, omega, extra, offsets, aggregation, freq_scale_sq, disable_trig=False)

Score cached keys for a single pruning round.

Evaluates the trigonometric importance formula over multiple future offsets and aggregates scores.

Parameters:
  • key_indices (Tensor) – Position indices of cached keys, shape (num_keys,).

  • round_start (int) – Current generation position.

  • amp (Tensor) – Amplitude per key per frequency, shape (num_keys, freq_count).

  • phi (Tensor) – Phase per key per frequency, shape (num_keys, freq_count).

  • omega (Tensor) – RoPE frequencies (inv_freq), shape (freq_count,).

  • extra (Tensor) – MLR extra term, shape (num_keys, freq_count).

  • offsets (Tensor) – Geometric offsets for future distance sampling, shape (num_offsets,).

  • aggregation (str) – ‘mean’ or ‘max’ over offsets.

  • freq_scale_sq (Tensor) – Per-frequency scaling weights, shape (freq_count,).

  • disable_trig (bool) – If True, use only the additive (MLR) term.

Returns:

Importance scores, shape (num_keys,). Higher = more important.

Return type:

Tensor
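
The shape of the computation can be sketched as follows. The source does not give the exact scoring formula, so the sketch assumes one plausible form: for each future offset, sum over frequencies of freq_scale_sq * (amp * cos(omega * delta + phi) + extra), where delta = round_start + offset - key_index. Both the formula and the definition of delta are assumptions for illustration, the function name is hypothetical, and nested lists stand in for tensors.

```python
import math
from typing import List


def score_keys_sketch(key_indices, round_start, amp, phi, omega, extra,
                      offsets, aggregation, freq_scale_sq,
                      disable_trig=False) -> List[float]:
    """Illustrative scoring under an assumed formula (see lead-in)."""
    scores = []
    for k, pos in enumerate(key_indices):
        per_offset = []
        for off in offsets:
            delta = round_start + off - pos  # assumed query-to-key distance
            total = 0.0
            for f, w in enumerate(omega):
                trig = 0.0 if disable_trig else amp[k][f] * math.cos(w * delta + phi[k][f])
                total += freq_scale_sq[f] * (trig + extra[k][f])
            per_offset.append(total)
        if aggregation == "mean":
            scores.append(sum(per_offset) / len(per_offset))
        elif aggregation == "max":
            scores.append(max(per_offset))
        else:
            raise ValueError(f"unknown aggregation: {aggregation!r}")
    return scores


# One frequency with omega = 0 and phi = 0, so the cosine is exactly 1.
args = dict(key_indices=[0, 5], round_start=10,
            amp=[[2.0], [1.0]], phi=[[0.0], [0.0]], omega=[0.0],
            extra=[[0.5], [0.25]], offsets=[1.0, 2.0, 4.0],
            aggregation="mean", freq_scale_sq=[1.0])
print(score_keys_sketch(**args))                     # [2.5, 1.25]
print(score_keys_sketch(**args, disable_trig=True))  # [0.5, 0.25]
```

With disable_trig=True only the additive term contributes, which matches the parameter description above regardless of the assumed trigonometric form.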

select_keys_to_keep(scores, *, kv_budget=None)

Select which keys to retain based on importance scores.

Parameters:
  • scores (Tensor) – Importance scores, shape (num_keys,). Higher = more important.

  • kv_budget (int | None) – Absolute number of tokens to retain; the top-K scores are kept. If kv_budget >= num_keys, all keys are kept.

Returns:

Boolean mask, shape (num_keys,). True = keep, False = evict.

Return type:

Tensor
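
Top-K selection into a boolean mask can be sketched as below. Plain lists stand in for tensors and the function name is hypothetical. Treating kv_budget=None as "keep everything" is an assumption, since the source does not describe that case, and so is the tie-breaking rule (earlier keys win).

```python
from typing import List, Optional


def select_keys_sketch(scores: List[float],
                       kv_budget: Optional[int] = None) -> List[bool]:
    """Return a keep-mask marking the kv_budget highest-scoring keys."""
    n = len(scores)
    if kv_budget is None or kv_budget >= n:  # None handling is an assumption
        return [True] * n
    # Indices sorted by descending score; the sort is stable, so ties keep
    # their original order.
    order = sorted(range(n), key=lambda i: scores[i], reverse=True)
    keep = set(order[:kv_budget])
    return [i in keep for i in range(n)]


print(select_keys_sketch([0.1, 0.9, 0.4, 0.7], kv_budget=2))
# [False, True, False, True]
```

Returning a mask rather than sorted indices preserves the original key order, so the retained cache entries stay in position order after filtering.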

to_complex_pairs(tensor, *, style='half')

Convert real tensor to complex representation for frequency analysis.

Maps head_dim real values to head_dim/2 complex values. For ‘half’ style: real part = first half of dimensions, imag part = second half.

Parameters:
  • tensor (Tensor) – Real-valued tensor with even last dimension.

  • style (str) – RoPE pairing style (‘half’ or ‘interleaved’).

Returns:

Complex tensor with last dimension halved.

Return type:

Tensor
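
The mapping from head_dim real values to head_dim/2 complex values can be sketched on a flat list. The ‘half’ branch follows the description above. The ‘interleaved’ branch assumes even indices are real parts and odd indices are imaginary parts, which is consistent with the even/odd pairing described for rotate_half but is not stated for this function. The name to_complex_pairs_sketch is hypothetical.

```python
from typing import List


def to_complex_pairs_sketch(values: List[float],
                            style: str = "half") -> List[complex]:
    """Pair up real values into complex numbers, one per RoPE frequency."""
    if len(values) % 2 != 0:
        raise ValueError("last dimension must be even")
    if style == "half":
        half = len(values) // 2
        return [complex(re, im) for re, im in zip(values[:half], values[half:])]
    if style == "interleaved":
        # Assumed: even index = real part, odd index = imaginary part.
        return [complex(re, im) for re, im in zip(values[0::2], values[1::2])]
    raise ValueError(f"unknown style: {style!r}")


print(to_complex_pairs_sketch([1, 2, 3, 4], style="half"))         # [(1+3j), (2+4j)]
print(to_complex_pairs_sketch([1, 2, 3, 4], style="interleaved"))  # [(1+2j), (3+4j)]
```

In this representation a RoPE rotation at one frequency is a multiplication by a unit complex number, which is why amplitude and phase can be read off per frequency.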