model_sparsify
Main API functions for sparse attention optimization.
Functions
- sparsify(model, config, forward_loop=None)
Applies sparse attention optimization to the model in-place.
This function replaces the model's attention modules with their sparse counterparts.
- Parameters:
  - model (Module) – A PyTorch model.
  - config (dict[str, Any] | SparseAttentionConfig) – A dictionary or an instance of SparseAttentionConfig specifying the values for the keys "sparse_cfg" and "method". The "sparse_cfg" key specifies the sparse attention configurations; the "method" key specifies the sparse attention method (e.g., "flash_skip_softmax"). The sparse attention configuration is a dictionary mapping wildcards or filter functions to sparse attention attributes; the wildcards or filter functions are matched against the module names. The sparse attention attributes include "threshold", "enable", and method-specific parameters. An example config dictionary is sketched below, after the parameter list. The "backend" parameter must be set to "pytorch":
    - "pytorch": Softmax patching approach (the only supported backend). This requires the model to be loaded with attn_implementation="eager".
  - forward_loop (Callable[[Module], None] | None) – Reserved for future use.
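A minimal sketch of such a config dictionary, assembled from the attribute names described above; the wildcard patterns, the threshold value, and the placement of "backend" are illustrative assumptions rather than documented values:

```python
# Illustrative sketch only: wildcard patterns, the threshold value, and the
# placement of "backend" are assumptions, not taken verbatim from the library.
config = {
    "method": "flash_skip_softmax",
    "backend": "pytorch",  # the only supported backend
    "sparse_cfg": {
        # wildcard matched against module names
        "*self_attn*": {
            "enable": True,
            "threshold": 1e-4,  # method-specific parameter (assumed value)
        },
        # hypothetical pattern disabling sparse attention for the first layer
        "*layers.0.*": {"enable": False},
    },
}
```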
Here are a few examples of correct forward_loop definitions:

Example 1:

```python
def forward_loop(model) -> None:
    # iterate over the data loader and forward data through the model
    for batch in data_loader:
        model(batch)
```

Example 2:

```python
def forward_loop(model) -> float:
    # evaluate the model on the task
    return evaluate(model, task, ....)
```
Note
Calibration does not require forwarding the entire dataset through the model. Please subsample the dataset or reduce the number of batches if needed.
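For instance, a forward_loop that only consumes a few batches could look like the following sketch, where data_loader and the batch count of 32 are placeholders:

```python
import itertools

def forward_loop(model) -> None:
    # Forward only a small subset of the dataset for calibration.
    for batch in itertools.islice(data_loader, 32):
        model(batch)
```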
Important
The model must always be loaded with attn_implementation="eager" for sparse attention to work correctly:

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    attn_implementation="eager",  # Required for sparse attention
    torch_dtype=torch.bfloat16,
)
```
This is because sparse attention works by patching torch.nn.functional.softmax, which is only called in the eager attention implementation.
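For intuition, here is a conceptual sketch of softmax patching, not the library's actual implementation: torch.nn.functional.softmax is wrapped so that attention probabilities below a threshold are zeroed, which can only take effect when the eager attention path calls softmax at the Python level.

```python
import torch
import torch.nn.functional as F

# Conceptual sketch only -- not the library's implementation. Fused SDPA/flash
# kernels never call this Python-level function, hence the eager requirement.
_original_softmax = F.softmax

def _thresholded_softmax(x, dim=-1, threshold=1e-4, **kwargs):
    probs = _original_softmax(x, dim=dim, **kwargs)
    # Drop attention probabilities below the threshold.
    return torch.where(probs < threshold, torch.zeros_like(probs), probs)

F.softmax = _thresholded_softmax  # patch active while sparse attention is enabled
```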
- Returns:
A PyTorch model with sparse attention applied and optionally calibrated.
- Return type:
Module
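An end-to-end usage sketch, tying the pieces above together; the import path, model path, and config values are assumptions based on this page rather than verified against the installed package:

```python
# Usage sketch: import path and config layout are illustrative assumptions.
import torch
from transformers import AutoModelForCausalLM

from model_sparsify import sparsify  # hypothetical import path

model = AutoModelForCausalLM.from_pretrained(
    "my-model-path",  # placeholder
    attn_implementation="eager",  # required: sparse attention patches softmax
    torch_dtype=torch.bfloat16,
)

config = {
    "method": "flash_skip_softmax",
    "backend": "pytorch",
    "sparse_cfg": {"*self_attn*": {"enable": True, "threshold": 1e-4}},
}

model = sparsify(model, config)  # modifies the model in-place and returns it
```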