model_sparsify

Main API functions for sparse attention optimization.

Functions

sparsify


sparsify(model, config, forward_loop=None)

Applies sparse attention optimization to the model in-place.

This function replaces the model's attention modules with their sparse counterparts.
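The following plain-Python sketch makes the in-place replacement concrete. It does not use ModelOpt or PyTorch. `Node` is a hypothetical stand-in for `torch.nn.Module`, and `SparseAttention`, `matches`, and `replace_in_place` are hypothetical names chosen for illustration. The sketch assumes three things: the keys of the "sparse_cfg" mapping (described under config below) are wildcard strings or filter functions matched against dotted module names, later matching keys override earlier ones, and a matched attention module is replaced unless its merged attributes set "enable" to False. The library's actual merging rules may differ.

```python
from __future__ import annotations

import fnmatch
from typing import Callable, Iterator, Union


class Node:
    """Hypothetical stand-in for torch.nn.Module: a named tree of submodules."""

    def __init__(self, kind: str, children: dict[str, Node] | None = None) -> None:
        self.kind = kind
        self.children = dict(children or {})

    def named_modules(self, prefix: str = "") -> Iterator[tuple[str, Node]]:
        # Yield (dotted_name, module) pairs depth-first; the root is named "".
        yield prefix, self
        for name, child in self.children.items():
            yield from child.named_modules(f"{prefix}.{name}" if prefix else name)


class SparseAttention(Node):
    """Hypothetical sparse counterpart that keeps a reference to the original."""

    def __init__(self, original: Node, attrs: dict) -> None:
        super().__init__("sparse_attention")
        self.original = original
        self.attrs = attrs


Matcher = Union[str, Callable[[str], bool]]


def matches(pattern: Matcher, name: str) -> bool:
    # A key is either a filter function or a wildcard string on the module name.
    return pattern(name) if callable(pattern) else fnmatch.fnmatchcase(name, pattern)


def replace_in_place(model: Node, sparse_cfg: dict) -> Node:
    modules = dict(model.named_modules())
    targets = []
    # Resolve attributes first so the tree is not mutated while it is walked.
    for name, module in modules.items():
        if module.kind != "attention":
            continue
        attrs: dict = {}
        for pattern, pattern_attrs in sparse_cfg.items():
            if matches(pattern, name):
                attrs.update(pattern_attrs)  # assumption: later keys override earlier ones
        if attrs and attrs.get("enable", True):  # assumption: enabled unless disabled
            targets.append((name, module, attrs))
    for name, module, attrs in targets:
        parent_name, _, child_name = name.rpartition(".")
        modules[parent_name].children[child_name] = SparseAttention(module, attrs)
    return model  # the same object the caller passed in


# Usage: a two-layer toy model; a filter function disables layer 1.
model = Node("model", {
    "layers": Node("stack", {
        "0": Node("block", {"self_attn": Node("attention"), "mlp": Node("mlp")}),
        "1": Node("block", {"self_attn": Node("attention"), "mlp": Node("mlp")}),
    }),
})
sparse_cfg = {
    "*self_attn*": {"threshold": 1e-3, "enable": True},
    (lambda name: name.startswith("layers.1.")): {"enable": False},
}
result = replace_in_place(model, sparse_cfg)
```

The function mutates the caller's object, so `result` and `model` are the same instance. This mirrors the "in-place" wording above and is not the library's internal code.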

Parameters:
  • model (Module) – A PyTorch model.

  • config (dict[str, Any] | SparseAttentionConfig) –

    A dictionary or an instance of SparseAttentionConfig specifying the values for keys "sparse_cfg" and "method".

    The "sparse_cfg" key specifies the sparse attention configurations. The "method" key specifies the sparse attention method (e.g., "flash_skip_softmax").

    The sparse attention configuration is a dictionary that maps wildcards or filter functions to their sparse attention attributes. The wildcards or filter functions are matched against the module names. The sparse attention attributes include "threshold", "enable", and method-specific parameters.

    An example config dictionary is given below:

    The "backend" parameter must be set to "pytorch":

    • "pytorch": Softmax patching approach (only supported backend)

    This requires the model to be loaded with attn_implementation="eager".

  • forward_loop (Callable[[Module], None] | None) –

    Reserved for future use.

    Here are a few examples of correct forward_loop definitions:

    Example 1:

    def forward_loop(model) -> None:
        # Iterate over the data loader and forward each batch through the model.
        # `data_loader` is assumed to be defined in the enclosing scope.
        for batch in data_loader:
            model(batch)

    Example 2:

    def forward_loop(model) -> float:
        # Evaluate the model on the task; `evaluate` and `task` are user-defined.
        return evaluate(model, task, ...)

    Note

    Calibration does not require forwarding the entire dataset through the model. Please subsample the dataset or reduce the number of batches if needed.

    Important

    The model must always be loaded with attn_implementation="eager" for sparse attention to work correctly:

    import torch
    from transformers import AutoModelForCausalLM

    # model_path is a placeholder for a local checkpoint directory or Hub model ID.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        attn_implementation="eager",  # Required for sparse attention
        torch_dtype=torch.bfloat16,
    )

    This is because sparse attention works by patching torch.nn.functional.softmax, which is only called in the eager attention implementation.
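The dependence on the eager path can be illustrated without PyTorch. In the sketch below, `F` is a hypothetical stand-in for the `torch.nn.functional` namespace. `sparse_softmax` is a hypothetical thresholded softmax that zeroes probabilities below `threshold` and renormalizes the rest; it is not the "flash_skip_softmax" algorithm. The eager function looks up `F.softmax` at call time, so the patch changes its output. The "fused" function binds its softmax ahead of time, loosely like a fused kernel that computes the softmax internally, so the patch has no effect on it.

```python
from __future__ import annotations

import contextlib
import math
import types

# Hypothetical stand-in for the torch.nn.functional namespace.
F = types.SimpleNamespace()


def softmax(row: list[float]) -> list[float]:
    # Numerically stable softmax over one row of attention scores.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


F.softmax = softmax


def eager_attention(scores: list[float]) -> list[float]:
    # Eager path: resolves F.softmax at call time, so a patch takes effect.
    return F.softmax(scores)


def fused_attention(scores: list[float], _kernel=softmax) -> list[float]:
    # "Fused" path: the kernel is bound at definition time and never calls F.softmax.
    return _kernel(scores)


@contextlib.contextmanager
def patched_softmax(threshold: float):
    original = F.softmax

    def sparse_softmax(row: list[float]) -> list[float]:
        # Hypothetical sparsification: drop small probabilities, then renormalize.
        probs = original(row)
        kept = [p if p >= threshold else 0.0 for p in probs]
        total = sum(kept)
        return [p / total for p in kept] if total > 0.0 else probs

    F.softmax = sparse_softmax
    try:
        yield
    finally:
        F.softmax = original  # always restore the unpatched function


scores = [4.0, 0.0, -4.0]
with patched_softmax(threshold=0.01):
    eager = eager_attention(scores)
    fused = fused_attention(scores)
```

Inside the `with` block, the smallest eager probability is zeroed, while the fused result keeps all three entries nonzero. The context manager restores the original function on exit, even if an exception is raised.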

Returns:

A PyTorch model with sparse attention applied and, optionally, calibrated.

Return type:

Module
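The note under forward_loop suggests subsampling the calibration data. One way to cap the number of batches is to wrap the loader with `itertools.islice`. This is a sketch only: `make_forward_loop` and `max_batches` are hypothetical names, the dataset is a placeholder `range`, and a plain callable stands in for the model. Also note that forward_loop is currently reserved for future use by sparsify.

```python
from __future__ import annotations

import itertools
from typing import Any, Callable, Iterable


def make_forward_loop(data_loader: Iterable[Any], max_batches: int) -> Callable[[Any], None]:
    """Return a forward_loop that forwards at most max_batches batches."""

    def forward_loop(model: Callable[[Any], Any]) -> None:
        # islice stops early, so the rest of the dataset is never iterated.
        for batch in itertools.islice(data_loader, max_batches):
            model(batch)

    return forward_loop


# Usage: a plain callable stands in for the model and records each batch it sees.
seen: list[int] = []
forward_loop = make_forward_loop(range(1000), max_batches=8)
forward_loop(seen.append)
```

With a real model, you would typically also wrap the loop body in `torch.no_grad()` so that no autograd graph is built during calibration.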