network
Utility functions for PyTorch models.
Functions
compare_dict | Compare two dictionaries and return keys with unmatched values.
get_model_attributes | Get the key attributes of a PyTorch model.
get_module_device | Get the device of a PyTorch module.
get_same_padding | Get the same padding for a given kernel size.
init_model_from_model_like | Initialize a model from a model-like object.
is_channels_last | Check if the model is using channels last memory format.
is_parallel | Check if a PyTorch model is parallelized.
make_divisible | Function taken from the original tf repo.
model_to | Convert model to the same device, dtype and memory layout as the target_model.
param_num | Get the number of parameters of a PyTorch model.
param_num_from_forward | Get the number of parameters of a PyTorch model from a forward pass.
remove_bn | Remove all batch normalization layers in the network.
run_forward_loop | Run multiple forward passes with a model according to the provided data loader.
set_submodule | The set function that complements nn.Module.get_submodule().
standardize_model_args | Standardize model arguments according to torch.onnx.export.
standardize_model_like_tuple | Standardize a model-like tuple.
standardize_named_model_args | Standardize model arguments according to torch.onnx.export and give them a name.
standardize_constructor_args | Standardize a constructor-like tuple.
unwrap_model | Unwrap a model that is wrapped by a supported wrapper module, or return the original model.
zero_grad | Set any gradients in the model's parameters to None.
create_param_grad_clear_hook | Create a hook to clear gradients for a parameter.
get_unwrapped_name | Get the cleaned module name (i.e., the name before wrapping with sharded modules).
- compare_dict(dict1, dict2)
Compare two dictionaries and return keys with unmatched values.
- Parameters:
dict1 (Dict[str, Any]) –
dict2 (Dict[str, Any]) –
- Return type:
Tuple[str, …]
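The comparison can be illustrated with a minimal pure-Python sketch. The name compare_dict_sketch is hypothetical, and the handling of keys that exist in only one dictionary (reported as unmatched here) is an assumption, not confirmed library behavior.

```python
from typing import Any, Dict, Tuple


def compare_dict_sketch(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> Tuple[str, ...]:
    """Return the keys whose values differ between the two dictionaries."""
    # Sentinel so that a stored None is not confused with an absent key.
    missing = object()
    # Keep a stable order: keys of dict1 first, then keys only found in dict2.
    keys = list(dict1) + [k for k in dict2 if k not in dict1]
    # Plain != is fine for scalars and strings; tensor values would need
    # a dedicated comparison (assumption: not handled in this sketch).
    return tuple(k for k in keys if dict1.get(k, missing) != dict2.get(k, missing))
```

For example, comparing {"a": 1, "b": 2} with {"a": 1, "b": 3, "c": 4} reports "b" (different value) and "c" (present on one side only).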
- create_param_grad_clear_hook(param)
Create a hook to clear gradients for a parameter.
The hook will be fired after the gradient is accumulated for the parameter. Important: for this to work, accum_grad should be kept alive as long as this utility is needed.
- get_model_attributes(model)
Get the key attributes of a PyTorch model.
- Parameters:
model (Module) –
- Return type:
Dict[str, Any]
- get_module_device(module)
Get the device of a PyTorch module.
- Parameters:
module (Module) –
- Return type:
device
- get_same_padding(kernel_size)
Get the same padding for a given kernel size.
- Parameters:
kernel_size (int | Tuple[int, int]) –
- Return type:
int | tuple
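As a rough illustration, "same" padding for a stride-1, dilation-1 convolution with an odd kernel size k is k // 2, because the output length in + 2p - k + 1 then equals in. The sketch below assumes exactly that rule; same_padding_sketch is a hypothetical name, and rejecting even kernel sizes is an assumption rather than documented behavior.

```python
from typing import Tuple, Union


def same_padding_sketch(kernel_size: Union[int, Tuple[int, int]]) -> Union[int, tuple]:
    """Padding that preserves spatial size for stride 1 and dilation 1."""
    if isinstance(kernel_size, tuple):
        # Handle each spatial dimension independently.
        return tuple(same_padding_sketch(k) for k in kernel_size)
    if kernel_size % 2 == 0:
        # Assumption: symmetric padding cannot preserve size for even kernels.
        raise ValueError("kernel size should be odd for symmetric same padding")
    return kernel_size // 2
```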
- get_unwrapped_name(name)
Get the cleaned module name (i.e., the name before wrapping with sharded modules).
- Parameters:
name (str) –
- Return type:
str
- init_model_from_model_like(model)
Initialize a model from a model-like object.
- Parameters:
model (Module | Type[Module] | Tuple | Callable) – A model-like object. Can be an nn.Module (returned as is), a model class or callable, or a tuple. If a tuple, it must be of the form (model_cls_or_callable,), (model_cls_or_callable, args), or (model_cls_or_callable, args, kwargs). The model will be initialized as model_cls_or_callable(*args, **kwargs).
- Return type:
Module
- is_channels_last(model)
Check if the model is using channels last memory format.
- Parameters:
model (Module) –
- is_parallel(model)
Check if a PyTorch model is parallelized.
- Parameters:
model (Module) –
- Return type:
bool
- make_divisible(v, divisor, min_val=None)
Function taken from the original TensorFlow repository.
It ensures that all layers have a channel number that is divisible by 8. It can be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
- Parameters:
v (int | float) –
divisor (int | None) –
- Return type:
int | float
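For reference, the rounding rule used by the TensorFlow MobileNet helper can be sketched as follows. The name make_divisible_sketch is hypothetical, and returning v unchanged when divisor is None is an assumption based only on the int | None type hint above.

```python
def make_divisible_sketch(v, divisor, min_val=None):
    """Round v to the nearest multiple of divisor, never below min_val."""
    if divisor is None:
        # Assumption: no divisor means no rounding.
        return v
    if min_val is None:
        min_val = divisor
    # Round to the nearest multiple of divisor (ties round up).
    new_v = max(min_val, int(v + divisor / 2) // divisor * divisor)
    # Make sure rounding down does not reduce the value by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
```

With divisor 8, an input of 37 rounds to 40, while 10 first rounds down to 8, which is more than 10% below 10, and is therefore bumped up to 16.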
- model_to(model, target_model)
Convert model to the same device, dtype and memory layout as the target_model.
- Parameters:
model (Module) –
target_model (Module) –
- param_num(network, trainable_only=False, unit=1000000.0)
Get the number of parameters of a PyTorch model.
- Parameters:
network (Module) – The PyTorch model.
trainable_only (bool) – Whether to only count trainable parameters. Default is False.
unit – The unit to return the number of parameters in. Default is 1e6 (million).
- Returns:
The number of parameters in the model in the given unit.
- Return type:
float
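The effect of trainable_only and unit can be shown with a small arithmetic sketch that uses plain shape tuples instead of real tensors. param_num_sketch and the (shape, requires_grad) pairs are hypothetical stand-ins for a model's parameters, not the library's implementation.

```python
from math import prod


def param_num_sketch(params, trainable_only=False, unit=1e6):
    """Count parameters given (shape, requires_grad) pairs, scaled by unit."""
    total = sum(
        prod(shape)
        for shape, requires_grad in params
        if requires_grad or not trainable_only
    )
    return total / unit


# A 784 -> 10 linear layer: a (10, 784) weight and a frozen (10,) bias.
linear_params = [((10, 784), True), ((10,), False)]
```

Here param_num_sketch(linear_params, unit=1) counts 7850 parameters in total and 7840 with trainable_only=True; with the default unit the total is reported in millions.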
- param_num_from_forward(model, trainable_only=False, args=None, unit=1000000.0)
Get the number of parameters of a PyTorch model from a forward pass.
- Parameters:
model (Module) – The PyTorch model.
trainable_only (bool) – Whether to only count trainable parameters. Default is False.
args (Tensor | Tuple | None) –
unit (float) – The unit to return the number of parameters in. Default is 1e6 (million).
- Returns:
The number of parameters from the model’s forward pass in the given unit.
This can be helpful for dynamic modules, where the state dict might contain extra parameters that are not actively used in the model, e.g., because of a DynamicModule that is deactivated for the forward pass. We circumvent this issue by counting only the parameters of modules that appear in a forward pass.
- remove_bn(model)
Remove all batch normalization layers in the network.
- Parameters:
model (Module) –
- run_forward_loop(model, data_loader, max_iters=None, collect_func=None, progress_bar=None, post_process=None)
Run multiple forward passes with a model according to the provided data loader.
- Parameters:
model – The model with which we run forward.
data_loader (Iterable) – An iterator with data samples.
max_iters (int | None) – Number of batches to run; by default there is no limit and the loop runs until data_loader is exhausted.
collect_func (Callable[[Any], Any | Tuple] | None) – A Callable that takes a batch of data from the data_loader as input and returns the input to model.forward() such that the return value (input) is either:
  - a single argument (type(input) != tuple) corresponding to model.forward(input)
  - a tuple of arguments corresponding to model.forward(*input)
  - a tuple of arguments such that type(input[-1]) == dict corresponding to model.forward(*input[:-1], **input[-1])
Note
In order to pass a dict as the last non-keyword argument, you need to use a tuple as input and add an empty dict as the last element, e.g., input = (x, {"y": y, "z": z}, {}). The empty dict at the end will then be interpreted as the keyword args.
See the args argument of torch.onnx.export for more info on the format of the return value of collect_func (input).
The default collect_func assumes that the data loader returns a tuple, e.g., (images, labels, ...), and returns the first element of the tuple.
progress_bar (str | None) – Set to a description string to see the progress bar.
post_process (Callable | None) – A callable that takes the model outputs and the data as input and can be used to run any post-processing or operations such as backward pass.
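The three input conventions for collect_func and the max_iters limit can be sketched in plain Python. call_forward and forward_loop_sketch are hypothetical names; the sketch takes any callable in place of model.forward, omits progress_bar and post_process, and returns the collected outputs only so the behavior can be inspected (an assumption, since the return value is not documented above).

```python
from itertools import islice


def call_forward(forward, inp):
    """Dispatch inp to forward following the three documented conventions."""
    if type(inp) is not tuple:
        return forward(inp)  # single argument
    if inp and type(inp[-1]) is dict:
        return forward(*inp[:-1], **inp[-1])  # trailing dict holds kwargs
    return forward(*inp)  # tuple of positional arguments


def first_element(batch):
    """Default collector: assumes batches look like (images, labels, ...)."""
    return batch[0]


def forward_loop_sketch(forward, data_loader, max_iters=None, collect_func=None):
    """Run forward on up to max_iters batches (all batches if None)."""
    collect_func = collect_func or first_element
    outputs = []
    # islice with stop=None consumes the loader until it is exhausted.
    for batch in islice(data_loader, max_iters):
        outputs.append(call_forward(forward, collect_func(batch)))
    return outputs
```

Note how a tuple such as ({"y": 1}, {}) passes the first dict positionally, because only the trailing empty dict is treated as keyword arguments.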
- set_submodule(model, target, target_submodule)
The set function that complements nn.Module.get_submodule().
- Parameters:
model (Module) –
target (str) –
target_submodule (Module) –
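The idea of setting a submodule by its dotted path can be sketched without PyTorch by walking attributes on plain objects. set_by_path_sketch is a hypothetical name, and the traversal via getattr/setattr is an assumption about the approach rather than the library's actual code.

```python
def set_by_path_sketch(root, target, new_child):
    """Replace the attribute addressed by a dotted path such as 'a.b.c'."""
    *parents, leaf = target.split(".")
    owner = root
    # Walk down to the object that owns the final attribute.
    for name in parents:
        owner = getattr(owner, name)
    # Swap in the replacement on the owning object.
    setattr(owner, leaf, new_child)
```

The target string uses the same dotted format that nn.Module.get_submodule() accepts, so a value written with this kind of setter can be read back through the matching getter.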
- standardize_constructor_args(constructor_args)
Standardize a constructor-like tuple.
- Parameters:
constructor_args (Callable | Tuple) –
- Return type:
Tuple[Callable, Tuple, Dict]
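A minimal sketch of this normalization is shown below. It assumes the constructor-like tuple follows the same (callable,), (callable, args), (callable, args, kwargs) forms documented for model-like tuples; standardize_constructor_sketch is a hypothetical name.

```python
from typing import Callable, Dict, Tuple


def standardize_constructor_sketch(constructor_args) -> Tuple[Callable, Tuple, Dict]:
    """Normalize a callable or a 1- to 3-tuple into (callable, args, kwargs)."""
    if not isinstance(constructor_args, tuple):
        constructor_args = (constructor_args,)
    if not 1 <= len(constructor_args) <= 3:
        raise ValueError("expected (callable,), (callable, args) or (callable, args, kwargs)")
    # Pad the missing trailing entries with empty args and kwargs.
    padding = ((), {})[len(constructor_args) - 1:]
    fn, args, kwargs = constructor_args + padding
    if not callable(fn):
        raise TypeError("first element must be callable")
    return fn, tuple(args), dict(kwargs)
```

The normalized triple can then be used uniformly as fn(*args, **kwargs), regardless of which form was provided.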
- standardize_model_args(model_or_fw_or_sig, args, use_kwargs=False)
Standardize model arguments according to torch.onnx.export.
- Parameters:
model_or_fw_or_sig (Module | Callable | Signature) – A nn.Module, its forward method, or its forward method’s signature.
args (Any | Tuple) – Refer to the dummy_input parameter in mtn.profile().
use_kwargs – Affects the return value, see below. For use_kwargs == False, the returned args are also compatible with torch.onnx.export.
- Returns:
Standardized model args that can be used in model.forward() in the same standardized way no matter how they were provided, see below for more info.
- Return type:
Tuple
If use_kwargs == False, the returned args can be used as
    args = standardize_model_args(model, args, use_kwargs=False)
    model(*args)
If use_kwargs == True, the returned args can be used as
    args = standardize_model_args(model, args, use_kwargs=True)
    model.forward(*args[:-1], **args[-1])
Warning
If use_kwargs == False, the model’s forward() method cannot contain keyword-only arguments (e.g. forward(..., *, kw_only_args)) without default values and you must not provide them in args.
Warning
If use_kwargs == False, you must not provide variable keyword arguments in args that are processed via variable keyword arguments in the model’s forward() method (e.g. forward(..., **kwargs)).
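The two return conventions can be illustrated with a simplified sketch built on inspect.signature. standardize_args_sketch is a hypothetical name and not the library's implementation; as a simplification, it rejects any remaining keyword-only or variable keyword arguments when use_kwargs is False, which is stricter than the warnings above.

```python
import inspect


def standardize_args_sketch(forward, args, use_kwargs=False):
    """Bind args to forward's signature and return them in a uniform layout."""
    # A single non-tuple argument is treated as one positional argument.
    if not isinstance(args, tuple):
        args = (args,)
    # A trailing dict is interpreted as keyword arguments.
    if args and isinstance(args[-1], dict):
        positional, keywords = args[:-1], dict(args[-1])
    else:
        positional, keywords = args, {}
    bound = inspect.signature(forward).bind(*positional, **keywords)
    bound.apply_defaults()  # fill in defaults for omitted arguments
    if use_kwargs:
        # Positional part followed by a dict of the remaining keyword part.
        return (*bound.args, dict(bound.kwargs))
    if bound.kwargs:
        # Simplification: this sketch does not try to express these positionally.
        raise ValueError("keyword-only or variable keyword arguments need use_kwargs=True")
    return bound.args
```

Either result can be fed back to the callable: forward(*out) for the positional form, and forward(*out[:-1], **out[-1]) for the keyword form.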
- standardize_model_like_tuple(model)
Standardize a model-like tuple.
- Parameters:
model (Module | Type[Module] | Tuple | Callable) –
- Return type:
Tuple[Type[Module], Tuple, Dict]
- standardize_named_model_args(model_or_fw_or_sig, args)
Standardize model arguments according to torch.onnx.export and give them a name.
- Parameters:
model_or_fw_or_sig (Module | Callable | Signature) – A nn.Module, its forward method, or its forward method’s signature.
args (Any | Tuple) – A tuple of args/kwargs or a torch.Tensor fed into the model’s forward() method.
- Return type:
Tuple[Dict[str, Any], Set[str]]
- Returns: A tuple (args_normalized, args_with_default) where
  - args_normalized is a dictionary of ordered model args where the key represents a unique serialized string based on the argument’s name in the function signature and the value contains the actual argument,
  - args_with_default is a set indicating whether the argument was retrieved from the default value in the function signature of the model’s forward() method or whether the argument exactly corresponds to the default value.
Note
See standardize_model_args() for more info as well.
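The shape of the returned pair can be illustrated with a simplified sketch. named_args_sketch is a hypothetical name; it uses plain parameter names as keys, whereas the exact serialized key format of the real function is not documented here.

```python
import inspect


def named_args_sketch(forward, args):
    """Return (ordered name -> value dict, names whose value is the default)."""
    if not isinstance(args, tuple):
        args = (args,)
    # A trailing dict is interpreted as keyword arguments.
    if args and isinstance(args[-1], dict):
        positional, keywords = args[:-1], dict(args[-1])
    else:
        positional, keywords = args, {}
    signature = inspect.signature(forward)
    bound = signature.bind(*positional, **keywords)
    provided = set(bound.arguments)  # names given explicitly by the caller
    bound.apply_defaults()
    named = dict(bound.arguments)  # ordered like the signature
    # Plain == is fine for scalars; tensors would need a dedicated comparison.
    with_default = {
        name
        for name, value in named.items()
        if name not in provided or value == signature.parameters[name].default
    }
    return named, with_default
```

An argument ends up in the second set either because it was omitted or because the caller passed a value equal to the signature default.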
- unwrap_model(model, warn=False, raise_error=False, msg='', force_unwrap=False)
Unwrap a model that is wrapped by a supported wrapper module, or return the original model.
- Parameters:
model (Module) –
warn (bool) –
raise_error (bool) –
msg (str) –
force_unwrap (bool) –
- Return type:
Module
- zero_grad(model)
Set any gradients in the model’s parameters to None.
- Parameters:
model (Module) –
- Return type:
None