# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import operator
from typing import Any, Dict, Iterator, List, Tuple, Union, Set, Sequence, TypeVar
from tripy import export, utils
from tripy.common.exception import raise_error, str_from_stack_info
from tripy.frontend.module.parameter import Parameter
from tripy.logging import logger
T = TypeVar("T")
def _check_param_compatible(original_param, new_param, param_name):
    """
    Raises an error if ``new_param`` is not compatible with ``original_param``.

    The check is skipped when ``original_param`` is not a :class:`Parameter`
    (callers pass ``None`` when the attribute does not exist yet), because
    there is nothing to compare the new value against.

    Args:
        original_param: The value currently stored under ``param_name``.
        new_param: The value that is about to replace it.
        param_name: Name used to identify the parameter in the error message.
    """
    if not isinstance(original_param, Parameter):
        return

    is_compatible = original_param._is_compatible(new_param)
    if not is_compatible:
        # NOTE(review): the message is followed by a trailing comma in the
        # source, so the original call may have forwarded additional details
        # (e.g. taken from `is_compatible`) - confirm and restore if so.
        raise_error(
            f"For parameter: {param_name}, new parameter is not compatible with the existing parameter.",
        )
def _is_homogeneous_container(container: Sequence, typ: T):
return all(isinstance(elem, typ) for elem in container)
def _contains_types(container: Sequence, types: type):
return any(any(isinstance(elem, typ) for typ in types) for elem in container)
# NOTE(review): `export` is imported by this file but no decorator is visible on
# this class; if it was registered through an `export` decorator, restore it here.
class Module:
    r"""
    Base class used to define neural network modules.
    You can nest modules by assigning them as attributes of other modules.

    Child modules or :class:`tripy.Parameter` s may be contained in Python ``list``\s or ``dict``\s.
    If using ``dict``\s, the keys must be strings.
    Nested data structures (for example, ``list``\s of ``list``\s) are not supported.
    Taking child modules as an example, this is allowed:
    ::

        self.linear = tp.Linear(2, 2)
        self.list_modules = [tp.Linear(2, 2), tp.Linear(2, 2)]
        self.dict_modules = {
            "linear": tp.Linear(2, 2),
            "layernorm": tp.LayerNorm(2),
        }

    Whereas this is not supported:
    ::

        self.list_modules = [[tp.Linear(2, 2)], [tp.Linear(2, 2)]]
        self.dict_modules = {
            (1, "linear"): tp.Linear(2, 2),
        }

    .. code-block:: python
        :caption: Example

        class AddBias(tp.Module):
            def __init__(self):
                super().__init__()
                self.bias = tp.Parameter(tp.Tensor([1.0, 1.0], dtype=tp.float32))

            def __call__(self, x):
                return x + self.bias

        add_bias = AddBias()

        input = tp.Tensor([1.0, 1.0], dtype=tp.float32)
        output = add_bias(input)

        assert np.array_equal(cp.from_dlpack(output).get(), np.array([2.0, 2.0]))
    """

    def __setattr__(self, name: str, value: Any) -> None:
        """
        Sets an attribute on the module.

        If ``value`` is a :class:`Parameter`, or ``name`` already refers to a
        parameter of this module, the new value is first checked for
        compatibility with the existing one. After the attribute is stored, a
        ``list`` or ``dict`` that contains parameters or modules mixed with
        other element types triggers a warning, because such containers are
        not picked up by :func:`state_dict`.

        Args:
            name: The attribute name.
            value: The value to assign.
        """
        if isinstance(value, Parameter) or name in dict(self.named_parameters()):
            _check_param_compatible(getattr(self, name, None), value, name)

        super().__setattr__(name, value)

        # avoid infinite recursion during initialization
        if value is None:
            return

        if isinstance(value, (list, dict)):
            container = value if isinstance(value, list) else value.values()
            if _contains_types(container, [Parameter, Module]) and (
                not _is_homogeneous_container(container, Parameter) and not _is_homogeneous_container(container, Module)
            ):
                stack_info = utils.get_stack_info()
                stack_info_msg = str_from_stack_info(stack_info)

                # NOTE(review): the call wrapping this message is not visible in the
                # source; a logger warning matches its wording - confirm.
                logger.warning(
                    "A container of mixed types will not be registered with this module's state_dict()."
                    + (f"\nNote: container was set here: {stack_info_msg}" if stack_info_msg else "")
                )

    def state_dict(self) -> Dict[str, Parameter]:
        """
        Returns a dictionary mapping names to parameters in the module.
        This will recurse over any nested child modules.

        Returns:
            A dictionary mapping names to parameters.

        .. code-block:: python
            :caption: Example

            # doc: print-locals state_dict

            class MyModule(tp.Module):
                def __init__(self):
                    super().__init__()
                    self.param = tp.Parameter(tp.ones((2,), dtype=tp.float32))
                    self.linear1 = tp.Linear(2, 2)
                    self.linear2 = tp.Linear(2, 2)

            module = MyModule()

            state_dict = module.state_dict()

            assert set(state_dict.keys()) == {"param", "linear1.weight", "linear1.bias", "linear2.weight", "linear2.bias"}
        """
        # `dict(...)` already builds a fresh mapping, so no extra copy is needed.
        state_dict = dict(self.named_parameters())

        for child_name, child in self.named_children():
            for name, param in child.state_dict().items():
                # We add a prefix for any parameters coming from nested modules
                # so they can be disambiguated correctly in higher level modules.
                state_dict[f"{child_name}.{name}"] = param

        return state_dict

    def load_state_dict(self, state_dict: Dict[str, Parameter], strict: bool = True) -> Tuple[Set[str], Set[str]]:
        r"""
        Loads parameters from the provided ``state_dict`` into the current module.
        This will recurse over any nested child modules.

        Args:
            state_dict: A dictionary mapping names to parameters.
            strict: If True, keys in ``state_dict`` must exactly match those in this module. If not,
                an error will be raised.

        Returns:
            A ``tuple`` of two ``set``\s of strings representing:

            - missing_keys: keys that are expected by this module but not provided in ``state_dict``.
            - unexpected_keys: keys that are not expected by this module but provided in ``state_dict``.

        .. code-block:: python
            :caption: Example

            # doc: no-print-locals

            class MyModule(tp.Module): # doc: omit
                def __init__(self): # doc: omit
                    super().__init__() # doc: omit
                    self.param = tp.Parameter(tp.ones((2,), dtype=tp.float32)) # doc: omit
                    self.linear1 = tp.Linear(2, 2) # doc: omit
                    self.linear2 = tp.Linear(2, 2) # doc: omit

            module = MyModule() # doc: omit
            state_dict = module.state_dict() # doc: omit

            # Using the `module` and `state_dict` from the `state_dict()` example:
            print(f"Before: {module.param}")

            state_dict["param"] = tp.Parameter(tp.zeros((2,), dtype=tp.float32))
            module.load_state_dict(state_dict)

            print(f"After: {module.param}")

            assert np.array_equal(cp.from_dlpack(module.state_dict()["param"]).get(), np.array(np.zeros((2,), dtype=np.float32)))

        .. seealso:: :func:`state_dict`
        """

        def find_module(module: Union[Module, List, Dict], sub_strs: List[str]):
            # Walks one path component at a time, indexing into lists and dicts
            # and reading attributes from modules.
            for child_name in sub_strs:
                if isinstance(module, list):
                    module = module[int(child_name)]
                elif isinstance(module, dict):
                    module = module[child_name]
                elif isinstance(module, Module):
                    module = operator.attrgetter(child_name)(module)
            return module

        expected_keys = set(self.state_dict())
        provided_keys = set(state_dict)
        missing_keys = expected_keys - provided_keys
        unexpected_keys = provided_keys - expected_keys

        if strict and (missing_keys or unexpected_keys):
            details = []
            if missing_keys:
                details.append(f"Missing keys: {missing_keys}\n")
            if unexpected_keys:
                details.append(f"Unexpected keys:\n{unexpected_keys}\n\nNote: Expected keys were:\n{expected_keys}")

            raise_error(
                "state_dict is incompatible.",
                details,
            )

        for nested_attr_name, param in state_dict.items():
            # Only reachable in non-strict mode: keys this module does not know are skipped.
            if nested_attr_name in unexpected_keys:
                continue

            submodule_name, _, param_name = nested_attr_name.rpartition(".")
            # If there is no submodule, it means we are accessing a parameter of self
            module = self
            if submodule_name:
                try:
                    # try to access module.submodule_name as it's the most common case
                    module = operator.attrgetter(submodule_name)(self)
                except AttributeError:
                    logger.verbose(f"Cannot access {submodule_name} directly, trying to find the correct module.")
                    # find module starting from the beginning
                    module = find_module(module, submodule_name.split("."))

            if not isinstance(param, Parameter):
                param = Parameter(param)

            if isinstance(module, Module):
                _check_param_compatible(getattr(module, param_name), param, nested_attr_name)
                setattr(module, param_name, param)
            elif isinstance(module, list):
                index = int(param_name)
                _check_param_compatible(module[index], param, nested_attr_name)
                module[index] = param
            elif isinstance(module, dict):
                _check_param_compatible(module[param_name], param, nested_attr_name)
                module[param_name] = param

        return (missing_keys, unexpected_keys)

    def named_children(self) -> Iterator[Tuple[str, "Module"]]:
        """
        Returns an iterator over immediate children of this module, yielding tuples
        containing the name of the child module and the child module itself.

        Returns:
            An iterator over tuples containing the name of the child module and the child module itself.

        .. code-block:: python
            :caption: Example

            # doc: no-print-locals

            class StackedLinear(tp.Module):
                def __init__(self):
                    super().__init__()
                    self.linear1 = tp.Linear(2, 2)
                    self.linear2 = tp.Linear(2, 2)

            stacked_linear = StackedLinear()

            for name, module in stacked_linear.named_children():
                print(f"{name}: {type(module).__name__}")

            assert [name for name, _ in stacked_linear.named_children()] == ["linear1", "linear2"]
        """
        yield from self._iterate_members_of_type(Module)

    def named_parameters(self) -> Iterator[Tuple[str, Parameter]]:
        """
        Returns an iterator over the parameters held directly by this module.

        Returns:
            An iterator over tuples containing the name of a parameter and the parameter itself.

        .. code-block:: python
            :caption: Example

            # doc: no-print-locals

            class Linear(tp.Module):
                def __init__(self):
                    super().__init__()
                    self.alpha = tp.Parameter(1)
                    self.beta = tp.Parameter(2)

            linear = Linear()

            for name, parameter in linear.named_parameters():
                print(f"{name}: {parameter}")

            assert [name for name, _ in linear.named_parameters()] == ["alpha", "beta"]
        """
        yield from self._iterate_members_of_type(Parameter)

    def _iterate_members_of_type(self, typ: T) -> Iterator[Tuple[str, T]]:
        """
        Yields ``(name, member)`` pairs for instance attributes of type ``typ``.

        Attributes that are themselves instances of ``typ`` are yielded under
        their own name. Members of a ``list`` or ``dict`` attribute are yielded
        as ``"<attr>.<index>"`` / ``"<attr>.<key>"``, but only when every
        element of the container is an instance of ``typ``.

        Args:
            typ: The class to look for.
        """
        for name, value in vars(self).items():
            if isinstance(value, typ):
                yield name, value
            elif isinstance(value, list) and _is_homogeneous_container(value, typ):
                for i, obj in enumerate(value):
                    yield f"{name}.{i}", obj
            elif isinstance(value, dict) and _is_homogeneous_container(value.values(), typ):
                for key, obj in value.items():
                    yield f"{name}.{key}", obj

    def __str__(self):
        """
        Returns a multi-line description listing each child module followed by
        the shape of each parameter held directly by this module.
        """
        from textwrap import indent

        class_name = self.__class__.__name__
        module_str = f"{class_name}(\n"

        # Add children with hierarchical indentation
        for name, child in self.named_children():
            c = indent(str(child), prefix=" ")
            module_str += f" {name}=\n{c},\n"

        # Add parameters with hierarchical indentation
        for name, param in self.named_parameters():
            module_str += f" {name}={param.shape},\n"

        module_str += ")"
        return module_str