Concepts#
Earth2 MIP has the following concepts:
Layered abstractions for wrapping machine learning models
Data Sources for loading initial conditions and scoring against.
Python APIs work with these abstractions to do things like scoring and ensemble inference.
Model packages allow reproducible loading of checkpoints from disk or the cloud. that work with these abstractions.
Command line tools that allow running scoring jobs in a parallel environment.
Model Wrappers#
The core model-related abstractions are
Machine learning Module (e.g.
torch.nn.Module
)Time loop (
earth2mip.time_loop.TimeLoop
).Forecast (
earth2mip.forecasts.Forecast
)
Each of these presents increasingly minimal interface.
We will demonstrate each of these interfaces for the persistence forecast. A persistence forecast is one that always returns the initial condition. It is a common baseline for a weather forecast.
A design philosphy we follow is that model wrappers should take plain torch tensors as inputs and outputs and provide static metadata (e.g. grid, channel, time) about how those tensors map onto the planet. An alternative philosophy is to use some kind of metadata-aware container either custom-built or off-the-shelf like xarray. We avoid the latter approach since there is always a temptation to implement all of mathematics on some container type or hide the arrays under several layers of containers. At the end of the day, most ML models ultimately take and receive a single multi dimesional array of data. Finally, ML developers are all familiar with basic array-like data types. We already need to wrap our models in a stack of abstractions to provide a common interface, so we don’t need to make the data more complex.
Module#
At the root is a machine learning model with inputs/outputs that correspond to
two-dimensional fields defined on the planet. These fields are reified as some
array-like data structure, the names of the fields/channels, and grid object
(earth2mip.grid.LatLongGrid
). To use a Module, fcn-mip needs to be provided
metadata about the Model’s inputs/outputs (see. earth2mip.schema.Model
).
Here as how to implement the persistence forecast as module:
import torch.nn
class PersistenceModule(torch.nn.Module):
def forward(self, x):
return x
This is a very simple, but has no metadata about the input or output, and from
an outside perspective, can be difficult to use if x
has more semantic
meaning or requirements.
Time Stepper#
- class earth2mip.time_loop.TimeStepper(*args, **kwargs)#
An functional interface that can be used for time stepping
state -> (state, output)
This uses a generic state, but concrete Tensors as input and output. This allows users to directly control the time-stepping logic and potentially modify the state in model-specific manner, but the basic initial conditions and running outputs are concrete torch Tensors.
One example is the graphcast time stepper. Graphcast uses jax and xarray to handle the state.
It should be used like this:
stepper = MyStepper() state = stepper.initialize(x, time) outputs = [] for i in range(10): state, output = stepper.step(state) outputs.append(output)
One benefit is that the state can be saved and reloaded trivially to restart the simulation.
- initialize(x, time)#
x is described by
self.input_info
- Parameters:
x (Tensor) –
time (datetime) –
- Return type:
StateT
- step(state)#
step the state and return the ml output as a tensor
The output tensor is described by
self.output_info
- Parameters:
state (StateT) –
- Return type:
tuple[StateT, Tensor]
Time Loop#
A Time Loop is a higher level interface encapsulating the Module, but also
timestepping logic, data preprocessing, and output.
earth2mip.time_loop.TimeLoop
defines this interface, and
earth2mip.networks.Inference
is implements this interface and can
turn a Module into a TimeLoop.
Here is a TimeLoop of the persistence forecast:
from earth2mip.time_loop import TimeLoop
from earth2mip.schema import Grid
import datetime
class PeristenceTimeLoop(TimeLoop):
time_step = datetime.timedelta(hours=12)
# 1 history level = only the current time as input
n_history_levels = 1
in_channel_names = ["a", "b", "c"]
out_channel_names = ["a", "b", "c"]
grid = Grid.grid_721x1440
def __call__(self, time, x, restart=None):
b, h, c, w, h == x.shape
assert b == 1
assert h == self.n_history_levels
assert c == len(self.in_channel_names)
assert (w, h) == self.grid.shape
while True:
yield time, x, None
time += self.time_step
This encapsulates the time stepping, and exposes other needed metadata.
Note
This time loop does not support restart capability.
Forecast#
Many scoring algorithms are most easily expressed as operations over 2D array of
states that we call a Forecast Array. The rows of this array correspond to
initial times, and the columns to lead times. The size of this array may be
unbounded.
For example, computing a lead time dependent metrics, such as RMSE
corresponds to averaging the square difference of Forecast Arrays of
observations and forecasts, and then averaging over the row dimension.
This is defined by the earth2mip.forecasts.Forecast
interface.
Compared to a TimeLoop, a Forecast encapsulates any time handling and initialization logic.
One advantage is that an archive of forecasts on disk can be represented as a Forecast
(see earth2mip.forecasts.XarrayForecast
).
This allows using the same code to score both static and streaming forecasts.
Finally, here is a earth2mip.forecasts.Forecast
implementation, for
a persistence forecast beginning on Jan 1, 2018 and producing ICs every 12
hours and sampling forecasts every 12 hours:
from earth2mip.forecasts import Forecast
import datetime
class PeristenceForecast(Forecast):
# only corresponds to out_channel_names
channel_names = ["a", "b", "c"]
def __init__(self, initial_data: Mapping[datetime.datetime, np.ndarray]):
self.initial_data = initial_data
def __getitem__(self, i):
initial_time = datetime.datetime(2018, 1, 1)
lead_dt = init_dt = datetime.timedelta(hours=12)
time = initial_time + init_dt * i
x = self.initial_data[time]
while True:
yield x
we can see that PeristenceForecast
encapsulates the initialization, time and
other logic.
Translating between Model Wrappers#
earth2mip provides implementations that translate between Model Wrappers.
To create a TimeLoop from a Module, use
earth2mip.networks.Inference
:
from earth2mip.networks import Inference
model = PersistenceModule()
# work around to not do any normalization
center = np.zeros([3])
scale = np.ones([3])
time_loop = Inference(
model,
center=center,
scale=scale,
grid=Grid.grid_721x1440,
time_step=datetime.timedelta(hours=12),
# note n_history_levels == n_history + 1
n_history=0,
channel_names=["a", "b", "c"],
)
To create a forecast from a TimeLoop, you can use
earth2mip.forecasts.TimeLoopForecast
:
from earth2mip.forecasts import TimeLoopForecast
forecast = TimeLoopForecast(
time_loop,
initial_data={
datetime.datetime(2018, 1, 1): np.zeros([1, 1, 3, 721, 1440])
},
)
Data Source#
Model package#
A model package is a directory containing a metadata.json
file following
this schema and any other static data
required to load the model. The model package typically contains model
parameters, normalization constants, etc.