# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from earth2studio.utils.coords import handshake_dim
from earth2studio.utils.type import CoordSystem
from .moments import mean, variance
class rmse:
    """
    Statistic for calculating the root mean squared error of two tensors
    over a set of given dimensions.

    Parameters
    ----------
    reduction_dimensions: list[str]
        A list of names corresponding to dimensions to perform the
        statistical reduction over. Example: ['lat', 'lon']
    weights: torch.Tensor | None, optional
        A tensor containing weights to assign to the reduction dimensions.
        Note that these weights must have the same number of dimensions
        as passed in reduction_dimensions.
        Example: if reduction_dimensions = ['lat', 'lon'] then
        assert weights.ndim == 2. Ignored if None, by default None.
    batch_update: bool, optional
        Whether to apply batch updates to the rmse with each invocation of
        __call__. This is particularly useful when data is received in a stream
        of batches. Each invocation of __call__ will return the running rmse.
        In particular, it will apply the square root operation after calculating
        the running mean squared error. By default False.
    """

    def __init__(
        self,
        reduction_dimensions: list[str],
        weights: torch.Tensor | None = None,
        batch_update: bool = False,
    ):
        # The (optionally weighted, optionally running) mean of the squared
        # error is delegated to the `mean` statistic; only the square root is
        # applied in this class.
        self.mean = mean(
            reduction_dimensions, weights=weights, batch_update=batch_update
        )
        self.weights = weights
        self._reduction_dimensions = reduction_dimensions
        self.batch_update = batch_update

    def __str__(self) -> str:
        """Return the reduction dimension names joined by "_" with an "rmse" suffix."""
        return "_".join(self._reduction_dimensions + ["rmse"])

    @property
    def reduction_dimensions(self) -> list[str]:
        """Names of the dimensions this statistic reduces over."""
        return self._reduction_dimensions

    def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
        """Output coordinate system of the statistic.

        Each reduction dimension is checked with `handshake_dim` and then
        removed from a copy of the input coordinates; `input_coords` itself is
        not modified.

        Parameters
        ----------
        input_coords : CoordSystem
            Input coordinate system to transform into output_coords

        Returns
        -------
        CoordSystem
            Coordinate system dictionary without the reduction dimensions
        """
        output_coords = input_coords.copy()
        for dimension in self.reduction_dimensions:
            handshake_dim(input_coords, dimension)
            output_coords.pop(dimension)
        return output_coords

    def __call__(
        self,
        x: torch.Tensor,
        x_coords: CoordSystem,
        y: torch.Tensor,
        y_coords: CoordSystem,
    ) -> tuple[torch.Tensor, CoordSystem]:
        """
        Apply metric to data `x` and `y`, reducing over `reduction_dimensions`.

        The squared difference ``(x - y) ** 2`` is passed, together with
        `x_coords`, to the underlying mean statistic and the square root of the
        result is returned.

        If batch_update was passed True upon metric initialization then this method
        returns the running sample RMSE over all seen batches.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, typically the forecast or prediction tensor, but RMSE is
            symmetric with respect to `x` and `y`.
        x_coords : CoordSystem
            Ordered dict representing coordinate system that describes the `x` tensor.
            `reduction_dimensions` must be in coords.
        y : torch.Tensor
            Input tensor #2 intended to be used as validation data, but RMSE is
            symmetric with respect to `x` and `y`. Must be broadcastable against
            `x`.
        y_coords : CoordSystem
            Ordered dict representing coordinate system that describes the `y` tensor.
            Accepted for interface symmetry; it is not inspected by this method,
            so the caller is responsible for `y` matching the layout of `x`.

        Returns
        -------
        tuple[torch.Tensor, CoordSystem]
            Returns root mean squared error tensor with appropriate reduced coordinates.
        """
        # Output coordinates are derived from x_coords only (see y_coords note).
        mse, output_coords = self.mean((x - y) ** 2, x_coords)
        return torch.sqrt(mse), output_coords
class spread_skill_ratio:
    """Metric for calculating the spread/skill ratio of an ensemble forecast.

    Specifically, the spread is defined as the standard deviation of the ensemble
    forecast. The skill is defined as the rmse of the ensemble mean prediction. The
    ratio of these two quantities, spread divided by skill, is defined as the
    spread/skill ratio.

    Parameters
    ----------
    ensemble_dimension : str
        The dimension over which the spread and skill are calculated over.
        This should usually be "ensemble".
    reduction_dimensions : list[str]
        Dimensions to reduce (mean) the spread/skill ratio over. This is commonly done
        over time but can also be the globe or some region.
        Example: ['time', 'lat', 'lon']
    ensemble_weights : torch.Tensor | None, optional
        A one-dimensional tensor containing weights to assign to the ensemble_dimension,
        by default None.
    reduction_weights : torch.Tensor | None, optional
        A tensor containing weights to assign to the reduction dimensions.
        Note that these weights must have the same number of dimensions
        as passed in reduction_dimensions.
        Example: if reduction_dimensions = ['lat', 'lon'] then
        assert weights.ndim == 2. Ignored if None, by default None.
    ensemble_batch_update : bool, optional
        Whether to apply batch updates to the ensemble mean and variance components
        of the spread and skill with each invocation of __call__. This is particularly
        useful when ensemble data is received in a stream of batches. Each invocation
        of __call__ will return the running spread/skill ratio, by default False.
    reduction_batch_update : bool, optional
        Whether to apply batch updates to the reduction rmse and averaging components
        of the spread/skill with each invocation of __call__. This is particularly
        useful when time data is received in a stream of batches, by default False.
    """

    def __init__(
        self,
        ensemble_dimension: str,
        reduction_dimensions: list[str],
        ensemble_weights: torch.Tensor | None = None,
        reduction_weights: torch.Tensor | None = None,
        ensemble_batch_update: bool = False,
        reduction_batch_update: bool = False,
    ):
        # Stored as a one-element list so it can be concatenated with the
        # reduction dimensions and handed to the moment statistics directly.
        self.ensemble_dimension = [ensemble_dimension]
        self._reduction_dimensions = reduction_dimensions

        # Statistics taken over the ensemble dimension.
        self.ensemble_mean = mean(
            reduction_dimensions=self.ensemble_dimension,
            weights=ensemble_weights,
            batch_update=ensemble_batch_update,
        )
        self.ensemble_var = variance(
            reduction_dimensions=self.ensemble_dimension,
            weights=ensemble_weights,
            batch_update=ensemble_batch_update,
        )

        # Statistics taken over the remaining reduction dimensions.
        self.reduced_rmse = rmse(
            reduction_dimensions=reduction_dimensions,
            weights=reduction_weights,
            batch_update=reduction_batch_update,
        )
        self.reduced_mean = mean(
            reduction_dimensions=reduction_dimensions,
            weights=reduction_weights,
            batch_update=reduction_batch_update,
        )

    def __str__(self) -> str:
        """Return all reduced dimension names joined by "_" with a "spread_skill" suffix."""
        return "_".join(
            self.ensemble_dimension + self._reduction_dimensions + ["spread_skill"]
        )

    @property
    def reduction_dimensions(self) -> list[str]:
        """Ensemble dimension followed by the reduction dimensions."""
        return self.ensemble_dimension + self._reduction_dimensions

    def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
        """Output coordinate system of the metric.

        Each dimension in `reduction_dimensions` (ensemble dimension included)
        is checked with `handshake_dim` and then removed from a copy of the
        input coordinates; `input_coords` itself is not modified.

        Parameters
        ----------
        input_coords : CoordSystem
            Input coordinate system to transform into output_coords

        Returns
        -------
        CoordSystem
            Coordinate system dictionary without the reduced dimensions
        """
        output_coords = input_coords.copy()
        for dimension in self.reduction_dimensions:
            handshake_dim(input_coords, dimension)
            output_coords.pop(dimension)
        return output_coords

    def __call__(
        self,
        x: torch.Tensor,
        x_coords: CoordSystem,
        y: torch.Tensor,
        y_coords: CoordSystem,
    ) -> tuple[torch.Tensor, CoordSystem]:
        """
        Apply metric to ensemble forecast `x` and observation `y`, reducing over
        the ensemble dimension and `reduction_dimensions`.

        If batch updates were enabled upon metric initialization then this method
        returns the running spread/skill ratio over all seen batches.

        Parameters
        ----------
        x : torch.Tensor
            The ensemble forecast input tensor. This is the tensor over which the
            ensemble mean and spread are calculated with respect to.
        x_coords : CoordSystem
            Ordered dict representing coordinate system that describes the `x` tensor.
            `reduction_dimensions` must be in coords.
        y : torch.Tensor
            The observation input tensor.
        y_coords : CoordSystem
            Ordered dict representing coordinate system that describes the `y` tensor.
            Forwarded unchanged to the reduced rmse statistic.

        Returns
        -------
        tuple[torch.Tensor, CoordSystem]
            Returns the spread/skill ratio tensor (spread divided by skill) with
            appropriate reduced coordinates.
        """
        # Skill: rmse of the ensemble mean against the observation.
        ens_mean, reduced_coords = self.ensemble_mean(x, x_coords)
        skill, _ = self.reduced_rmse(ens_mean, reduced_coords, y, y_coords)

        # Spread: the ensemble variance is averaged over the reduction
        # dimensions first; the square root is taken only afterwards.
        mean_variance, output_coords = self.reduced_mean(*self.ensemble_var(x, x_coords))

        # Spread divided by skill, as the metric name states. The previous
        # implementation returned the reciprocal (skill / spread).
        return torch.sqrt(mean_variance) / skill, output_coords