sparse_attention

Extensible sparse attention module.

Classes

SparseAttentionModule

Generic sparse attention module wrapper for applying sparsity to attention layers.

class SparseAttentionModule

Bases: DynamicModule

Generic sparse attention module wrapper for applying sparsity to attention layers.

This module wraps existing attention implementations to add sparse attention capabilities by patching torch.nn.functional.softmax.

Forward Flow:

  1. Check if sparse attention is enabled (pass-through if disabled)

  2. Create softmax patch context with sparse_softmax function

  3. Apply sparse attention by patching F.softmax:

     • Patches torch.nn.functional.softmax with sparse_softmax

     • sparse_softmax applies the method's sparsity logic before softmax

  4. Forward through original attention with sparsity applied

Requirements:

  • Model must be loaded with attn_implementation="eager" for proper softmax interception

  • Only PyTorch backend is supported (patches F.softmax)
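A small self-contained demonstration (not from the library) of why eager attention is required: PyTorch's fused `scaled_dot_product_attention` computes softmax inside its kernel and never calls the Python-level `F.softmax`, so a patch cannot intercept it, whereas eager attention materializes the score matrix and calls `F.softmax` explicitly:

```python
import torch
import torch.nn.functional as F

calls = []
_orig = F.softmax
# Counting wrapper to observe whether F.softmax is actually invoked.
F.softmax = lambda *a, **k: (calls.append(1), _orig(*a, **k))[1]

q = k = v = torch.randn(1, 2, 4, 8)

# Fused path: softmax happens inside the SDPA kernel; the patch never fires.
F.scaled_dot_product_attention(q, k, v)
fused_calls = len(calls)

# Eager path: explicit matmul + F.softmax; the patch intercepts it.
scores = (q @ k.transpose(-2, -1)) / (8 ** 0.5)
out = F.softmax(scores, dim=-1) @ v
eager_calls = len(calls) - fused_calls

F.softmax = _orig  # restore the original implementation
```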

Attributes:

_enabled: bool

Whether sparse attention is enabled

_method: str

The sparse attention method to use (e.g., "flash_skip_softmax")

_method_config: dict

Configuration dictionary for the sparse method (threshold, br, bc, etc.)

_sparse_method_instance: SparseAttentionMethod

Instance of the configured sparse attention method

disable()

Disable sparse attention for this module.

enable()

Enable sparse attention for this module.

forward(*args, **kwargs)

Forward with selected sparse attention method.

This method dispatches to the appropriate sparse attention implementation based on the configured method and backend.
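A hypothetical sketch of the enable/disable pass-through behavior described above; the real SparseAttentionModule wraps an existing attention layer rather than this toy `Identity`, and its enabled path additionally patches F.softmax before delegating:

```python
import torch

class ToggleAttention(torch.nn.Module):
    """Illustrative wrapper: pass-through when disabled, sparse path
    when enabled (sparsity logic elided here)."""
    def __init__(self, inner_attention):
        super().__init__()
        self.inner = inner_attention
        self._enabled = True

    def enable(self):
        self._enabled = True

    def disable(self):
        self._enabled = False

    @property
    def is_enabled(self):
        return self._enabled

    def forward(self, *args, **kwargs):
        if not self._enabled:
            # Pass-through: behave exactly like the wrapped attention.
            return self.inner(*args, **kwargs)
        # Enabled path: the real module would patch F.softmax here
        # before forwarding through the original attention.
        return self.inner(*args, **kwargs)
```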

get_stats()

Get sparsity statistics from the stats manager.

Returns:

Dictionary with sparsity statistics, including 'average_sparsity' when available. Currently returns an empty dict; statistics collection will be added in a follow-up calibration PR.

Return type:

dict

property is_enabled: bool

Check if sparse attention is enabled.

set_from_attribute_config(attribute_cfg=None)

Set sparse attention attributes from configuration.

Similar to TensorQuantizer.set_from_attribute_config.

Parameters:

attribute_cfg (SparseAttentionAttributeConfig | dict | None) – Sparse attention attribute configuration.
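An illustrative attribute configuration dict. The exact keys accepted by SparseAttentionAttributeConfig are assumptions inferred from the attributes documented above (_enabled, _method, _method_config); consult the config class for the authoritative schema:

```python
# Hypothetical config keyed off the documented attributes; key names
# and values are assumptions, not the library's verified schema.
attribute_cfg = {
    "enable": True,
    "method": "flash_skip_softmax",
    "method_config": {
        "threshold": 1e-4,  # skip attention probabilities below this
        "br": 64,           # block rows
        "bc": 64,           # block columns
    },
}
# module.set_from_attribute_config(attribute_cfg)  # hypothetical call site
```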