network

Utility functions for PyTorch models.

Functions

compare_dict

Compare two dictionaries and return keys with unmatched values.

get_model_attributes

Get the key attributes of a PyTorch model.

get_module_device

Get the device of a PyTorch module.

get_same_padding

Get the same padding for a given kernel size.

init_model_from_model_like

Initialize a model from a model-like object.

is_channels_last

Check if the model is using channels last memory format.

is_parallel

Check if a PyTorch model is parallelized.

make_divisible

Function taken from the original TensorFlow repo.

model_to

Convert model to the same device, dtype and memory layout as the target_model.

param_num

Get the number of parameters of a PyTorch model.

param_num_from_forward

Get the number of parameters of a PyTorch model from a forward pass.

remove_bn

Remove all batch normalization layers in the network.

run_forward_loop

Run multiple forward passes with a model according to the provided data loader.

set_submodule

The set function that complements nn.Module.get_submodule().

standardize_model_args

Standardize model arguments according to torch.onnx.export.

standardize_model_like_tuple

Standardize a model-like tuple.

standardize_named_model_args

Standardize model arguments according to torch.onnx.export and give them a name.

standardize_constructor_args

Standardize a constructor-like tuple.

unwrap_model

Unwrap a model that is wrapped by a supported wrapper module or return the original model.

zero_grad

Set any gradients in the model's parameters to None.

create_param_grad_clear_hook

Create a hook to clear gradients for a parameter.

get_unwrapped_name

Get the cleaned module name (i.e., the name before wrapping with sharded modules).

compare_dict(dict1, dict2)

Compare two dictionaries and return keys with unmatched values.

Parameters:
  • dict1 (Dict[str, Any]) –

  • dict2 (Dict[str, Any]) –

Return type:

Tuple[str, …]
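The comparison can be sketched as follows. This is a minimal illustration, not the library's implementation: it assumes that a key present in only one dictionary counts as unmatched and that keys are returned in sorted order, neither of which is stated above. The name compare_dict_sketch is hypothetical.

```python
from typing import Any, Dict, Tuple

_MISSING = object()  # sentinel so that a stored None is not confused with "absent"


def compare_dict_sketch(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> Tuple[str, ...]:
    """Return the keys whose values differ between the two dictionaries."""
    keys = set(dict1) | set(dict2)
    # Assumption: a key missing from one side is reported as unmatched.
    return tuple(
        sorted(k for k in keys if dict1.get(k, _MISSING) != dict2.get(k, _MISSING))
    )


unmatched = compare_dict_sketch({"a": 1, "b": 2}, {"a": 1, "b": 3, "c": 4})
print(unmatched)
```

The sketch compares values with !=, so it only suits values whose comparison yields a plain bool.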

create_param_grad_clear_hook(param)

Create a hook to clear gradients for a parameter.

The hook will be fired after the gradient is accumulated for the parameter. Important: For this to work, accum_grad should be kept alive as long as this utility is needed.

get_model_attributes(model)

Get the key attributes of a PyTorch model.

Parameters:

model (Module) –

Return type:

Dict[str, Any]

get_module_device(module)

Get the device of a PyTorch module.

Parameters:

module (Module) –

Return type:

device

get_same_padding(kernel_size)

Get the same padding for a given kernel size.

Parameters:

kernel_size (int | Tuple[int, int]) –

Return type:

int | tuple
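For an odd kernel size k with stride 1 and dilation 1, "same" padding amounts to k // 2 on each side. The sketch below illustrates that rule under those assumptions. How the library treats even kernel sizes is not stated here, so the sketch simply rejects them. get_same_padding_sketch is a hypothetical name.

```python
from typing import Tuple, Union

KernelSize = Union[int, Tuple[int, ...]]


def get_same_padding_sketch(kernel_size: KernelSize) -> KernelSize:
    """Return per-side padding that preserves spatial size for odd kernels."""
    if isinstance(kernel_size, tuple):
        # Handle each spatial dimension independently.
        return tuple(get_same_padding_sketch(k) for k in kernel_size)
    if kernel_size % 2 == 0:
        # Assumption of this sketch: only odd kernel sizes are supported.
        raise ValueError("this sketch only handles odd kernel sizes")
    return kernel_size // 2


print(get_same_padding_sketch(3), get_same_padding_sketch((3, 5)))
```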

get_unwrapped_name(name)

Get the cleaned module name (i.e., the name before wrapping with sharded modules).

Parameters:

name (str) –

Return type:

str

init_model_from_model_like(model)

Initialize a model from a model-like object.

Parameters:

model (Module | Type[Module] | Tuple | Callable) – A model-like object. Can be a nn.Module (returned as it is), a model class or callable, or a tuple. If a tuple, it must be of the form (model_cls_or_callable,) or (model_cls_or_callable, args) or (model_cls_or_callable, args, kwargs). Model will be initialized as model_cls_or_callable(*args, **kwargs).

Return type:

Module
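The tuple forms described above can be illustrated with a small sketch. It is not the library's code: Module is a stand-in for torch.nn.Module so the example runs without PyTorch, and TinyNet and the *_sketch functions are hypothetical names.

```python
from typing import Any, Callable, Dict, Tuple


class Module:
    """Stand-in for torch.nn.Module so the sketch runs without PyTorch."""


class TinyNet(Module):
    """Hypothetical model class used only for the demonstration."""

    def __init__(self, width: int = 8, depth: int = 1):
        self.width, self.depth = width, depth


def standardize_model_like_sketch(model) -> Tuple[Callable, Tuple, Dict[str, Any]]:
    """Expand a class/callable or a 1-3 element tuple to (callable, args, kwargs)."""
    if not isinstance(model, tuple):
        model = (model,)
    if not 1 <= len(model) <= 3:
        raise ValueError("expected (cls,), (cls, args) or (cls, args, kwargs)")
    cls_or_callable = model[0]
    args = tuple(model[1]) if len(model) > 1 else ()
    kwargs = dict(model[2]) if len(model) > 2 else {}
    return cls_or_callable, args, kwargs


def init_model_from_model_like_sketch(model) -> Module:
    """Return a Module instance as is; otherwise build one from the model-like."""
    if isinstance(model, Module):
        return model
    cls_or_callable, args, kwargs = standardize_model_like_sketch(model)
    return cls_or_callable(*args, **kwargs)


net = init_model_from_model_like_sketch((TinyNet, (16,), {"depth": 3}))
print(net.width, net.depth)
```

Passing an existing instance returns the same object, while passing the bare class calls it with no arguments.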

is_channels_last(model)

Check if the model is using channels last memory format.

Parameters:

model (Module) –

is_parallel(model)

Check if a PyTorch model is parallelized.

Parameters:

model (Module) –

Return type:

bool

make_divisible(v, divisor, min_val=None)

Function taken from the original TensorFlow repo.

It ensures that all layers have a channel number that is divisible by 8. It can be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py

Parameters:
  • v (int | float) –

  • divisor (int | None) –

Return type:

int | float
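The rounding rule of the referenced MobileNet helper can be sketched as below: round v to the nearest multiple of divisor, never go below min_val, and never end up more than 10% below v. Treating divisor=None as "return v unchanged" is an assumption suggested by the int | None type hint, not something the description states. make_divisible_sketch is a hypothetical name.

```python
from typing import Optional, Union

Number = Union[int, float]


def make_divisible_sketch(v: Number, divisor: Optional[int], min_val: Optional[int] = None) -> Number:
    """Round v to a multiple of divisor, following the MobileNet rounding rule."""
    if divisor is None:
        return v  # assumption: no divisor means no rounding
    if min_val is None:
        min_val = divisor
    # Round to the nearest multiple of divisor, but not below min_val.
    new_v = max(min_val, int(v + divisor / 2) // divisor * divisor)
    # Make sure rounding down does not reduce the value by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


print(make_divisible_sketch(30, 8), make_divisible_sketch(10, 8))
```

With divisor 8, the value 10 first rounds down to 8. That is more than 10% below 10, so the result is bumped to 16.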

model_to(model, target_model)

Convert model to the same device, dtype and memory layout as the target_model.

Parameters:
  • model (Module) –

  • target_model (Module) –

param_num(network, trainable_only=False, unit=1000000.0)

Get the number of parameters of a PyTorch model.

Parameters:
  • network (Module) – The PyTorch model.

  • trainable_only (bool) – Whether to only count trainable parameters. Default is False.

  • unit (float) – The unit to return the number of parameters in. Default is 1e6 (million).

Returns:

The number of parameters in the model in the given unit.

Return type:

float
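The count and the unit scaling can be sketched as follows. The function is duck-typed: it only needs an object with a parameters() method whose items expose numel() and requires_grad, as torch.nn.Module parameters do. FakeParam and FakeModel are hypothetical stand-ins so the example runs without PyTorch, and param_num_sketch is not the library's implementation.

```python
from dataclasses import dataclass
from typing import Iterator, List


@dataclass
class FakeParam:
    """Hypothetical stand-in for a parameter tensor."""

    n: int
    requires_grad: bool = True

    def numel(self) -> int:
        return self.n


class FakeModel:
    """Hypothetical stand-in exposing parameters() like nn.Module."""

    def __init__(self, params: List[FakeParam]):
        self._params = params

    def parameters(self) -> Iterator[FakeParam]:
        return iter(self._params)


def param_num_sketch(network, trainable_only: bool = False, unit: float = 1e6) -> float:
    """Sum element counts over parameters and express the total in `unit`."""
    total = sum(
        p.numel()
        for p in network.parameters()
        if p.requires_grad or not trainable_only
    )
    return total / unit


model = FakeModel([FakeParam(200), FakeParam(20, requires_grad=False)])
print(param_num_sketch(model, unit=1), param_num_sketch(model, trainable_only=True, unit=1))
```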

param_num_from_forward(model, trainable_only=False, args=None, unit=1000000.0)

Get the number of parameters of a PyTorch model from a forward pass.

Parameters:
  • model (Module) – The PyTorch model.

  • trainable_only (bool) – Whether to only count trainable parameters. Default is False.

  • args (Tensor | Tuple | None) –

  • unit (float) – The unit to return the number of parameters in. Default is 1e6 (million).

Returns:

The number of parameters from the model’s forward pass in the given unit.

This can be helpful for dynamic modules, where the state dict might contain extra parameters that are not actively used in the model, e.g., because of a DynamicModule that is deactivated for the forward pass. We circumvent this issue by counting only the parameters of modules that appear in a forward pass.
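One way to count only the parameters of modules that take part in a forward pass is to register forward hooks. The sketch below shows that idea with standard PyTorch calls. It is an assumption about the approach, not the library's implementation, and it misses parameters that a parent module reads directly without calling the owning submodule. TwoBranch and param_num_from_forward_sketch are hypothetical names.

```python
import torch
from torch import nn


class TwoBranch(nn.Module):
    """Hypothetical model with one branch that never runs in forward()."""

    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 2)    # 4 * 2 + 2 = 10 parameters
        self.unused = nn.Linear(4, 3)  # 4 * 3 + 3 = 15 parameters, never called

    def forward(self, x):
        return self.used(x)


def param_num_from_forward_sketch(model, args, trainable_only=False, unit=1e6):
    """Count parameters owned by modules that are called during one forward pass."""
    seen = {}  # id(param) -> param, so shared parameters are counted once

    def hook(module, inputs, output):
        # Only the module's own parameters; children get their own hook call.
        for p in module.parameters(recurse=False):
            if p.requires_grad or not trainable_only:
                seen[id(p)] = p

    handles = [m.register_forward_hook(hook) for m in model.modules()]
    try:
        inputs = args if isinstance(args, tuple) else (args,)
        with torch.no_grad():
            model(*inputs)
    finally:
        for handle in handles:
            handle.remove()  # always leave the model without the extra hooks
    return sum(p.numel() for p in seen.values()) / unit


model = TwoBranch()
active = param_num_from_forward_sketch(model, torch.zeros(1, 4), unit=1)
total = sum(p.numel() for p in model.parameters())
print(active, total)
```

Here the full parameter list holds 25 elements, while only the 10 belonging to the branch that actually runs are counted.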

remove_bn(model)

Remove all batch normalization layers in the network.

Parameters:

model (Module) –

run_forward_loop(model, data_loader, max_iters=None, collect_func=None, progress_bar=None, post_process=None)

Run multiple forward passes with a model according to the provided data loader.

Parameters:
  • model – The model with which we run forward.

  • data_loader (Iterable) – An iterator with data samples.

  • max_iters (int | None) – Number of batches to run; by default there is no limit and the loop runs until data_loader is exhausted.

  • collect_func (Callable[[Any], Any | Tuple] | None) –

    A Callable that takes a batch of data from the data_loader as input and returns the input to model.forward() such that the return value (input) is either:

    1. a single argument (type(input) != tuple) corresponding to

      model.forward(input)
      
    2. a tuple of arguments corresponding to

      model.forward(*input)
      
    3. a tuple of arguments such that type(input[-1]) == dict corresponding to

      model.forward(*input[:-1], **input[-1])
      

    Note

    In order to pass a dict as the last non-keyword argument, you need to use a tuple as input and add an empty dict as the last element, e.g.,

    input = (x, {"y": y, "z": z}, {})
    

    The empty dict at the end will then be interpreted as the keyword args.

    See the args argument of torch.onnx.export for more info on the format of the return value of collect_func (input).

    The default collect_func assumes that the data loader returns a tuple, e.g., (images, labels, ...), and returns the first element of the tuple.

  • progress_bar (str | None) – Set to a description string to see the progress bar.

  • post_process (Callable | None) – A callable that takes the model outputs and the data as input and can be used to run any post-processing or operations such as backward pass.
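The three input formats accepted from collect_func, together with the max_iters and post_process behavior, can be sketched in plain Python. This is only an illustration under stated assumptions: it omits the progress bar and any gradient or device handling the library may perform, and call_forward and run_forward_loop_sketch are hypothetical names.

```python
from typing import Any, Callable, Iterable, Optional


def call_forward(forward: Callable, inp: Any) -> Any:
    """Dispatch on the three input formats described above."""
    if not isinstance(inp, tuple):
        return forward(inp)                   # case 1: single argument
    if inp and isinstance(inp[-1], dict):
        return forward(*inp[:-1], **inp[-1])  # case 3: trailing dict holds kwargs
    return forward(*inp)                      # case 2: positional arguments


def run_forward_loop_sketch(
    forward: Callable,
    data_loader: Iterable,
    max_iters: Optional[int] = None,
    collect_func: Optional[Callable] = None,
    post_process: Optional[Callable] = None,
) -> None:
    if collect_func is None:
        # Default described above: take the first element of each batch tuple.
        collect_func = lambda batch: batch[0]
    for i, batch in enumerate(data_loader):
        if max_iters is not None and i >= max_iters:
            break
        output = call_forward(forward, collect_func(batch))
        if post_process is not None:
            post_process(output, batch)


def forward(x, y=None, **extra):
    """Hypothetical forward function that echoes how it was called."""
    return (x, y, extra)


outputs = []
batches = [(1, "label"), (2, "label"), (3, "label")]
run_forward_loop_sketch(
    forward, batches, max_iters=2, post_process=lambda out, batch: outputs.append(out)
)
print(outputs)
```

The trailing empty dict matters for a call such as call_forward(forward, (1, {"y": 5}, {})): the dict {"y": 5} is then passed positionally instead of being unpacked as keyword arguments.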

set_submodule(model, target, target_submodule)

The set function that complements nn.Module.get_submodule().

Parameters:
  • model (Module) –

  • target (str) –

  • target_submodule (Module) –
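nn.Module.get_submodule() resolves a dotted path such as "backbone.conv". A matching setter can walk to the parent of the target and rebind the last attribute, as in the sketch below. It uses plain getattr and setattr on SimpleNamespace objects so it runs without PyTorch. That the same attribute access applies to nn.Module is an assumption here, and set_submodule_sketch is a hypothetical name.

```python
from types import SimpleNamespace
from typing import Any


def set_submodule_sketch(model: Any, target: str, target_submodule: Any) -> None:
    """Replace the attribute addressed by the dotted path `target`."""
    *parent_path, child_name = target.split(".")
    parent = model
    for name in parent_path:
        parent = getattr(parent, name)  # descend to the parent of the target
    setattr(parent, child_name, target_submodule)


model = SimpleNamespace(backbone=SimpleNamespace(conv="old_conv"), head="old_head")
set_submodule_sketch(model, "backbone.conv", "new_conv")
set_submodule_sketch(model, "head", "new_head")
print(model.backbone.conv, model.head)
```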

standardize_constructor_args(constructor_args)

Standardize a constructor-like tuple.

Parameters:

constructor_args (Callable | Tuple) –

Return type:

Tuple[Callable, Tuple, Dict]

standardize_model_args(model_or_fw_or_sig, args, use_kwargs=False)

Standardize model arguments according to torch.onnx.export.

Parameters:
  • model_or_fw_or_sig (Module | Callable | Signature) – A nn.Module, its forward method, or its forward method’s signature.

  • args (Any | Tuple) – Refer to the dummy_input parameter in mtn.profile().

  • use_kwargs – Affects the return value, see below. For use_kwargs==False, the returned args are also compatible with torch.onnx.export.

Returns:

Standardized model args that can be used in model.forward() in the same standardized way no matter how they were provided, see below for more info.

Return type:

Tuple

  • If use_kwargs == False, the returned args can be used as

    args = standardize_model_args(model, args, use_kwargs=False)
    model(*args)
    
  • If use_kwargs == True, the returned args can be used as

    args = standardize_model_args(model, args, use_kwargs=True)
    model.forward(*args[:-1], **args[-1])
    

Warning

If use_kwargs == False the model’s forward() method cannot contain keyword-only arguments (e.g. forward(..., *, kw_only_args)) without default values and you must not provide them in args.

Warning

If use_kwargs == False you must not provide variable keyword arguments in args that are processed via variable keyword arguments in the model’s forward() method (e.g. forward(..., **kwargs)).
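The two return formats and the warnings above can be illustrated with Python's inspect module. The sketch assumes that standardization means binding the provided args to the forward signature and filling in defaults. It is not the library's implementation, and standardize_model_args_sketch is a hypothetical name.

```python
import inspect
from typing import Any, Callable, Tuple


def standardize_model_args_sketch(forward: Callable, args: Any, use_kwargs: bool = False) -> Tuple:
    """Bind args (torch.onnx.export style) to forward's signature."""
    if not isinstance(args, tuple):
        args = (args,)
    # A trailing dict is interpreted as keyword arguments.
    if args and isinstance(args[-1], dict):
        positional, keywords = args[:-1], args[-1]
    else:
        positional, keywords = args, {}
    bound = inspect.signature(forward).bind(*positional, **keywords)
    bound.apply_defaults()  # assumption: defaults are filled in explicitly
    if use_kwargs:
        return (*bound.args, dict(bound.kwargs))
    # Keyword-only or **kwargs values supplied by the caller have no positional form.
    if any(name in keywords for name in bound.kwargs):
        raise ValueError("keyword-only/variable keyword args require use_kwargs=True")
    return bound.args


def forward(x, y=2, *, flag=False):
    """Hypothetical forward function with a keyword-only argument."""
    return (x, y, flag)


pos_only = standardize_model_args_sketch(forward, (1, {"y": 5}))
with_kwargs = standardize_model_args_sketch(forward, (1, {"flag": True}), use_kwargs=True)
print(forward(*pos_only), forward(*with_kwargs[:-1], **with_kwargs[-1]))
```

In this sketch, supplying flag with use_kwargs=False raises a ValueError, mirroring the first warning.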

standardize_model_like_tuple(model)

Standardize a model-like tuple.

Parameters:

model (Module | Type[Module] | Tuple | Callable) –

Return type:

Tuple[Type[Module], Tuple, Dict]

standardize_named_model_args(model_or_fw_or_sig, args)

Standardize model arguments according to torch.onnx.export and give them a name.

Parameters:
  • model_or_fw_or_sig (Module | Callable | Signature) – A nn.Module, its forward method, or its forward method’s signature.

  • args (Any | Tuple) – A tuple of args/kwargs or torch.Tensor fed into the model’s forward() method.

Return type:

Tuple[Dict[str, Any], Set[str]]

Returns:

A tuple (args_normalized, args_with_default) where

  • args_normalized is a dictionary of ordered model args where the key represents a unique serialized string based on the argument’s name in the function signature and the value contains the actual argument,

  • args_with_default is a set indicating whether the argument was retrieved from the default value in the function signature of the model’s forward() method or whether the argument exactly corresponds to the default value.

Note

See standardize_model_args() for more info as well.

unwrap_model(model, warn=False, raise_error=False, msg='', force_unwrap=False)

Unwrap a model that is wrapped by a supported wrapper module or return the original model.

Parameters:
  • model (Module) –

  • warn (bool) –

  • raise_error (bool) –

  • msg (str) –

  • force_unwrap (bool) –

Return type:

Module

zero_grad(model)

Set any gradients in the model’s parameters to None.

Parameters:

model (Module) –

Return type:

None