#
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
import operator
from typing import Any, Dict, Iterable, Iterator, List, Sequence, Set, Tuple, TypeVar, Union

from tripy import export, utils
from tripy.common.exception import raise_error, str_from_stack_info
from tripy.frontend.module.parameter import Parameter
from tripy.logging import logger
T = TypeVar("T")
def _check_param_compatible(original_param, new_param, param_name):
    """
    Verify that ``new_param`` may replace ``original_param``.

    Nothing is checked when ``original_param`` is not a ``Parameter`` (for
    example when the attribute does not exist yet). Otherwise the existing
    parameter's ``_is_compatible`` result is consulted and, if it is falsy,
    ``raise_error`` is called with that result's ``error_details``.

    Args:
        original_param: The value currently stored under ``param_name``.
        new_param: The value that is about to replace it.
        param_name: Name used to identify the parameter in the error message.
    """
    if isinstance(original_param, Parameter):
        compatibility = original_param._is_compatible(new_param)
        if not compatibility:
            raise_error(
                f"For parameter: {param_name}, new parameter is not compatible with the existing parameter.",
                details=compatibility.error_details,
            )
def _is_homogeneous_container(container: Iterable, typ: type) -> bool:
    """
    Report whether every element of ``container`` is an instance of ``typ``.

    Args:
        container: Any iterable of elements; callers pass both lists and
            ``dict.values()`` views, so a ``Sequence`` is not required.
        typ: The class (or tuple of classes) to test each element against
            with ``isinstance``.

    Returns:
        ``True`` if all elements are instances of ``typ``. An empty container
        is vacuously homogeneous and also yields ``True``.
    """
    return all(isinstance(elem, typ) for elem in container)
def _contains_types(container: Iterable, types: Iterable[type]) -> bool:
    """
    Report whether ``container`` holds at least one instance of any of ``types``.

    Args:
        container: Any iterable of elements (a list or a ``dict.values()`` view).
        types: The candidate classes; an element matches if it is an instance
            of at least one of them.

    Returns:
        ``True`` if some element is an instance of one of ``types``, otherwise
        ``False`` (including when ``container`` or ``types`` is empty).
    """
    # ``isinstance`` accepts a tuple of classes, so a single pass over the
    # container is enough; an empty tuple never matches anything.
    candidate_types = tuple(types)
    return any(isinstance(elem, candidate_types) for elem in container)
@export.public_api(document_under="modules/index.rst")
class Module:
    r"""
    Base class used to define neural network modules.
    You can nest modules by assigning them as attributes of other modules.

    Child modules or :class:`tripy.Parameter` s may be contained in Python ``list``\s or ``dict``\s.
    If using ``dict``\s, the keys must be strings.
    Nested data structures (for example, ``list``\s of ``list``\s) are not supported.
    Taking child modules as an example, this is allowed:
    ::

        self.linear = tp.Linear(2, 2)
        self.list_modules = [tp.Linear(2, 2), tp.Linear(2, 2)]
        self.dict_modules = {
            "linear": tp.Linear(2, 2),
            "layernorm": tp.LayerNorm(2),
        }

    Whereas this is not supported:
    ::

        self.list_modules = [[tp.Linear(2, 2)], [tp.Linear(2, 2)]]
        self.dict_modules = {
            (1, "linear"): tp.Linear(2, 2),
        }

    .. code-block:: python
        :linenos:
        :caption: Example

        class AddBias(tp.Module):
            def __init__(self):
                super().__init__()
                self.bias = tp.Parameter(tp.Tensor([1.0, 1.0], dtype=tp.float32))

            def __call__(self, x):
                return x + self.bias

        add_bias = AddBias()
        input = tp.Tensor([1.0, 1.0], dtype=tp.float32)
        output = add_bias(input)
        assert np.array_equal(cp.from_dlpack(output).get(), np.array([2.0, 2.0]))
    """

    def __setattr__(self, name: str, value: Any) -> None:
        """
        Set an attribute, validating parameter replacements and warning about
        containers that mix parameters/modules with other types.

        Args:
            name: The attribute name.
            value: The value to assign.
        """
        # Replacing an existing parameter (or assigning a new Parameter) must
        # pass the compatibility check before the attribute is overwritten.
        if isinstance(value, Parameter) or name in dict(self.named_parameters()):
            _check_param_compatible(getattr(self, name, None), value, name)

        super().__setattr__(name, value)

        # avoid infinite recursion during initialization
        if value is None:
            return

        if isinstance(value, (list, dict)):
            container = value if isinstance(value, list) else value.values()
            holds_tracked_type = _contains_types(container, [Parameter, Module])
            is_homogeneous = _is_homogeneous_container(container, Parameter) or _is_homogeneous_container(
                container, Module
            )
            # Only homogeneous containers are picked up by named_parameters()/
            # named_children(), so let the user know a mixed one is ignored.
            if holds_tracked_type and not is_homogeneous:
                stack_info = utils.get_stack_info()
                stack_info.fetch_source_code()
                stack_info_msg = str_from_stack_info(stack_info)
                logger.warning(
                    "A container of mixed types will not be registered with this module's state_dict()."
                    + (f"\nNote: container was set here: {stack_info_msg}" if stack_info_msg else "")
                )

    def state_dict(self) -> Dict[str, Parameter]:
        r"""
        Returns a dictionary mapping names to parameters in the module.
        This will recurse over any nested child modules.

        Returns:
            A dictionary mapping names to parameters.

        .. code-block:: python
            :linenos:
            :caption: Example

            # doc: print-locals state_dict

            class MyModule(tp.Module):
                def __init__(self):
                    super().__init__()
                    self.param = tp.Parameter(tp.ones((2,), dtype=tp.float32))
                    self.linear1 = tp.Linear(2, 2)
                    self.linear2 = tp.Linear(2, 2)

            module = MyModule()

            state_dict = module.state_dict()

            assert set(state_dict.keys()) == {"param", "linear1.weight", "linear1.bias", "linear2.weight", "linear2.bias"}
        """
        # dict() already builds a fresh mapping, so no extra copy is needed.
        state_dict = dict(self.named_parameters())

        for child_name, child in self.named_children():
            for name, param in child.state_dict().items():
                # We add a prefix for any parameters coming from nested modules
                # so they can be disambiguated correctly in higher level modules.
                state_dict[f"{child_name}.{name}"] = param
        return state_dict

    def load_state_dict(self, state_dict: Dict[str, Parameter], strict: bool = True) -> Tuple[Set[str], Set[str]]:
        r"""
        Loads parameters from the provided ``state_dict`` into the current module.
        This will recurse over any nested child modules.

        Args:
            state_dict: A dictionary mapping names to parameters.
            strict: If True, keys in ``state_dict`` must exactly match those in this module. If not,
                an error will be raised.

        Returns:
            A ``tuple`` of two ``set``\s of strings representing:

            - missing_keys: keys that are expected by this module but not provided in ``state_dict``.
            - unexpected_keys: keys that are not expected by this module but provided in ``state_dict``.

        .. code-block:: python
            :linenos:
            :caption: Example

            # doc: no-print-locals

            class MyModule(tp.Module): # doc: omit
                def __init__(self): # doc: omit
                    super().__init__() # doc: omit
                    self.param = tp.Parameter(tp.ones((2,), dtype=tp.float32)) # doc: omit
                    self.linear1 = tp.Linear(2, 2) # doc: omit
                    self.linear2 = tp.Linear(2, 2) # doc: omit

            module = MyModule() # doc: omit
            state_dict = module.state_dict() # doc: omit

            # Using the `module` and `state_dict` from the `state_dict()` example:
            print(f"Before: {module.param}")

            state_dict["param"] = tp.Parameter(tp.zeros((2,), dtype=tp.float32))
            module.load_state_dict(state_dict)

            print(f"After: {module.param}")

            assert np.array_equal(cp.from_dlpack(module.state_dict()["param"]).get(), np.array(np.zeros((2,), dtype=np.float32)))

        .. seealso:: :func:`state_dict`
        """

        def resolve_owner(path: str) -> Union["Module", List, Dict]:
            # Walk the dotted path one component at a time and pick the lookup
            # from the type of the object reached so far. Resolving the whole
            # path with attribute access would return a bound method for a dict
            # key such as "values" or "items" instead of the stored child, and
            # the parameter would then be silently skipped.
            owner = self
            for part in path.split(".") if path else ():
                if isinstance(owner, list):
                    owner = owner[int(part)]
                elif isinstance(owner, dict):
                    owner = owner[part]
                elif isinstance(owner, Module):
                    owner = getattr(owner, part)
            return owner

        expected_keys = set(self.state_dict())
        provided_keys = set(state_dict)
        missing_keys = expected_keys - provided_keys
        unexpected_keys = provided_keys - expected_keys

        if strict and (missing_keys or unexpected_keys):
            details = []
            if missing_keys:
                details.append(f"Missing keys: {missing_keys}\n")
            if unexpected_keys:
                details.append(f"Unexpected keys:\n{unexpected_keys}\n\nNote: Expected keys were:\n{expected_keys}")
            raise_error(
                "state_dict is incompatible.",
                details,
            )

        for nested_attr_name, param in state_dict.items():
            # In non-strict mode, keys this module does not know are ignored.
            if nested_attr_name in unexpected_keys:
                continue

            # Everything before the last "." names the owner (a module, list or
            # dict); an empty owner path means the parameter belongs to self.
            owner_path, _, param_name = nested_attr_name.rpartition(".")
            owner = resolve_owner(owner_path)

            if not isinstance(param, Parameter):
                param = Parameter(param)

            if isinstance(owner, Module):
                _check_param_compatible(getattr(owner, param_name), param, nested_attr_name)
                setattr(owner, param_name, param)
            elif isinstance(owner, list):
                index = int(param_name)
                _check_param_compatible(owner[index], param, nested_attr_name)
                owner[index] = param
            elif isinstance(owner, dict):
                _check_param_compatible(owner[param_name], param, nested_attr_name)
                owner[param_name] = param

        return (missing_keys, unexpected_keys)

    def named_children(self) -> Iterator[Tuple[str, "Module"]]:
        r"""
        Returns an iterator over immediate children of this module, yielding tuples
        containing the name of the child module and the child module itself.

        Returns:
            An iterator over tuples containing the name of the child module and the child module itself.

        .. code-block:: python
            :linenos:
            :caption: Example

            # doc: no-print-locals

            class StackedLinear(tp.Module):
                def __init__(self):
                    super().__init__()
                    self.linear1 = tp.Linear(2, 2)
                    self.linear2 = tp.Linear(2, 2)

            stacked_linear = StackedLinear()

            for name, module in stacked_linear.named_children():
                print(f"{name}: {type(module).__name__}")

            assert [name for name, _ in stacked_linear.named_children()] == ["linear1", "linear2"]
        """
        yield from self._iterate_members_of_type(Module)

    def named_parameters(self) -> Iterator[Tuple[str, Parameter]]:
        r"""
        Returns:
            An iterator over tuples containing the name of a parameter and the parameter itself.

        .. code-block:: python
            :linenos:
            :caption: Example

            # doc: no-print-locals

            class Linear(tp.Module):
                def __init__(self):
                    super().__init__()
                    self.alpha = tp.Parameter(1)
                    self.beta = tp.Parameter(2)

            linear = Linear()

            for name, parameter in linear.named_parameters():
                print(f"{name}: {parameter}")

            assert [name for name, _ in linear.named_parameters()] == ["alpha", "beta"]
        """
        yield from self._iterate_members_of_type(Parameter)

    def _iterate_members_of_type(self, typ: T) -> Iterator[Tuple[str, T]]:
        """
        Yield ``(name, member)`` pairs for every instance attribute that is an
        instance of ``typ``, or that is a list/dict containing only such instances.

        Members found inside a list are named ``"<attr>.<index>"`` and members
        found inside a dict are named ``"<attr>.<key>"``.
        """
        for name, value in vars(self).items():
            if isinstance(value, typ):
                yield name, value
            elif isinstance(value, list) and _is_homogeneous_container(value, typ):
                for i, obj in enumerate(value):
                    yield f"{name}.{i}", obj
            elif isinstance(value, dict) and _is_homogeneous_container(value.values(), typ):
                for key, obj in value.items():
                    yield f"{name}.{key}", obj

    def __str__(self):
        """
        Return a multi-line description listing child modules (recursively)
        followed by each parameter's name and shape.
        """
        from textwrap import indent

        class_name = self.__class__.__name__
        pieces = [f"{class_name}(\n"]

        # Add children with hierarchical indentation
        for name, child in self.named_children():
            c = indent(str(child), prefix=" ")
            pieces.append(f" {name}=\n{c},\n")

        # Add parameters with hierarchical indentation
        for name, param in self.named_parameters():
            pieces.append(f" {name}={param.shape},\n")

        pieces.append(")")
        return "".join(pieces)