calibration
Calibration framework for sparse attention methods.
Classes
DynamicThresholdCalibrator – Dynamic threshold calibrator.
RulerDatasetBuilder – Builder for RULER calibration datasets.
Functions
calibrate_sparse_attention – Calibrate sparse attention parameters for optimal sparsity.
- class DynamicThresholdCalibrator
Bases: object
Dynamic threshold calibrator.
The calibration fits the model:
t = 1 - exp(-a * (S / (1 - S))^b / L^c)
to (t_j, S_ij, L_i) tuples collected from a forward pass. Taking logs yields a model that is linear in (log a, b, -c):
log(-log(1 - t_j)) = log(a) + b * logit(S_ij) - c * log(L_i)
which is solved in closed form with np.linalg.lstsq. At inference time, given a target sparsity S, the threshold is t = 1 - exp(-a * (S / (1 - S))^b / L^c).
- Properties:
Bounded in (0, 1) by construction (no clamping required).
Correct asymptotes (for b, c > 0): t -> 0 as S -> 0 or L -> inf; t -> 1 as S -> 1 or L -> 0.
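To make the inference-time formula and its log-log linearization concrete, here is a minimal sketch in plain Python. The helper names dynamic_threshold and log_log_rhs are hypothetical (they are not methods of DynamicThresholdCalibrator), the coefficients are arbitrary illustration values, and rejecting S outside (0, 1) is a choice made here because logit(S) is undefined there.

```python
import math


def dynamic_threshold(a: float, b: float, c: float, sparsity: float, seq_len: float) -> float:
    """Threshold t = 1 - exp(-a * (S / (1 - S))**b / L**c) for a target sparsity S."""
    if not 0.0 < sparsity < 1.0:
        raise ValueError("target sparsity must lie strictly between 0 and 1")
    odds = sparsity / (1.0 - sparsity)  # S / (1 - S)
    return 1.0 - math.exp(-a * odds**b / seq_len**c)


def log_log_rhs(a: float, b: float, c: float, sparsity: float, seq_len: float) -> float:
    """Right-hand side of the linearized model: log(a) + b * logit(S) - c * log(L)."""
    logit_s = math.log(sparsity / (1.0 - sparsity))
    return math.log(a) + b * logit_s - c * math.log(seq_len)


# Hypothetical coefficients, chosen only for illustration.
a, b, c = 2.0, 1.5, 0.5
t = dynamic_threshold(a, b, c, sparsity=0.8, seq_len=4096)  # 1 - exp(-0.25) ~= 0.2212
# log(-log(1 - t)) reproduces the linear model (up to float rounding).
assert math.isclose(math.log(-math.log(1.0 - t)), log_log_rhs(a, b, c, 0.8, 4096))
```

With b, c > 0 the threshold grows with the target sparsity and shrinks as the sequence gets longer, matching the stated asymptotes. In exact arithmetic t stays strictly inside (0, 1); in floating point, extreme inputs can still round to 0.0 or 1.0.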
- __init__(threshold_trials=None)
Initialize dynamic threshold calibrator.
- Parameters:
threshold_trials (list[float] | None) – List of thresholds to try during calibration. Should span a range that achieves sparsities from ~10% to ~95%.
- calibrate(model, forward_loop, phase)
Calibrate (a, b, c) for the dynamic threshold model.
Algorithm: set thresholds = threshold_trials on all modules and run a single forward pass. Each module returns a sparsity list (one entry per threshold) per sample. For each (t_j, L_i, S_ij) triple, form:
y_ij = log(-log(1 - t_j))
x_S = logit(S_ij) = log(S_ij / (1 - S_ij))
x_L = log(L_i)
The model log(-log(1 - t)) = log(a) + b * logit(S) - c * log(L) is linear in (log a, b, -c) and is solved with np.linalg.lstsq.
At inference time, given a target sparsity S, the threshold is t = 1 - exp(-a * (S / (1 - S))^b / L^c).
- Parameters:
model (Module) – The model with sparse attention modules.
forward_loop (Callable) – Callable that takes model and forwards calibration data.
phase (str) – Phase to calibrate ('prefill' or 'decode').
- Returns:
Dict with calibration results including a, b, c, r_squared, and num_data_points.
- Return type:
dict[str, Any]
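The least-squares step of the algorithm can be sketched with NumPy as below. This is an illustrative reimplementation under stated assumptions, not the calibrator's actual code: fit_dynamic_threshold is a hypothetical name, the (N samples x T thresholds) layout of the sparsity array is assumed, dropping points with S at or near 0 or 1 (where logit(S) diverges) is an assumption, and r_squared here is the ordinary coefficient of determination computed in the log-log space. The output keys mirror the ones listed under Returns.

```python
import numpy as np


def fit_dynamic_threshold(thresholds, lengths, sparsities, eps=1e-6):
    """Fit (a, b, c) in log(-log(1 - t)) = log(a) + b*logit(S) - c*log(L) by least squares.

    thresholds: shape (T,) trial thresholds t_j, each in (0, 1).
    lengths:    shape (N,) sequence lengths L_i of the calibration samples.
    sparsities: shape (N, T) measured sparsity S_ij of sample i at threshold j.
    """
    t = np.asarray(thresholds, dtype=np.float64)
    L = np.asarray(lengths, dtype=np.float64)
    S = np.asarray(sparsities, dtype=np.float64)

    # One (t_j, L_i, S_ij) triple per entry of the (N, T) grid.
    t_grid = np.broadcast_to(t[None, :], S.shape)
    L_grid = np.broadcast_to(L[:, None], S.shape)

    # Assumption: skip points where logit(S) is undefined or numerically unstable.
    keep = (S > eps) & (S < 1.0 - eps)
    y = np.log(-np.log(1.0 - t_grid[keep]))  # y_ij
    x_s = np.log(S[keep] / (1.0 - S[keep]))  # x_S = logit(S_ij)
    x_l = np.log(L_grid[keep])               # x_L = log(L_i)

    # Design-matrix columns multiply (log a, b, -c) respectively.
    X = np.column_stack([np.ones_like(y), x_s, x_l])
    coef, *_ = np.linalg.lstsq(X, y, rcond=None)
    log_a, b, neg_c = coef

    residuals = y - X @ coef
    ss_tot = np.sum((y - y.mean()) ** 2)
    r_squared = 1.0 - np.sum(residuals**2) / ss_tot if ss_tot > 0 else float("nan")
    return {
        "a": float(np.exp(log_a)),
        "b": float(b),
        "c": float(-neg_c),
        "r_squared": float(r_squared),
        "num_data_points": int(y.size),
    }


# Round-trip check on synthetic data generated from known coefficients.
a0, b0, c0 = 2.0, 1.5, 0.5
trials = np.array([0.05, 0.1, 0.2, 0.4, 0.6])
lengths = np.array([1024.0, 4096.0, 16384.0])
# Invert the model for S: odds = (-log(1 - t) * L**c / a) ** (1 / b), S = odds / (1 + odds).
odds = (-np.log(1.0 - trials[None, :]) * lengths[:, None] ** c0 / a0) ** (1.0 / b0)
result = fit_dynamic_threshold(trials, lengths, odds / (1.0 + odds))
```

Because a is recovered as exp(log a), it is always positive, which is what keeps the inference-time threshold inside (0, 1). The signs of b and c, on the other hand, come out of the fit, so a sanity check that both are positive is worthwhile.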
- class RulerDatasetBuilder
Bases: object
Builder for RULER calibration datasets.
- __init__(samples, max_seqlen, tokenizer_name_or_path, num_length_bins=4, max_length_filter=65536, seed=42, cache_dir=None, data_dir=None)
Initialize RULER dataset builder.
- Parameters:
samples (int) – Total number of samples to generate (distributed evenly across length bins)
max_seqlen (int) – Maximum sequence length (length bins auto-generated as powers of 2)
tokenizer_name_or_path (str | object) – HuggingFace tokenizer path or tokenizer object
seed (int) – Random seed for reproducibility
num_length_bins (int) – Number of length bins to generate (default: 4)
max_length_filter (int) – Maximum sequence length to keep (default: 65536)
cache_dir (str | None) – Optional cache directory. If None, uses ~/.cache/modelopt/data/
data_dir (str | Path | None) – Optional path to the RULER data directory (containing an ‘essays’ subdirectory). Required for NIAH tasks with an essay haystack when the default pip-installed layout is not used.
Note
Length bins are auto-generated as descending powers of 2: [max_seqlen, max_seqlen/2, max_seqlen/4, …]. Generation stops when num_length_bins bins have been produced or the length drops below 1024. Subtasks are set to all of the difficult tasks defined in RULER_TASKS.
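The binning rule in the note above can be sketched as follows. length_bins and samples_per_bin are hypothetical helpers, not part of the documented RulerDatasetBuilder API. Halving with integer division (relevant only when max_seqlen is not a power of 2) and giving any leftover samples to the longest bins are assumptions; the interaction with max_length_filter is not modeled.

```python
def length_bins(max_seqlen: int, num_length_bins: int = 4, min_length: int = 1024) -> list[int]:
    """Descending halvings of max_seqlen, stopping at num_length_bins or below min_length."""
    bins = []
    length = max_seqlen
    while len(bins) < num_length_bins and length >= min_length:
        bins.append(length)
        length //= 2  # next bin is half as long
    return bins


def samples_per_bin(samples: int, num_bins: int) -> list[int]:
    """Split samples evenly across bins (assumes num_bins >= 1)."""
    base, rem = divmod(samples, num_bins)
    # Assumption: the remainder goes to the first (longest) bins.
    return [base + (1 if i < rem else 0) for i in range(num_bins)]
```

For example, max_seqlen=4096 yields only three bins, because the next halving (512) falls below 1024 before the default num_length_bins=4 is reached.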
- build_calibration_dataset()
Build the complete calibration dataset.
If cache_dir was set, checks the cache first and returns the cached data if present. Otherwise, generates the dataset, saves it to the cache (if cache_dir is set), and returns it.
- Returns:
List of calibration samples with ‘input’ and ‘length’ fields
- Return type:
list[dict[str, Any]]
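The check-cache-then-generate flow described for build_calibration_dataset follows a common pattern, sketched below. load_or_build, the JSON file format, and the single-file cache are hypothetical choices for illustration. The builder's actual cache location and file naming are not documented here. Following the method description, a None cache path disables caching. A real cache would also need to key the file on the builder's parameters (sample count, max_seqlen, tokenizer, seed), which this sketch leaves to the caller.

```python
from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Callable

Sample = dict[str, Any]  # e.g. {"input": "...", "length": 4096}


def load_or_build(cache_file: Path | None, generate: Callable[[], list[Sample]]) -> list[Sample]:
    """Return cached samples if the cache file exists; otherwise generate and optionally cache them."""
    if cache_file is not None and cache_file.exists():
        return json.loads(cache_file.read_text())  # cache hit: skip generation
    samples = generate()
    if cache_file is not None:
        cache_file.parent.mkdir(parents=True, exist_ok=True)
        cache_file.write_text(json.dumps(samples))  # persist for the next call
    return samples
```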
- calibrate_sparse_attention(model, config, forward_loop=None)
Calibrate sparse attention parameters for optimal sparsity.
Supports both prefill and decode phase calibration with per-phase target sparsity.
- Parameters:
model (Module) – Model with sparse attention modules
config (dict[str, Any]) – Sparse attention configuration dict
forward_loop (Callable | None) – Callable that forwards calibration data through the model. If None, a RULER dataset is generated automatically. Only used for prefill calibration.
- Returns:
Dictionary with calibration results for each phase
- Return type:
dict[str, Any]
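If you supply your own forward_loop instead of relying on the auto-generated RULER data, it only needs to accept the model and run calibration inputs through it. The sketch below builds one from samples shaped like the output of build_calibration_dataset (dicts with 'input' and 'length' fields). make_forward_loop and encode are hypothetical names; encode stands in for whatever turns text into model inputs, such as a tokenizer call that returns tensors on the model's device. The config argument is omitted because its schema is not documented in this section.

```python
from typing import Any, Callable, Iterable


def make_forward_loop(
    samples: Iterable[dict[str, Any]],
    encode: Callable[[str], Any],
) -> Callable[[Any], None]:
    """Build a forward_loop(model) that feeds each sample's 'input' text through the model."""
    samples = list(samples)  # materialize so the loop can be run more than once

    def forward_loop(model: Any) -> None:
        for sample in samples:
            # With PyTorch you would typically run this under torch.no_grad().
            model(encode(sample["input"]))

    return forward_loop
```

The resulting callable would then be passed as forward_loop= to calibrate_sparse_attention. Per the parameter description, it is used only for prefill calibration.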