Source code for nvtripy.frontend.module.groupnorm

#
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from dataclasses import dataclass

from nvtripy import export, utils
from nvtripy.common import datatype
from nvtripy.common.exception import raise_error
from nvtripy.frontend.module.module import Module
from nvtripy.frontend.module.parameter import DefaultParameter
from nvtripy.frontend.tensor import Tensor

from nvtripy.frontend.module.instancenorm import InstanceNorm


[docs] @export.public_api(document_under="operations/modules") @dataclass @utils.wrappers.constant_fields(["num_groups", "num_channels", "dtype"]) class GroupNorm(Module): r""" Applies group normalization over the input tensor: :math:`\text{GroupNorm}(x) = \Large \frac{x - \bar{x}}{ \sqrt{\sigma^2 + \epsilon}} \normalsize * \gamma + \beta` where :math:`\bar{x}` is the mean and :math:`\sigma^2` is the variance. The input should have shape :math:`[N, C, D1, ...]` where :math:`N` is the batch size, :math:`C` is the number of channels, and :math:`D1, ...` are the feature dimensions. """ num_groups: int r"""The number of groups to split the channels into.""" num_channels: int """The number of channels expected in the input.""" dtype: datatype.dtype r"""The data type used to perform the operation.""" weight: Tensor r"""The :math:`\gamma` parameter of shape :math:`[\text{num_channels}]`.""" bias: Tensor r"""The :math:`\beta` parameter of shape :math:`[\text{num_channels}]`.""" eps: float """A value added to the denominator to prevent division by zero. Defaults to 1e-5.""" def __init__( self, num_groups: int, num_channels: int, dtype: datatype.dtype = datatype.float32, eps: float = 1e-5 ) -> None: r""" Args: num_groups: The number of groups to split the channels into. num_channels: The number of channels expected in the input. dtype: The data type to use for the weight and bias parameters. eps: :math:`\epsilon` value to prevent division by zero. .. code-block:: python :linenos: group_norm = tp.GroupNorm(2, 4) group_norm.weight = tp.ones(group_norm.weight.shape) group_norm.bias = tp.zeros(group_norm.bias.shape) input = tp.iota((1, 4, 1, 1), dim=1) output = group_norm(input) np_out = cp.from_dlpack(output).get() # doc: omit assert np_out.shape == (1, 4, 1, 1) torch_tensor = torch.from_dlpack(input) # doc: omit torch_gn = torch.nn.GroupNorm(2, 2).to(torch.device("cuda")) # doc: omit torch_gn.weight.data = torch.from_dlpack(group_norm.weight) # doc: omit torch_gn.bias.data = torch.from_dlpack(group_norm.bias) # doc: omit torch_out = cp.from_dlpack(torch_gn(torch_tensor).detach()).get() # doc: omit assert np_out.shape == torch_out.shape assert np.allclose(np_out, torch_out) """ super().__init__() if num_channels % num_groups: raise_error( "The number of groups must divide number of channels evenly.", details=[f"Got {num_groups} groups but {num_channels} channels."], ) self.num_groups = num_groups self.num_channels = num_channels self.dtype = dtype # Replace with random weights when #74 is completed. self.weight = DefaultParameter((num_channels,), dtype=dtype) self.bias = DefaultParameter((num_channels,), dtype=dtype) self.eps = eps
[docs] def forward(self, x: "nvtripy.Tensor") -> "nvtripy.Tensor": r""" Args: x: The input tensor. Returns: A tensor of the same shape as the input. """ if x.rank < 3: raise_error( f"Input must have a rank of at least 3, but got input of rank: {x.rank}", details=[ "The input should have shape [N, C, D1, ...] where N is the batch size, C is the number of channels, and D1, ... are the feature dimensions." ], ) from nvtripy.frontend.ops.reshape import reshape from nvtripy.frontend.ops.split import split from nvtripy.frontend.ops.stack import stack from nvtripy.frontend.ops.flatten import flatten from nvtripy.frontend.module.instancenorm import InstanceNorm from nvtripy.frontend.ops.ones import ones from nvtripy.frontend.ops.zeros import zeros instance_norm = InstanceNorm(self.num_groups, dtype=self.dtype, eps=self.eps) instance_norm.weight = ones((self.num_groups,), dtype=self.dtype) instance_norm.bias = zeros((self.num_groups,), dtype=self.dtype) # Use InstanceNorm as a WAR due to lack of TRT API compatibility for scale/bias with shape (num_channels, ) input_reshaped = stack(split(x, self.num_groups, 1), 1) x = instance_norm(input_reshaped) x = flatten(x, start_dim=1, end_dim=2) broadcast_shape = (1, self.num_channels) + (1,) * (x.rank - 2) return x * reshape(self.weight, broadcast_shape) + reshape(self.bias, broadcast_shape)