methods
Sparse attention methods package.
Classes
SparseAttentionMethod – Base class for sparse attention methods.
Functions
get_sparse_method – Get sparse attention method by name and optional version.
register_sparse_method – Decorator to register sparse attention methods with version support.
- class SparseAttentionMethod
Bases: ABC
Base class for sparse attention methods.
- __init__()
Initialize base sparse attention method.
- abstract apply_sparsity(attention_scores, sparse_mask=None)
Apply sparsity mask to attention scores.
- Parameters:
attention_scores (Tensor) – Pre-softmax attention scores [batch, heads, seq_q, seq_k]
sparse_mask (Tensor | None) – Optional pre-computed mask. If None, the mask is computed internally.
- Returns:
Attention scores with the masked-out elements set to -inf
- Return type:
Tensor
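The reason for writing -inf into pre-softmax scores can be illustrated with a small sketch. It uses plain Python lists instead of tensors, and the helpers `mask_row` and `softmax` are hypothetical names introduced only for this illustration; they are not part of the package.

```python
import math


def mask_row(scores, keep):
    """Replace every score whose keep flag is False with -inf."""
    return [s if k else -math.inf for s, k in zip(scores, keep)]


def softmax(row):
    """Numerically stable softmax over one row of scores.

    Assumes at least one entry in the row is finite (i.e. kept).
    """
    peak = max(x for x in row if x != -math.inf)
    exps = [math.exp(x - peak) for x in row]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


scores = [2.0, 0.5, -1.0, 1.0]      # one query row of pre-softmax scores
keep = [True, False, False, True]   # True marks elements to keep

masked = mask_row(scores, keep)
weights = softmax(masked)
```

Because `exp(-inf)` is zero, the masked positions receive exactly zero attention weight, and the remaining weights still sum to one.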
- abstract calculate_sparsity(attention_scores)
Calculate the sparsity mask and statistics without applying the mask.
- Parameters:
attention_scores (Tensor) – Pre-softmax attention scores [batch, heads, seq_q, seq_k]
- Returns:
Tuple of (sparse_mask, stats_dict) where
sparse_mask: Boolean tensor indicating which elements to keep
stats_dict: Dictionary with sparsity statistics
- get_threshold_info()
Get threshold information for display/debugging.
- Returns:
Dictionary with threshold information. Should include:
'type': 'static', 'dynamic', or 'none'
'value': threshold value (for static)
'scale_factor': scale factor (for dynamic)
Other method-specific info
- abstract property name: str
Method name identifier.
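A minimal sketch of how a subclass might fill in this interface is shown here. Several things are assumptions made for illustration: the `SparseAttentionMethod` class in the block is a local stand-in that mirrors the documented members rather than the package's class, `StaticThresholdDemo` and its `gap` parameter are hypothetical, the `"sparsity"` statistics key is invented, and scores are a 2-D nested list (`[seq_q][seq_k]`) instead of a 4-D tensor.

```python
import math
from abc import ABC, abstractmethod


class SparseAttentionMethod(ABC):
    """Local stand-in mirroring the documented interface (not the library class)."""

    @abstractmethod
    def calculate_sparsity(self, attention_scores): ...

    @abstractmethod
    def apply_sparsity(self, attention_scores, sparse_mask=None): ...

    def get_threshold_info(self):
        return {"type": "none"}

    @property
    @abstractmethod
    def name(self): ...


class StaticThresholdDemo(SparseAttentionMethod):
    """Hypothetical method: keep scores within `gap` of each row's maximum."""

    def __init__(self, gap=2.0):
        super().__init__()
        self.gap = gap

    def calculate_sparsity(self, attention_scores):
        # Build the boolean keep-mask without modifying the scores.
        mask = [[s >= max(row) - self.gap for s in row] for row in attention_scores]
        total = sum(len(row) for row in mask)
        kept = sum(sum(row) for row in mask)
        return mask, {"sparsity": 1.0 - kept / total}

    def apply_sparsity(self, attention_scores, sparse_mask=None):
        # Fall back to computing the mask when none is supplied.
        if sparse_mask is None:
            sparse_mask, _ = self.calculate_sparsity(attention_scores)
        return [
            [s if k else -math.inf for s, k in zip(row, keep)]
            for row, keep in zip(attention_scores, sparse_mask)
        ]

    def get_threshold_info(self):
        return {"type": "static", "value": self.gap}

    @property
    def name(self):
        return "static_threshold_demo"


method = StaticThresholdDemo(gap=2.0)
scores = [[3.0, 0.0, 2.0], [1.0, 1.5, -4.0]]
mask, stats = method.calculate_sparsity(scores)
masked = method.apply_sparsity(scores)
```

Separating `calculate_sparsity` from `apply_sparsity` lets a caller inspect the mask and statistics first, then pass the same mask back in to avoid computing it twice.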
- get_sparse_method(name, version=None)
Get sparse attention method by name and optional version.
- Parameters:
name (str) – Method name to retrieve
version (str | None) – Optional version string. If None, uses latest version.
- Returns:
Method class
- Raises:
ValueError – If method name or version is not registered
- Return type:
type[SparseAttentionMethod]
Example
>>> get_sparse_method("flash_skip_softmax")  # Latest version
>>> get_sparse_method("flash_skip_softmax", "v1")  # Specific version
- register_sparse_method(name, version='v1')
Decorator to register sparse attention methods with version support.
- Parameters:
name (str) – Method name to register
version (str) – Version string (default: "v1")
Example:
@register_sparse_method("my_method", version="v3")
class MyMethodV3(SparseAttentionMethod):
    ...
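How a versioned registry of this kind could work is sketched here. This is a stand-in reimplementation for illustration, not the package's code: the `_REGISTRY` dictionary and its layout are hypothetical, and the rule that "latest" means the highest number after a leading `v` is an assumption, since the page does not say how versions are ordered.

```python
# Hypothetical storage layout: {method name: {version string: class}}
_REGISTRY = {}


def register_sparse_method(name, version="v1"):
    """Stand-in decorator that records a class under (name, version)."""
    def decorator(cls):
        _REGISTRY.setdefault(name, {})[version] = cls
        return cls  # return the class unchanged so it stays usable directly
    return decorator


def get_sparse_method(name, version=None):
    """Stand-in lookup that raises ValueError for unknown names or versions."""
    if name not in _REGISTRY:
        raise ValueError(f"Unknown sparse attention method: {name!r}")
    versions = _REGISTRY[name]
    if version is None:
        # Assumption: versions look like "v<int>" and the largest int is latest.
        version = max(versions, key=lambda v: int(v.lstrip("v")))
    if version not in versions:
        raise ValueError(f"Unknown version {version!r} for method {name!r}")
    return versions[version]


@register_sparse_method("my_method")
class MyMethodV1:
    ...


@register_sparse_method("my_method", version="v3")
class MyMethodV3:
    ...


latest = get_sparse_method("my_method")        # resolves to the v3 class
pinned = get_sparse_method("my_method", "v1")  # resolves to the v1 class
```

Returning the class rather than an instance matches the documented return type `type[SparseAttentionMethod]`, so the caller decides how and when to construct the method.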