utils
Utility functions for prune-related and search-space-related tasks.
Note
Generally, methods in the modelopt.torch.nas module should use these utility
functions directly instead of accessing the SearchSpace class. This is to ensure that
potentially required pre- and post-processing operations are performed correctly.
Classes

- enable_modelopt_patches – Context manager to enable modelopt patches.
- no_modelopt_patches – Context manager to disable modelopt patches.
- set_modelopt_patches_enabled – Context manager that sets patches to on or off.
Functions

- inference_flops – Get the inference FLOPs of a PyTorch model.
- print_search_space_summary – Print the search space summary.
- get_subnet_config – Return the config dict of all hyperparameters.
- sample – Sample searchable hparams using the provided sample_func and return the resulting config.
- select – Select the sub-net according to the provided config dict.
- is_modelopt_patches_enabled – Check if modelopt patches for model are enabled.
- replace_forward – Context manager to temporarily replace the forward method of the underlying type of a model.
- class enable_modelopt_patches
Bases: _DecoratorContextManager

Context manager to enable modelopt patches such as those for autonas/fastnas.

It can also be used as a decorator (make sure to instantiate with parentheses).
For example:
    modelopt_model.train()
    modelopt_model(inputs)  # architecture changes

    with mtn.no_modelopt():
        with mtn.enable_modelopt():
            modelopt_model(inputs)  # architecture changes

    @mtn.enable_modelopt()
    def forward(model, inputs):
        return model(inputs)

    with mtn.no_modelopt():
        forward(modelopt_model, inputs)  # architecture changes because of decorator on forward
- __init__()
Constructor.
- get_subnet_config(model, configurable=None)
Return the config dict of all hyperparameters.
- Parameters:
model (Module) – A model that contains DynamicModule(s).
configurable (bool | None) – If None, return all hparams; if True, return only configurable hparams without duplicates.
- Returns:
A dict of (parameter_name, choice) that specifies an active subnet.
- Return type:
dict[str, Any]
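A minimal usage sketch; the import path follows this module, and `model` is assumed to be a model already converted via mtn.convert():

    from modelopt.torch.nas.utils import get_subnet_config

    full_config = get_subnet_config(model)  # all hyperparameters
    search_config = get_subnet_config(model, configurable=True)  # configurable hparams only, no duplicates
    for name, choice in search_config.items():
        print(f"{name}: {choice}")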
- inference_flops(network, dummy_input=None, data_shape=None, unit=1000000.0, return_str=False)
Get the inference FLOPs of a PyTorch model.
- Parameters:
network (Module) – The PyTorch model.
dummy_input (Any | tuple[Any, ...] | None) – The dummy input as defined in
mtn.convert().
data_shape (tuple[Any, ...] | None) – The shape of the dummy input if the dummy input is a single tensor. If provided, dummy_input must be None.
unit (float) – The unit to return the number of FLOPs in. Default is 1e6 (million).
return_str (bool) – Whether to return the number of FLOPs as a string.
- Returns:
The number of inference FLOPs in the given unit as either string or float.
- Return type:
float | str
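A hedged usage sketch; the input shape (1, 3, 224, 224) is a hypothetical single-tensor shape chosen for illustration, and `model` is assumed to be converted via mtn.convert():

    import torch
    from modelopt.torch.nas.utils import inference_flops

    dummy_input = torch.randn(1, 3, 224, 224)  # hypothetical input

    # Either pass the dummy input directly ...
    flops = inference_flops(model, dummy_input=dummy_input)  # float, in millions (unit=1e6)

    # ... or pass only the shape of a single-tensor input (dummy_input must then be None).
    flops_str = inference_flops(model, data_shape=(1, 3, 224, 224), return_str=True)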
- is_modelopt_patches_enabled()
Check if modelopt patches for model are enabled.
- Return type:
bool
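A short sketch querying the patch state together with the context managers on this page (it assumes patches are enabled by default after conversion):

    from modelopt.torch.nas.utils import is_modelopt_patches_enabled, no_modelopt_patches

    print(is_modelopt_patches_enabled())      # True if patches are currently active
    with no_modelopt_patches():
        print(is_modelopt_patches_enabled())  # False inside the disabled context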
- class no_modelopt_patches
Bases: _DecoratorContextManager

Context manager to disable modelopt patches to the model.

Disabling modelopt patches is useful when you want to use the model's original behavior. For example, you can use this to perform a forward pass without NAS operations.

It can also be used as a decorator (make sure to instantiate with parentheses).
For example:
    modelopt_model.train()
    modelopt_model(inputs)  # architecture changes

    with mtn.no_modelopt():
        modelopt_model(inputs)  # architecture does not change

    @mtn.no_modelopt()
    def forward(model, inputs):
        return model(inputs)

    forward(modelopt_model, inputs)  # architecture does not change
- __init__()
Constructor.
- print_search_space_summary(model, skipped_hparams=['kernel_size'])
Print the search space summary.
- Parameters:
model (Module) – A model that contains DynamicModule(s).
skipped_hparams (list[str]) – Names of hyperparameters to skip in the summary.
- Return type:
None
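A small sketch; "out_channels" below is a hypothetical hparam name used only for illustration:

    from modelopt.torch.nas.utils import print_search_space_summary

    print_search_space_summary(model)  # skips "kernel_size" by default
    print_search_space_summary(model, skipped_hparams=["kernel_size", "out_channels"])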
- replace_forward(model, new_forward)
Context manager to temporarily replace the forward method of the underlying type of a model.
The original forward function is temporarily accessible via model.forward_original.
- Parameters:
model (Module) – The model whose type’s forward method is to be temporarily replaced.
new_forward (Callable) – The new forward method. It should either be a method bound to the model instance or take the model (self) as the first argument.
- Return type:
Iterator[None]
For example:
    fake_forward = lambda self, inputs: None  # takes the model (self) as the first argument

    with replace_forward(model, fake_forward):
        out = model(inputs)  # this output is None
    out_original = model(inputs)  # this output is the original output
- sample(model, sample_func=<function choice>)
Sample searchable hparams using the provided sample_func and return the resulting config.
- Parameters:
model (Module) – A searchable model that contains one or more DynamicModule(s).
sample_func (Callable[[Sequence[T]], T] | dict[str, Callable[[Sequence[T]], T]]) – A sampling function for hyperparameters. Default: random sampling.
- Returns:
A dict of (parameter_name, choice) that specifies an active subnet.
- Return type:
dict[str, Any]
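A hedged sketch; passing `min` as sample_func assumes the hparam choices are comparable (e.g., numeric), which is an illustrative assumption rather than a library guarantee:

    from modelopt.torch.nas.utils import sample

    random_config = sample(model)                     # default: random choice per hparam
    smallest_config = sample(model, sample_func=min)  # deterministically pick the smallest choice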
- select(model, config, strict=True)
Select the sub-net according to the provided config dict.
- Parameters:
model (Module) – A model that contains DynamicModule(s).
config (dict[str, Any]) – Config of the target subnet as returned by
mtn.config() and mtn.search().
strict (bool) – Raise an error when the config does not contain all necessary keys.
- Return type:
None
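A sketch combining sample and select to activate a subnet (assumes a converted, searchable model):

    from modelopt.torch.nas.utils import sample, select

    config = sample(model)               # draw a subnet config
    select(model, config)                # activate it; raises on missing keys (strict=True)
    select(model, config, strict=False)  # tolerate missing keys instead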
- class set_modelopt_patches_enabled
Bases: _DecoratorContextManager

Context manager that sets patches to on or off.
It can be used as a context manager or as a function. If used as a function, the setting is applied globally (thread-local).
- Parameters:
enabled – Whether to enable (True) or disable (False) patched methods.
For example:
    modelopt_model.train()
    modelopt_model(inputs)  # architecture changes

    mtn.set_modelopt_enabled(False)
    modelopt_model(inputs)  # architecture does not change

    with mtn.set_modelopt_enabled(True):
        modelopt_model(inputs)  # architecture changes
    modelopt_model(inputs)  # architecture does not change
- __init__(enabled)
Constructor.
- Parameters:
enabled (bool)
- clone()
Clone the context manager.