algorithms
High-level search and model design algorithms to help you optimize your model.
Functions
profile | Profile statistics of the search space of a converted model or a regular model. |
search | Search a given prunable model for the best sub-net and return the search model. |
- profile(model, dummy_input=None, constraints=None, deployment=None, strict=False, verbose=True, use_centroid=False)
Profile statistics of the search space of a converted model or a regular model.
- Parameters:
model (Module) – The model to be profiled. Can be converted or not.
dummy_input (Any | tuple[Any, ...] | None) –
Arguments of model.forward(). This is used for exporting and calculating inference-based metrics, such as latency/FLOPs. The format of dummy_input follows the convention of the args argument in torch.onnx.export. Specifically, dummy_input can be:
  - a single argument (type(dummy_input) != tuple) corresponding to model.forward(dummy_input)
  - a tuple of arguments corresponding to model.forward(*dummy_input)
  - a tuple of arguments such that type(dummy_input[-1]) == dict corresponding to model.forward(*dummy_input[:-1], **dummy_input[-1])
Warning
In this case the model’s forward() method cannot contain keyword-only arguments (e.g. forward(..., *, kw_only_args)) or variable keyword arguments (e.g. forward(..., **kwargs)) since these cannot be sorted into positional arguments.
Note
In order to pass a dict as the last non-keyword argument, you need to use a tuple as dummy_input and add an empty dict as the last element, e.g., dummy_input = (x, {"y": y, "z": z}, {})
The empty dict at the end will then be interpreted as the keyword args.
See torch.onnx.export for more info.
Note that if you provide a dummy_input with batch size b, the results will be computed based on batch size b.
constraints (dict[str, str | float | dict | None] | ConstraintsFunc | None) –
The dictionary from constraint name to the upper bound that the searched model has to satisfy. Currently, flops and params are supported as constraints. The constraints dictionary generally takes the following form:
    constraints = {"params": 5.0e6, "flops": 4.5e8}
Note
We recommend providing only the most relevant constraint, e.g., flops:
    constraints = {"flops": 4.5e8}
Note that you can also provide a percentage value instead of an absolute value, e.g.,
    # search for a model with <= 60% of the original model flops
    constraints = {"flops": "60%"}
strict (bool) – Raise an error if constraints are not satisfiable.
verbose (bool) – Print detailed profiling results.
use_centroid (bool) – By default, profile reports the median of the evaluation results from randomly sampled subnets (instead of the evaluation result from the deterministic centroid subnet). Set use_centroid to True to use the deterministic centroid subnet for profiling.
deployment (dict[str, str] | None)
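The three dummy_input forms described under Parameters can be illustrated with a small pure-Python sketch. The names call_forward and toy_forward are hypothetical and exist only for this illustration; the sketch mirrors the documented convention and is not the library's own implementation.

```python
from typing import Any, Callable


def call_forward(forward: Callable[..., Any], dummy_input: Any) -> Any:
    """Call ``forward`` following the documented dummy_input convention (sketch)."""
    if not isinstance(dummy_input, tuple):
        # Form 1: a single non-tuple argument -> forward(dummy_input)
        return forward(dummy_input)
    if dummy_input and isinstance(dummy_input[-1], dict):
        # Form 3: a trailing dict is interpreted as keyword arguments
        return forward(*dummy_input[:-1], **dummy_input[-1])
    # Form 2: a plain tuple of positional arguments
    return forward(*dummy_input)


def toy_forward(x, y=0, z=0):
    """Hypothetical stand-in for model.forward()."""
    return (x, y, z)


single = call_forward(toy_forward, 1)
positional = call_forward(toy_forward, (1, 2))
with_kwargs = call_forward(toy_forward, (1, {"y": 2, "z": 3}))
# A dict meant as the last positional argument needs a trailing empty dict,
# otherwise it would be unpacked as keyword arguments.
dict_as_arg = call_forward(toy_forward, (1, {"a": 2}, {}))
```

In the last call the inner dict reaches toy_forward as the positional argument y, because the trailing empty dict takes the role of the keyword arguments.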
- Return type:
tuple[bool, dict[str, dict]]
- Returns: A tuple (is_all_sat, stats) where
is_all_sat is a bool indicating whether all constraints can be satisfied, and stats is a dictionary containing statistics for the search space if the model is converted, e.g., the FLOPs and params for the min, centroid, and max subnets and their max/min ratios, the size of the search space, and the number of configurable hparams.
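To make the meaning of absolute and percentage constraints concrete, here is a minimal sketch that converts a constraints dictionary into absolute upper bounds and checks a candidate against them. The helpers resolve_upper_bounds and all_satisfied, the metric values, and the idea of testing the smallest subnet are assumptions made for illustration; the sketch handles only numeric and percentage-string values and does not reproduce how profile computes is_all_sat internally.

```python
def resolve_upper_bounds(constraints, original):
    """Turn absolute or percentage constraints into absolute upper bounds."""
    bounds = {}
    for name, limit in constraints.items():
        if isinstance(limit, str) and limit.endswith("%"):
            # "60%" means at most 60% of the original model's metric
            bounds[name] = float(limit[:-1]) / 100.0 * original[name]
        else:
            bounds[name] = float(limit)
    return bounds


def all_satisfied(bounds, candidate):
    """Return True if every metric of the candidate is within its upper bound."""
    return all(candidate[name] <= bound for name, bound in bounds.items())


# Hypothetical metrics of the original model and of its smallest subnet
original = {"flops": 1.0e9, "params": 8.0e6}
smallest_subnet = {"flops": 4.0e8, "params": 3.0e6}

bounds = resolve_upper_bounds({"flops": "60%", "params": 5.0e6}, original)
is_all_sat = all_satisfied(bounds, smallest_subnet)
```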
- search(model, constraints, dummy_input, config=None)
Search a given prunable model for the best sub-net and return the search model.
The best sub-net maximizes the score given by score_func while satisfying the constraints.
- Parameters:
model (Module) – The converted model to be searched.
constraints (dict[str, str | float | dict | None]) –
The dictionary from constraint name to the upper bound that the searched model has to satisfy. Currently, flops and params are supported as constraints. The constraints dictionary generally takes the following form:
    constraints = {"params": 5.0e6, "flops": 4.5e8}
We recommend providing only the most relevant constraint, e.g., flops:
    constraints = {"flops": 4.5e8}
You can also provide a percentage value instead of an absolute value, e.g.,
    # search for a model with <= 60% of the original model flops
    constraints = {"flops": "60%"}
dummy_input (Any | tuple[Any, ...]) –
Arguments of model.forward(). This is used for exporting and calculating inference-based metrics, such as latency/FLOPs. The format of dummy_input follows the convention of the args argument in torch.onnx.export. Specifically, dummy_input can be:
  - a single argument (type(dummy_input) != tuple) corresponding to model.forward(dummy_input)
  - a tuple of arguments corresponding to model.forward(*dummy_input)
  - a tuple of arguments such that type(dummy_input[-1]) == dict corresponding to model.forward(*dummy_input[:-1], **dummy_input[-1])
Warning
In this case the model’s forward() method cannot contain keyword-only arguments (e.g. forward(..., *, kw_only_args)) or variable keyword arguments (e.g. forward(..., **kwargs)) since these cannot be sorted into positional arguments.
Note
In order to pass a dict as the last non-keyword argument, you need to use a tuple as dummy_input and add an empty dict as the last element, e.g., dummy_input = (x, {"y": y, "z": z}, {})
The empty dict at the end will then be interpreted as the keyword args.
See torch.onnx.export for more info.
Note that if you provide a dummy_input with batch size b, the results will be computed based on batch size b.
config (dict[str, Any] | None) –
Additional optional arguments to configure the search. Currently, we support:
  - checkpoint: Path to save/restore a checkpoint with a dictionary containing the intermediate search state. If provided, the intermediate search state will be automatically restored before the search (if it exists) and saved during the search.
  - verbose: Whether to print detailed search space profiling and search stats during search.
  - forward_loop: A Callable that takes a model as input and runs a forward loop on it. It is recommended to choose the data loader used inside the forward loop carefully to reduce the runtime. Cannot be provided at the same time as data_loader and collect_func.
  - data_loader: An iterator yielding batches of data for calibrating the normalization layers in the model or computing gradient scores. It is recommended to use the same data loader as for training but with significantly fewer iterations. Cannot be provided at the same time as forward_loop.
  - collect_func: A Callable that takes a batch of data from the data loader as input and returns the input to model.forward() as described in run_forward_loop. Cannot be provided at the same time as forward_loop.
  - max_iter_data_loader: Maximum number of iterations to run the data loader.
  - score_func: A callable taking the model as input and returning a single accuracy/score metric (float). This metric will be maximized during search.
    Note
    The score_func is required for autonas and fastnas modes. It will be evaluated on models in eval mode (model.eval()).
  - loss_func: A Callable which takes the model output (i.e., the output of model.forward()) and the batch of data as its inputs and returns a scalar loss. This is a required argument if the model is converted via gradnas mode.
    It should be possible to run a backward pass on the loss value returned by this method.
    collect_func will be used to gather the inputs to model.forward() from a batch of data yielded by data_loader.
    loss_func should support the following usage:
        for i, batch in enumerate(data_loader):
            if i >= max_iter_data_loader:
                break
            # Assuming collect_func returns a tuple of arguments
            output = model(*collect_func(batch))
            loss = loss_func(output, batch)
            loss.backward()
Note
Additional configuration options may be added by individual algorithms. Please refer to the documentation of the individual algorithms for more information.
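The mutual-exclusion rules among forward_loop, data_loader, and collect_func can be checked before a config is passed on. The sketch below is only an illustration: check_data_options is a hypothetical helper, the (inputs, labels) batch layout is an assumption, a plain list stands in for a real data loader, and the score_func is a placeholder rather than a real metric.

```python
def check_data_options(config):
    """Raise ValueError if forward_loop is combined with data_loader/collect_func."""
    if "forward_loop" in config:
        clash = [key for key in ("data_loader", "collect_func") if key in config]
        if clash:
            raise ValueError(f"forward_loop cannot be combined with {clash}")


def collect_func(batch):
    """Map a batch to the positional arguments of model.forward()."""
    # Assumed batch layout: (inputs, labels); only the inputs go to forward()
    inputs, _labels = batch
    return (inputs,)


config = {
    # A list of two toy batches stands in for a real data loader
    "data_loader": [([1.0, 2.0], 0), ([3.0, 4.0], 1)],
    "collect_func": collect_func,
    "max_iter_data_loader": 1,
    # Placeholder score; a real score_func would evaluate the model
    "score_func": lambda model: 0.0,
    "verbose": False,
}

check_data_options(config)  # passes: forward_loop is not set
```

Returning a tuple from collect_func matches the usage model(*collect_func(batch)) shown for loss_func.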
- Return type:
tuple[Module, dict[str, Any]]
- Returns: A tuple (subnet, state_dict) where
subnet is the searched subnet (nn.Module), which can be used for subsequent tasks like fine-tuning, and state_dict contains the history and detailed stats of the search procedure.
Note
The given model is modified (exported) in-place to match the best subnet found by the search algorithm. The returned subnet is thus a reference to the same model instance as the input model.
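Because the model is modified in place, a typical workflow continues with the same object after the search. The following sketch shows one way this might look; search_then_finetune, search_fn, and finetune are hypothetical names, where search_fn stands for the search function documented above and is passed in as a parameter so that the sketch does not assume any particular import path.

```python
def search_then_finetune(search_fn, model, dummy_input, score_func, finetune):
    """Run a search with a 60% FLOPs budget, then fine-tune the result (sketch)."""
    subnet, state_dict = search_fn(
        model,
        {"flops": "60%"},  # keep at most 60% of the original model's FLOPs
        dummy_input,
        config={"score_func": score_func},
    )
    # The documentation states that the returned subnet is the input model itself.
    assert subnet is model
    finetune(subnet)
    return subnet, state_dict
```

Since subnet and model refer to the same instance, keep a separate copy of the original model beforehand if it is still needed after the search.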