Prognostic Models

Prognostic models in Earth2Studio are models designed to perform time integration. For example, given a set of atmospheric fields at a particular time, a prognostic model predicts the same fields 6 hours into the future. Applying the model auto-regressively, feeding each prediction back in as the next input, produces a longer forecast.
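The auto-regressive idea can be sketched in a few lines. This is a toy illustration, not an Earth2Studio API: toy_step is a hypothetical stand-in for a trained network, and numpy arrays replace torch tensors so the sketch runs without PyTorch.

```python
import numpy as np


def toy_step(x: np.ndarray) -> np.ndarray:
    """Hypothetical 6-hour step standing in for a trained network."""
    return 0.9 * x  # relax every field 10% toward zero


def rollout(
    x0: np.ndarray, n_steps: int, dt_hours: int = 6
) -> list[tuple[np.timedelta64, np.ndarray]]:
    """Auto-regressive rollout: each prediction is fed back in as the next input."""
    states = [(np.timedelta64(0, "h"), x0)]  # step 0 is the initial condition
    x = x0
    for i in range(1, n_steps + 1):
        x = toy_step(x)  # the previous output becomes the new input
        states.append((np.timedelta64(i * dt_hours, "h"), x))
    return states


x0 = np.ones((2, 4, 8))  # toy initial condition with (variable, lat, lon) dims
for lead_time, x in rollout(x0, n_steps=3):
    print(lead_time, float(x.mean()))
```

Each entry pairs a lead time with the predicted state, mirroring how a prognostic model steps forward one fixed interval at a time.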

Prognostic model usage falls into two categories, each typically served by its own API:

  1. Single time-step predictions (the model's __call__ method)

  2. Time-series predictions (the model's create_iterator method)

The list of prognostic models built into Earth2Studio can be found in the API documentation under earth2studio.models.px: Prognostic.

Prognostic Interface

The full requirements for a standard prognostic model are defined explicitly in earth2studio/models/px/base.py:


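from __future__ import annotations

from collections.abc import Iterator
from typing import Any, Protocol, runtime_checkable

import torch

from earth2studio.utils.type import CoordSystem


@runtime_checkable
class PrognosticModel(Protocol):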
    """Prognostic model interface"""

    def __call__(
        self,
        x: torch.Tensor,
        coords: CoordSystem,
    ) -> tuple[torch.Tensor, CoordSystem]:
        """Forward pass of the prognostic model, time integrating a single time-step

        Parameters
        ----------
        x : torch.Tensor
            Input tensor representing the current state to time-integrate
        coords : CoordSystem
            Ordered dict representing coordinate system that describes the tensor

        Returns
        -------
        tuple[torch.Tensor, CoordSystem]
            Output tensor and coordinate dictionary one time-step into the future
        """
        pass

    def create_iterator(
        self, x: torch.Tensor, coords: CoordSystem
    ) -> Iterator[tuple[torch.Tensor, CoordSystem]]:
        """Creates an iterator which can be used to perform time-integration of the
        prognostic model. Will return the initial condition first (0th step).

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, which can be viewed as the initial state of the prognostic model
        coords : CoordSystem
            Input coordinate system

        Yields
        ------
        Iterator[tuple[torch.Tensor, CoordSystem]]
            Iterator that generates time-steps of the prognostic model containing the
            output data tensor and coordinate system dictionary.
        """
        pass

    def input_coords(self) -> CoordSystem:
        """Input coordinate system of prognostic model, time dimension should contain
        time-delta objects

        Returns
        -------
        CoordSystem
            Coordinate system dictionary
        """
        pass

    def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
        """Output coordinate system of the prognostic model given an input coordinate
        system.

        Parameters
        ----------
        input_coords : CoordSystem
            Input coordinate system to transform into output_coords

        Returns
        -------
        CoordSystem
            Coordinate system dictionary

        Raises
        ------
        ValueError
            If input_coords are not valid
        """
        pass

    def to(self, device: Any) -> PrognosticModel:
        """Moves prognostic model onto inference device; this is typically satisfied via
        `torch.nn.Module`.

        Parameters
        ----------
        device : Any
            Object representing the inference device, typically `torch.device` or str

        Returns
        -------
        PrognosticModel
            Returns the instance of the prognostic model
        """
        pass
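To make the interface concrete, here is a minimal sketch of a class that satisfies it. ToyDecayModel is hypothetical and not part of Earth2Studio: CoordSystem is redefined locally as an OrderedDict, numpy arrays stand in for torch tensors, and the "physics" simply scales every field by 0.9 per 6-hour step.

```python
from collections import OrderedDict
from collections.abc import Iterator
from typing import Any

import numpy as np

# Local stand-in for earth2studio's CoordSystem: an ordered dict of numpy arrays
CoordSystem = OrderedDict


class ToyDecayModel:
    """Hypothetical prognostic model: each 6-hour step scales all fields by 0.9."""

    def __init__(self) -> None:
        self.device: Any = "cpu"

    def input_coords(self) -> CoordSystem:
        # lead_time holds time-delta objects, as the protocol requires
        return OrderedDict(
            lead_time=np.array([np.timedelta64(0, "h")]),
            variable=np.array(["t2m", "u10m"]),
            lat=np.linspace(90.0, -90.0, 4),
            lon=np.linspace(0.0, 315.0, 8),
        )

    def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
        expected = self.input_coords()
        for key in ("variable", "lat", "lon"):
            if key not in input_coords or not np.array_equal(input_coords[key], expected[key]):
                raise ValueError(f"Invalid '{key}' coordinate")
        out = OrderedDict(input_coords)  # copy so the caller's dict is untouched
        out["lead_time"] = input_coords["lead_time"] + np.timedelta64(6, "h")
        return out

    def __call__(self, x: np.ndarray, coords: CoordSystem) -> tuple[np.ndarray, CoordSystem]:
        # Single time-step: 6 hours forward
        return 0.9 * x, self.output_coords(coords)

    def create_iterator(
        self, x: np.ndarray, coords: CoordSystem
    ) -> Iterator[tuple[np.ndarray, CoordSystem]]:
        yield x, coords  # step 0 is the initial condition
        while True:
            x, coords = self(x, coords)
            yield x, coords

    def to(self, device: Any) -> "ToyDecayModel":
        self.device = device  # a real model would move its weights, e.g. via torch.nn.Module.to
        return self


model = ToyDecayModel().to("cpu")
coords = model.input_coords()
x = np.ones([len(v) for v in coords.values()])  # shape (1, 2, 4, 8)
for step, (x, coords) in enumerate(model.create_iterator(x, coords)):
    print(step, coords["lead_time"], float(x.mean()))
    if step == 2:
        break
```

Note that create_iterator is built on top of __call__, so single-step and time-series predictions share the same forward logic.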

Note

Prognostic models do not need to inherit from this protocol; it simply defines the required APIs. Prognostic models can also maintain internal state across iterator steps if necessary.
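Because the interface is a typing.Protocol marked @runtime_checkable, conformance is structural: an object with the right methods passes an isinstance check whether or not it inherits from the protocol. The sketch below shows this with a cut-down, hypothetical protocol (Steppable, not Earth2Studio's) and a model that keeps a step counter as internal state across iterator steps.

```python
from collections.abc import Iterator
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class Steppable(Protocol):
    """Hypothetical cut-down protocol with a single required method."""

    def create_iterator(self, x: Any, coords: dict) -> Iterator[tuple[Any, dict]]: ...


class CountingModel:  # does NOT inherit from Steppable
    def __init__(self) -> None:
        self.steps_taken = 0  # internal state kept across iterator steps

    def create_iterator(self, x: float, coords: dict) -> Iterator[tuple[float, dict]]:
        yield x, coords  # step 0 is the initial condition
        while True:
            self.steps_taken += 1
            x = x + 1.0
            yield x, {**coords, "step": self.steps_taken}


model = CountingModel()
print(isinstance(model, Steppable))  # True: structural check, no inheritance needed
```

Runtime isinstance checks against a Protocol only confirm that the methods exist; they do not validate argument or return types.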

Prognostic models also tend to extend two classes:

  1. earth2studio.models.px.utils.PrognosticMixin: a utility class that defines the iterator hooks used in all built-in models. These hooks provide finer-grained control over time-series predictions.

  2. earth2studio.models.auto.AutoModel: defines APIs for models whose checkpoints can be automatically downloaded and cached. See the AutoModels guide for additional details.
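The general idea of iterator hooks can be sketched as follows. This is not Earth2Studio's implementation: HookedIteratorMixin, front_hook, rear_hook, and _forward are hypothetical names chosen for illustration, so consult earth2studio.models.px.utils for the actual hook API.

```python
from collections.abc import Callable, Iterator
from typing import Any

Hook = Callable[[Any, dict], tuple[Any, dict]]


def identity_hook(x: Any, coords: dict) -> tuple[Any, dict]:
    """Default hook: pass data through unchanged."""
    return x, coords


class HookedIteratorMixin:
    """Hypothetical mixin that wraps every rollout step with two hooks."""

    # staticmethod so the default hook is not bound to the instance
    front_hook: Hook = staticmethod(identity_hook)  # runs before each step
    rear_hook: Hook = staticmethod(identity_hook)  # runs after each step

    def _forward(self, x: Any, coords: dict) -> tuple[Any, dict]:
        raise NotImplementedError  # subclasses implement one time-step

    def create_iterator(self, x: Any, coords: dict) -> Iterator[tuple[Any, dict]]:
        yield x, coords  # step 0 is the initial condition, no hooks applied
        while True:
            x, coords = self.front_hook(x, coords)
            x, coords = self._forward(x, coords)
            x, coords = self.rear_hook(x, coords)
            yield x, coords


class Doubler(HookedIteratorMixin):
    """Toy model whose single step doubles its input."""

    def _forward(self, x: Any, coords: dict) -> tuple[Any, dict]:
        return 2 * x, {**coords, "step": coords.get("step", 0) + 1}


model = Doubler()
model.rear_hook = lambda x, coords: (min(x, 10), coords)  # e.g. clamp each output
it = model.create_iterator(3, {})
print([next(it)[0] for _ in range(4)])  # [3, 6, 10, 10]
```

Keeping hooks outside the forward step lets a user adjust every step of a rollout (for example, perturbing or post-processing states) without modifying the model itself.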

Prognostic Usage

Loading a Pre-trained Prognostic Model

The following two commands download and load a pre-trained, built-in prognostic model. More information on automatic checkpoint downloading can be found in the AutoModels section.

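# DLWP is used as an example; built-in prognostic models that implement the
# AutoModel API (load_default_package / load_model) follow the same pattern
from earth2studio.models.px import DLWP

model_package = DLWP.load_default_package()  # package describing the default checkpoint
model = DLWP.load_model(model_package)  # fetch (and cache) the checkpoint, then build the model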

Single Step Prediction

A prognostic model predicts a single time-step when it is called directly, through its __call__ method.

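# Assume model is an instance of a PrognosticModel, and x / coords hold an
# initial condition consistent with model.input_coords()
x = ...  # input tensor (torch.Tensor), placeholder
coords = ...  # coordinate system (CoordSystem, an ordered dict of numpy arrays), placeholder
x, coords = model(x, coords)  # predict a single time-step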
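Since output_coords is specified to raise ValueError for invalid input coordinates, it can in principle be used to validate an input coordinate system and preview the prediction's coordinates before running the forward pass. The sketch below is a hypothetical, standalone output_coords (toy_output_coords) for a 6-hour model, using numpy and an OrderedDict in place of CoordSystem; the expected dimensions are assumptions, not those of any built-in model.

```python
from collections import OrderedDict

import numpy as np


def toy_output_coords(input_coords: OrderedDict) -> OrderedDict:
    """Hypothetical output_coords for a 6-hour model with (variable, lat, lon) trailing dims."""
    if "lead_time" not in input_coords or list(input_coords)[-3:] != ["variable", "lat", "lon"]:
        raise ValueError("Expected a lead_time dim and trailing (variable, lat, lon) dims")
    out = OrderedDict(input_coords)  # copy so the input coordinates are not modified
    out["lead_time"] = input_coords["lead_time"] + np.timedelta64(6, "h")
    return out


coords = OrderedDict(
    lead_time=np.array([np.timedelta64(0, "h")]),
    variable=np.array(["t2m"]),
    lat=np.linspace(90.0, -90.0, 3),
    lon=np.linspace(0.0, 270.0, 4),
)
preview = toy_output_coords(coords)  # no forward pass needed
print(preview["lead_time"], tuple(len(v) for v in preview.values()))
```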

Time-series Prediction

To predict a time-series, use the create_iterator API, which returns an iterator that yields the model state at each time-step as the model rolls out.

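# Assume model is an instance of a PrognosticModel, and x / coords hold an
# initial condition consistent with model.input_coords()
x = ...  # input tensor (torch.Tensor), placeholder
coords = ...  # coordinate system (CoordSystem), placeholder
model_iterator = model.create_iterator(x, coords)  # create iterator for time integration
for step, (x, coords) in enumerate(model_iterator):
    # Perform operations for each time-step
    # The first output is always time-step 0 (the input)
    print(step, x.shape)
    if step == 4:
        break  # the iterator may not stop on its own, so end the rollout explicitly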
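Because a rollout iterator may keep producing steps until the caller stops, itertools.islice is a compact way to take a fixed number of steps. In this sketch, toy_iterator is a hypothetical stand-in; with a real model you would pass model.create_iterator(x, coords) instead.

```python
from collections.abc import Iterator
from itertools import islice


def toy_iterator(x: float) -> Iterator[tuple[float, dict]]:
    """Hypothetical stand-in for model.create_iterator: never terminates."""
    step = 0
    while True:
        yield x, {"lead_time_hours": 6 * step}  # step 0 is the initial condition
        x, step = 0.5 * x, step + 1


n_steps = 4
# islice stops the unbounded iterator after step 0 plus n_steps forecast steps
outputs = list(islice(toy_iterator(8.0), n_steps + 1))
for x, coords in outputs:
    print(coords["lead_time_hours"], x)
```

Collecting into a list keeps every step in memory; for long rollouts of large fields, iterating over the islice directly and processing each step as it arrives is usually preferable.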

Custom Prognostic Models

Integrating your own prognostic model only requires satisfying the interface above. We recommend looking at the custom prognostic example in the Extending Earth2Studio examples, which walks through implementing a prognostic model for your own needs.

Contributing a Prognostic Model

Want to add your prognostic model to the package? Great, we will be happy to work with you. At a minimum, we expect the model to abide by the defined interface and to meet the requirements set forth in our contribution guide. Contributors are typically expected to provide the model weights at a downloadable location from which they can be fetched.

Open an issue when you have an initial implementation you would like us to review.