utils
Utility functions for prune-related and search-space related tasks.
Note
Generally, methods in the modelopt.torch.nas module should use these utility functions directly instead of accessing the SearchSpace class. This ensures that potentially required pre- and post-processing operations are performed correctly.
Classes
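enable_modelopt_patches | Context manager to enable modelopt patches such as those for autonas/fastnas.
no_modelopt_patches | Context manager to disable modelopt patches to the model.
set_modelopt_patches_enabled | Context manager that sets patches to on or off.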
Functions
inference_flops | Get the inference FLOPs of a PyTorch model.
print_search_space_summary | Print the search space summary.
get_subnet_config | Return the config dict of all hyperparameters.
sample | Sample searchable hparams using the provided sample_func and return resulting config.
select | Select the sub-net according to the provided config dict.
is_modelopt_patches_enabled | Check if modelopt patches for model are enabled.
replace_forward | Context manager to temporarily replace the forward method of the underlying type of a model.
- class enable_modelopt_patches
Bases: _DecoratorContextManager
Context manager to enable modelopt patches such as those for autonas/fastnas.
It can also be used as a decorator (make sure to instantiate it with parentheses).
For example:
modelopt_model.train()
modelopt_model(inputs)  # architecture changes

with mtn.no_modelopt():
    with mtn.enable_modelopt():
        modelopt_model(inputs)  # architecture changes

@mtn.enable_modelopt()
def forward(model, inputs):
    return model(inputs)

with mtn.no_modelopt():
    forward(modelopt_model, inputs)  # architecture changes because of decorator on forward
- __init__()
Constructor.
- get_subnet_config(model, configurable=None)
Return the config dict of all hyperparameters.
- Parameters:
model (Module) – A model that contains DynamicModule(s).
configurable (bool | None) – If None, return all hyperparameters; if True, return only the configurable hyperparameters, without duplicates.
- Returns:
A dict of (parameter_name, choice) that specifies an active subnet.
- Return type:
Dict[str, Any]
- inference_flops(network, dummy_input=None, data_shape=None, unit=1000000.0, return_str=False)
Get the inference FLOPs of a PyTorch model.
- Parameters:
network (Module) – The PyTorch model.
dummy_input (Any | Tuple | None) – The dummy input as defined in mtn.convert().
data_shape (Tuple | None) – The shape of the dummy input if the dummy input is a single tensor. If provided, dummy_input must be None.
unit (float) – The unit to return the number of FLOPs in. Default is 1e6 (million).
return_str (bool) – Whether to return the number of FLOPs as a string.
- Returns:
The number of inference FLOPs in the given unit as either string or float.
- Return type:
float | str
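The unit and return_str options only change how the FLOP count is reported. The pure-Python sketch below illustrates that arithmetic for a stack of dense layers; it is not how inference_flops traces a model, and the helper names linear_flops and total_flops, the 2-FLOPs-per-multiply-add convention, and the two-decimal string format are all assumptions.

```python
from typing import List, Tuple, Union


def linear_flops(in_features: int, out_features: int, batch: int = 1, bias: bool = True) -> int:
    """FLOPs of one dense layer, counting each multiply-add as 2 FLOPs (assumed convention)."""
    flops = 2 * batch * in_features * out_features
    if bias:
        flops += batch * out_features  # one extra add per output element
    return flops


def total_flops(
    layer_shapes: List[Tuple[int, int]],
    batch: int = 1,
    unit: float = 1e6,
    return_str: bool = False,
) -> Union[float, str]:
    """Sum per-layer FLOPs, then scale by `unit` like the documented unit/return_str options."""
    raw = sum(linear_flops(i, o, batch) for i, o in layer_shapes)
    scaled = raw / unit
    # Hypothetical two-decimal format; inference_flops may format its string differently.
    return f"{scaled:.2f}" if return_str else scaled


# A 784 -> 256 -> 10 MLP: 401,664 + 5,130 = 406,794 FLOPs, i.e. about 0.41 MFLOPs.
print(total_flops([(784, 256), (256, 10)], return_str=True))  # 0.41
```

With the default unit of 1e6 the result reads as millions of FLOPs; passing unit=1 returns the raw count.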
- is_modelopt_patches_enabled()
Check if modelopt patches for model are enabled.
- Return type:
bool
- class no_modelopt_patches
Bases: _DecoratorContextManager
Context manager to disable modelopt patches to the model.
Disabling modelopt patches is useful when you want to use the model’s original behavior; for example, you can use this to perform a forward pass without NAS operations. It can also be used as a decorator (make sure to instantiate it with parentheses).
For example:
modelopt_model.train()
modelopt_model(inputs)  # architecture changes

with mtn.no_modelopt():
    modelopt_model(inputs)  # architecture does not change

@mtn.no_modelopt()
def forward(model, inputs):
    return model(inputs)

forward(modelopt_model, inputs)  # architecture does not change
- __init__()
Constructor.
- print_search_space_summary(model, skipped_hparams=['kernel_size'])
Print the search space summary.
- Parameters:
model (Module) – A model that contains DynamicModule(s).
skipped_hparams (List[str]) – Names of hyperparameters to leave out of the summary. Default: ['kernel_size'].
- Return type:
None
- replace_forward(model, new_forward)
Context manager to temporarily replace the forward method of the underlying type of a model.
The original forward function is temporarily accessible via model.forward_original.
- Parameters:
model (Module) – The model whose type’s forward method is to be temporarily replaced.
new_forward (Callable) – The new forward method. It should either be a bound method of the model instance or take the model (self) as its first argument.
- Return type:
Iterator[None]
For example:
fake_forward = lambda self, *args, **kwargs: None

with replace_forward(model, fake_forward):
    out = model(inputs)  # this output is None

out_original = model(inputs)  # this output is the original output
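The general technique here is to swap a method on an object's class for the duration of a with block and restore it in a finally clause. In the sketch below, ToyModule stands in for nn.Module so the example runs without PyTorch; ToyModule and replace_forward_sketch are hypothetical, and using a types.MethodType check to tell the two accepted forms of new_forward apart is an assumption, not modelopt's code.

```python
import types
from contextlib import contextmanager
from typing import Any, Callable, Iterator


class ToyModule:
    """Stand-in for nn.Module: calling an instance dispatches to its forward method."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.forward(*args, **kwargs)

    def forward(self, x: int) -> int:
        return 2 * x


@contextmanager
def replace_forward_sketch(model: Any, new_forward: Callable) -> Iterator[None]:
    """Temporarily replace forward on type(model), not on the instance (hypothetical helper)."""
    model_type = type(model)
    original = model_type.forward
    if isinstance(new_forward, types.MethodType):
        # A bound method already carries its instance, so discard the implicit self.
        model_type.forward = lambda self, *args, **kwargs: new_forward(*args, **kwargs)
    else:
        # A plain function must accept the model (self) as its first argument.
        model_type.forward = new_forward
    model_type.forward_original = original  # exposed as model.forward_original(...)
    try:
        yield
    finally:
        # Restore the class even if the with-body raised.
        model_type.forward = original
        del model_type.forward_original


model = ToyModule()
with replace_forward_sketch(model, lambda self, *args, **kwargs: None):
    print(model(3), model.forward_original(3))  # None 6
print(model(3))  # 6
```

Because this sketch patches the class, every ToyModule instance sees the replacement while the block is active, which is why the restore sits in finally.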
- sample(model, sample_func=<function choice>)
Sample searchable hparams using the provided sample_func and return resulting config.
- Parameters:
model (Module) – A searchable model that contains one or more DynamicModule(s).
sample_func (Callable[[Sequence[T]], T] | Dict[str, Callable[[Sequence[T]], T]]) – A sampling function for hyperparameters. Default: random sampling.
- Returns:
A dict of (parameter_name, choice) that specifies an active subnet.
- Return type:
Dict[str, Any]
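sample_func accepts either one sampler for all hparams or a dict of samplers. The sketch below applies that idea to a plain dict mapping hparam names to their choices. sample_config is hypothetical, and its dict semantics (exact-name keys, random.choice for unlisted hparams) are an assumption that may not match how sample resolves dict keys.

```python
import random
from typing import Any, Callable, Dict, List, Sequence, Union

Sampler = Callable[[Sequence[Any]], Any]


def sample_config(
    choices: Dict[str, List[Any]],
    sample_func: Union[Sampler, Dict[str, Sampler]] = random.choice,
) -> Dict[str, Any]:
    """Pick one choice per hparam name and return the resulting config dict."""
    config = {}
    for name, options in choices.items():
        if isinstance(sample_func, dict):
            # Assumed semantics: exact-name lookup, random sampling for unlisted names.
            func = sample_func.get(name, random.choice)
        else:
            func = sample_func
        config[name] = func(options)
    return config


space = {"layer1.out_features": [64, 128, 256], "layer1.depth": [1, 2, 3]}
print(sample_config(space, max))  # {'layer1.out_features': 256, 'layer1.depth': 3}
print(sample_config(space, {"layer1.depth": min}))  # depth 1, width drawn at random
```

Passing max or min as the sampler is a simple way to pick the largest or smallest subnet when the choices are ordered numbers.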
- select(model, config, strict=True)
Select the sub-net according to the provided config dict.
- Parameters:
model (Module) – A model that contains DynamicModule(s).
config (Dict[str, Any]) – Config of the target subnet as returned by mtn.config() or mtn.search().
strict (bool) – Raise an error when the config does not contain all necessary keys.
- Return type:
None
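A config returned by get_subnet_config or by search is meant to be fed back into select. The toy round trip below illustrates that contract and the strict check. ToyHparam, get_config, and select_config are hypothetical stand-ins, and rejecting a partial config only when strict=True is this sketch's reading of the strict parameter, not modelopt's implementation.

```python
from typing import Any, Dict, List


class ToyHparam:
    """Hypothetical searchable hyperparameter: a fixed list of choices plus an active value."""

    def __init__(self, choices: List[Any]) -> None:
        self.choices = list(choices)
        self.active = max(self.choices)  # start at the largest subnet

    def set_active(self, value: Any) -> None:
        if value not in self.choices:
            raise ValueError(f"{value!r} is not one of {self.choices}")
        self.active = value


def get_config(hparams: Dict[str, ToyHparam]) -> Dict[str, Any]:
    """Snapshot the active choice of every hparam, analogous to a subnet config."""
    return {name: hp.active for name, hp in hparams.items()}


def select_config(hparams: Dict[str, ToyHparam], config: Dict[str, Any], strict: bool = True) -> None:
    """Activate the choices in `config`; with strict=True every hparam must be listed."""
    missing = sorted(set(hparams) - set(config))
    if strict and missing:
        raise KeyError(f"config is missing hparams: {missing}")
    for name, value in config.items():
        hparams[name].set_active(value)  # unknown names raise KeyError here


hps = {"fc.out_features": ToyHparam([64, 128]), "num_layers": ToyHparam([1, 2])}
saved = get_config(hps)  # {'fc.out_features': 128, 'num_layers': 2}
select_config(hps, {"num_layers": 1}, strict=False)  # partial update is allowed
select_config(hps, saved)  # a full snapshot round-trips under strict=True
```

The strict check runs before any hparam is modified, so a rejected config leaves the current subnet unchanged.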
- class set_modelopt_patches_enabled
Bases: _DecoratorContextManager
Context manager that turns patches on or off.
It can be used as a context manager or as a function. If used as a function, the setting applies globally within the current thread (it is thread-local).
- Parameters:
enabled – Whether to enable (True) or disable (False) patched methods.
For example:
modelopt_model.train()
modelopt_model(inputs)  # architecture changes

mtn.set_modelopt_enabled(False)
modelopt_model(inputs)  # architecture does not change

with mtn.set_modelopt_enabled(True):
    modelopt_model(inputs)  # architecture changes

modelopt_model(inputs)  # architecture does not change
- __init__(enabled)
Constructor.
- Parameters:
enabled (bool) – Whether to enable (True) or disable (False) patched methods.
- clone()
Clone the context manager.