gradnas
Module implementing the gradnas pruning algorithm for search.
Summary:
For language models, the gradnas algorithm provides a better score than the L1 norm (used by fastnas) for sorting the various pruning choices.
Details:
Furthermore, gradnas can score hparams that are implemented only abstractly. For example, it can sort the heads in a multi-head attention layer, even though the attention heads do not each have a unique tensor parameter associated with them.
We rank the prunable choices of a particular hparam by Sum((gradient of loss wrt pruning mask)^2), computed separately for each choice. The pruning mask of an hparam is a binary mask indicating which choices of the hparam are pruned (0 means pruned, 1 means not pruned).
When the gradient of the loss is computed in the backward pass, all masks are set to 1. For more on using masks to measure sensitivity, see this paper: https://arxiv.org/pdf/1905.10650.pdf
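To make the score concrete, here is a minimal pure-Python sketch on a hypothetical toy layer, not the gradnas implementation: each of H heads produces a scalar output, one mask entry multiplies each head's output, the masked outputs are summed into y, and the loss is 0.5 * (y - t)^2. With all masks held at 1, the chain rule gives d(loss)/d(mask_h) = (y - t) * output_h. The toy model, the function name `head_scores`, summing over samples, and pruning the lowest-scoring heads first are our assumptions; in practice the mask gradients would come from an autograd backward pass rather than a hand-derived formula.

```python
from typing import List, Sequence


def head_scores(outputs: Sequence[Sequence[float]], targets: Sequence[float]) -> List[float]:
    """Return Sum((d loss / d mask_h)^2) over samples for each head h.

    Hypothetical toy layer: y = sum_h mask_h * outputs[h] and
    loss = 0.5 * (y - t) ** 2, so d loss / d mask_h = (y - t) * outputs[h].
    """
    num_heads = len(outputs[0])
    scores = [0.0] * num_heads
    for head_out, t in zip(outputs, targets):
        mask = [1.0] * num_heads  # all masks set to 1 while computing gradients
        y = sum(m * o for m, o in zip(mask, head_out))
        residual = y - t  # d loss / d y
        for h in range(num_heads):
            grad = residual * head_out[h]  # chain rule through y = ... + mask_h * outputs[h]
            scores[h] += grad ** 2  # accumulate the squared mask gradient
    return scores


# Two samples, three heads; a lower score means less sensitive, so it is pruned first.
scores = head_scores([[1.0, 0.1, 2.0], [0.5, 0.0, 1.0]], targets=[2.0, 1.0])
prune_order = sorted(range(len(scores)), key=lambda h: scores[h])
print(prune_order)  # [1, 0, 2]
```

Note that no head owns a dedicated weight tensor here; the mask alone is enough to measure each head's sensitivity, which is why abstract hparams such as attention heads can be scored.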
Classes
- GradientBinarySearcher: Binary searcher for the gradient algorithm.
- GradientDataManager: Class for managing gradient data for an hparam.
- class GradientBinarySearcher
Bases: BinarySearcher
Binary searcher for the gradient algorithm.
- SETUP_GRADIENT_FUNC: Dict[Type[DynamicModule], Callable[[DynamicModule], Tuple[GradientDataManager, RemovableHandle]]]
- before_search()
Setup search with gradient-based score.
- Return type:
None
- property default_search_config: Dict[str, Any]
Get the default config for the searcher.
- static gradnas_score_func(model)
Score function for gradnas algorithm.
If we prune N neurons from layer L, the total degradation is estimated as the sum of the degradation values of the N pruned neurons. In the fast algorithm, by contrast, the degradation due to pruning is estimated directly from validation_score(model after pruning). The rest of the algorithm is the same as the fast algorithm.
- Parameters:
model (Module)
- Return type:
float
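The additive degradation estimate described for gradnas_score_func can be sketched as follows. This is an illustration under stated assumptions, not the module's code: `estimated_degradation` is a hypothetical helper, and we assume that shrinking a layer to `keep` choices prunes the lowest-scoring ones.

```python
from typing import Sequence


def estimated_degradation(scores: Sequence[float], keep: int) -> float:
    """Hypothetical helper: estimated degradation when keeping `keep` choices.

    Under the additive assumption, pruning N choices costs the sum of their
    individual scores; here the N lowest-scoring choices are the ones pruned.
    """
    if not 0 <= keep <= len(scores):
        raise ValueError("keep must be between 0 and len(scores)")
    num_pruned = len(scores) - keep
    return sum(sorted(scores)[:num_pruned])


scores = [1.2725, 0.0121, 5.09]  # per-choice scores for one layer (illustrative values)
print(estimated_degradation(scores, keep=2))  # only the 0.0121 choice is pruned
```

In this sketch, the per-choice scores are computed once, and evaluating any candidate width is a sum over a sorted prefix, whereas the fast algorithm's estimate requires evaluating validation_score on each pruned model.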
- property hparam_names_for_search: Set[str]
The names of the hparams that gradient binary search can optimize over; only certain types of hparams are supported.
- sanitize_search_config(config)
Sanitize the search config dict.
- Parameters:
config (Dict[str, Any] | None)
- Return type:
Dict[str, Any]
- class GradientDataManager
Bases: object
Class for managing gradient data for an hparam.
- __init__(shape, model, reduce_func=<function GradientDataManager.<lambda>>)
Initialize GradientDataManager.
- process_gradient()
Process gradient of the mask.
- property score
The score of the hparam based on the stored gradients.
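The interface above (a mask, process_gradient, and score) can be illustrated with a minimal pure-Python stand-in. This is a sketch under assumptions, not the actual class: the real GradientDataManager also takes a `model` argument and presumably works on tensors, whereas here `grad` is a plain list that the caller fills in to simulate a backward pass. We also assume that `process_gradient` adds `reduce_func` of the squared gradient to a running score, and that `reduce_func` defaults to the identity. The class name `ToyGradientDataManager` is hypothetical.

```python
from typing import Callable, List, Optional


class ToyGradientDataManager:
    """Hypothetical stand-in for GradientDataManager using plain Python lists."""

    def __init__(
        self,
        shape: int,
        reduce_func: Callable[[List[float]], List[float]] = lambda x: x,  # assumed identity default
    ) -> None:
        self.mask = [1.0] * shape  # all ones: no choice is pruned while scoring
        self.grad: Optional[List[float]] = None  # set by a (simulated) backward pass
        self._score = [0.0] * shape
        self.reduce_func = reduce_func

    def process_gradient(self) -> None:
        """Add the reduced squared mask gradient to the running score, then clear it."""
        if self.grad is None:
            raise RuntimeError("no gradient has been recorded for the mask")
        squared = self.reduce_func([g * g for g in self.grad])
        self._score = [s + q for s, q in zip(self._score, squared)]
        self.grad = None  # ready for the next backward pass

    @property
    def score(self) -> List[float]:
        """Accumulated score per choice of the hparam."""
        return self._score


manager = ToyGradientDataManager(shape=3)
for grad in ([1.1, 0.11, 2.2], [0.25, 0.0, 0.5]):  # mask gradients from two batches
    manager.grad = grad
    manager.process_gradient()
print(manager.score)
```

Accumulating over batches in this way yields, per choice, the Sum((gradient of loss wrt pruning mask)^2) described in the module summary.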