sparse_attention
Extensible sparse attention module.
Classes
- SparseAttentionModule: Generic sparse attention module wrapper for applying sparsity to attention layers.
- class SparseAttentionModule
Bases: DynamicModule
Generic sparse attention module wrapper for applying sparsity to attention layers.
This module wraps existing attention implementations to add sparse attention capabilities by patching torch.nn.functional.softmax.
Forward Flow:
Check if sparse attention is enabled (pass-through if disabled)
Create softmax patch context with sparse_softmax function
Apply sparse attention by patching F.softmax:
  - Patches torch.nn.functional.softmax with sparse_softmax
  - sparse_softmax applies the method's sparsity logic before the softmax
Forward through original attention with sparsity applied
Requirements:
Model must be loaded with attn_implementation="eager" for proper softmax interception
Only PyTorch backend is supported (patches F.softmax)
Attributes:
- _enabled: bool
Whether sparse attention is enabled
- _method: str
The sparse attention method to use (e.g., "flash_skip_softmax")
- _method_config: dict
Configuration dictionary for the sparse method (threshold, br, bc, etc.)
- _sparse_method_instance: SparseAttentionMethod
Instance of the configured sparse attention method
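Taken together, the attributes above might hold values like the following. The field names come from this page; every concrete value (and the comment semantics for threshold, br, bc) is an assumption for illustration.

```python
# Illustrative attribute values (all concrete values are assumptions):
enabled = True
method = "flash_skip_softmax"
method_config = {
    "threshold": 1e-4,  # assumed: attention scores below this are skipped
    "br": 128,          # assumed: block rows for the blockwise kernel
    "bc": 128,          # assumed: block columns for the blockwise kernel
}
```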
- disable()
Disable sparse attention for this module.
- enable()
Enable sparse attention for this module.
- forward(*args, **kwargs)
Forward with selected sparse attention method.
This method dispatches to the appropriate sparse attention implementation based on the configured method and backend.
- get_stats()
Get sparsity statistics from the stats manager.
- Returns:
Dictionary with sparsity statistics, including 'average_sparsity' when available. Currently returns an empty dict; statistics collection will be added in the calibration PR.
- Return type:
dict
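Because get_stats() currently returns an empty dict, callers should read 'average_sparsity' defensively. A small sketch; `report_sparsity` is a hypothetical helper, not part of this API:

```python
def report_sparsity(stats: dict) -> str:
    """Format sparsity stats, tolerating the empty dict get_stats() returns today."""
    if "average_sparsity" in stats:
        return f"average sparsity: {stats['average_sparsity']:.1%}"
    return "sparsity statistics not yet collected"
```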
- property is_enabled: bool
Check if sparse attention is enabled.
- set_from_attribute_config(attribute_cfg=None)
Set sparse attention attributes from configuration.
Similar to TensorQuantizer.set_from_attribute_config.
- Parameters:
attribute_cfg (SparseAttentionAttributeConfig | dict | None) – Sparse attention attribute configuration.
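Since attribute_cfg may be a plain dict, a call might look like the following. The key names and values are assumptions inferred from the attributes listed above, and the commented call requires a real SparseAttentionModule instance:

```python
# Hypothetical attribute configuration (keys and values are assumptions):
attribute_cfg = {
    "enable": True,
    "method": "flash_skip_softmax",        # method name from this page
    "method_config": {"threshold": 1e-4},  # threshold value is an assumption
}
# module.set_from_attribute_config(attribute_cfg)  # module: a SparseAttentionModule
```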