earth2studio.models.px.DiagnosticWrapper

class earth2studio.models.px.DiagnosticWrapper(px_model, dx_models, interpolate_coords=False, keep_px_output=True)

Wraps a prognostic model and one or more diagnostic models into a single prognostic model. This allows diagnostic model outputs to be included in workflows that expect a prognostic model.

The outputs of the diagnostic models are concatenated to the output of the prognostic model in the order given in dx_models.

Results will be returned in the coordinate system of the last diagnostic model.
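The ordering rule above can be illustrated with a small pure-Python sketch. The helper `combined_variables` is hypothetical and not part of earth2studio. It only models how the variable axis of the wrapper's output is assembled: the prognostic variables first, then each diagnostic model's variables in the order of dx_models. The behavior shown for `keep_px_output=False` (prognostic variables omitted) is an assumption based on the parameter description, and the variable names are placeholders.

```python
from collections.abc import Sequence


def combined_variables(
    px_vars: Sequence[str],
    dx_vars: Sequence[Sequence[str]],
    keep_px_output: bool = True,
) -> list[str]:
    """Hypothetical helper: assemble the wrapper's output variable list.

    Assumes prognostic outputs come first (if kept) and diagnostic outputs
    are appended in the same order as ``dx_models``.
    """
    out: list[str] = list(px_vars) if keep_px_output else []
    for vars_of_one_dx in dx_vars:
        out.extend(vars_of_one_dx)
    return out


# A prognostic model with two variables and two diagnostic models
# (variable names are placeholders).
px_vars = ["t2m", "u10m"]
dx_vars = [["tp"], ["ws10m", "wd10m"]]
print(combined_variables(px_vars, dx_vars))
# ['t2m', 'u10m', 'tp', 'ws10m', 'wd10m']
print(combined_variables(px_vars, dx_vars, keep_px_output=False))
# ['tp', 'ws10m', 'wd10m']
```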

Model compatibility requirements:

  • Input variables of each model in the chain [px_model, dx_models[0], dx_models[1], …] must be available in the outputs of one of the previous models.

  • If interpolate_coords == False, the coordinates of each model in the chain must be mappable to the next model using earth2studio.utils.coords.map_coords.

  • If interpolate_coords == True, the coordinates of each model in the chain must be interpolable to those of the next model using earth2studio.utils.interp.LatLonInterpolation.
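A minimal sketch of the variable requirement, assuming that a diagnostic model may draw its inputs from the outputs of any earlier model in the chain (the prognostic model or a preceding diagnostic model). `check_variable_chain` is a hypothetical helper, not an earth2studio function, and it ignores the coordinate (mapping or interpolation) requirements entirely.

```python
from collections.abc import Sequence


def check_variable_chain(
    px_outputs: Sequence[str],
    dx_chain: Sequence[tuple[Sequence[str], Sequence[str]]],
) -> None:
    """Hypothetical check of the variable requirement for a wrapper chain.

    ``dx_chain`` holds one ``(input_vars, output_vars)`` pair per diagnostic
    model, in the same order as ``dx_models``. Raises ``ValueError`` if a
    diagnostic model needs a variable that no earlier model produces.
    """
    available = set(px_outputs)
    for i, (inputs, outputs) in enumerate(dx_chain):
        missing = [v for v in inputs if v not in available]
        if missing:
            raise ValueError(f"dx_models[{i}] is missing input variables: {missing}")
        # Outputs of this diagnostic model become available to later models.
        available.update(outputs)


# Placeholder variable names. The second diagnostic model consumes "c",
# which the first one produces, so the chain is valid.
check_variable_chain(["a", "b"], [(["a"], ["c"]), (["c"], ["d"])])

# Swapping the order breaks the chain: "c" is not yet available.
try:
    check_variable_chain(["a", "b"], [(["c"], ["d"]), (["a"], ["c"])])
except ValueError as err:
    print(err)  # dx_models[0] is missing input variables: ['c']
```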

Parameters:
  • px_model (PrognosticModel) – The prognostic model to use as the base model.

  • dx_models (DiagnosticModel | Sequence[DiagnosticModel]) – The diagnostic models whose outputs are concatenated to the output of px_model.

  • interpolate_coords (bool, default False) – Whether to use bilinear interpolation to map spatial coordinates. If False, nearest-neighbor interpolation will be used. Must be set to True if any models have 2D latitude/longitude coordinates.

  • keep_px_output (bool, default True) – Whether to include the output of px_model in the wrapper's output.

__call__(x, coords)

Runs the prognostic model one step.

Parameters:
  • x (torch.Tensor) – Input tensor

  • coords (CoordSystem) – Coordinate system, should have dimensions [time, variable, *domain_dims]

Returns:

  • x (torch.Tensor)

  • coords (CoordSystem)

Return type:

tuple[Tensor, OrderedDict[str, ndarray]]
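The behavior of a single wrapper step can be sketched with plain Python callables standing in for the models. The stub models, the dictionary-based `State`, and `wrapped_step` are hypothetical simplifications: real earth2studio models take and return a tensor plus a CoordSystem, and the wrapper additionally maps or interpolates coordinates between models. The sketch assumes each diagnostic model sees the prognostic output together with the outputs of earlier diagnostic models.

```python
from collections.abc import Callable, Sequence

# Hypothetical stand-in for (tensor, coords): variable name -> value.
State = dict[str, float]


def wrapped_step(
    px_step: Callable[[State], State],
    dx_models: Sequence[Callable[[State], State]],
    x: State,
    keep_px_output: bool = True,
) -> State:
    """Hypothetical sketch of one wrapper step: prognostic step, then diagnostics."""
    px_out = px_step(x)
    available = dict(px_out)  # variables visible to the diagnostic models
    out = dict(px_out) if keep_px_output else {}
    for dx in dx_models:
        dx_out = dx(available)
        available.update(dx_out)  # later diagnostics may use earlier outputs
        out.update(dx_out)  # appended after the prognostic variables
    return out


def px_step(x: State) -> State:
    # Toy prognostic model: advance each variable by a fixed rule.
    return {"t2m": x["t2m"] + 1.0, "u10m": x["u10m"] * 0.5}


def wind_speed(x: State) -> State:
    # Toy diagnostic model: derive a new variable from the prognostic output.
    return {"ws10m": abs(x["u10m"])}


print(wrapped_step(px_step, [wind_speed], {"t2m": 280.0, "u10m": -4.0}))
# {'t2m': 281.0, 'u10m': -2.0, 'ws10m': 2.0}
```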

create_iterator(x, coords)

Creates an iterator that can be used to perform time integration of the prognostic model. The initial condition is returned first (0th step).

Parameters:
  • x (torch.Tensor) – Input tensor

  • coords (CoordSystem) – Input coordinate system

Yields:

Iterator[tuple[torch.Tensor, CoordSystem]] – Iterator that generates time steps of the prognostic model, each containing the output data tensor and coordinate system dictionary.

Return type:

Iterator[tuple[Tensor, OrderedDict[str, ndarray]]]
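The iteration pattern can be sketched as a Python generator that yields the initial condition first and then one state per prognostic step. `wrapped_iterator` and the toy models are hypothetical, states are simplified to dictionaries, and the sketch assumes that diagnostic outputs are appended to every yielded state, including the 0th step. Because the generator is unbounded, the caller limits the number of steps, here with `itertools.islice`.

```python
from collections.abc import Callable, Iterator, Sequence
from itertools import islice

# Hypothetical stand-in for (tensor, coords): variable name -> value.
State = dict[str, float]


def wrapped_iterator(
    px_step: Callable[[State], State],
    dx_models: Sequence[Callable[[State], State]],
    x: State,
) -> Iterator[State]:
    """Hypothetical sketch: yield step 0 (the initial condition), then each
    later prognostic step, with diagnostic outputs appended every time."""

    def with_diagnostics(state: State) -> State:
        out = dict(state)
        for dx in dx_models:
            out.update(dx(out))
        return out

    state = x
    while True:  # unbounded: the caller decides how many steps to take
        yield with_diagnostics(state)
        state = px_step(state)  # advance only the prognostic variables


def px_step(x: State) -> State:
    # Toy prognostic model: warm by one degree per step.
    return {"t2m": x["t2m"] + 1.0}


def anomaly(x: State) -> State:
    # Toy diagnostic model: deviation from a fixed reference value.
    return {"t2m_anom": x["t2m"] - 280.0}


for step, state in enumerate(islice(wrapped_iterator(px_step, [anomaly], {"t2m": 280.0}), 3)):
    print(step, state)
# 0 {'t2m': 280.0, 't2m_anom': 0.0}
# 1 {'t2m': 281.0, 't2m_anom': 1.0}
# 2 {'t2m': 282.0, 't2m_anom': 2.0}
```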