earth2studio.models.px.InterpModAFNO
- class earth2studio.models.px.InterpModAFNO(interp_model, center, scale, geop, lsm, px_model=None, num_interp_steps=6)
ModAFNO interpolation for global prognostic models. Interpolates the output of a forecast model to a shorter time step (by default, from 6 hours to 1 hour). Operates on a 0.25 degree lat-lon equirectangular grid with 73 variables.
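As a rough illustration of what the 0.25 degree grid implies, the sketch below builds the latitude and longitude axes and counts the values in one 73-variable snapshot. It assumes the common convention of 721 latitudes from 90 to -90 (both poles included) and 1440 longitudes from 0 to 359.75; the helper `latlon_grid` is hypothetical and not part of earth2studio.

```python
def latlon_grid(res_deg: float = 0.25) -> tuple[list[float], list[float]]:
    """Hypothetical helper: equirectangular lat/lon axes at res_deg spacing.

    Assumes latitudes run from 90 to -90 inclusive and longitudes from 0
    up to (but not including) 360.
    """
    n_lat = round(180 / res_deg) + 1
    n_lon = round(360 / res_deg)
    lat = [90.0 - i * res_deg for i in range(n_lat)]
    lon = [j * res_deg for j in range(n_lon)]
    return lat, lon


lat, lon = latlon_grid()
n_vars = 73  # variable count stated for this model
n_values = n_vars * len(lat) * len(lon)  # values per time step
n_bytes_fp32 = 4 * n_values  # float32 storage for one snapshot
print(len(lat), len(lon), n_values, n_bytes_fp32)
```

Under these assumptions this prints `721 1440 75791520 303166080`, i.e. roughly 289 MiB per float32 snapshot, before counting any batch or time dimensions.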
Note
For more information on the model, please refer to:
Warning
The model requires a base forecast model to be set before execution. This can be done by setting the px_model attribute.
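The warning describes a two-phase setup: construct the interpolator, then attach a base forecast model through `px_model` before running it. The sketch below mimics only that contract, using a hypothetical `InterpolatorStub` class and a toy base model; it does not reproduce earth2studio's internals, and raising `ValueError` for the missing-model case is an assumption.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class InterpolatorStub:
    """Hypothetical stand-in that only models the px_model requirement."""

    px_model: Optional[Callable[[float], float]] = None

    def __call__(self, x: float) -> float:
        # Refuse to run until a base forecast model has been attached.
        if self.px_model is None:
            raise ValueError("px_model must be set before execution")
        return self.px_model(x)


interp = InterpolatorStub()  # px_model defaults to None
try:
    interp(0.0)
    failed_without_base = False
except ValueError:
    failed_without_base = True  # expected: no base model attached yet

interp.px_model = lambda x: x + 6.0  # attach a toy base forecast model
result = interp(0.0)  # runs now that px_model is set
```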
- Parameters:
interp_model (torch.nn.Module) – The interpolation model that performs the time interpolation
center (torch.Tensor) – Model center normalization tensor
scale (torch.Tensor) – Model scale normalization tensor
geop (torch.Tensor) – Geopotential height data used as a static feature
lsm (torch.Tensor) – Land-sea mask data used as a static feature
px_model (PrognosticModel, optional) – The base forecast model that produces the coarse time-resolution forecasts. If not provided, it must be set by the user before executing the model, by default None.
num_interp_steps (int, optional) – Number of interpolation steps to perform between forecast steps, by default 6
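The `center` and `scale` parameters suggest the usual per-variable standardization, z = (x - center) / scale, with the inverse applied to model outputs. The sketch below shows that round trip on plain Python lists; treating `center` and `scale` this way is an assumption based on their names, not a description of the model's code, and the numbers are arbitrary.

```python
def normalize(x: list[float], center: list[float], scale: list[float]) -> list[float]:
    """Assumed standardization: subtract the center, divide by the scale, per variable."""
    return [(xi - c) / s for xi, c, s in zip(x, center, scale)]


def denormalize(z: list[float], center: list[float], scale: list[float]) -> list[float]:
    """Inverse of normalize: map normalized values back to physical units."""
    return [zi * s + c for zi, c, s in zip(z, center, scale)]


center = [275.0, 100000.0]  # arbitrary example values (a temperature and a pressure)
scale = [10.0, 1000.0]
x = [280.0, 101325.0]

z = normalize(x, center, scale)  # approximately [0.5, 1.325]
x_back = denormalize(z, center, scale)  # recovers x up to floating-point rounding
```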
- __call__(x, coords)
Runs the prognostic model one step.
- Parameters:
x (torch.Tensor) – Input tensor
coords (CoordSystem) – Input coordinate system
- Returns:
Output tensor and coordinate system one interpolated time step in the future (1 hour with the default settings)
- Return type:
tuple[torch.Tensor, CoordSystem]
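Each call advances the coordinate system by one output step. The sketch below shows that bookkeeping on a plain `OrderedDict`, assuming the coordinate system stores forecast offsets under a `lead_time` key; real `CoordSystem` values are NumPy arrays, whereas lists of `datetime.timedelta` are used here to stay dependency-free. `advance_coords` is a hypothetical helper.

```python
from collections import OrderedDict
from datetime import datetime, timedelta


def advance_coords(coords: OrderedDict, dt: timedelta = timedelta(hours=1)) -> OrderedDict:
    """Return a copy of coords with every lead_time shifted by dt (assumed key name)."""
    out = OrderedDict(coords)  # shallow copy preserves the dimension order
    out["lead_time"] = [lt + dt for lt in coords["lead_time"]]
    return out


coords = OrderedDict(
    time=[datetime(2024, 1, 1)],
    lead_time=[timedelta(0)],
    variable=["t2m"],
)
next_coords = advance_coords(coords)  # lead_time shifted by one hour; input left unchanged
```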
- create_iterator(x, coords)
Creates an iterator that can be used to perform time integration of the prognostic model. The initial condition is returned first (0th step).
- Parameters:
x (torch.Tensor) – Input tensor
coords (CoordSystem) – Input coordinate system
- Yields:
Iterator[tuple[torch.Tensor, CoordSystem]] – Iterator that generates time steps of the prognostic model, each containing the output data tensor and coordinate system dictionary.
- Return type:
Iterator[tuple[Tensor, OrderedDict[str, ndarray]]]
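One way to picture the iterator: the initial condition comes first, then each coarse step of the base model is split into `num_interp_steps` finer outputs. The sketch below uses plain linear interpolation in place of the learned `interp_model` and a toy callable in place of `px_model`; `interp_iterator` is a hypothetical name, and linear blending is a deliberate simplification, not what ModAFNO does.

```python
from datetime import timedelta
from itertools import islice
from typing import Callable, Iterator


def interp_iterator(
    x0: list[float],
    base_step: Callable[[list[float]], list[float]],
    base_dt: timedelta = timedelta(hours=6),
    num_interp_steps: int = 6,
) -> Iterator[tuple[timedelta, list[float]]]:
    """Yield (lead_time, state) pairs spaced base_dt / num_interp_steps apart.

    Linear interpolation stands in for the learned interpolation model.
    """
    dt = base_dt / num_interp_steps
    lead = timedelta(0)
    x_prev = x0
    yield lead, x_prev  # 0th step: the initial condition
    while True:
        x_next = base_step(x_prev)  # one coarse forecast step
        for k in range(1, num_interp_steps + 1):
            w = k / num_interp_steps  # blend weight; w == 1 reproduces x_next
            lead += dt
            yield lead, [(1 - w) * a + w * b for a, b in zip(x_prev, x_next)]
        x_prev = x_next


# Toy base model that adds 6.0 per 6-hour step; take the first 8 hourly outputs.
steps = list(islice(interp_iterator([0.0], lambda x: [x[0] + 6.0]), 8))
```

Writing it as a generator keeps the rollout lazy: the base model is only called when the consumer asks for an output beyond the current coarse step.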