sparse_attention
Extensible sparse attention module.
Classes

SparseAttentionModule: Generic sparse attention module wrapper for applying sparsity to attention layers.
- class SparseAttentionModule
Bases: DynamicModule

Generic sparse attention module wrapper for applying sparsity to attention layers.

This module wraps existing attention implementations to add sparse attention capabilities. The activation mechanism is delegated to the configured method via method.get_sparse_context(module), so each method defines how it integrates with the forward pass (e.g. softmax patching, kernel flags).

Forward Flow:
1. Check whether sparse attention is enabled (pass-through if disabled)
2. Obtain the method-specific context via _sparse_method_instance.get_sparse_context(self)
3. Run the original forward inside that context
4. Collect statistics if the stats manager is enabled
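The forward flow above can be sketched with plain Python. This is a simplified, self-contained illustration, not the actual implementation; the toy classes and the no-op context are assumptions standing in for a real SparseAttentionMethod and attention forward:

```python
import contextlib

class ToySparseMethod:
    """Hypothetical stand-in for a SparseAttentionMethod implementation."""

    def get_sparse_context(self, module):
        # A real method would patch softmax or set kernel flags here;
        # this stand-in just returns a no-op context manager.
        return contextlib.nullcontext()

class ToySparseAttentionModule:
    """Minimal sketch of the wrapper's forward flow (names are illustrative)."""

    def __init__(self):
        self._enabled = True
        self._sparse_method_instance = ToySparseMethod()

    def _original_forward(self, *args, **kwargs):
        # Stand-in for the wrapped attention implementation's forward.
        return args[0]

    def forward(self, *args, **kwargs):
        if not self._enabled:
            # Step 1: pass-through when sparse attention is disabled.
            return self._original_forward(*args, **kwargs)
        # Steps 2-3: run the original forward inside the method's context.
        with self._sparse_method_instance.get_sparse_context(self):
            out = self._original_forward(*args, **kwargs)
        # Step 4: statistics collection would happen here if a stats
        # manager were enabled.
        return out
```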
Attributes:
- _enabled: bool
Whether sparse attention is enabled
- _method: str
The sparse attention method to use (e.g., "flash_skip_softmax")
- _method_config: dict
Configuration dictionary for the sparse method (threshold, br, bc, etc.)
- _sparse_method_instance: SparseAttentionMethod
Instance of the configured sparse attention method
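To make _method_config concrete, a configuration for a block-sparse method might look like the following. The keys mirror those named above (threshold, br, bc), but the values and comments are purely illustrative assumptions:

```python
# Purely illustrative _method_config dictionary; the exact key set
# depends on the chosen sparse attention method.
method_config = {
    "threshold": 1e-4,  # skip blocks whose contribution falls below this
    "br": 64,           # block rows
    "bc": 64,           # block columns
}
```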
- disable()
Disable sparse attention for this module.
- enable()
Enable sparse attention for this module.
- forward(*args, **kwargs)
Forward with selected sparse attention method.
This method dispatches to the appropriate sparse attention implementation based on the configured method and backend.
- get_stats()
Get sparsity statistics from the stats manager.
- Returns:
Dictionary with sparsity statistics, including 'average_sparsity' if available. Returns an empty dict if the stats manager is not enabled.
- Return type:
dict
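A sketch of the documented get_stats() contract: 'average_sparsity' is the key named above, but representing the stats manager as a plain dict is an assumption made only for illustration:

```python
def get_stats_sketch(stats_manager):
    """Illustrative: return {} without a stats manager, else a summary dict."""
    if stats_manager is None:
        # Mirrors the documented behaviour: empty dict when stats are disabled.
        return {}
    return {"average_sparsity": stats_manager.get("average_sparsity", 0.0)}
```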
- get_threshold_info()
Get threshold information from the sparse method instance.
- Returns:
Dictionary with threshold information from the sparse method.
- Return type:
dict[str, Any]
- property is_enabled: bool
Check if sparse attention is enabled.
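The enable()/disable() pair and the is_enabled property form a simple toggle. A minimal sketch of that contract, with the member names mirroring the entries above and the default-enabled state being an assumption:

```python
class ToggleSketch:
    """Illustrative toggle mirroring enable()/disable()/is_enabled."""

    def __init__(self):
        self._enabled = True  # assumed default; check the library's docs

    def enable(self):
        """Enable sparse attention for this module."""
        self._enabled = True

    def disable(self):
        """Disable sparse attention for this module."""
        self._enabled = False

    @property
    def is_enabled(self):
        """Check if sparse attention is enabled."""
        return self._enabled

m = ToggleSketch()
m.disable()   # sparse attention off: forward passes through unchanged
m.enable()    # sparse attention back on
```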
- set_from_attribute_config(attribute_cfg=None)
Set sparse attention attributes from configuration.
Similar to TensorQuantizer.set_from_attribute_config.
- Parameters:
attribute_cfg (SparseAttentionAttributeConfig | dict | None) – Sparse attention attribute configuration.
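A rough sketch of how such an attribute-config setter could behave, assuming the config carries the attributes listed earlier. The recognised key set and the underscore-prefixing are assumptions for illustration, not the actual implementation:

```python
class ConfigSketch:
    """Illustrative: copy recognised keys from a config dict onto the module."""

    # Hypothetical allow-list; the real config schema is defined by
    # SparseAttentionAttributeConfig.
    _allowed = {"method", "method_config", "enabled"}

    def set_from_attribute_config(self, attribute_cfg=None):
        # Accept a dict or None; a dataclass config could be vars()'d first.
        for key, value in (attribute_cfg or {}).items():
            if key in self._allowed:
                # Store under the private names used by the module
                # (_method, _method_config, ...).
                setattr(self, f"_{key}", value)

mod = ConfigSketch()
mod.set_from_attribute_config(
    {"method": "flash_skip_softmax", "method_config": {"threshold": 1e-4}}
)
```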