methods

Sparse attention methods package.

Classes

SparseAttentionMethod

Base class for sparse attention methods.

Functions

get_sparse_method

Get sparse attention method by name and optional version.

register_sparse_method

Decorator to register sparse attention methods with version support.

class SparseAttentionMethod

Bases: ABC

Base class for sparse attention methods.

__init__()

Initialize base sparse attention method.

abstract apply_sparsity(attention_scores, sparse_mask=None)

Apply sparsity mask to attention scores.

Parameters:
  • attention_scores (Tensor) – Pre-softmax attention scores [batch, heads, seq_q, seq_k]

  • sparse_mask (Tensor | None) – Optional pre-computed mask. If None, calculates internally.

Returns:

Masked attention scores with sparse elements set to -inf

Return type:

Tensor

abstract calculate_sparsity(attention_scores)

Calculate sparsity mask and statistics without applying.

Parameters:

attention_scores (Tensor) – Pre-softmax attention scores [batch, heads, seq_q, seq_k]

Returns:

Tuple of (sparse_mask, stats_dict) where:

  • sparse_mask: Boolean tensor indicating which elements to keep

  • stats_dict: Dictionary with sparsity statistics

Return type:

tuple[Tensor, dict]

get_threshold_info()

Get threshold information for display/debugging.

Returns:

Dictionary with threshold information. Should include:

  • ‘type’: ‘static’, ‘dynamic’, or ‘none’

  • ‘value’: threshold value (for static)

  • ‘scale_factor’: scale factor (for dynamic)

  • Other method-specific info

Return type:

dict

abstract property name: str

Method name identifier.
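To make the interface above concrete, here is a minimal sketch of a subclass. It is hypothetical: it uses NumPy arrays in place of torch Tensors, defines a stand-in for the real `SparseAttentionMethod` ABC so the snippet is self-contained, and the keep-above-a-static-threshold rule (`ThresholdMethod`) is an invented example, not a method shipped by this package.

```python
from abc import ABC, abstractmethod

import numpy as np


class SparseAttentionMethod(ABC):
    """Minimal stand-in mirroring the documented interface."""

    @abstractmethod
    def apply_sparsity(self, attention_scores, sparse_mask=None): ...

    @abstractmethod
    def calculate_sparsity(self, attention_scores): ...

    @property
    @abstractmethod
    def name(self): ...


class ThresholdMethod(SparseAttentionMethod):
    """Hypothetical method: keep scores at or above a static threshold."""

    def __init__(self, threshold=0.0):
        self.threshold = threshold

    def calculate_sparsity(self, attention_scores):
        # Boolean mask of elements to keep, plus summary statistics.
        sparse_mask = attention_scores >= self.threshold
        stats = {"density": float(sparse_mask.mean())}
        return sparse_mask, stats

    def apply_sparsity(self, attention_scores, sparse_mask=None):
        if sparse_mask is None:
            sparse_mask, _ = self.calculate_sparsity(attention_scores)
        # Masked-out elements become -inf so a later softmax zeroes them.
        return np.where(sparse_mask, attention_scores, -np.inf)

    def get_threshold_info(self):
        return {"type": "static", "value": self.threshold}

    @property
    def name(self):
        return "threshold"
```

Scores are expected in the documented `[batch, heads, seq_q, seq_k]` layout; `calculate_sparsity` only inspects, while `apply_sparsity` actually masks.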

get_sparse_method(name, version=None)

Get sparse attention method by name and optional version.

Parameters:
  • name (str) – Method name to retrieve

  • version (str | None) – Optional version string. If None, uses latest version.

Returns:

Method class

Raises:

ValueError – If method name or version is not registered

Return type:

type[SparseAttentionMethod]

Example

>>> get_sparse_method("flash_skip_softmax")  # Latest version
>>> get_sparse_method("flash_skip_softmax", "v1")  # Specific version

register_sparse_method(name, version='v1')

Decorator to register sparse attention methods with version support.

Parameters:
  • name (str) – Method name to register

  • version (str) – Version string (default: “v1”)

Example:

@register_sparse_method("my_method", version="v3")
class MyMethodV3(SparseAttentionMethod): ...
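The registry pattern behind these two functions can be sketched as follows. This is a hypothetical re-implementation for illustration only; the internal `_REGISTRY` name and the latest-version rule (lexicographic max over tags like "v1", "v2") are assumptions, not the package's actual code.

```python
# Hypothetical registry: method name -> {version -> class}.
_REGISTRY: dict[str, dict[str, type]] = {}


def register_sparse_method(name, version="v1"):
    """Decorator: file the class under (name, version)."""
    def decorator(cls):
        _REGISTRY.setdefault(name, {})[version] = cls
        return cls
    return decorator


def get_sparse_method(name, version=None):
    """Look up a registered class; latest version if none is given."""
    if name not in _REGISTRY:
        raise ValueError(f"Unknown method: {name!r}")
    versions = _REGISTRY[name]
    if version is None:
        # Assumes version tags sort lexicographically ("v1" < "v2").
        version = max(versions)
    if version not in versions:
        raise ValueError(f"Unknown version {version!r} for {name!r}")
    return versions[version]


@register_sparse_method("my_method", version="v1")
class MyMethodV1: ...


@register_sparse_method("my_method", version="v2")
class MyMethodV2: ...
```

With this sketch, `get_sparse_method("my_method")` resolves to `MyMethodV2`, while `get_sparse_method("my_method", "v1")` returns the pinned class, matching the documented behavior.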