Prognostic Models
Earth2Studio provides a set of prognostic models designed to perform time integration. For example, given a set of atmospheric fields at a particular time, a prognostic model predicts the same fields 6 hours into the future, and it can be applied auto-regressively to step further ahead.
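As a minimal illustration of auto-regressive time integration, the sketch below uses a hypothetical `step` function that advances a state by 6 hours. The toy exponential decay stands in for a trained network, and the rollout feeds each output back in as the next input. None of these names are Earth2Studio APIs.

```python
from datetime import datetime, timedelta

STEP = timedelta(hours=6)  # time step of the hypothetical model


def step(state: list[float], time: datetime) -> tuple[list[float], datetime]:
    """Hypothetical one-step model: decays every value by 10%."""
    return [0.9 * v for v in state], time + STEP


def rollout(state: list[float], time: datetime, nsteps: int):
    """Apply the one-step model repeatedly, feeding each output back in."""
    history = [(state, time)]  # step 0 is the initial condition
    for _ in range(nsteps):
        state, time = step(state, time)
        history.append((state, time))
    return history


history = rollout([1.0, 2.0], datetime(2024, 1, 1), nsteps=4)
for state, time in history:
    print(time.isoformat(), [round(v, 4) for v in state])
```

Four 6-hour steps advance the state by one day; a longer forecast is simply more iterations of the same single-step model.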
Prognostic models are typically used in one of two ways, each served by its own API:
Single time-step predictions (calling the model directly)
Time-series predictions (create_iterator)
The list of prognostic models already built into Earth2Studio can be found in the API documentation: earth2studio.models.px: Prognostic.
Prognostic Interface
The full requirements for a standard prognostic model are defined explicitly in earth2studio/models/px/base.py.
from __future__ import annotations

from collections.abc import Iterator
from typing import Any, Protocol, runtime_checkable

import torch

from earth2studio.utils.type import CoordSystem


@runtime_checkable
class PrognosticModel(Protocol):
    """Prognostic model interface"""

    def __call__(
        self,
        x: torch.Tensor,
        coords: CoordSystem,
    ) -> tuple[torch.Tensor, CoordSystem]:
        """Forward pass of the prognostic model, time integrating a single time-step

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, the current state to time-integrate
        coords : CoordSystem
            Ordered dict representing coordinate system that describes the tensor

        Returns
        -------
        tuple[torch.Tensor, CoordSystem]
            Output tensor and coordinate dictionary one time-step into the future
        """
        pass

    def create_iterator(
        self, x: torch.Tensor, coords: CoordSystem
    ) -> Iterator[tuple[torch.Tensor, CoordSystem]]:
        """Creates an iterator which can be used to perform time-integration of the
        prognostic model. Will return the initial condition first (0th step).

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, which can be viewed as the initial state of the prognostic model
        coords : CoordSystem
            Input coordinate system

        Yields
        ------
        Iterator[tuple[torch.Tensor, CoordSystem]]
            Iterator that generates time-steps of the prognostic model containing the
            output data tensor and coordinate system dictionary.
        """
        pass

    def input_coords(self) -> CoordSystem:
        """Input coordinate system of prognostic model, time dimension should contain
        time-delta objects

        Returns
        -------
        CoordSystem
            Coordinate system dictionary
        """
        pass

    def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
        """Output coordinate system of the prognostic model given an input coordinate
        system.

        Parameters
        ----------
        input_coords : CoordSystem
            Input coordinate system to transform into output_coords

        Returns
        -------
        CoordSystem
            Coordinate system dictionary

        Raises
        ------
        ValueError
            If input_coords are not valid
        """
        pass

    def to(self, device: Any) -> PrognosticModel:
        """Moves prognostic model onto inference device, this is typically satisfied via
        `torch.nn.Module`.

        Parameters
        ----------
        device : Any
            Object representing the inference device, typically `torch.device` or str

        Returns
        -------
        PrognosticModel
            Returns instance of prognostic model
        """
        pass
Note
Prognostic models do not need to inherit from this protocol; it is only used to define the required APIs. Prognostic models may also maintain internal state while the iterator is running, if necessary.
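Because the protocol is decorated with `@runtime_checkable`, conformance is structural: `isinstance` only checks that the required members exist, not their signatures or types. The sketch below uses a stripped-down, hypothetical copy of the protocol (`MiniPrognostic`, with the same member names but no torch types) so that it runs standalone.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class MiniPrognostic(Protocol):
    """Hypothetical stand-in with the same member names as PrognosticModel."""

    def __call__(self, x, coords): ...
    def create_iterator(self, x, coords): ...
    def input_coords(self): ...
    def output_coords(self, input_coords): ...
    def to(self, device): ...


class NotInheriting:
    """Does not subclass the protocol, but defines every required member."""

    def __call__(self, x, coords): return x, coords
    def create_iterator(self, x, coords): yield x, coords
    def input_coords(self): return {}
    def output_coords(self, input_coords): return input_coords
    def to(self, device): return self


class Incomplete:
    """Only callable; the other required members are missing."""

    def __call__(self, x, coords): return x, coords


print(isinstance(NotInheriting(), MiniPrognostic))  # True
print(isinstance(Incomplete(), MiniPrognostic))  # False
```

Since the check is purely by attribute presence, a passing `isinstance` does not guarantee that the methods behave correctly; it only confirms the API surface is there.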
Prognostic models also tend to extend two classes:
earth2studio.models.px.utils.PrognosticMixin: a utility class that defines the iterator hooks used in all the built-in models. These hooks provide finer control over the time-series prediction of models.
earth2studio.models.auto.AutoModel: defines APIs for models whose checkpoints can be automatically downloaded and cached. See the AutoModels guide for additional details.
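The idea behind iterator hooks can be illustrated with a small pure-Python sketch: the iterator calls user-replaceable functions before and after every time step. The names `HookedIterator`, `pre_step_hook`, `post_step_hook` and `ToyModel` are hypothetical and are not the actual members of PrognosticMixin; consult its API reference for the real hook names. The sketch also shows an iterator carrying state between steps, as the Note above allows.

```python
from typing import Callable, Iterator


class HookedIterator:
    """Hypothetical mixin: stores hooks the iterator calls around every step."""

    def __init__(self) -> None:
        # Identity hooks by default; users may overwrite these attributes.
        self.pre_step_hook: Callable[[float], float] = lambda x: x
        self.post_step_hook: Callable[[float], float] = lambda x: x


class ToyModel(HookedIterator):
    def step(self, x: float) -> float:
        return x + 1.0  # stand-in for one model time step

    def create_iterator(self, x: float) -> Iterator[float]:
        yield x  # step 0 is the initial condition
        while True:
            x = self.pre_step_hook(x)
            x = self.step(x)
            x = self.post_step_hook(x)
            yield x


model = ToyModel()
model.post_step_hook = lambda x: min(x, 2.5)  # e.g. clamp the state after each step
it = model.create_iterator(0.0)
print([next(it) for _ in range(5)])  # [0.0, 1.0, 2.0, 2.5, 2.5]
```

Keeping the hooks as attributes means a user can modify the rollout (for example, perturbing or clamping the state) without subclassing or rewriting the model's stepping logic.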
Prognostic Usage
Loading a Pre-trained Prognostic Model
The following two calls download and load a pre-trained built-in prognostic model. More information on automatic downloading of checkpoints can be found in the AutoModels section.
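# DLWP is used as one example of a built-in prognostic model; other built-in
# models with downloadable checkpoints follow the same pattern.
from earth2studio.models.px import DLWP

model_package = DLWP.load_default_package()  # Default checkpoint package
model = DLWP.load_model(model_package)  # Instantiate the model from the package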
Single Step Prediction
A prognostic model can be called directly (through its __call__ method) to predict a single time step.
import torch
from earth2studio.utils.type import CoordSystem

# Assume model is an instance of a PrognosticModel
x = torch.Tensor(...)  # Input tensor
coords = CoordSystem(...)  # Coordinate system describing x
x, coords = model(x, coords)  # Predict a single time step
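A single step returns new coordinates alongside the new tensor, and that is where the time information advances. As a rough sketch, assuming a coordinate system is an ordered mapping from dimension names to NumPy arrays and that the model shifts a `lead_time` dimension by 6 hours (both assumptions for illustration; the real dimensions depend on each model's `output_coords`), the hypothetical `advance` helper below mimics the coordinate update.

```python
from collections import OrderedDict

import numpy as np

# Hypothetical input coordinates: one lead time, two variables, a coarse grid.
coords = OrderedDict(
    lead_time=np.array([np.timedelta64(0, "h")]),
    variable=np.array(["t2m", "u10m"]),
    lat=np.linspace(90.0, -90.0, 3),
    lon=np.linspace(0.0, 240.0, 3),
)


def advance(coords: OrderedDict, dt: np.timedelta64) -> OrderedDict:
    """Return a copy of coords with lead_time shifted forward by dt."""
    out = coords.copy()  # shallow copy; the input mapping is left untouched
    out["lead_time"] = coords["lead_time"] + dt
    return out


out = advance(coords, np.timedelta64(6, "h"))
print(out["lead_time"])
print(list(out))  # ['lead_time', 'variable', 'lat', 'lon']
```

Only the time dimension changes here; the variable and grid coordinates pass through unchanged, which is what lets the output of one step be fed back in as the input of the next.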
Time-series Prediction
To predict a time series, use the create_iterator API, which returns an iterator that yields the model's output at each time step as the model rolls out.
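import torch
from earth2studio.utils.type import CoordSystem

# Assume model is an instance of a PrognosticModel
nsteps = 4  # Number of time steps to roll out (example value)
x = torch.Tensor(...)  # Input tensor
coords = CoordSystem(...)  # Coordinate system
model_iterator = model.create_iterator(x, coords)  # Create iterator for time integration
for step, (x, coords) in enumerate(model_iterator):
    # Perform operations for each time step; the first output is always step 0 (the input)
    print(step, x.shape)
    if step == nsteps:
        break  # Stop the rollout explicitly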
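Because a model iterator may keep producing steps for as long as it is asked, the standard library's `itertools.islice` is a convenient alternative to a manual `break` when a fixed number of outputs is wanted. The generator below is a hypothetical stand-in for `create_iterator` (doubling instead of a real model step) so the sketch runs on its own.

```python
from itertools import islice
from typing import Iterator


def create_iterator(x: int) -> Iterator[int]:
    """Hypothetical stand-in: yields the input first, then one value per step."""
    yield x  # step 0 is the initial condition
    while True:
        x = x * 2  # stand-in for one model time step
        yield x


# Take the initial condition plus three predicted steps.
steps = list(islice(create_iterator(1), 4))
print(steps)  # [1, 2, 4, 8]
```

Note that `islice(it, n)` returns n items including the step-0 input, so n predicted steps require `n + 1`.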
Custom Prognostic Models
Integrating your own prognostic model is straightforward: the model only needs to satisfy the interface above. We recommend looking at the custom prognostic example in the Extending Earth2Studio examples, which walks through the process of implementing a prognostic model for your own needs.
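As a compact illustration (not a substitute for the official custom prognostic example), the sketch below implements the interface with a toy persistence model: the forecast equals the input, shifted 6 hours later. It assumes, for illustration only, a `lead_time`/`variable`/`lat`/`lon` coordinate layout and uses a local `CoordSystem = OrderedDict` stand-in for Earth2Studio's real type alias. Subclassing `torch.nn.Module` supplies `__call__` (which dispatches to `forward`) and `to`.

```python
from collections import OrderedDict
from collections.abc import Iterator

import numpy as np
import torch

# Assumption: hypothetical local stand-in for Earth2Studio's CoordSystem alias,
# an ordered mapping of dimension name -> coordinate array.
CoordSystem = OrderedDict


class PersistenceModel(torch.nn.Module):
    """Toy prognostic model: the forecast equals the input, 6 hours later."""

    def __init__(self) -> None:
        super().__init__()
        self.dt = np.timedelta64(6, "h")  # model time step

    def input_coords(self) -> CoordSystem:
        # Hypothetical coordinate layout on a coarse 4 x 4 grid.
        return OrderedDict(
            lead_time=np.array([np.timedelta64(0, "h")]),
            variable=np.array(["t2m"]),
            lat=np.linspace(90.0, -90.0, 4),
            lon=np.linspace(0.0, 270.0, 4),
        )

    def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
        if list(input_coords) != list(self.input_coords()):
            raise ValueError("Input coordinates have unexpected dimensions")
        out = input_coords.copy()
        out["lead_time"] = input_coords["lead_time"] + self.dt
        return out

    @torch.no_grad()
    def forward(
        self, x: torch.Tensor, coords: CoordSystem
    ) -> tuple[torch.Tensor, CoordSystem]:
        # Persistence: copy the state, advance only the coordinates.
        return x.clone(), self.output_coords(coords)

    def create_iterator(
        self, x: torch.Tensor, coords: CoordSystem
    ) -> Iterator[tuple[torch.Tensor, CoordSystem]]:
        yield x, coords  # step 0 is the initial condition
        while True:
            x, coords = self(x, coords)
            yield x, coords


model = PersistenceModel()
coords = model.input_coords()
x = torch.randn(1, 1, 4, 4)  # (lead_time, variable, lat, lon)
for step, (y, c) in enumerate(model.create_iterator(x, coords)):
    print(step, c["lead_time"])
    if step == 2:
        break
```

A real model would replace the body of `forward` with a network evaluation; the coordinate handling and the iterator structure stay the same.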
Contributing a Prognostic Model
Want to add your prognostic model to the package? Great, we will be happy to work with you. At a minimum, we expect the model to abide by the defined interface and to meet the requirements set forth in our contribution guide. Contributors are typically expected to host the model weights at a downloadable location from which they can be fetched.
Open an issue when you have an initial implementation you would like us to review.