
Module implementing the gradnas pruning algorithm for search.


For language models, the gradnas algorithm gives a better score for ranking pruning choices than the L1 norm (fastnas).


Further, we can get scores for hparams that are implemented only abstractly. For example, we can use this algorithm to rank the heads in a multi-head attention layer, even though an attention head does not have a unique tensor parameter associated with it.

We are ranking the prunable choices for a particular hparam based on Sum((gradient of loss wrt pruning mask)^2). The pruning mask of an hparam is a binary mask indicating which choices of the hparam are pruned (0 means pruned and 1 means not pruned).
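The ranking rule above can be illustrated on a toy model. This is a minimal sketch in plain Python, not the library's implementation: the scalar two-layer network, the helper name mask_gradient_scores, and the sample data are all assumptions made for illustration. Each hidden unit j has a mask m_j multiplying its output; the gradient of a squared-error loss with respect to m_j is evaluated with every mask held at 1, squared, and summed over the samples.

```python
def mask_gradient_scores(w_in, w_out, samples):
    """Return, for each hidden unit j, the sum over samples of (dLoss/dmask_j)^2.

    Toy model (an assumption for illustration only):
        y = sum_j m_j * w_out[j] * w_in[j] * x,   loss = (y - target)^2
    with every mask m_j held at 1 during the forward and backward pass.
    """
    scores = [0.0] * len(w_in)
    for x, target in samples:
        # Per-unit contribution to the output; this is also dy/dm_j at m_j = 1.
        contributions = [wo * wi * x for wi, wo in zip(w_in, w_out)]
        y = sum(contributions)
        dloss_dy = 2.0 * (y - target)
        for j, contribution in enumerate(contributions):
            grad_mask = dloss_dy * contribution  # chain rule: dLoss/dm_j
            scores[j] += grad_mask ** 2
    return scores


w_in = [1.0, 0.5, 2.0]
w_out = [1.0, 0.1, 1.0]
samples = [(1.0, 2.0), (2.0, 5.0)]  # (input, target) pairs

scores = mask_gradient_scores(w_in, w_out, samples)
# Units sorted from lowest score (least sensitive) to highest.
ranking = sorted(range(len(scores)), key=lambda j: scores[j])
print(ranking)
```

In this toy setup unit 1, whose weights are smallest, receives the lowest score and would be the first candidate for pruning.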

When computing the gradient of the loss in the backward pass, the masks are set to 1 for all tensors. For more on using masks to measure sensitivity, see this paper:




class GradientBinarySearcher

Bases: BinarySearcher

Binary searcher for gradient algorithm.

SETUP_GRADIENT_FUNC: Dict[Type[DynamicModule], Callable[[DynamicModule], Tuple[GradientDataManager, RemovableHandle]]]

Set up the search with a gradient-based score.

Return type:


property default_search_config: Dict[str, Any]

Get the default config for the searcher.

static gradnas_score_func(model)

Score function for gradnas algorithm.

If we prune N neurons from layer L, the total degradation is the sum of the degradation values of the N pruned neurons. In the fast algorithm, by contrast, the degradation due to pruning is estimated directly from validation_score(model after pruning). The rest of the algorithm is exactly the same as the fast algorithm.
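The additive estimate described above can be sketched as a cumulative sum. This is an illustrative sketch only: the function name estimated_degradation and the example scores are hypothetical, and it assumes that neurons are pruned in order of increasing score.

```python
from itertools import accumulate


def estimated_degradation(neuron_scores):
    """Return a list whose entry N is the estimated degradation after pruning N neurons.

    Assumes (for illustration) that the N lowest-scored neurons are pruned first and
    that the total degradation is the sum of their individual scores.
    """
    ascending = sorted(neuron_scores)
    # Entry 0 corresponds to pruning nothing, hence zero degradation.
    return [0.0] + list(accumulate(ascending))


degradation = estimated_degradation([4.0, 0.5, 2.5, 1.0])
print(degradation)
```

Because the resulting list is non-decreasing in N, a binary search over N against a degradation budget is well defined.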


model (Module) –

Return type:


We can only optimize over certain types of hparams in gradient binary search.


Sanitize the search config dict.


config (Dict[str, Any] | None) –

Return type:

Dict[str, Any]

class GradientDataManager

Bases: object

Class for managing gradient data for an hparam.

__init__(shape, model, reduce_func=<function GradientDataManager.<lambda>>)

Initialize GradientDataManager.


Process gradient of the mask.

property score

The score of the hparam based on the stored gradients.
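As a rough illustration of how such a manager could behave, the sketch below accumulates squared mask gradients and applies a reduce function when the score is read. It is a hypothetical stand-in in plain Python, not the library class: the name SquaredGradientAccumulator, the add_gradient method, and the per_head reducer are assumptions, and the real class works with a model and tensor shapes rather than lists.

```python
class SquaredGradientAccumulator:
    """Hypothetical stand-in (not the library class) that sums squared mask gradients."""

    def __init__(self, size, reduce_func=lambda values: values):
        self._sum_sq = [0.0] * size
        self._reduce_func = reduce_func

    def add_gradient(self, mask_grad):
        """Add the element-wise square of one backward pass's mask gradient."""
        if len(mask_grad) != len(self._sum_sq):
            raise ValueError("gradient length does not match mask size")
        for i, g in enumerate(mask_grad):
            self._sum_sq[i] += g * g

    @property
    def score(self):
        """Accumulated squared gradients, reduced to one value per hparam choice."""
        return self._reduce_func(list(self._sum_sq))


def per_head(values, head_dim=2):
    """Example reducer: sum each consecutive group of head_dim entries into one head score."""
    return [sum(values[i:i + head_dim]) for i in range(0, len(values), head_dim)]


acc = SquaredGradientAccumulator(4, reduce_func=per_head)
acc.add_gradient([1.0, 2.0, 0.5, 0.0])
acc.add_gradient([0.0, 1.0, 0.5, 1.0])
head_scores = acc.score
print(head_scores)
```

A reducer of this kind is one way a mask with several entries per attention head could yield a single score per head.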