triattention
TriAttention: Trigonometric KV cache compression.
Functions
build_geometric_offsets | Build geometric offset sequence [1, 2, 4, 8, ..., max_length].
invert_rope | Invert RoPE rotation to recover pre-RoPE representation.
rotate_half | Rotate tensor for RoPE.
to_complex_pairs | Convert real tensor to complex representation for frequency analysis.
- build_geometric_offsets(max_length, device)
Build geometric offset sequence [1, 2, 4, 8, …, max_length].
Used for multi-distance scoring in TriAttention: each offset represents a future distance at which the key’s importance is evaluated.
- Parameters:
max_length (int) – Maximum offset value (must be >= 1).
device (device) – Device for the output tensor.
- Returns:
1D float tensor of powers of 2 up to max_length.
- Return type:
Tensor
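The documented behavior can be sketched roughly as follows. This is a minimal illustration, not the library's implementation: the name build_geometric_offsets_sketch is hypothetical, the float32 dtype is an assumption, and so is the choice to stop at the largest power of 2 not exceeding max_length when max_length is not itself a power of 2.

```python
import torch


def build_geometric_offsets_sketch(max_length: int, device=None) -> torch.Tensor:
    """Hypothetical stand-in: powers of 2 that do not exceed max_length."""
    if max_length < 1:
        raise ValueError("max_length must be >= 1")
    offsets = []
    value = 1
    while value <= max_length:
        offsets.append(float(value))
        value *= 2  # geometric growth: 1, 2, 4, 8, ...
    # Assumption: float32 output; the docs only say "1D float tensor".
    return torch.tensor(offsets, dtype=torch.float32, device=device)


offsets_16 = build_geometric_offsets_sketch(16)  # tensor([ 1.,  2.,  4.,  8., 16.])
offsets_10 = build_geometric_offsets_sketch(10)  # tensor([1., 2., 4., 8.])
```

Under these assumptions the number of offsets grows logarithmically with max_length, so scoring a key at every offset stays cheap even for long contexts.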
- invert_rope(rotated, cos, sin, scale, *, style='half')
Invert RoPE rotation to recover pre-RoPE representation.
- Parameters:
rotated (Tensor) – RoPE-rotated tensor.
cos (Tensor) – Cosine table from rotary embedding.
sin (Tensor) – Sine table from rotary embedding.
scale (float) – Attention scaling factor applied during RoPE.
style (str) – RoPE pairing style (‘half’ or ‘interleaved’).
- Returns:
Pre-RoPE tensor with RoPE rotation undone.
- Return type:
Tensor
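To make the inversion concrete, here is a hedged round-trip sketch for the ‘half’ style only. It assumes the forward rotation has the common form (x * cos + rotate_half(x) * sin) * scale, with cos and sin tables that repeat the same angles across both halves of the last dimension; whether the library applies scale in exactly this way is an assumption. All names ending in _sketch are hypothetical.

```python
import torch


def rotate_half_sketch(x: torch.Tensor) -> torch.Tensor:
    # 'half' style: pair dimension i with dimension i + head_dim/2.
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rope_sketch(x, cos, sin, scale=1.0):
    # Assumed forward form: rotate by +theta, then multiply by scale.
    return (x * cos + rotate_half_sketch(x) * sin) * scale


def invert_rope_sketch(rotated, cos, sin, scale=1.0):
    # Undo the scaling, then rotate by -theta (the sign of the sin term flips).
    unscaled = rotated / scale
    return unscaled * cos - rotate_half_sketch(unscaled) * sin


torch.manual_seed(0)
seq_len, head_dim, scale = 6, 8, 1.5
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, dtype=torch.float64) / head_dim))
angles = torch.outer(torch.arange(seq_len, dtype=torch.float64), inv_freq)
angles = torch.cat((angles, angles), dim=-1)  # same angle for both members of a pair
cos, sin = angles.cos(), angles.sin()

x = torch.randn(seq_len, head_dim, dtype=torch.float64)
recovered = invert_rope_sketch(apply_rope_sketch(x, cos, sin, scale), cos, sin, scale)
```

The inverse works because each pair of dimensions is rotated by a plane rotation, and rotating by the negated angle restores the original pair.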
- rotate_half(x, *, style='half')
Rotate tensor for RoPE. Supports ‘half’ (front/back) and ‘interleaved’ (even/odd).
- Parameters:
x (Tensor) – Tensor to rotate.
style (str) – RoPE pairing style (‘half’ or ‘interleaved’).
- Return type:
Tensor
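The two pairing styles can be illustrated with a small sketch. It follows the widely used RoPE convention in which each pair (a, b) is mapped to (-b, a); the function name rotate_half_sketch is hypothetical, and the exact sign and ordering conventions of the library are assumed rather than confirmed.

```python
import torch


def rotate_half_sketch(x: torch.Tensor, *, style: str = "half") -> torch.Tensor:
    if style == "half":
        # Pairs are (x[i], x[i + d/2]): front half with back half.
        half = x.shape[-1] // 2
        x1, x2 = x[..., :half], x[..., half:]
        return torch.cat((-x2, x1), dim=-1)
    if style == "interleaved":
        # Pairs are (x[2i], x[2i + 1]): even index with odd index.
        x_even, x_odd = x[..., 0::2], x[..., 1::2]
        return torch.stack((-x_odd, x_even), dim=-1).flatten(-2)
    raise ValueError(f"unknown style: {style!r}")


x = torch.tensor([1.0, 2.0, 3.0, 4.0])
half_out = rotate_half_sketch(x)                        # tensor([-3., -4.,  1.,  2.])
inter_out = rotate_half_sketch(x, style="interleaved")  # tensor([-2.,  1., -4.,  3.])
```

In both styles, applying the rotation twice negates the input, which is what makes it behave like multiplication by the imaginary unit on each pair.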
- to_complex_pairs(tensor, *, style='half')
Convert real tensor to complex representation for frequency analysis.
Maps head_dim real values to head_dim/2 complex values. For the ‘half’ style, the real part is the first half of the dimensions and the imaginary part is the second half.
- Parameters:
tensor (Tensor) – Real-valued tensor with even last dimension.
style (str) – RoPE pairing style (‘half’ or ‘interleaved’).
- Returns:
Complex tensor with last dimension halved.
- Return type:
Tensor
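A minimal sketch of this mapping is shown below. The ‘half’ branch follows the description above; the ‘interleaved’ branch (real part from even indices, imaginary part from odd indices) is an assumption inferred from the even/odd pairing used elsewhere in this module. The name to_complex_pairs_sketch is hypothetical.

```python
import torch


def to_complex_pairs_sketch(tensor: torch.Tensor, *, style: str = "half") -> torch.Tensor:
    if tensor.shape[-1] % 2 != 0:
        raise ValueError("last dimension must be even")
    if style == "half":
        half = tensor.shape[-1] // 2
        real, imag = tensor[..., :half], tensor[..., half:]
    elif style == "interleaved":
        # Assumption: even indices are real parts, odd indices are imaginary parts.
        real, imag = tensor[..., 0::2], tensor[..., 1::2]
    else:
        raise ValueError(f"unknown style: {style!r}")
    # torch.complex needs same-dtype floating inputs; contiguous() avoids strided views.
    return torch.complex(real.contiguous(), imag.contiguous())


x = torch.tensor([1.0, 2.0, 3.0, 4.0])
z_half = to_complex_pairs_sketch(x)                        # tensor([1.+3.j, 2.+4.j])
z_inter = to_complex_pairs_sketch(x, style="interleaved")  # tensor([1.+2.j, 3.+4.j])
```

In this representation a RoPE rotation of a pair by angle θ corresponds to multiplying the complex value by exp(iθ), which is why the complex form is convenient for frequency analysis.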