earth2studio.models.px.DiagnosticWrapper
- class earth2studio.models.px.DiagnosticWrapper(px_model, dx_models, interpolate_coords=False, keep_px_output=True)
Wraps a prognostic model and one or more diagnostic models into a single prognostic model. This allows diagnostic model outputs to be included in workflows that expect a prognostic model.
The outputs of the diagnostic models are concatenated to the output of the prognostic model in the order given in dx_models.
Results will be returned in the coordinate system of the last diagnostic model.
Model compatibility requirements:
- The input variables of each model in the chain [px_model, dx_models[0], dx_models[1], …] must be available in the outputs of one of the previous models.
- If interpolate_coords == False, the coordinates of each model in the chain must be mappable to the next model using earth2studio.utils.coords.map_coords.
- If interpolate_coords == True, the coordinates of each model in the chain must be interpolatable to the next model using earth2studio.utils.interp.LatLonInterpolation.
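The variable-availability rule above can be sketched as a simple check over variable names. This is a minimal illustration, not part of earth2studio: the helper check_chain is hypothetical, each diagnostic model is reduced to an (input variables, output variables) pair, the variable names are placeholders, and coordinate compatibility is not checked.

```python
from typing import Sequence, Tuple

# Hypothetical sketch: each diagnostic is reduced to (input_vars, output_vars).
# Real models describe their variables through coordinate systems, which this
# sketch does not model; coordinate compatibility is not checked here.
DxSpec = Tuple[Sequence[str], Sequence[str]]


def check_chain(px_outputs: Sequence[str], dx_models: Sequence[DxSpec]) -> None:
    """Raise ValueError if a diagnostic needs a variable no earlier model produces."""
    available = set(px_outputs)
    for i, (inputs, outputs) in enumerate(dx_models):
        missing = [v for v in inputs if v not in available]
        if missing:
            raise ValueError(f"dx_models[{i}] needs unavailable variables: {missing}")
        # Outputs of this diagnostic become available to later diagnostics.
        available.update(outputs)


px_out = ["u10m", "v10m", "t2m"]
# OK: dx_models[1] consumes "ws10m", which dx_models[0] produces.
check_chain(px_out, [(["u10m", "v10m"], ["ws10m"]), (["ws10m", "t2m"], ["wind_chill"])])
# Reversed order fails because "ws10m" is not yet available.
try:
    check_chain(px_out, [(["ws10m", "t2m"], ["wind_chill"]), (["u10m", "v10m"], ["ws10m"])])
except ValueError as err:
    print(err)  # dx_models[0] needs unavailable variables: ['ws10m']
```

Because availability accumulates along the chain, the order of dx_models matters: a diagnostic may only consume outputs of models listed before it.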
- Parameters:
px_model (PrognosticModel) – The prognostic model to use as the base model.
dx_models (DiagnosticModel | Sequence[DiagnosticModel]) – The diagnostic models whose outputs are concatenated to the output of px_model.
interpolate_coords (bool, default False) – Whether to use bilinear interpolation to map spatial coordinates. If False, nearest-neighbor interpolation is used. Must be set to True if any models have 2D latitude/longitude coordinates.
keep_px_output (bool, default True) – Whether to include the output of px_model in the output.
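The resulting output variable order can be sketched over plain lists of variable names. The helper wrapped_output_variables is hypothetical, the variable names are placeholders, and the sketch assumes that keep_px_output=False simply drops the prognostic variables from the concatenated output.

```python
# Hypothetical sketch of the wrapper's output variable order.
# Assumption: keep_px_output=False drops the prognostic variables entirely,
# leaving only the diagnostic outputs.
def wrapped_output_variables(px_outputs, dx_outputs, keep_px_output=True):
    """Return px outputs (optional) followed by each diagnostic's outputs in dx_models order."""
    variables = list(px_outputs) if keep_px_output else []
    for outputs in dx_outputs:
        variables.extend(outputs)
    return variables


px_out = ["u10m", "v10m", "t2m"]
dx_out = [["ws10m"], ["wind_chill"]]  # placeholder diagnostic outputs
print(wrapped_output_variables(px_out, dx_out))
# ['u10m', 'v10m', 't2m', 'ws10m', 'wind_chill']
print(wrapped_output_variables(px_out, dx_out, keep_px_output=False))
# ['ws10m', 'wind_chill']
```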
- __call__(x, coords)
Runs the prognostic model one step.
- Parameters:
x (torch.Tensor) – Input tensor
coords (CoordSystem) – Coordinate system; should have dimensions [time, variable, *domain_dims]
- Returns:
x (torch.Tensor)
coords (CoordSystem)
- Return type:
tuple[Tensor, OrderedDict[str, ndarray]]
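The data flow of one wrapped step can be sketched with toy stand-ins. Everything here is hypothetical: a "tensor" is a dict mapping variable name to a value at a single grid point, toy_px, toy_wind_speed and wrapped_step are illustrative plain functions, and the coordinate mapping performed by map_coords or LatLonInterpolation is omitted. The sketch shows the prognostic model running once, each diagnostic seeing the outputs of all earlier models, and the results concatenated in dx_models order.

```python
import math

# Hypothetical toy stand-ins, not earth2studio APIs: a "tensor" is a dict
# mapping variable name -> value at one grid point; models are plain functions.
def toy_px(state):
    """Toy prognostic model: advance the prognostic variables by one step."""
    return {"u10m": state["u10m"] + 1.0, "v10m": state["v10m"], "t2m": state["t2m"] - 0.5}


def toy_wind_speed(state):
    """Toy diagnostic model: derive 10 m wind speed from u10m and v10m."""
    return {"ws10m": math.hypot(state["u10m"], state["v10m"])}


def wrapped_step(state, px, dx_models, keep_px_output=True):
    """Run px once, then each diagnostic on everything produced so far."""
    available = dict(px(state))  # prognostic outputs, in their original order
    dx_outputs = {}
    for dx in dx_models:
        out = dx(available)
        available.update(out)  # later diagnostics may consume earlier outputs
        dx_outputs.update(out)
    # Coordinate mapping between models (map_coords / interpolation) is omitted.
    return available if keep_px_output else dx_outputs


x0 = {"u10m": 2.0, "v10m": 4.0, "t2m": 280.0}
print(wrapped_step(x0, toy_px, [toy_wind_speed]))
# {'u10m': 3.0, 'v10m': 4.0, 't2m': 279.5, 'ws10m': 5.0}
print(wrapped_step(x0, toy_px, [toy_wind_speed], keep_px_output=False))
# {'ws10m': 5.0}
```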
- create_iterator(x, coords)
Creates an iterator that can be used to perform time integration of the prognostic model. The initial condition is returned first (0th step).
- Parameters:
x (torch.Tensor) – Input tensor
coords (CoordSystem) – Input coordinate system
- Yields:
Iterator[tuple[torch.Tensor, CoordSystem]] – Iterator that generates time steps of the prognostic model, each containing the output data tensor and coordinate system dictionary.
- Return type:
Iterator[tuple[Tensor, OrderedDict[str, ndarray]]]
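Time integration can be sketched as a generator that wraps a prognostic iterator and applies the diagnostics to each yielded state. This is an assumption about structure, not the library's code: states are toy dicts standing in for (tensor, coords) pairs, toy_px_iterator, add_wind_speed and wrapped_iterator are hypothetical, and the sketch assumes the diagnostics are also applied to the 0th step (the initial condition).

```python
import itertools
import math
from typing import Callable, Dict, Iterator

State = Dict[str, float]  # hypothetical stand-in for a (tensor, coords) pair


def toy_px_iterator(state: State) -> Iterator[State]:
    """Toy prognostic iterator: yields the initial condition, then one state per step."""
    yield dict(state)  # 0th step is the initial condition itself
    while True:
        state = {"u10m": state["u10m"] + 1.0, "v10m": state["v10m"]}
        yield state


def add_wind_speed(state: State) -> State:
    """Toy diagnostic: append ws10m to the prognostic variables."""
    return {**state, "ws10m": math.hypot(state["u10m"], state["v10m"])}


def wrapped_iterator(
    state: State,
    px_iter: Callable[[State], Iterator[State]],
    diagnose: Callable[[State], State],
) -> Iterator[State]:
    """Apply the diagnostics to every state the prognostic iterator yields."""
    for px_state in px_iter(state):
        yield diagnose(px_state)


# The iterator is unbounded, so take only the first three steps (0, 1, 2).
steps = list(itertools.islice(
    wrapped_iterator({"u10m": 2.0, "v10m": 0.0}, toy_px_iterator, add_wind_speed), 3))
print([s["ws10m"] for s in steps])  # [2.0, 3.0, 4.0]
```

Wrapping the prognostic generator rather than re-implementing the time loop keeps the stepping logic in one place; the wrapper only post-processes each yielded state.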