Diagnostic Models

Diagnostic models in Earth2Studio are a set of models that are independent of time and focus on predicting new or modified values. For example, given an instantaneous set of atmospheric fields, a diagnostic model may predict a new field such as precipitation. These models differ from Prognostic Models in that they do not perform time integration. Calculations such as statistics or metrics could also fall into a diagnostic classification, but we reserve the term diagnostic model for models that predict physical processes, not for standard mathematical calculations or reductions performed for the purpose of analysis.

The list of diagnostic models already built into Earth2Studio can be found in the API documentation under earth2studio.models.dx: Diagnostic.

Diagnostic Interface

The full requirements for a standard diagnostic model are defined explicitly in earth2studio/models/dx/base.py.

@runtime_checkable
class DiagnosticModel(Protocol):
    """Diagnostic model interface"""

    def __call__(
        self,
        x: torch.Tensor,
        coords: CoordSystem,
    ) -> tuple[torch.Tensor, CoordSystem]:
        """Execution of the diagnostic model that transforms physical data

        Parameters
        ----------
        x : torch.Tensor
            Input tensor intended to apply diagnostic function on
        coords : CoordSystem
            Ordered dict representing coordinate system that describes the tensor

        Returns
        -------
        tuple[torch.Tensor, CoordSystem]:
            Output tensor and respective coordinate system dictionary
        """
        pass

    def input_coords(self) -> CoordSystem:
        """Input coordinate system of diagnostic model

        Returns
        -------
        CoordSystem
            Coordinate system dictionary
        """
        pass

    def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
        """Output coordinate system of the diagnostic model given an input coordinate
        system.

        Parameters
        ----------
        input_coords : CoordSystem
            Input coordinate system to transform into output_coords

        Returns
        -------
        CoordSystem
            Coordinate system dictionary

        Raises
        ------
        ValueError
            If input_coords are not valid
        """
        pass

    def to(self, device: Any) -> DiagnosticModel:
        """Moves diagnostic model onto inference device, this is typically satisfied via
        `torch.nn.Module`.

        Parameters
        ----------
        device : Any
            Object representing the inference device, typically `torch.device` or str

        Returns
        -------
        DiagnosticModel
            Returns instance of diagnostic
        """
        pass

Note

Diagnostic models do not need to inherit from this protocol; it is simply used to define the required APIs.
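Because the protocol is decorated with @runtime_checkable, isinstance can check whether an object provides the required methods even though its class never inherits from the protocol. The sketch below uses a simplified, hypothetical stand-in named DiagnosticLike, with the tensor and coordinate types loosened to Any and OrderedDict, so it runs with only the Python standard library. Keep in mind that a runtime protocol check only verifies that the methods exist. It does not validate their signatures or return values.

```python
from collections import OrderedDict
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class DiagnosticLike(Protocol):
    """Simplified, hypothetical stand-in for DiagnosticModel (types loosened)."""

    def __call__(self, x: Any, coords: OrderedDict) -> tuple[Any, OrderedDict]: ...

    def input_coords(self) -> OrderedDict: ...

    def output_coords(self, input_coords: OrderedDict) -> OrderedDict: ...

    def to(self, device: Any) -> "DiagnosticLike": ...


class PassThrough:
    """Implements all four methods without inheriting from DiagnosticLike."""

    def __call__(self, x, coords):
        return x, coords

    def input_coords(self):
        return OrderedDict(variable=["t2m"])

    def output_coords(self, input_coords):
        return input_coords

    def to(self, device):
        return self


class MissingTo:
    """Lacks the to() method, so it does not satisfy the protocol."""

    def __call__(self, x, coords):
        return x, coords

    def input_coords(self):
        return OrderedDict()

    def output_coords(self, input_coords):
        return input_coords


print(isinstance(PassThrough(), DiagnosticLike))  # True
print(isinstance(MissingTo(), DiagnosticLike))    # False
```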

Diagnostic models also tend to extend one class:

  1. earth2studio.models.auto.AutoModel: Defines APIs for models whose checkpoints can be automatically downloaded and cached. See the AutoModels guide for additional details.

Diagnostic Usage

Loading a Pre-trained Diagnostic

The following two commands can be used to download and load a pre-trained, built-in diagnostic model. More information on automatic downloading of checkpoints can be found in the AutoModels section.

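from earth2studio.models.dx import PrecipitationAFNO

# PrecipitationAFNO is one built-in diagnostic; others that extend AutoModel
# are loaded the same way (DiagnosticModel itself is only a protocol)
model_package = PrecipitationAFNO.load_default_package()
model = PrecipitationAFNO.load_model(model_package)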

Prediction

The workhorse of diagnostic models is the __call__() function, which takes in a data tensor and its coordinate system and returns the output tensor together with its coordinate system.

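import torch

from earth2studio.utils.type import CoordSystem

# Assume model is an instance of a diagnostic model
x: torch.Tensor = ...  # Input tensor
coords: CoordSystem = ...  # Ordered dict describing each dimension of x
x, coords = model(x, coords)  # Apply the diagnostic (no time integration)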
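Each key of a coordinate system corresponds to one dimension of the tensor, in order. As a result, output_coords tells a caller the shape of a diagnostic's result before the model is called, which is useful, for example, to preallocate storage. The sketch below is illustrative only. Numpy arrays stand in for torch.Tensor, and the output_coords function and the variable names (t2m, tcwv, tp) are hypothetical choices for a model that maps two input fields to a single precipitation field on the same grid.

```python
from collections import OrderedDict

import numpy as np


def coords_shape(coords: OrderedDict) -> tuple[int, ...]:
    """Tensor shape implied by a coordinate system: one axis per key, in order."""
    return tuple(len(values) for values in coords.values())


def output_coords(input_coords: OrderedDict) -> OrderedDict:
    """Hypothetical diagnostic: two input variables in, one variable out, same grid."""
    if list(input_coords["variable"]) != ["t2m", "tcwv"]:
        raise ValueError("Expected input variables ['t2m', 'tcwv']")
    out = OrderedDict(input_coords)  # copy keeps the key (dimension) order
    out["variable"] = np.array(["tp"])  # replacing a key does not move it
    return out


in_coords = OrderedDict(
    time=np.array([np.datetime64("2024-01-01T00:00")]),
    variable=np.array(["t2m", "tcwv"]),
    lat=np.linspace(90, -90, 721),
    lon=np.linspace(0, 359.75, 1440),
)
out_coords = output_coords(in_coords)

print(coords_shape(in_coords))   # (1, 2, 721, 1440)
print(coords_shape(out_coords))  # (1, 1, 721, 1440)

# Allocate storage for the diagnostic's output without running the model
out_buffer = np.empty(coords_shape(out_coords), dtype=np.float32)
```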

Custom Diagnostic Models

Integrating your own diagnostic is easy: just satisfy the interface above. We recommend having a look at the Extending Earth2Studio examples, which step users through the simple process of implementing their own diagnostic model.
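As an illustration of satisfying the interface, the sketch below implements a deliberately simple, hypothetical diagnostic that derives 10 m wind speed from the u10m and v10m wind components. It is not a model shipped with Earth2Studio. Numpy arrays stand in for torch.Tensor so that it runs without PyTorch, and the class name, variable names and validation rules are assumptions made for this example. A real diagnostic would typically subclass torch.nn.Module, which provides to().

```python
from collections import OrderedDict
from typing import Any

import numpy as np


class ToyWindSpeed:
    """Hypothetical diagnostic: 10 m wind speed from u10m / v10m components.

    numpy arrays stand in for torch.Tensor so the sketch runs without PyTorch.
    """

    def __init__(self, lat: np.ndarray, lon: np.ndarray) -> None:
        self.lat = lat
        self.lon = lon

    def input_coords(self) -> OrderedDict:
        # Trailing dimensions the model expects; leading dims (e.g. batch) pass through
        return OrderedDict(
            variable=np.array(["u10m", "v10m"]),
            lat=self.lat,
            lon=self.lon,
        )

    def output_coords(self, input_coords: OrderedDict) -> OrderedDict:
        if list(input_coords)[-3:] != ["variable", "lat", "lon"]:
            raise ValueError("Expected trailing dimensions (variable, lat, lon)")
        if list(input_coords["variable"]) != ["u10m", "v10m"]:
            raise ValueError("Expected input variables ['u10m', 'v10m']")
        output = OrderedDict(input_coords)  # copy keeps dimension order
        output["variable"] = np.array(["ws10m"])
        return output

    def __call__(
        self, x: np.ndarray, coords: OrderedDict
    ) -> tuple[np.ndarray, OrderedDict]:
        out_coords = self.output_coords(coords)  # validates coords first
        axis = list(coords).index("variable")
        u = np.take(x, 0, axis=axis)
        v = np.take(x, 1, axis=axis)
        speed = np.sqrt(u**2 + v**2)
        return np.expand_dims(speed, axis=axis), out_coords

    def to(self, device: Any) -> "ToyWindSpeed":
        # No parameters to move; a torch.nn.Module subclass would handle this
        return self


# Usage: prepend a batch dimension to the model's expected input coordinates
model = ToyWindSpeed(lat=np.linspace(90, -90, 3), lon=np.linspace(0, 270, 4))
coords = OrderedDict(batch=np.arange(1))
coords.update(model.input_coords())
x = np.empty((1, 2, 3, 4))
x[:, 0] = 3.0  # u10m
x[:, 1] = 4.0  # v10m
y, y_coords = model.to("cpu")(x, coords)
print(y.shape)         # (1, 1, 3, 4)
print(float(y.max()))  # 5.0
```

Leading dimensions such as batch are carried through untouched, and output_coords is reused inside __call__ so that validation and the returned coordinate system cannot drift apart.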

Contributing a Diagnostic Model

Want to add your diagnostic to the package? Great, we will be happy to work with you. At a minimum, we expect the model to abide by the defined interface and to meet the requirements set forth in our contribution guide. Typically, users are expected to provide the model weights in a downloadable location from which they can be fetched.

Open an issue when you have an initial implementation you would like us to review. If you’re aware of an existing model and want us to implement it, open a feature request and we will get it triaged.