triattention

TriAttention: Trigonometric KV cache compression.

Functions

build_geometric_offsets

Build the geometric offset sequence [1, 2, 4, 8, ...] of powers of 2 up to max_length.

invert_rope

Invert RoPE rotation to recover pre-RoPE representation.

rotate_half

Rotate tensor for RoPE.

to_complex_pairs

Convert real tensor to complex representation for frequency analysis.

build_geometric_offsets(max_length, device)

Build the geometric offset sequence [1, 2, 4, 8, …] of powers of 2 up to max_length.

Used for multi-distance scoring in TriAttention: each offset is a future distance at which the key’s importance is evaluated.

Parameters:
  • max_length (int) – Maximum offset value (must be >= 1).

  • device (device) – Device for the output tensor.

Returns:

1D float tensor of powers of 2 up to max_length.

Return type:

Tensor
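The documented behaviour can be sketched in a few lines. This is an illustrative re-implementation, not the module's source: the name geometric_offsets_sketch is hypothetical, and it assumes that when max_length is not itself a power of 2 the sequence stops at the largest power of 2 below it.

```python
import torch


def geometric_offsets_sketch(max_length: int, device="cpu") -> torch.Tensor:
    """Hypothetical sketch: powers of 2 that do not exceed max_length."""
    if max_length < 1:
        raise ValueError("max_length must be >= 1")
    offsets = []
    value = 1
    while value <= max_length:
        offsets.append(float(value))
        value *= 2
    # Assumption: float32 output, matching the documented "1D float tensor".
    return torch.tensor(offsets, dtype=torch.float32, device=device)


print(geometric_offsets_sketch(8).tolist())   # [1.0, 2.0, 4.0, 8.0]
print(geometric_offsets_sketch(10).tolist())  # [1.0, 2.0, 4.0, 8.0] under the assumption above
```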

invert_rope(rotated, cos, sin, scale, *, style='half')

Invert RoPE rotation to recover pre-RoPE representation.

Parameters:
  • rotated (Tensor) – RoPE-rotated tensor.

  • cos (Tensor) – Cosine table from rotary embedding.

  • sin (Tensor) – Sine table from rotary embedding.

  • scale (float) – Attention scaling factor applied during RoPE.

  • style (str) – RoPE pairing style (‘half’ or ‘interleaved’).

Returns:

Pre-RoPE tensor with RoPE rotation undone.

Return type:

Tensor
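A round-trip check makes the inversion concrete. The sketch below is not the module's implementation. It assumes the forward rotation is scale * (x * cos + rotate_half(x) * sin) in the ‘half’ style, with unscaled cos and sin tables; if the tables already carry the scale factor, they would need to be divided by it as well. All function names ending in _sketch are hypothetical.

```python
import torch


def rotate_half_sketch(x: torch.Tensor) -> torch.Tensor:
    # 'half' style: split the last dimension into front/back halves.
    front, back = x.chunk(2, dim=-1)
    return torch.cat((-back, front), dim=-1)


def apply_rope_sketch(x, cos, sin, scale):
    # Assumed forward convention (not confirmed by the documentation).
    return (x * cos + rotate_half_sketch(x) * sin) * scale


def invert_rope_sketch(rotated, cos, sin, scale):
    # Undo the scale, then rotate by the negative angle.
    base = rotated / scale
    return base * cos - rotate_half_sketch(base) * sin


torch.manual_seed(0)
seq_len, head_dim = 4, 8

# Standard RoPE frequency table, duplicated across both halves.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
angles = torch.outer(torch.arange(seq_len, dtype=torch.float32), inv_freq)
angles = torch.cat((angles, angles), dim=-1)
cos, sin = angles.cos(), angles.sin()

x = torch.randn(seq_len, head_dim)
rotated = apply_rope_sketch(x, cos, sin, scale=1.5)
recovered = invert_rope_sketch(rotated, cos, sin, scale=1.5)
print(torch.allclose(recovered, x, atol=1e-5))
```

The inverse works because applying rotate_half twice negates the input, so the cross terms cancel and cos² + sin² = 1 restores the original values.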

rotate_half(x, *, style='half')

Rotate tensor for RoPE. Supports two pairing styles: ‘half’ (front/back halves) and ‘interleaved’ (even/odd indices).

Parameters:
  • x (Tensor) – Tensor to rotate.

  • style (str) – RoPE pairing style (‘half’ or ‘interleaved’).

Return type:

Tensor
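The two pairing styles differ only in which elements form a pair. The sketch below uses the hypothetical name rotate_half_sketch and assumes the common sign convention in which each pair (a, b) maps to (-b, a); the module's actual convention is not stated in this reference.

```python
import torch


def rotate_half_sketch(x: torch.Tensor, *, style: str = "half") -> torch.Tensor:
    """Hypothetical sketch of both pairing styles."""
    if style == "half":
        # Pairs are (x[i], x[i + d/2]): front half with back half.
        half = x.shape[-1] // 2
        front, back = x[..., :half], x[..., half:]
        return torch.cat((-back, front), dim=-1)
    if style == "interleaved":
        # Pairs are (x[2i], x[2i + 1]): even index with odd index.
        even, odd = x[..., 0::2], x[..., 1::2]
        return torch.stack((-odd, even), dim=-1).flatten(-2)
    raise ValueError(f"unknown style: {style!r}")


x = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(rotate_half_sketch(x, style="half").tolist())         # [-3.0, -4.0, 1.0, 2.0]
print(rotate_half_sketch(x, style="interleaved").tolist())  # [-2.0, 1.0, -4.0, 3.0]
```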

to_complex_pairs(tensor, *, style='half')

Convert real tensor to complex representation for frequency analysis.

Maps head_dim real values to head_dim/2 complex values. For ‘half’ style: real part = first half of dimensions, imag part = second half.

Parameters:
  • tensor (Tensor) – Real-valued tensor with even last dimension.

  • style (str) – RoPE pairing style (‘half’ or ‘interleaved’).

Returns:

Complex tensor with last dimension halved.

Return type:

Tensor
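As an illustration of the mapping, the sketch below builds the complex view for both styles. The ‘half’ branch follows the description above; the ‘interleaved’ branch is an assumption (real part from even indices, imaginary part from odd indices), and to_complex_pairs_sketch is a hypothetical name rather than the module's code.

```python
import torch


def to_complex_pairs_sketch(tensor: torch.Tensor, *, style: str = "half") -> torch.Tensor:
    """Hypothetical sketch: head_dim real values -> head_dim/2 complex values."""
    if tensor.shape[-1] % 2 != 0:
        raise ValueError("last dimension must be even")
    if style == "half":
        half = tensor.shape[-1] // 2
        return torch.complex(tensor[..., :half], tensor[..., half:])
    if style == "interleaved":
        # Assumption: even indices are real parts, odd indices imaginary parts.
        return torch.complex(tensor[..., 0::2], tensor[..., 1::2])
    raise ValueError(f"unknown style: {style!r}")


x = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(to_complex_pairs_sketch(x, style="half").tolist())         # [(1+3j), (2+4j)]
print(to_complex_pairs_sketch(x, style="interleaved").tolist())  # [(1+2j), (3+4j)]
```

In this view a standard RoPE rotation by an angle θ corresponds to multiplying each complex value by e^{iθ}, so the magnitude of each value is unchanged by the rotation.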