searcher
Standard interface to implement a searcher algorithm.
A searcher is useful whenever we want to search/optimize over a set of hyperparameters in the model. Searchers are usually used in conjunction with a mode whose entrypoints can define a search space, i.e., convert the model into a search space. The searcher then optimizes over this search space.
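As a conceptual illustration of this split, consider the hedged sketch below; the convert_to_search_space helper and the mode name are hypothetical stand-ins for a mode's entrypoint and are not part of this module.

```python
# Conceptual sketch of the mode/searcher split; `convert_to_search_space` and the
# mode name are hypothetical stand-ins for a mode's entrypoint, not part of this module.
from torch import nn

model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))

# 1. A mode's entrypoint converts the model into a search space (hypothetical call):
# search_space_model = convert_to_search_space(model, mode="<mode_name>")

# 2. A searcher then optimizes over that search space via BaseSearcher.search(),
#    documented below.
```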
Classes
BaseSearcher: A basic search interface that can be used to search/optimize a model.
- class BaseSearcher
Bases:
ABC
A basic search interface that can be used to search/optimize a model.
The base interface supports basic features like setting up a search, checkpointing, and loading logic and defines a minimal workflow to follow.
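A minimal sketch of a concrete searcher is shown below, assuming BaseSearcher has been imported from this module. The fixed-evaluation strategy, the "num_candidates" config key, and the availability of the default_state_dict keys as attributes during the search are illustrative assumptions; only the abstract members need to be overridden.

```python
from typing import Any, Dict


class SimpleSearcher(BaseSearcher):
    """Toy searcher that tracks the best score over a fixed number of evaluations."""

    @property
    def default_search_config(self) -> Dict[str, Any]:
        # Extend the base defaults with a sketch-specific knob (assumed pattern).
        return {**super().default_search_config, "num_candidates": 10}

    @property
    def default_state_dict(self) -> Dict[str, Any]:
        # Initial search state; this is also what gets checkpointed and restored.
        return {"best_score": float("-inf"), "history": []}

    def run_search(self) -> None:
        # Illustrative loop: repeatedly evaluate the model's score and keep the best.
        for step in range(self.config["num_candidates"]):
            score = self.eval_score()
            self.history.append({"step": step, "score": score})  # assumes state keys are attributes
            if score > self.best_score:
                self.best_score = score
```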
- final __init__()
Overriding the __init__ method is not allowed.
- Return type:
None
- after_search()
Optional post-processing steps after the search.
- Return type:
None
- before_search()
Optional pre-processing steps before the search.
- Return type:
None
- config: Dict[str, Any]
- constraints: Dict[str, str | float | Dict | None]
- construct_forward_loop(silent=True, progress_bar_msg=None, max_iter_data_loader=None, post_process_fn=False)
Get a runnable forward loop over the model using the provided configs.
- Return type:
Callable[[Module], None] | None
- property default_search_config: Dict[str, Any]
Get the default config for the searcher.
- abstract property default_state_dict: Dict[str, Any]
Return default state dict.
- deployment: Dict[str, str] | None
- dummy_input: Any | Tuple
- eval_score(silent=True)
Optionally silent evaluation of the score function.
- Return type:
float
- forward_loop: Callable[[Module], None] | None
- property has_score: bool
Check if the model has a score function.
- load_search_checkpoint()
Load the search checkpoint and return an indicator of whether the checkpoint was loaded.
- Return type:
bool
- model: Module
- reset_search()
Reset search at the beginning.
- Return type:
None
- abstract run_search()
Run actual search.
- Return type:
None
- sanitize_search_config(config)
Sanitize the search config dict.
- Parameters:
config (Dict[str, Any] | None) – The (optional) search config dict to sanitize.
- Return type:
Dict[str, Any]
- save_search_checkpoint()
Save the search checkpoint.
- Return type:
None
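The checkpointing contract above can be mirrored by a standalone toy like the following; the "checkpoint" config key and the torch.save/torch.load persistence format are assumptions for the sketch, not requirements of the interface.

```python
import os
from typing import Any, Dict

import torch


class TinyCheckpointing:
    """Standalone toy illustrating the save/load checkpoint contract."""

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.best_score = float("-inf")
        self.history = []

    def state_dict(self) -> Dict[str, Any]:
        return {"best_score": self.best_score, "history": self.history}

    def save_search_checkpoint(self) -> None:
        # Persist the current search state if a checkpoint path was configured.
        path = self.config.get("checkpoint")
        if path:
            torch.save(self.state_dict(), path)

    def load_search_checkpoint(self) -> bool:
        # Return True only if a previous state was actually restored.
        path = self.config.get("checkpoint")
        if not path or not os.path.exists(path):
            return False
        state = torch.load(path)
        self.best_score, self.history = state["best_score"], state["history"]
        return True
```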
- final search(model, constraints, dummy_input=None, config=None)
Search a given prunable model for the best sub-net and return the searched model.
The best sub-net maximizes the score given by score_func while satisfying the constraints.
- Parameters:
model (Module) – The converted model to be searched.
constraints (Dict[str, str | float | Dict | None]) – The dictionary from constraint name to the upper bound that the searched model has to satisfy.
dummy_input (Any | Tuple | None) – Arguments of model.forward(). This is used for exporting and calculating inference-based metrics, such as latency/FLOPs. The format of dummy_input follows the convention of the args argument in torch.onnx.export.
config (Dict[str, Any] | None) – Additional optional arguments to configure the search.
- Return type:
Dict[str, Any]
- Returns: A tuple (subnet, state_dict) where
subnet is the searched subnet (nn.Module), which can be used for subsequent tasks like fine-tuning, and state_dict contains the history and detailed stats of the search procedure.
- final state_dict()
The state dictionary that can be stored/loaded.
- Return type:
Dict[str, Any]
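A hedged usage sketch of the search() and state_dict() APIs documented above; the SimpleSearcher class from the earlier sketch, the "flops" constraint name, and the "num_candidates" config key are illustrative assumptions rather than part of this interface.

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))  # toy stand-in for a converted model

searcher = SimpleSearcher()  # concrete BaseSearcher subclass from the sketch above
results = searcher.search(
    model,                             # the converted model to be searched
    constraints={"flops": 2.0e6},      # constraint name -> upper bound (assumed name)
    dummy_input=torch.randn(1, 32),    # follows the args convention of torch.onnx.export
    config={"num_candidates": 20},     # additional optional search configuration
)

# The search state (history and detailed stats) can also be retrieved directly
# and stored/loaded across runs:
state = searcher.state_dict()
```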