GroupNorm

class tripy.GroupNorm(num_groups: int, num_channels: int, dtype: dtype = float32, eps: float = 1e-05)

Bases: Module

Applies group normalization over the input tensor:

\(\text{GroupNorm}(x) = \Large \frac{x - \bar{x}}{ \sqrt{\sigma^2 + \epsilon}} \normalsize * \gamma + \beta\)

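where \(\bar{x}\) and \(\sigma^2\) are the mean and variance computed separately for each sample over all values in a group of channels (including their spatial dimensions), and \(\gamma\) and \(\beta\) are the per-channel weight and bias parameters.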
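As a cross-check of the formula, here is a minimal pure-Python sketch (not Tripy's implementation). It assumes a (batch, channels, spatial) layout with the spatial dimensions flattened, equal-sized contiguous channel groups, and a biased (population) variance; `group_norm_reference` and `Tensor3D` are hypothetical names used only for illustration.

```python
import math
from typing import List

Tensor3D = List[List[List[float]]]  # (N, C, S): batch, channels, flattened spatial dims


def group_norm_reference(
    x: Tensor3D,
    num_groups: int,
    weight: List[float],
    bias: List[float],
    eps: float = 1e-5,
) -> Tensor3D:
    num_channels = len(x[0])
    # Assumption: channels are split into equal, contiguous groups.
    if num_channels % num_groups != 0:
        raise ValueError("num_channels must be divisible by num_groups")
    per_group = num_channels // num_groups
    out: Tensor3D = []
    for sample in x:  # statistics are computed independently for each sample
        out_sample = [[0.0] * len(ch) for ch in sample]
        for g in range(num_groups):
            channels = range(g * per_group, (g + 1) * per_group)
            # Mean and biased variance over every value in this group of channels.
            values = [v for c in channels for v in sample[c]]
            mean = sum(values) / len(values)
            var = sum((v - mean) ** 2 for v in values) / len(values)
            denom = math.sqrt(var + eps)
            for c in channels:
                # Normalize, then apply the per-channel gamma (weight) and beta (bias).
                out_sample[c] = [(v - mean) / denom * weight[c] + bias[c] for v in sample[c]]
        out.append(out_sample)
    return out


# Same data as tp.iota((1, 4, 2, 2), dim=1): channel c is filled with the value c.
x = [[[float(c)] * 4 for c in range(4)]]
y = group_norm_reference(x, num_groups=2, weight=[1.0] * 4, bias=[0.0] * 4)
```

Rounded to three decimals, `y` holds -1, 1, -1, 1 per channel, matching the example output on this page. The biased variance is used here because an unbiased estimate would not reproduce those ±1 values.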

Parameters:
  • num_groups (int) – The number of groups to split the channels into.

  • num_channels (int) – The number of channels expected in the input.

  • dtype (dtype) – The data type to use for the weight and bias parameters.

  • eps (float) – The \(\epsilon\) value added to the variance to prevent division by zero.
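The channel-to-group assignment can be sketched as follows, assuming Tripy follows the usual convention of splitting the channels into `num_groups` contiguous, equal-sized groups (so `num_channels` must be divisible by `num_groups`). `channel_groups` is a hypothetical helper, not part of Tripy.

```python
from typing import List


def channel_groups(num_channels: int, num_groups: int) -> List[List[int]]:
    # Assumption: equal-sized, contiguous groups of channel indices.
    if num_channels % num_groups != 0:
        raise ValueError("num_channels must be divisible by num_groups")
    size = num_channels // num_groups
    return [list(range(g * size, (g + 1) * size)) for g in range(num_groups)]


# The configuration used by tp.GroupNorm(2, 4).
groups = channel_groups(num_channels=4, num_groups=2)
```

With `num_groups=2` and `num_channels=4`, channels 0 and 1 share one mean and variance, and channels 2 and 3 share another.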

Example
import tripy as tp

group_norm = tp.GroupNorm(2, 4)
group_norm.weight = tp.ones_like(group_norm.weight)
group_norm.bias = tp.zeros_like(group_norm.bias)

input = tp.iota((1, 4, 2, 2), dim=1)
output = group_norm(input)
>>> group_norm.state_dict()
{}
>>> input
tensor(
    [[[[0.0000, 0.0000],
       [0.0000, 0.0000]],

      [[1.0000, 1.0000],
       [1.0000, 1.0000]],

      [[2.0000, 2.0000],
       [2.0000, 2.0000]],

      [[3.0000, 3.0000],
       [3.0000, 3.0000]]]], 
    dtype=float32, loc=gpu:0, shape=(1, 4, 2, 2))
>>> output
tensor(
    [[[[-1.0000, -1.0000],
       [-1.0000, -1.0000]],

      [[1.0000, 1.0000],
       [1.0000, 1.0000]],

      [[-1.0000, -1.0000],
       [-1.0000, -1.0000]],

      [[1.0000, 1.0000],
       [1.0000, 1.0000]]]], 
    dtype=float32, loc=gpu:0, shape=(1, 4, 2, 2))
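The example output can be checked by hand. Since `weight` is set to ones and `bias` to zeros, \(\gamma = 1\) and \(\beta = 0\), so only the normalization matters. Group 0 holds channels 0 and 1, that is, four values of 0 and four values of 1:

```latex
\begin{aligned}
\bar{x}_0 &= \tfrac{1}{8}\left(4 \cdot 0 + 4 \cdot 1\right) = 0.5 \\
\sigma_0^2 &= \tfrac{1}{8}\left(4\,(0 - 0.5)^2 + 4\,(1 - 0.5)^2\right) = 0.25 \\
\frac{0 - 0.5}{\sqrt{0.25 + 10^{-5}}} &\approx -1, \qquad
\frac{1 - 0.5}{\sqrt{0.25 + 10^{-5}}} \approx 1
\end{aligned}
```

Group 1 (channels 2 and 3, values 2 and 3) has mean 2.5 and the same variance 0.25, so it also maps to -1 and 1, which yields the -1, 1, -1, 1 pattern across the four channels.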
num_groups: int

The number of groups to split the channels into.

num_channels: int

The number of channels expected in the input.

dtype: dtype

The data type used to perform the operation.

weight: Parameter

The \(\gamma\) parameter of shape \([\text{num_channels}]\).

bias: Parameter

The \(\beta\) parameter of shape \([\text{num_channels}]\).
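Because `weight` and `bias` have one entry per channel rather than per group, channels that share group statistics can still be scaled and shifted differently. Below is a minimal sketch of that final affine step; `apply_affine` is a hypothetical helper that operates on already-normalized values with the spatial dimensions flattened.

```python
from typing import List


def apply_affine(
    normalized: List[List[float]], weight: List[float], bias: List[float]
) -> List[List[float]]:
    # normalized has shape (C, S): per-channel values after group-wise normalization.
    # gamma (weight[c]) and beta (bias[c]) are indexed by channel, not by group.
    assert len(normalized) == len(weight) == len(bias)
    return [[v * weight[c] + bias[c] for v in ch] for c, ch in enumerate(normalized)]


# Normalized values from the example on this page (4 channels, 2x2 spatial flattened).
normalized = [[-1.0] * 4, [1.0] * 4, [-1.0] * 4, [1.0] * 4]
# Scale channel 0 by 2 and shift channel 3 by 3; channels 1 and 2 keep gamma=1, beta=0.
out = apply_affine(normalized, weight=[2.0, 1.0, 1.0, 1.0], bias=[0.0, 0.0, 0.0, 3.0])
```

Channels 0 and 1 belong to the same group, yet only channel 0 is rescaled, which is the point of per-channel \(\gamma\) and \(\beta\).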

eps: float

A value added to the variance in the denominator to prevent division by zero. Defaults to 1e-5.
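The role of `eps` is easiest to see on a group whose values are all equal, where the variance is exactly zero. The hypothetical `normalize_group` below is a plain-Python sketch (not Tripy code) that normalizes one flattened group.

```python
import math
from typing import List


def normalize_group(values: List[float], eps: float = 1e-5) -> List[float]:
    # Biased variance over the group, as in the formula above.
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]


constant = [3.0] * 8
with_eps = normalize_group(constant)  # var is 0, so each value becomes 0 / sqrt(eps)

raised = False
try:
    normalize_group(constant, eps=0.0)  # 0.0 / sqrt(0.0) has no defined value
except ZeroDivisionError:
    raised = True
```

Python raises `ZeroDivisionError` for `0.0 / 0.0`; in IEEE floating-point tensor arithmetic the same expression typically produces NaN instead, which a positive `eps` likewise avoids.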

__call__(x: Tensor) → Tensor
Parameters:

x (Tensor) – The input tensor.

Returns:

A tensor of the same shape as the input.

Return type:

Tensor