# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from earth2studio.statistics.utils import _broadcast_weights
from earth2studio.utils.coords import handshake_dim
from earth2studio.utils.type import CoordSystem
class mean:
    """
    Statistic for calculating the sample mean over a set of given dimensions.

    Parameters
    ----------
    reduction_dimensions: list[str]
        A list of names corresponding to dimensions to perform the
        statistical reduction over. Example: ['lat', 'lon']
    weights: torch.Tensor, optional
        A tensor containing weights to assign to the reduction dimensions.
        Note that these weights must have the same number of dimensions
        as passed in reduction_dimensions.
        Example: if reduction_dimensions = ['lat', 'lon'] then
        assert weights.ndim == 2.
        By default None.
    batch_update: bool, optional
        Whether to apply batch updates to the mean with each invocation of
        __call__. This is particularly useful when data is received in a
        stream of batches. Each invocation of __call__ will return the
        running mean.
        By default False.
    """

    def __init__(
        self,
        reduction_dimensions: list[str],
        weights: torch.Tensor | None = None,
        batch_update: bool = False,
    ):
        if weights is not None and weights.ndim != len(reduction_dimensions):
            raise ValueError(
                "Error! Weights must be the same dimension as reduction_dimensions"
            )
        self._reduction_dimensions = reduction_dimensions
        self.weights = weights
        self.batch_update = batch_update
        if self.batch_update:
            # Running total of weights seen so far; becomes a tensor after
            # the first batch is processed.
            self.n = 0

    def __str__(self) -> str:
        return "_".join(self._reduction_dimensions + ["mean"])

    @property
    def reduction_dimensions(self) -> list[str]:
        return self._reduction_dimensions

    def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
        """Output coordinate system of the computed statistic, corresponding to the given input coordinates

        Parameters
        ----------
        input_coords : CoordSystem
            Input coordinate system to transform into output_coords

        Returns
        -------
        CoordSystem
            Coordinate system dictionary with the reduction dimensions removed
        """
        output_coords = input_coords.copy()
        for dimension in self.reduction_dimensions:
            handshake_dim(input_coords, dimension)
            output_coords.pop(dimension)
        return output_coords

    def __call__(
        self, x: torch.Tensor, coords: CoordSystem
    ) -> tuple[torch.Tensor, CoordSystem]:
        """
        Apply the mean operation over the tensor x.

        If batch_update was passed True upon metric initialization then this method
        returns the running sample mean over all seen batches.

        Parameters
        ----------
        x: torch.Tensor
            Input data to compute sample mean.
        coords: CoordSystem
            Coordinates referring to the input data, x.

        Returns
        -------
        tuple[torch.Tensor, CoordSystem]
            The (weighted) sample mean and its reduced coordinate system.

        Raises
        ------
        ValueError
            If any of the reduction dimensions is absent from coords.
        """
        if not all(rd in coords for rd in self._reduction_dimensions):
            raise ValueError(
                "Initialized reduction_dimensions do not appear in passed coords."
            )
        dims = [list(coords).index(rd) for rd in self._reduction_dimensions]
        output_coords = CoordSystem(
            {
                key: coords[key]
                for key in coords
                if key not in self._reduction_dimensions
            }
        )
        # Expand weights (or uniform defaults) to broadcast against x.
        weights = _broadcast_weights(
            self.weights, self._reduction_dimensions, coords
        ).to(x.device)
        weights_sum = torch.sum(weights)
        # If not applying batch updating then return regular mean.
        if not self.batch_update:
            return torch.sum(weights * x, dim=dims) / weights_sum, output_coords
        # If batch updating then accumulate the weighted sum and weight total
        # and return the running mean.
        else:
            if self.n == 0:
                self.sum = torch.sum(weights * x, dim=dims)
            else:
                self.sum += torch.sum(weights * x, dim=dims)
            self.n += weights_sum
            return self.sum / self.n, output_coords
class variance:
    """
    Statistic for calculating the sample variance over a set of given dimensions.

    Parameters
    ----------
    reduction_dimensions: list[str]
        A list of names corresponding to dimensions to perform the
        statistical reduction over. Example: ['lat', 'lon']
    weights: torch.Tensor, optional
        A tensor containing weights to assign to the reduction dimensions.
        Note that these weights must have the same number of dimensions
        as passed in reduction_dimensions.
        Example: if reduction_dimensions = ['lat', 'lon'] then
        assert weights.ndim == 2.
        By default None.
    batch_update: bool, optional
        Whether to apply batch updates to the variance with each invocation of
        __call__. This is particularly useful when data is received in a
        stream of batches. Each invocation of __call__ will return the
        running variance.
        By default False.
    """

    def __init__(
        self,
        reduction_dimensions: list[str],
        weights: torch.Tensor | None = None,
        batch_update: bool = False,
    ):
        if weights is not None and weights.ndim != len(reduction_dimensions):
            raise ValueError(
                "Error! Weights must be the same dimension as reduction_dimensions"
            )
        self._reduction_dimensions = reduction_dimensions
        self.weights = weights
        self.batch_update = batch_update
        if self.batch_update:
            # Running total of weights seen so far; becomes a tensor after
            # the first batch is processed.
            self.n = 0

    def __str__(self) -> str:
        return "_".join(self._reduction_dimensions + ["variance"])

    def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
        """Output coordinate system of the computed statistic, corresponding to the given input coordinates

        Parameters
        ----------
        input_coords : CoordSystem
            Input coordinate system to transform into output_coords

        Returns
        -------
        CoordSystem
            Coordinate system dictionary with the reduction dimensions removed
        """
        output_coords = input_coords.copy()
        for dimension in self.reduction_dimensions:
            handshake_dim(input_coords, dimension)
            output_coords.pop(dimension)
        return output_coords

    @property
    def reduction_dimensions(self) -> list[str]:
        return self._reduction_dimensions

    def __call__(
        self, x: torch.Tensor, coords: CoordSystem
    ) -> tuple[torch.Tensor, CoordSystem]:
        """
        Apply the sample variance operation over the tensor x.

        If batch_update was passed True upon metric initialization then this method
        returns the running sample variance over all seen batches, combined via
        the pairwise (Chan et al.) update formula.

        Parameters
        ----------
        x: torch.Tensor
            Input data to compute sample variance.
        coords: CoordSystem
            Coordinates referring to the input data, x.

        Returns
        -------
        tuple[torch.Tensor, CoordSystem]
            The (weighted) sample variance and its reduced coordinate system.

        Raises
        ------
        ValueError
            If any of the reduction dimensions is absent from coords.
        """
        if not all(rd in coords for rd in self._reduction_dimensions):
            raise ValueError(
                "Initialized reduction_dimensions do not appear in passed coords."
            )
        dims = [list(coords).index(rd) for rd in self._reduction_dimensions]
        output_coords = CoordSystem(
            {
                key: coords[key]
                for key in coords
                if key not in self._reduction_dimensions
            }
        )
        # Expand weights (or uniform defaults) to broadcast against x.
        weights = _broadcast_weights(
            self.weights, self._reduction_dimensions, coords
        ).to(x.device)
        weights_sum = torch.sum(weights)
        # If not applying batch updating then return regular variance.
        if not self.batch_update:
            m = torch.sum(weights * x, dim=dims, keepdim=True) / weights_sum
            # Unbiased divisor for reliability (frequency) weights:
            # V1 - V2 / V1 with V1 = sum(w), V2 = sum(w^2).
            div = weights_sum - torch.sum(weights**2) / weights_sum
            return torch.sum(weights * (x - m) ** 2, dim=dims) / div, output_coords
        # If batch updating then fold this batch into the running moments.
        else:
            temp_n = weights_sum
            weighted = weights * x
            temp_sum = torch.sum(weighted, dim=dims)
            # Keep the reduced dims (size 1) so the centered residual
            # broadcasts correctly regardless of where the reduction
            # dimensions sit in the coordinate order.
            temp_mean = torch.sum(weighted, dim=dims, keepdim=True) / temp_n
            temp_sum2 = torch.sum(weights * (x - temp_mean) ** 2, dim=dims)
            # First batch then no correction
            if self.n == 0:
                self.n = temp_n
                self.sum = temp_sum
                self.sum2 = temp_sum2
            # Second-order correction with each batch (pairwise combination
            # of sums of squared deviations).
            else:
                delta = self.sum * temp_n / self.n - temp_sum
                self.sum += temp_sum
                self.sum2 += (
                    temp_sum2 + self.n / temp_n / (self.n + temp_n) * delta**2
                )
                self.n += temp_n
            # Clamp the divisor so a single (or tiny-weight) batch does not
            # divide by zero or a negative number.
            return (
                self.sum2 / torch.clamp(self.n - 1.0, min=1.0),
                output_coords,
            )
class std:
    """
    Statistic for calculating the sample standard deviation over a set of given dimensions.

    Parameters
    ----------
    reduction_dimensions: list[str]
        A list of names corresponding to dimensions to perform the
        statistical reduction over. Example: ['lat', 'lon']
    weights: torch.Tensor, optional
        A tensor containing weights to assign to the reduction dimensions.
        Note that these weights must have the same number of dimensions
        as passed in reduction_dimensions.
        Example: if reduction_dimensions = ['lat', 'lon'] then
        assert weights.ndim == 2.
        By default None.
    batch_update: bool, optional
        Whether to apply batch updates to the standard deviation with each
        invocation of __call__. This is particularly useful when data is
        received in a stream of batches. Each invocation of __call__ will
        return the running standard deviation.
        By default False.
    """

    def __init__(
        self,
        reduction_dimensions: list[str],
        weights: torch.Tensor | None = None,
        batch_update: bool = False,
    ):
        # Delegate all computation (including weight validation and batch
        # accumulation) to the variance statistic; std is its square root.
        self.var = variance(
            reduction_dimensions, weights=weights, batch_update=batch_update
        )
        self._reduction_dimensions = reduction_dimensions
        self.weights = weights

    def __str__(self) -> str:
        return "_".join(self._reduction_dimensions + ["std"])

    @property
    def reduction_dimensions(self) -> list[str]:
        return self._reduction_dimensions

    def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
        """Output coordinate system of the computed statistic, corresponding to the given input coordinates

        Parameters
        ----------
        input_coords : CoordSystem
            Input coordinate system to transform into output_coords

        Returns
        -------
        CoordSystem
            Coordinate system dictionary with the reduction dimensions removed
        """
        output_coords = input_coords.copy()
        for dimension in self.reduction_dimensions:
            handshake_dim(input_coords, dimension)
            output_coords.pop(dimension)
        return output_coords

    def __call__(
        self, x: torch.Tensor, coords: CoordSystem
    ) -> tuple[torch.Tensor, CoordSystem]:
        """
        Apply the sample standard deviation operation over the tensor x.

        If batch_update was passed True upon metric initialization then this method
        returns the running sample standard deviation over all seen batches.

        Parameters
        ----------
        x: torch.Tensor
            Input data to compute sample standard deviation.
        coords: CoordSystem
            Coordinates referring to the input data, x.

        Returns
        -------
        tuple[torch.Tensor, CoordSystem]
            The (weighted) sample standard deviation and its reduced
            coordinate system.
        """
        var, output_coords = self.var(x, coords)
        return torch.sqrt(var), output_coords