Module

API

class warp_nn.modules.module.Module(*args, **kwargs)

Bases: ABC

Abstract base class for all modules.

Modules can contain other modules (sub-modules), organized in a nested tree structure. Sub-modules can be assigned as regular attributes of the parent module.

Important

Sub-modules assigned as regular attributes of the parent module are not registered automatically. Therefore, a subclass must call the __post_init__() method as the last step of its constructor (__init__).
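The requirement can be illustrated with a small pure-Python stand-in. `ToyModule`, `Leaf`, and `Pair` are hypothetical names. The attribute scan inside `__post_init__` is an assumption about how registration might work and is not the warp_nn implementation. The sketch shows only that sub-modules assigned in `__init__` become visible once `__post_init__()` runs at the end of the constructor.

```python
from __future__ import annotations


class ToyModule:
    """Hypothetical stand-in for Module; not the warp_nn implementation."""

    def __post_init__(self) -> None:
        # Assumed behaviour: scan instance attributes in assignment order and
        # register every attribute that is itself a ToyModule.
        self._modules = {
            name: value
            for name, value in vars(self).items()
            if isinstance(value, ToyModule)
        }

    def named_modules(self) -> list[tuple[str, ToyModule]]:
        return list(self._modules.items())


class Leaf(ToyModule):
    def __init__(self) -> None:
        self.__post_init__()  # nothing to register, but the call is harmless


class Pair(ToyModule):
    def __init__(self) -> None:
        self.first = Leaf()
        self.second = Leaf()
        # Last step of the constructor: without this call, `first` and
        # `second` would stay plain attributes, unknown to named_modules().
        self.__post_init__()


pair = Pair()
print([name for name, _ in pair.named_modules()])  # ['first', 'second']
```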

__call__(*args, **kwargs) -> Any

Forward pass of the module.

Raises:

NotImplementedError – If the module subclass does not implement the method.
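The following is a minimal plain-Python sketch of this contract. `ToyModule` and `Scale` are hypothetical names. The base class raising NotImplementedError mirrors the documented behaviour. The list-based forward pass is only an illustration and is not a warp_nn layer.

```python
from __future__ import annotations

from typing import Any


class ToyModule:
    """Hypothetical stand-in for Module, reduced to the __call__ contract."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Base class: no forward pass is defined.
        raise NotImplementedError(f"{type(self).__name__} does not implement __call__")


class Scale(ToyModule):
    """Hypothetical subclass whose forward pass multiplies each input by a factor."""

    def __init__(self, factor: float) -> None:
        self.factor = factor
        # A real warp_nn subclass would also call self.__post_init__() here.

    def __call__(self, x: list[float]) -> list[float]:
        return [self.factor * v for v in x]


print(Scale(2.0)([1.0, 3.0]))  # [2.0, 6.0]
try:
    ToyModule()([1.0])
except NotImplementedError as err:
    print(err)  # ToyModule does not implement __call__
```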

__post_init__() -> None

Register sub-modules and parameters.

Important

A module subclass must call this method to register the sub-modules and parameters assigned to it as regular attributes, unless they have already been registered manually with register_module() or register_parameter().

load_state_dict(
state_dict: dict[str, array],
) -> None

Load a state dictionary into the module.

Parameters:

state_dict – The state dictionary to load into the module.

Raises:

NotImplementedError – If the state dictionary contains an unsupported type.
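The sketch below pairs loading with the output of state_dict(). Everything in it is an assumption for illustration: `ToyModule` is hypothetical, plain lists stand in for Warp arrays, and "unsupported type" is modelled as any value that is not a list. Values are copied in place, which is also an assumption, so references obtained earlier see the loaded data.

```python
from __future__ import annotations


class ToyModule:
    """Hypothetical stand-in for Module; parameters are plain lists keyed by name."""

    def __init__(self, params: dict[str, list[float]]) -> None:
        self._params = params

    def state_dict(self) -> dict[str, list[float]]:
        # Fresh dict whose values are references to the stored lists.
        return dict(self._params)

    def load_state_dict(self, state_dict: dict[str, object]) -> None:
        for name, value in state_dict.items():
            # Assumed rule: only lists are supported, standing in for the
            # documented NotImplementedError on unsupported value types.
            if not isinstance(value, list):
                raise NotImplementedError(
                    f"unsupported type for {name!r}: {type(value).__name__}"
                )
            # Overwrite the existing storage in place (an assumption), so that
            # references obtained earlier see the new values.
            self._params[name][:] = value


src = ToyModule({"weight": [1.0, 2.0]})
dst = ToyModule({"weight": [0.0, 0.0]})
alias = dst._params["weight"]
dst.load_state_dict(src.state_dict())
print(alias)  # [1.0, 2.0]
```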

modules() -> list[Module]

Get the registered modules.

The modules will be returned in the order that they were registered.

Returns:

A list of modules.

named_modules() -> list[tuple[str, Module]]

Get the registered modules and their names.

The modules will be returned in the order that they were registered.

Returns:

A list of (name, module) pairs.

named_parameters() -> list[tuple[str, Parameter]]

Get the registered parameters and their names.

The parameters will be returned in the order that they were registered.

Returns:

A list of (name, parameter) pairs.

parameters(
*,
include_submodules: bool = True,
as_array: bool = True,
) -> list[Parameter | array]

Get the registered parameters.

The parameters will be returned in the order that they were registered.

Parameters:
  • include_submodules – Whether to include the parameters of the registered sub-modules.

  • as_array – If True, return the parameters as Warp arrays; if False, return them as Parameter instances.

Returns:

A list of parameters.
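The sketch below shows what the two keyword flags control, using hypothetical plain-Python stand-ins. `ToyParam` and `ToyModule` are hypothetical, and `ToyParam.data` (a list) plays the role of the Warp array returned when `as_array=True`. The sketch lists a module's own parameters before those of its sub-modules, and that ordering is an assumption.

```python
from __future__ import annotations


class ToyParam:
    """Hypothetical stand-in for Parameter; `data` plays the role of the Warp array."""

    def __init__(self, data: list[float]) -> None:
        self.data = data


class ToyModule:
    """Hypothetical stand-in for Module with directly filled registries."""

    def __init__(self) -> None:
        self._params: dict[str, ToyParam] = {}
        self._modules: dict[str, ToyModule] = {}

    def parameters(
        self, *, include_submodules: bool = True, as_array: bool = True
    ) -> list[ToyParam | list[float]]:
        # Own parameters first, in registration order (assumed ordering).
        params = list(self._params.values())
        if include_submodules:
            # Then the parameters of each registered sub-module, recursively;
            # recursive calls always return ToyParam objects.
            for sub in self._modules.values():
                params.extend(sub.parameters(as_array=False))
        # Unwrap to the underlying data only at this level, if requested.
        return [p.data for p in params] if as_array else params


child = ToyModule()
child._params["weight"] = ToyParam([1.0, 2.0])
root = ToyModule()
root._params["bias"] = ToyParam([0.5])
root._modules["child"] = child

print(root.parameters())                          # [[0.5], [1.0, 2.0]]
print(root.parameters(include_submodules=False))  # [[0.5]]
```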

register_module(
name: str,
module: Module,
) -> Module

Register a sub-module with this module.

Modules are registered in the order in which this method is called.

Parameters:
  • name – The name of the module.

  • module – The module to register.

Returns:

The module itself.

Raises:
  • TypeError – If module is not an instance of Module.

  • KeyError – If a module with the same name is already registered.
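The following plain-Python stand-in sketches the documented checks and the ordering guarantee. `ToyModule` is hypothetical. Reading "The module itself" as "the module that was just registered" is an assumption, and so is the use of an insertion-ordered dict as the registry.

```python
from __future__ import annotations


class ToyModule:
    """Hypothetical stand-in for Module, showing only sub-module registration."""

    def __init__(self) -> None:
        self._modules: dict[str, ToyModule] = {}  # dicts keep insertion order

    def register_module(self, name: str, module: ToyModule) -> ToyModule:
        if not isinstance(module, ToyModule):
            raise TypeError(f"{name!r} is not a ToyModule instance")
        if name in self._modules:
            raise KeyError(f"a module named {name!r} is already registered")
        self._modules[name] = module
        # Assumed return value: the registered module, so the call can be
        # chained with an attribute assignment.
        return module

    def named_modules(self) -> list[tuple[str, ToyModule]]:
        return list(self._modules.items())


parent = ToyModule()
encoder = parent.register_module("encoder", ToyModule())
parent.register_module("decoder", ToyModule())
print([name for name, _ in parent.named_modules()])  # ['encoder', 'decoder']

try:
    parent.register_module("encoder", ToyModule())  # duplicate name
except KeyError as err:
    print("KeyError:", err)
```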

register_parameter(
name: str,
parameter: Parameter,
) -> Parameter

Register a parameter with this module.

Parameters are registered in the order in which this method is called.

Parameters:
  • name – The name of the parameter.

  • parameter – The parameter to register.

Returns:

The parameter itself.

Raises:
  • TypeError – If parameter is not an instance of Parameter.

  • KeyError – If a parameter with the same name is already registered.

state_dict(
*,
destination: dict[str, array] | None = None,
prefix: str = '',
) -> dict[str, array]

Get the state dictionary, which maps the names of all parameters of the module and its sub-modules to references to their arrays.

Parameters:
  • destination – The dictionary in which to store the state. This argument is used for internal recursion and should not be set by the user.

  • prefix – The prefix prepended to the names of the parameters and sub-modules. This argument is used for internal recursion and should not be set by the user.

Returns:

The state dictionary.
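The sketch below shows how `destination` and `prefix` can drive the recursion. `ToyModule` is hypothetical, and plain lists stand in for Warp arrays. The `"<sub-module>.<parameter>"` naming scheme with a dot separator is an assumption and is not confirmed by this page.

```python
from __future__ import annotations


class ToyModule:
    """Hypothetical stand-in for Module; values are plain lists instead of Warp arrays."""

    def __init__(self) -> None:
        self._params: dict[str, list[float]] = {}
        self._modules: dict[str, ToyModule] = {}

    def state_dict(
        self, *, destination: dict[str, list[float]] | None = None, prefix: str = ""
    ) -> dict[str, list[float]]:
        # Top-level call starts a fresh dict; recursive calls fill the same one.
        if destination is None:
            destination = {}
        for name, value in self._params.items():
            destination[prefix + name] = value  # a reference, not a copy
        for name, sub in self._modules.items():
            # Assumed naming scheme: "<sub-module>.<parameter>".
            sub.state_dict(destination=destination, prefix=prefix + name + ".")
        return destination


inner = ToyModule()
inner._params["weight"] = [1.0, 2.0]
outer = ToyModule()
outer._params["scale"] = [3.0]
outer._modules["block"] = inner

sd = outer.state_dict()
print(list(sd))                                # ['scale', 'block.weight']
print(sd["block.weight"] is inner._params["weight"])  # True
```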

to(
device: Device,
) -> Module

Move the module to the specified device.

Parameters:

device – The device to move the module to.

Returns:

The module itself.

property device: Device

Device on which the module is allocated.