Statistics

Statistics differ in principle from prognostic and diagnostic models: we assume that a statistic reduces existing coordinates, so the output tensor's coordinate system is a subset of the input coordinate system. This makes statistics less flexible than diagnostic models, but it also means they have fewer API requirements.
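To make the subset relationship concrete, here is a minimal sketch of how output coordinates can be derived from input coordinates when some dimensions are reduced. It assumes, as a stand-in for the package's CoordSystem type, an ordered dict mapping dimension names to NumPy arrays; reduced_coords is a hypothetical helper, not part of Earth2Studio.

```python
from collections import OrderedDict

import numpy as np

# Assumption: a coordinate system is an ordered mapping from dimension name
# to coordinate values (a stand-in for Earth2Studio's CoordSystem type).
CoordSystem = OrderedDict[str, np.ndarray]


def reduced_coords(
    input_coords: CoordSystem, reduction_dimensions: list[str]
) -> CoordSystem:
    """Hypothetical helper: drop reduced dimensions, keep the rest in order."""
    missing = [dim for dim in reduction_dimensions if dim not in input_coords]
    if missing:
        raise ValueError(f"Reduction dimensions {missing} not in input coordinates")
    return OrderedDict(
        (dim, values)
        for dim, values in input_coords.items()
        if dim not in reduction_dimensions
    )


input_coords = OrderedDict(
    ensemble=np.arange(4),
    variable=np.array(["t2m", "u10m"]),
    lat=np.linspace(90, -90, 721),
    lon=np.linspace(0, 359.75, 1440),
)
output_coords = reduced_coords(input_coords, ["ensemble"])
print(list(output_coords))  # ['variable', 'lat', 'lon']
```

Because no new dimensions are created, the output coordinates can be known before any data is processed, which is what the reduction_dimensions property is for.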

Statistics Interface

The Statistic API is small: besides the reduction_dimensions property and the output_coords method, it specifies a __call__() method whose signature matches similar methods across the package.

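from typing import Protocol, runtime_checkable

import torch

from earth2studio.utils.type import CoordSystem


@runtime_checkable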
class Statistic(Protocol):
    """Statistic interface."""

    @property
    def reduction_dimensions(self) -> list[str]:
        """Gives the input dimensions over which the statistic performs a reduction.
        This is used to determine, a priori, the output dimensions of a statistic.
        """
        pass

    def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
        """Output coordinate system of the computed statistic, corresponding to the given input coordinates

        Parameters
        ----------
        input_coords : CoordSystem
            Input coordinate system to transform into output_coords

        Returns

The base API hints at, and inspection of the earth2studio.statistics.moments examples, the use of a few properties to make statistic handling easier: reduction_dimensions, which are a list of dimensions that will be reduced over, weights, which must be broadcastable with reduction_dimensions, and batch_update, which is useful for applying statistics when data comes in streams/batches.

Where applicable, specified reduction_dimensions set a requirement for the coordinates passed in the call method.

Custom Statistics#

Integrating your own statistics is easy, just satisfy the interface above. We recommend users look at the custom statistic example in the Extending Earth2Studio examples.

Metrics#

Like statistics, metrics are reductions across existing dimensions. Unlike statistics, which are usually defined over a single input, we define metrics to take a pair of inputs. Otherwise, the API and requirements are similar to the statistics requirements.

Metrics Interface#

    ) -> tuple[torch.Tensor, CoordSystem]:
        """Apply statistic to data `x`, with coordinates `coords` and reduce
        over dimensions `reduction_dimensions`.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor intended to apply statistic to.
        coords : CoordSystem
            Ordered dict representing coordinate system that describes the tensor.
            `reduction_dimensions` must be in coords.
        """
        pass


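from typing import Protocol, runtime_checkable

import torch

from earth2studio.utils.type import CoordSystem


@runtime_checkable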
class Metric(Protocol):
    """Metrics interface."""

    @property
    def reduction_dimensions(self) -> list[str]:
        pass

    def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
        """Output coordinate system of the computed metric, corresponding to the given input coordinates

        Parameters
        ----------
        input_coords : CoordSystem
            Input coordinate system to transform into output_coords

        Returns
        -------
        CoordSystem
            Coordinate system dictionary
        """
        pass

    def __call__(
        self,
        x: torch.Tensor,
        x_coords: CoordSystem,
        y: torch.Tensor,
        y_coords: CoordSystem,
    ) -> tuple[torch.Tensor, CoordSystem]:
        """Apply metric to data `x` and `y`, checking that their coordinates
        are broadcastable, while reducing over `reduction_dimensions`.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor #1 intended to apply metric to. `x` is typically understood
            to be the forecast or prediction tensor.
        x_coords : CoordSystem
            Ordered dict representing coordinate system that describes the `x` tensor.
            `reduction_dimensions` must be in coords.
        y : torch.Tensor
            Input tensor #2 intended to apply metric to. `y` is typically the observation
            or validation tensor.
        y_coords : CoordSystem
            Ordered dict representing coordinate system that describes the `y` tensor.
            `reduction_dimensions` must be in coords.
        """
        pass
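As an illustration of the Metric protocol, here is a minimal sketch of a hypothetical root-mean-square error metric. It is not Earth2Studio's own implementation: it assumes a local CoordSystem stand-in (an ordered dict of NumPy arrays) and, for simplicity, requires x and y to share identical dimensions in the same order instead of performing the general broadcast check the docstring describes.

```python
from collections import OrderedDict

import numpy as np
import torch

# Assumption: stand-in for Earth2Studio's CoordSystem type.
CoordSystem = OrderedDict[str, np.ndarray]


class RMSE:
    """Hypothetical RMSE metric that structurally satisfies the Metric protocol."""

    def __init__(self, reduction_dimensions: list[str]):
        self._reduction_dimensions = reduction_dimensions

    @property
    def reduction_dimensions(self) -> list[str]:
        return self._reduction_dimensions

    def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
        # Reduced dimensions disappear; all other dimensions are kept in order.
        return OrderedDict(
            (dim, values)
            for dim, values in input_coords.items()
            if dim not in self._reduction_dimensions
        )

    def __call__(
        self,
        x: torch.Tensor,
        x_coords: CoordSystem,
        y: torch.Tensor,
        y_coords: CoordSystem,
    ) -> tuple[torch.Tensor, CoordSystem]:
        # Simplification: require identical dimension names and order.
        if list(x_coords) != list(y_coords):
            raise ValueError("x and y must have the same dimensions in the same order")
        missing = [d for d in self._reduction_dimensions if d not in x_coords]
        if missing:
            raise ValueError(f"Reduction dimensions {missing} not in coords")
        # Map dimension names to tensor axes, then reduce over those axes.
        dims = tuple(list(x_coords).index(d) for d in self._reduction_dimensions)
        error = torch.sqrt(torch.mean((x - y) ** 2, dim=dims))
        return error, self.output_coords(x_coords)


coords = OrderedDict(time=np.arange(3), lat=np.arange(4), lon=np.arange(5))
forecast = torch.zeros(3, 4, 5)
observation = torch.full((3, 4, 5), 2.0)
rmse, rmse_coords = RMSE(["lat", "lon"])(forecast, coords, observation, coords)
print(rmse.shape, list(rmse_coords))  # torch.Size([3]) ['time']
```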

Contributing Statistics and Metrics

Want to add your own statistics or metrics to the package? Great, we will be happy to work with you. At a minimum, we expect a contribution to abide by the interfaces defined above. We may also work with you to make sure appropriate reduction_dimensions are defined and, where possible, that weights and batching are supported.
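The sketch below shows one way a statistic might combine reduction_dimensions, weights and batch_update. RunningWeightedMean is hypothetical and is not the implementation in earth2studio.statistics.moments; it assumes a local CoordSystem stand-in, expects weights to be a tensor that already broadcasts against the input, and, when batch_update is enabled, accumulates weighted sums across calls so that data streamed in batches along a reduced dimension yields the mean over everything seen so far.

```python
from collections import OrderedDict
from typing import Optional

import numpy as np
import torch

# Assumption: stand-in for Earth2Studio's CoordSystem type.
CoordSystem = OrderedDict[str, np.ndarray]


class RunningWeightedMean:
    """Hypothetical weighted mean statistic with optional batch accumulation."""

    def __init__(
        self,
        reduction_dimensions: list[str],
        weights: Optional[torch.Tensor] = None,
        batch_update: bool = False,
    ):
        self._reduction_dimensions = reduction_dimensions
        # Assumed to broadcast against every input tensor passed to __call__.
        self.weights = weights
        self.batch_update = batch_update
        self._sum: Optional[torch.Tensor] = None  # running sum of w * x
        self._norm: Optional[torch.Tensor] = None  # running sum of w

    @property
    def reduction_dimensions(self) -> list[str]:
        return self._reduction_dimensions

    def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
        return OrderedDict(
            (dim, values)
            for dim, values in input_coords.items()
            if dim not in self._reduction_dimensions
        )

    def __call__(
        self, x: torch.Tensor, coords: CoordSystem
    ) -> tuple[torch.Tensor, CoordSystem]:
        missing = [d for d in self._reduction_dimensions if d not in coords]
        if missing:
            raise ValueError(f"Reduction dimensions {missing} not in coords")
        dims = tuple(list(coords).index(d) for d in self._reduction_dimensions)
        if self.weights is None:
            w = torch.ones_like(x)
        else:
            w = torch.broadcast_to(self.weights, x.shape)
        batch_sum = (w * x).sum(dim=dims)
        batch_norm = w.sum(dim=dims)
        if self.batch_update and self._sum is not None:
            # Fold this batch into the totals from earlier calls.
            self._sum = self._sum + batch_sum
            self._norm = self._norm + batch_norm
        else:
            # No accumulation: this call's totals replace any previous state.
            self._sum, self._norm = batch_sum, batch_norm
        return self._sum / self._norm, self.output_coords(coords)


# Ensemble members arrive in two batches; the mean covers all members seen so far.
stat = RunningWeightedMean(["ensemble"], batch_update=True)
first = torch.tensor([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]])
mean, _ = stat(first, OrderedDict(ensemble=np.arange(2), lat=np.arange(3)))
print(mean)  # tensor([2., 3., 4.])
second = torch.tensor([[5.0, 6.0, 7.0]])
mean, coords = stat(second, OrderedDict(ensemble=np.arange(2, 3), lat=np.arange(3)))
print(mean, list(coords))  # tensor([3., 4., 5.]) ['lat']
```

Keeping running sums rather than a running mean keeps the update exact regardless of batch size. The trade-off is that the object is stateful, so a fresh instance (or batch_update=False) is needed to start a new accumulation.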