sparse_attention

Extensible sparse attention module.

Classes

SparseAttentionModule

Generic sparse attention module wrapper for applying sparsity to attention layers.

class SparseAttentionModule

Bases: DynamicModule

Generic sparse attention module wrapper for applying sparsity to attention layers.

This module wraps existing attention implementations to add sparse attention capabilities. The activation mechanism is delegated to the configured method via method.get_sparse_context(module), so each method defines how it integrates with the forward pass (e.g. softmax patching, kernel flags).

Forward Flow:

  1. Check if sparse attention is enabled (pass-through if disabled)

  2. Obtain method-specific context via _sparse_method_instance.get_sparse_context(self)

  3. Run the original forward inside the context

  4. Collect statistics if stats manager is enabled

Attributes:

_enabled: bool

Whether sparse attention is enabled

_method: str

The sparse attention method to use (e.g., "flash_skip_softmax")

_method_config: dict

Configuration dictionary for the sparse method (threshold, br, bc, etc.)

_sparse_method_instance: SparseAttentionMethod

Instance of the configured sparse attention method

disable()

Disable sparse attention for this module.

enable()

Enable sparse attention for this module.

forward(*args, **kwargs)

Forward with selected sparse attention method.

This method dispatches to the appropriate sparse attention implementation based on the configured method and backend.

get_stats()

Get sparsity statistics from the stats manager.

Returns:

Dictionary of sparsity statistics, including 'average_sparsity' when available. Returns an empty dict if the stats manager is not enabled.

Return type:

dict

get_threshold_info()

Get threshold information from the sparse method instance.

Returns:

Dictionary with threshold information from the sparse method.

Return type:

dict[str, Any]

property is_enabled: bool

Check if sparse attention is enabled.

set_from_attribute_config(attribute_cfg=None)

Set sparse attention attributes from configuration.

Similar to TensorQuantizer.set_from_attribute_config.

Parameters:

attribute_cfg (SparseAttentionAttributeConfig | dict | None) – Sparse attention attribute configuration.