gradnas

Module implementing gradnas pruning algorithm for search.

Summary:

The gradnas algorithm produces a better score than the L1 norm (used by fastnas) for ranking pruning choices in language models.

Details:

Furthermore, scores can be computed for hparams that are defined only abstractly. For example, this algorithm can be used to rank the heads in a multi-head attention layer, even though an attention head does not have a unique tensor parameter associated with it.

The prunable choices for a particular hparam are ranked by Sum((gradient of loss wrt pruning mask)^2). The pruning mask of an hparam is a binary mask indicating which choices of the hparam are pruned (0 means pruned and 1 means not pruned).

When the gradient of the loss is calculated in the backward pass, the masks are set to 1 for all tensors, so nothing is actually pruned while sensitivity is measured. For more on using masks to measure sensitivity, see this paper: https://arxiv.org/pdf/1905.10650.pdf
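The squared mask-gradient score described above can be illustrated with a small, self-contained sketch. It is not the module's implementation: the toy "layer" (a masked sum of per-head outputs), the squared-error loss, and the names mask_gradients and squared_gradient_scores are all assumptions made for illustration. The gradient with respect to the mask is written out analytically so that no deep learning framework is required.

```python
# Toy illustration only; every name below is hypothetical and not part of the library.
# The "layer" output is sum_i(mask_i * head_output_i) and the loss is squared error.

def mask_gradients(head_outputs, target):
    """Return d(loss)/d(mask_i) for each head, evaluated with every mask entry set to 1."""
    mask = [1.0] * len(head_outputs)  # nothing is pruned while measuring sensitivity
    prediction = sum(m * h for m, h in zip(mask, head_outputs))
    residual = prediction - target
    # loss = residual^2, so d(loss)/d(mask_i) = 2 * residual * head_output_i
    return [2.0 * residual * h for h in head_outputs]


def squared_gradient_scores(batches):
    """Accumulate the squared mask gradient of each head over all batches."""
    scores = None
    for head_outputs, target in batches:
        grads = mask_gradients(head_outputs, target)
        if scores is None:
            scores = [0.0] * len(grads)
        scores = [s + g * g for s, g in zip(scores, grads)]
    return scores


# Three toy samples, each with three "heads" and a scalar target.
batches = [
    ([1.0, 0.1, -0.5], 1.0),
    ([2.0, 0.0, 0.5], 2.0),
    ([-1.0, 0.2, 1.0], 0.5),
]
scores = squared_gradient_scores(batches)
# Heads ordered from most to least sensitive; the last entry is the first pruning candidate.
ranking = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
```

In this toy setup the second head (index 1) contributes little to the output, so it receives the smallest score and would be the first candidate for pruning.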

Classes

GradientBinarySearcher

Binary searcher for gradient algorithm.

GradientDataManager

Class for managing gradient data for an hparam.

class GradientBinarySearcher

Bases: BinarySearcher

Binary searcher for gradient algorithm.

SETUP_GRADIENT_FUNC: Dict[Type[DynamicModule], Callable[[DynamicModule], Tuple[GradientDataManager, RemovableHandle]]]

Setup search with gradient-based score.

Return type:

None

property default_search_config: Dict[str, Any]

Get the default config for the searcher.

static gradnas_score_func(model)

Score function for gradnas algorithm.

If we prune N neurons from layer L, the total degradation is the sum of the degradation values of the N pruned neurons. In the fast algorithm, by contrast, the degradation due to pruning is estimated directly from validation_score(model after pruning). The rest of the algorithm is exactly the same as the fast algorithm.

Parameters:

model (Module) –

Return type:

float
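The additive degradation estimate described for gradnas_score_func can be sketched as follows. This is a minimal illustration under stated assumptions, not the library's code: it assumes the per-choice degradation values are already available as a list of floats and that the lowest-scoring choices are pruned first. The helper names total_degradation and degradation_curve are hypothetical.

```python
# Illustrative sketch; the helper names and the pruning order are assumptions.

def total_degradation(scores, num_pruned):
    """Estimated degradation when the num_pruned lowest-scoring choices are pruned."""
    if not 0 <= num_pruned <= len(scores):
        raise ValueError("num_pruned must be between 0 and len(scores)")
    # The estimate is additive: sum the degradation values of the pruned choices.
    return sum(sorted(scores)[:num_pruned])


def degradation_curve(scores):
    """Estimated degradation for every possible number of pruned choices."""
    return [total_degradation(scores, n) for n in range(len(scores) + 1)]


# Toy per-neuron degradation values for one layer.
layer_scores = [5.0, 0.5, 2.0, 0.25]
curve = degradation_curve(layer_scores)
```

Because the scores are non-negative, the resulting curve never decreases as more choices are pruned, which is the kind of monotonic behavior a binary search over the amount of pruning can rely on.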

We can only optimize over certain types of hparams in gradient binary search.

sanitize_search_config(config)

Sanitize the search config dict.

Parameters:

config (Dict[str, Any] | None) –

Return type:

Dict[str, Any]

class GradientDataManager

Bases: object

Class for managing gradient data for an hparam.

__init__(shape, model, reduce_func=<function GradientDataManager.<lambda>>)

Initialize GradientDataManager.

process_gradient()

Process gradient of the mask.

property score

The score of the hparam based on the stored gradients.