Statistics#
Statistics are distinct from prognostic and diagnostic models in principle because we assume that statistics reduce existing coordinates, so the output tensors have a coordinate system that is a subset of the input coordinate system. This makes statistics less flexible than diagnostic models, but they also have fewer API requirements.
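As a rough illustration of this subset property, the sketch below derives an output coordinate system by dropping the reduced dimensions from an ordered mapping. The helper name drop_reduced_dims and the sample coordinates are hypothetical, and plain Python lists stand in for the coordinate arrays a real CoordSystem would hold.

```python
from collections import OrderedDict


def drop_reduced_dims(input_coords, reduction_dimensions):
    """Return a copy of input_coords without the reduced dimensions.

    Hypothetical helper for illustration; not part of the package API.
    """
    missing = [d for d in reduction_dimensions if d not in input_coords]
    if missing:
        raise ValueError(f"reduction dimensions not in coords: {missing}")
    # Keep the remaining dimensions in their original order
    return OrderedDict(
        (dim, values)
        for dim, values in input_coords.items()
        if dim not in reduction_dimensions
    )


coords = OrderedDict(
    [("ensemble", [0, 1, 2]), ("lat", [-45.0, 0.0, 45.0]), ("lon", [0.0, 180.0])]
)
out = drop_reduced_dims(coords, ["ensemble"])
print(list(out))  # ['lat', 'lon']
```

The output keys are always a subset of the input keys, which is what lets the output dimensions be known before any data is processed.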
Statistics Interface#
The statistics API only specifies a __call__() method that matches similar methods across the package.
from typing import Protocol, runtime_checkable

from earth2studio.utils.type import CoordSystem


@runtime_checkable
class Statistic(Protocol):
    """Statistic interface."""

    @property
    def reduction_dimensions(self) -> list[str]:
        """Gives the input dimensions over which the statistic performs a
        reduction. This is used to determine, a priori, the output dimensions
        of a statistic.
        """
        pass

    def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
        """Output coordinate system of the computed statistic, corresponding to the given input coordinates

        Parameters
        ----------
        input_coords : CoordSystem
            Input coordinate system to transform into output_coords

        Returns
        -------
        CoordSystem
            Coordinate system dictionary
        """
        pass
The base API hints at, and the earth2studio.statistics.moments examples illustrate, a few properties that make statistics easier to handle: reduction_dimensions, a list of the dimensions that will be reduced over; weights, which must be broadcastable with reduction_dimensions; and batch_update, which is useful for applying statistics when data arrives in streams or batches. Where applicable, the specified reduction_dimensions set a requirement for the coordinates passed in the call method: each reduction dimension must be present in those coordinates.
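To make the roles of weights and batch updating more concrete, here is a minimal sketch of a weighted mean that accumulates across batches. The class name StreamingWeightedMean and its update method are hypothetical and are not the earth2studio.statistics.moments implementation; plain lists stand in for tensors, and the reduction is over a single dimension.

```python
class StreamingWeightedMean:
    """Weighted mean over one dimension, updated batch by batch.

    Hypothetical illustration only; not the package implementation.
    """

    def __init__(self):
        self.weighted_sum = 0.0  # running sum of value * weight
        self.weight_total = 0.0  # running sum of weights

    def update(self, values, weights):
        """Fold a new batch into the running mean and return the current mean."""
        if len(values) != len(weights):
            raise ValueError("weights must match the reduced dimension")
        self.weighted_sum += sum(v * w for v, w in zip(values, weights))
        self.weight_total += sum(weights)
        return self.weighted_sum / self.weight_total


stat = StreamingWeightedMean()
first = stat.update([1.0, 2.0], [1.0, 1.0])
second = stat.update([6.0], [2.0])
print(first)   # 1.5
print(second)  # 3.75
```

Feeding the data in two batches gives the same result as a single weighted mean over all three values, which is the property that makes batch updates useful for streamed data.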
Custom Statistics#
Integrating your own statistics is easy: just satisfy the interface above. We recommend looking at the custom statistic example in the Extending Earth2Studio examples.
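A custom statistic could look roughly like the sketch below. Everything here is an assumption made for illustration: StatisticLike is a local stand-in for the package's Statistic protocol, EnsembleMax is a hypothetical statistic, and nested Python lists replace torch.Tensor so the example runs without dependencies. It also only handles a reduction over the leading dimension of 2D data.

```python
from collections import OrderedDict
from typing import Protocol, runtime_checkable


@runtime_checkable
class StatisticLike(Protocol):
    """Hypothetical local stand-in for the Statistic protocol."""

    @property
    def reduction_dimensions(self) -> list[str]: ...

    def output_coords(self, input_coords: OrderedDict) -> OrderedDict: ...

    def __call__(self, x: list, coords: OrderedDict) -> tuple[list, OrderedDict]: ...


class EnsembleMax:
    """Maximum over the leading dimension of a 2D nested list."""

    def __init__(self, reduction_dimension: str = "ensemble"):
        self._dims = [reduction_dimension]

    @property
    def reduction_dimensions(self) -> list[str]:
        return self._dims

    def output_coords(self, input_coords):
        # Output coordinates are the input coordinates minus the reduced ones
        return OrderedDict(
            (k, v) for k, v in input_coords.items() if k not in self._dims
        )

    def __call__(self, x, coords):
        if list(coords)[0] != self._dims[0]:
            raise ValueError("reduction dimension must be the leading dimension")
        # zip(*x) iterates over columns, i.e. across the leading dimension
        reduced = [max(column) for column in zip(*x)]
        return reduced, self.output_coords(coords)


stat = EnsembleMax()
coords = OrderedDict([("ensemble", [0, 1]), ("lat", [-45.0, 0.0, 45.0])])
y, y_coords = stat([[1.0, 5.0, 2.0], [3.0, 4.0, 0.0]], coords)
print(isinstance(stat, StatisticLike))  # True
print(y, list(y_coords))  # [3.0, 5.0, 2.0] ['lat']
```

Because the protocol is runtime checkable, the isinstance check only confirms that the expected members exist; it does not validate signatures or behavior.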
Metrics#
Like statistics, metrics are reductions across existing dimensions. Unlike statistics, which are usually defined over a single input, we define metrics to take a pair of inputs. Otherwise, the API and requirements are similar to the statistics requirements.
Metrics Interface#
    # Statistic.__call__: the call method of the Statistic protocol
    def __call__(
        self,
        x: torch.Tensor,
        coords: CoordSystem,
    ) -> tuple[torch.Tensor, CoordSystem]:
        """Apply statistic to data `x`, with coordinates `coords` and reduce
        over dimensions `reduction_dimensions`.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor intended to apply statistic to.
        coords : CoordSystem
            Ordered dict representing coordinate system that describes the tensor.
            `reduction_dimensions` must be in coords.
        """
        pass

from typing import Protocol, runtime_checkable

import torch

from earth2studio.utils.type import CoordSystem


@runtime_checkable
class Metric(Protocol):
    """Metrics interface."""

    @property
    def reduction_dimensions(self) -> list[str]:
        """Input dimensions over which the metric performs a reduction."""
        pass

    def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
        """Output coordinate system of the computed metric, corresponding to the given input coordinates

        Parameters
        ----------
        input_coords : CoordSystem
            Input coordinate system to transform into output_coords

        Returns
        -------
        CoordSystem
            Coordinate system dictionary
        """
        pass

    def __call__(
        self,
        x: torch.Tensor,
        x_coords: CoordSystem,
        y: torch.Tensor,
        y_coords: CoordSystem,
    ) -> tuple[torch.Tensor, CoordSystem]:
        """Apply metric to data `x` and `y`, checking that their coordinates
        are broadcastable, while reducing over `reduction_dimensions`.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor #1 intended to apply metric to. `x` is typically understood
            to be the forecast or prediction tensor.
        x_coords : CoordSystem
            Ordered dict representing coordinate system that describes the `x` tensor.
            `reduction_dimensions` must be in coords.
        y : torch.Tensor
            Input tensor #2 intended to apply metric to. `y` is typically the observation
            or validation tensor.
        y_coords : CoordSystem
            Ordered dict representing coordinate system that describes the `y` tensor.
            `reduction_dimensions` must be in coords.
        """
        pass
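
The paired-input call can be sketched as follows. LeadingDimMAE is a hypothetical metric written only for illustration: it uses nested Python lists in place of torch.Tensor, reduces only over the leading dimension of 2D data, and, as a simplifying assumption, requires the two coordinate systems to be identical rather than merely broadcastable.

```python
from collections import OrderedDict


class LeadingDimMAE:
    """Mean absolute error over the leading dimension (hypothetical sketch)."""

    def __init__(self, reduction_dimension: str = "time"):
        self._dims = [reduction_dimension]

    @property
    def reduction_dimensions(self):
        return self._dims

    def output_coords(self, input_coords):
        return OrderedDict(
            (k, v) for k, v in input_coords.items() if k not in self._dims
        )

    def __call__(self, x, x_coords, y, y_coords):
        # Simplification: demand identical coordinates instead of broadcasting
        if x_coords != y_coords:
            raise ValueError("x and y must share the same coordinates")
        if list(x_coords)[0] != self._dims[0]:
            raise ValueError("reduction dimension must be the leading dimension")
        n = len(x)
        # Pair up matching columns of x and y, then average the absolute errors
        out = [
            sum(abs(a - b) for a, b in zip(x_col, y_col)) / n
            for x_col, y_col in zip(zip(*x), zip(*y))
        ]
        return out, self.output_coords(x_coords)


coords = OrderedDict([("time", [0, 1]), ("lat", [-30.0, 30.0])])
forecast = [[1.0, 2.0], [3.0, 6.0]]
observed = [[2.0, 2.0], [5.0, 2.0]]
error, error_coords = LeadingDimMAE()(forecast, coords, observed, coords)
print(error, list(error_coords))  # [1.5, 2.0] ['lat']
```

As in the interface above, x plays the role of the forecast and y the observation, and the returned coordinates omit the reduced dimension.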
Contributing Statistics and Metrics#
Want to add your own statistics or metrics to the package? Great, we will be happy to
work with you. At a minimum, we expect the statistic or metric to abide by the interfaces
defined above. We may also work with you to ensure that applicable reduction_dimensions
are defined and, where possible, that weights and batching are supported.