conversion

Conversion and restoration utilities for sparse attention.

Functions

convert_to_sparse_attention_model | Convert model to use sparse attention.
disable_sparse_attention | Disable sparse attention for matching modules.
enable_sparse_attention | Enable sparse attention for matching modules.
is_attn_sparsified | Check if a model has sparse attention applied.
replace_sparse_attention_modules | Replace regular attention modules with sparse attention modules.
restore_sparse_attention_model | Restore sparse attention model from saved state.
restore_sparse_attention_state | Restore sparse attention state from state dict.
set_sparse_attention_attribute | Set sparse attention attributes for modules matching pattern.
set_sparse_attention_by_cfg | Apply sparse attention configuration to model.
update_sparse_attention_metadata | Update metadata with sparse attention state.
- convert_to_sparse_attention_model(model, config)
Convert model to use sparse attention.
- Parameters:
model (ModelLikeModule) – Model to convert
config (SparseAttentionConfig) – Sparse attention configuration
- Returns:
Tuple of (converted_model, metadata)
- Return type:
tuple[Module, dict[str, Any]]
- disable_sparse_attention(model, wildcard_or_filter_func)
Disable sparse attention for matching modules.
Similar to mtq.disable_quantizer for API consistency.
- Parameters:
model (Module) – Model with sparse attention applied
wildcard_or_filter_func (str | Callable) – Wildcard string or filter function to match module names. For example: "lm_head", "layer_0", etc.
Example
>>> import modelopt.torch.sparsity.attention_sparsity as sparse_attn
>>> model = sparse_attn.sparsify(model, config)
>>> # Disable sparse attention for lm_head
>>> sparse_attn.disable_sparse_attention(model, "*lm_head*")
- enable_sparse_attention(model, wildcard_or_filter_func)
Enable sparse attention for matching modules.
Similar to mtq.enable_quantizer for API consistency.
- Parameters:
model (Module) – Model with sparse attention applied
wildcard_or_filter_func (str | Callable) – Wildcard string or filter function to match module names. For example: "attention", "attn", etc.
Example
>>> import modelopt.torch.sparsity.attention_sparsity as sparse_attn
>>> model = sparse_attn.sparsify(model, config)
>>> # Re-enable sparse attention for all attention modules
>>> sparse_attn.enable_sparse_attention(model, "*attention*")
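The `wildcard_or_filter_func` argument accepted by `enable_sparse_attention` and `disable_sparse_attention` behaves like shell-style pattern matching over module names, or a user-supplied predicate. A minimal stand-alone sketch of that resolution logic (the `_matches` helper and module names below are illustrative, not part of the modelopt API):

```python
import fnmatch
from typing import Callable, Union

def _matches(name: str, wildcard_or_filter_func: Union[str, Callable[[str], bool]]) -> bool:
    """Hypothetical helper: resolve a wildcard string or filter callable
    against a dotted module name, mirroring the documented API shape."""
    if callable(wildcard_or_filter_func):
        return wildcard_or_filter_func(name)
    return fnmatch.fnmatch(name, wildcard_or_filter_func)

names = ["model.layers.0.self_attn", "model.lm_head", "model.layers.1.self_attn"]

# Wildcard string: matches both attention modules
attn = [n for n in names if _matches(n, "*self_attn*")]

# Filter function: matches only the output head
head = [n for n in names if _matches(n, lambda n: n.endswith("lm_head"))]
```

Either form selects the submodules that the enable/disable call will act on.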
- is_attn_sparsified(model)
Check if a model has sparse attention applied.
Similar to quantization’s is_quantized for API consistency.
- Parameters:
model (Module) – Model to check
- Returns:
True if model contains any SparseAttentionModule instances
- Return type:
bool
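Conceptually, the check is a scan over the model's submodules for the sparse attention wrapper type. A toy sketch of that pattern, assuming nothing beyond the docstring (the stand-in classes below are not torch or modelopt classes):

```python
class StandInModule:
    """Toy stand-in for a regular nn.Module with child traversal."""
    def __init__(self, children=()):
        self._children = list(children)

    def modules(self):
        yield self
        for child in self._children:
            yield from child.modules()

class SparseAttentionStandIn(StandInModule):
    """Toy stand-in for modelopt's SparseAttentionModule wrapper."""

def is_attn_sparsified_sketch(model) -> bool:
    # True if any submodule is an instance of the sparse wrapper type
    return any(isinstance(m, SparseAttentionStandIn) for m in model.modules())

plain = StandInModule([StandInModule()])
converted = StandInModule([SparseAttentionStandIn()])
```

The real function performs the same membership test against `SparseAttentionModule` inside a torch model.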
- replace_sparse_attention_modules(model, version=None)
Replace regular attention modules with sparse attention modules.
Recursively replace all attention modules in the model with their sparse attention counterparts.
- Parameters:
model (Module) – Model to process
version – State version for tracking (optional)
- restore_sparse_attention_model(model, config, metadata)
Restore sparse attention model from saved state.
- Parameters:
model (ModelLikeModule) – Model to restore
config (SparseAttentionConfig) – Sparse attention configuration
metadata (dict[str, Any]) – Saved metadata
- Returns:
Restored model
- Return type:
Module
- restore_sparse_attention_state(model, state_dict)
Restore sparse attention state from state dict.
- Parameters:
model (Module) – Model with sparse attention modules
state_dict (dict[str, Any]) – Saved state dictionary
- set_sparse_attention_attribute(model, wildcard_or_filter, attribute_cfg)
Set sparse attention attributes for modules matching pattern.
Similar to quantization’s set_quantizer_attribute.
- Parameters:
model (Module) – Model to configure
wildcard_or_filter (str | Callable) – Pattern to match module names
attribute_cfg (dict[str, Any]) – Attributes to apply (must include 'method')
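In outline, this walks the named modules, tests each name against the pattern, and merges the attribute dict into every match. A hedged stdlib sketch of that flow (the helper name, module names, and the "softmax_skip" method value are hypothetical; the real function mutates SparseAttentionModule instances):

```python
import fnmatch

def set_attribute_sketch(modules, wildcard_or_filter, attribute_cfg):
    """Toy version: `modules` maps module name -> attribute dict."""
    if "method" not in attribute_cfg:
        raise ValueError("attribute_cfg must include 'method'")
    for name, attrs in modules.items():
        matched = (wildcard_or_filter(name) if callable(wildcard_or_filter)
                   else fnmatch.fnmatch(name, wildcard_or_filter))
        if matched:
            # Merge the requested attributes into the matching module
            attrs.update(attribute_cfg)

mods = {"layers.0.attn": {}, "lm_head": {}}
set_attribute_sketch(mods, "*attn*", {"method": "softmax_skip", "threshold": 1e-4})
```

After the call, only `layers.0.attn` carries the new attributes; `lm_head` is untouched.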
- set_sparse_attention_by_cfg(model, sparse_cfg)
Apply sparse attention configuration to model.
Similar to quantization’s set_quantizer_by_cfg.
- Parameters:
model (Module) – Model with sparse attention modules
sparse_cfg (dict) – Sparse configuration dictionary mapping patterns to attributes
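A `sparse_cfg` dictionary maps name patterns to attribute dicts, and applying it amounts to one per-pattern attribute update. The keys and method value in this sketch are hypothetical, not taken from modelopt:

```python
# Hypothetical configuration: pattern -> attributes to apply to matching modules
sparse_cfg = {
    "*self_attn*": {"method": "softmax_skip", "threshold": 1e-4},
    "*lm_head*": {"enable": False},
}

# Conceptually equivalent to one set_sparse_attention_attribute call per entry:
# for pattern, attrs in sparse_cfg.items():
#     set_sparse_attention_attribute(model, pattern, attrs)
```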
- update_sparse_attention_metadata(model, config, metadata)
Update metadata with sparse attention state.
- Parameters:
model (Module) – Model with sparse attention
config (SparseAttentionConfig) – Configuration used
metadata (dict[str, Any]) – Metadata dict to update
- Return type:
None