utils

Utility functions for pruning-related and search-space-related tasks.

Note

Generally, methods in the modelopt.torch.nas module should use these utility functions directly instead of accessing the SearchSpace class. This is to ensure that potentially required pre- and post-processing operations are performed correctly.

Classes

enable_modelopt_patches

Context manager to enable modelopt patches such as those for autonas/fastnas.

no_modelopt_patches

Context manager to disable modelopt patches to the model.

set_modelopt_patches_enabled

Context manager that turns modelopt patches on or off.

Functions

inference_flops

Get the inference FLOPs of a PyTorch model.

print_search_space_summary

Print the search space summary.

get_subnet_config

Return the config dict of all hyperparameters.

sample

Sample searchable hparams using the provided sample_func and return the resulting config.

select

Select the sub-net according to the provided config dict.

is_modelopt_patches_enabled

Check whether modelopt patches for the model are enabled.

replace_forward

Context manager to temporarily replace the forward method of the underlying type of a model.

class enable_modelopt_patches

Bases: _DecoratorContextManager

Context manager to enable modelopt patches such as those for autonas/fastnas.

It can also be used as a decorator (make sure to instantiate it with parentheses).

For example:

import modelopt.torch.nas as mtn

# modelopt_model: a model converted by mtn.convert(); inputs: a matching input batch
modelopt_model.train()
modelopt_model(inputs)  # architecture changes

with mtn.no_modelopt_patches():
    with mtn.enable_modelopt_patches():
        modelopt_model(inputs)  # architecture changes


@mtn.enable_modelopt_patches()
def forward(model, inputs):
    return model(inputs)


with mtn.no_modelopt_patches():
    forward(modelopt_model, inputs)  # architecture changes because of the decorator on forward

__init__()

Constructor.
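enable_modelopt_patches, no_modelopt_patches and set_modelopt_patches_enabled all derive from _DecoratorContextManager, which is why they work both in a with block and as a decorator. The pure-Python sketch below shows one way such a pattern can be built, assuming the patch state is a thread-local boolean that defaults to enabled. The names _state, patches_enabled, _Toggle, enable_patches and disable_patches are hypothetical; this is an illustration of the pattern, not modelopt's implementation.

```python
import functools
import threading

# Hypothetical thread-local flag standing in for modelopt's internal patch state.
_state = threading.local()


def patches_enabled() -> bool:
    """Return this thread's flag (assumed to default to True)."""
    return getattr(_state, "enabled", True)


class _Toggle:
    """Hypothetical base: usable as `with _Toggle():` or as `@_Toggle()`."""

    target: bool = True

    def __enter__(self) -> None:
        self._prev = patches_enabled()  # remember the old state so nesting works
        _state.enabled = self.target

    def __exit__(self, *exc) -> bool:
        _state.enabled = self._prev  # always restore, even if the body raised
        return False  # do not swallow exceptions

    def __call__(self, func):
        # Decorator form: run every call of func inside a fresh context.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with self.__class__():
                return func(*args, **kwargs)

        return wrapper


class enable_patches(_Toggle):
    target = True


class disable_patches(_Toggle):
    target = False
```

The decorator creates a new instance per call instead of reusing self, so nested or recursive calls each keep their own saved state, and the outer with block's setting is back in force as soon as the decorated function returns.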

get_subnet_config(model, configurable=None)

Return the config dict of all hyperparameters.

Parameters:
  • model (Module) – A model that contains DynamicModule(s).

  • configurable (bool | None) – Which hyperparameters to include: None returns all hyperparameters; True returns only the configurable ones, without duplicates.

Returns:

A dict of (parameter_name, choice) that specifies an active subnet.

Return type:

Dict[str, Any]

inference_flops(network, dummy_input=None, data_shape=None, unit=1000000.0, return_str=False)

Get the inference FLOPs of a PyTorch model.

Parameters:
  • network (Module) – The PyTorch model.

  • dummy_input (Any | Tuple | None) – The dummy input as defined in mtn.convert().

  • data_shape (Tuple | None) – The shape of the dummy input if the dummy input is a single tensor. If provided, dummy_input must be None.

  • unit (float) – The unit in which to return the number of FLOPs. Default is 1e6 (million).

  • return_str (bool) – Whether to return the number of FLOPs as a string.

Returns:

The number of inference FLOPs in the given unit as either string or float.

Return type:

float | str
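The unit parameter rescales the raw count; with the default of 1e6 the result is in millions of FLOPs. For intuition, the sketch below counts FLOPs by hand for a hypothetical two-layer MLP with layer sizes 784, 256 and 10, assuming 2 FLOPs per multiply-accumulate and ignoring bias and activations. linear_flops is a hypothetical helper; inference_flops may use a different counting convention (some profilers report MACs instead), so its numbers need not match this sketch.

```python
from typing import Sequence, Tuple


def linear_flops(
    layer_dims: Sequence[Tuple[int, int]], batch_size: int = 1, unit: float = 1e6
) -> float:
    """Hypothetical hand count for a stack of Linear layers, expressed in `unit`."""
    # Each Linear(d_in, d_out) performs d_in * d_out multiply-accumulates per sample.
    macs = sum(batch_size * d_in * d_out for d_in, d_out in layer_dims)
    # Assumed convention: 1 MAC = 2 FLOPs; bias and activations are ignored.
    return 2 * macs / unit
```

For the MLP above, 784 * 256 + 256 * 10 = 203,264 multiply-accumulates, i.e. 406,528 FLOPs, so linear_flops([(784, 256), (256, 10)]) returns about 0.41 with the default unit.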

is_modelopt_patches_enabled()

Check whether modelopt patches for the model are enabled.

Return type:

bool

class no_modelopt_patches

Bases: _DecoratorContextManager

Context manager to disable modelopt patches to the model.

Disabling modelopt patches is useful when you want the model’s original behavior. For example, you can use this to perform a forward pass without NAS operations.

It can also be used as a decorator (make sure to instantiate it with parentheses).

For example:

import modelopt.torch.nas as mtn

modelopt_model.train()
modelopt_model(inputs)  # architecture changes

with mtn.no_modelopt_patches():
    modelopt_model(inputs)  # architecture does not change


@mtn.no_modelopt_patches()
def forward(model, inputs):
    return model(inputs)


forward(modelopt_model, inputs)  # architecture does not change

__init__()

Constructor.

print_search_space_summary(model, skipped_hparams=['kernel_size'])

Print the search space summary.

Parameters:
  • model (Module) – A model that contains DynamicModule(s).

  • skipped_hparams (List[str]) – Hyperparameter names (e.g. 'kernel_size') to leave out of the summary.

Return type:

None

replace_forward(model, new_forward)

Context manager to temporarily replace the forward method of the underlying type of a model.

The original forward function is temporarily accessible via model.forward_original.

Parameters:
  • model (Module) – The model whose type’s forward method is to be temporarily replaced.

  • new_forward (Callable) – The new forward method. It must either be a method bound to the model instance or take the model (self) as its first argument.

Return type:

Iterator[None]

For example:

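from modelopt.torch.nas.utils import replace_forward


def fake_forward(self, *args, **kwargs):
    return None  # takes the model (self) as the first argument, ignores the inputs


with replace_forward(model, fake_forward):
    out = model(inputs)  # this output is None

out_original = model(inputs)  # this output is the original output
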
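To make the mechanics concrete, here is a pure-Python sketch of a type-level forward swap, assuming the replacement is installed on type(model) and the original is parked on the class as forward_original, as the description above states. Toy stands in for an nn.Module (calling it dispatches to self.forward), and replace_forward_sketch is a hypothetical name; this is not modelopt's code.

```python
from contextlib import contextmanager
from typing import Callable, Iterator


class Toy:
    """Stand-in for an nn.Module: calling the object dispatches to self.forward."""

    def forward(self, x):
        return x * 2

    def __call__(self, x):
        return self.forward(x)


@contextmanager
def replace_forward_sketch(model, new_forward: Callable) -> Iterator[None]:
    """Hypothetical sketch: temporarily swap the forward method on type(model)."""
    cls = type(model)
    # A method bound to the model is unwrapped to its plain function (self first).
    func = getattr(new_forward, "__func__", new_forward)
    cls.forward_original = cls.forward  # reachable as model.forward_original
    cls.forward = func
    try:
        yield
    finally:
        # Restore even if the body raises, then drop the temporary attribute.
        cls.forward = cls.forward_original
        del cls.forward_original
```

With m = Toy(), calling m(3) inside replace_forward_sketch(m, lambda self, x: self.forward_original(x) + 1) gives 7 instead of 6. Because the patch lives on the class, every instance of that class sees the replacement while the context is active, which is why the description speaks of the model's underlying type.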
sample(model, sample_func=<function choice>)

Sample searchable hparams using the provided sample_func and return the resulting config.

Parameters:
  • model (Module) – A searchable model that contains one or more DynamicModule(s).

  • sample_func (Callable[[Sequence[T]], T] | Dict[str, Callable[[Sequence[T]], T]]) – A sampling function for hyperparameters. Default: random sampling.

Returns:

A dict of (parameter_name, choice) that specifies an active subnet.

Return type:

Dict[str, Any]
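The type Callable[[Sequence[T]], T] means any function that picks one element from a sequence of choices is a valid sample_func: the built-ins max and min qualify, as does random.choice, which the <function choice> default in the signature presumably refers to. The sketch below applies this to a hypothetical search space; SPACE and sample_config are illustrative names, and only the single-callable form of sample_func is modeled, not the dict form.

```python
import random
from typing import Callable, Dict, Sequence, TypeVar

T = TypeVar("T")

# Hypothetical search space: hyperparameter name -> candidate choices.
SPACE: Dict[str, Sequence[int]] = {
    "backbone.stage1.out_channels": [32, 48, 64],
    "backbone.stage2.depth": [2, 3, 4],
}


def sample_config(
    space: Dict[str, Sequence[T]],
    sample_func: Callable[[Sequence[T]], T] = random.choice,
) -> Dict[str, T]:
    """Pick one choice per hyperparameter with sample_func; return the config dict."""
    return {name: sample_func(choices) for name, choices in space.items()}
```

sample_config(SPACE, max) picks the largest value for every hyperparameter and sample_config(SPACE, min) the smallest; by the same logic, passing max or min as sample_func to sample should select the largest or smallest subnet.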

select(model, config, strict=True)

Select the sub-net according to the provided config dict.

Parameters:
  • model (Module) – A model that contains DynamicModule(s).

  • config (Dict[str, Any]) – Config of the target subnet as returned by mtn.config() and mtn.search().

  • strict (bool) – If True, raise an error when the config does not contain all necessary keys.

Return type:

None
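The effect of strict can be sketched on plain dictionaries. In the hypothetical select_config below, active holds the current choice per hyperparameter and choices the allowed values; only the strict check on missing keys mirrors the documented behavior, while rejecting unknown names and out-of-range values is an added assumption. This is an illustration, not modelopt's implementation.

```python
from typing import Any, Dict, Sequence


def select_config(
    active: Dict[str, Any],
    choices: Dict[str, Sequence[Any]],
    config: Dict[str, Any],
    strict: bool = True,
) -> None:
    """Hypothetical sketch: validate config against choices, then update active in place."""
    missing = set(choices) - set(config)
    if strict and missing:
        # Documented behavior: strict mode rejects configs that omit hyperparameters.
        raise ValueError(f"config is missing keys: {sorted(missing)}")
    for name, value in config.items():
        # Assumption: unknown names and invalid values are rejected as well.
        if name not in choices:
            raise KeyError(f"unknown hyperparameter: {name}")
        if value not in choices[name]:
            raise ValueError(f"{value!r} is not a valid choice for {name}")
    active.update(config)  # reached only after every entry has been validated
```

With strict=False a partial config changes only the listed hyperparameters and leaves the rest at their current values; validating everything before updating means a rejected config leaves active untouched.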

class set_modelopt_patches_enabled

Bases: _DecoratorContextManager

Context manager that turns modelopt patches on or off.

It can be used as a context manager or as a function. When called as a plain function, it turns patches on or off globally for the current thread (the setting is thread-local) until it is changed again.

Parameters:
  • enabled (bool) – Whether to enable (True) or disable (False) patched methods.

For example:

import modelopt.torch.nas as mtn

modelopt_model.train()
modelopt_model(inputs)  # architecture changes

mtn.set_modelopt_patches_enabled(False)
modelopt_model(inputs)  # architecture does not change

with mtn.set_modelopt_patches_enabled(True):
    modelopt_model(inputs)  # architecture changes

modelopt_model(inputs)  # architecture does not change

__init__(enabled)

Constructor.

Parameters:
  • enabled (bool) – Whether to enable (True) or disable (False) modelopt patches.

clone()

Clone the context manager.
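Unlike enable_modelopt_patches and no_modelopt_patches, set_modelopt_patches_enabled also has an effect as a bare function call. One way to get both behaviors, and the way torch.set_grad_enabled works, is to apply the new state in the constructor and restore the previous state only in __exit__. The sketch below assumes that design and a thread-local flag defaulting to True; set_enabled and is_enabled are hypothetical stand-ins, not modelopt's code.

```python
import threading

# Hypothetical per-thread flag standing in for modelopt's internal state.
_state = threading.local()


def is_enabled() -> bool:
    """Return this thread's patch flag (assumed to default to True)."""
    return getattr(_state, "enabled", True)


class set_enabled:
    """Hypothetical sketch of a setter that works as a call and as a context manager."""

    def __init__(self, enabled: bool) -> None:
        self.prev = is_enabled()  # saved so the `with` form can undo the change
        self.enabled = enabled
        _state.enabled = enabled  # applied immediately, so a bare call takes effect

    def __enter__(self) -> None:
        pass  # nothing to do: the state was already set in __init__

    def __exit__(self, *exc) -> bool:
        _state.enabled = self.prev  # only the `with` form restores the old value
        return False
```

A bare set_enabled(False) therefore leaves patches off for the rest of the thread, matching the class-level example, while with set_enabled(True): is scoped to its block. Other threads keep their own flag.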