earth2studio.models.px.DiagnosticWrapper

class earth2studio.models.px.DiagnosticWrapper(px_model, dx_model, prepare_dx_input_coords=None, prepare_dx_input_tensor=None, prepare_output_coords=None, prepare_output_tensor=None)

Wraps a prognostic model and one or more diagnostic models into a single prognostic model. The micro-pipeline this wrapper encapsulates has the following four steps:

  1. Execute one step of the prognostic model

  2. Prepare the output of the prognostic model for each diagnostic model

  3. Execute a forward pass of each diagnostic model using the prepared prognostic data

  4. Prepare the outputs of the prognostic and diagnostic models for the final return
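The four steps above can be sketched with plain NumPy stand-ins. This is a minimal illustration, not the earth2studio implementation: `toy_px`, `toy_dx`, and `wrapped_step` are hypothetical, NumPy arrays stand in for `torch.Tensor`, step 2 is an identity (the documented default instead interpolates), and step 4 assumes concatenation along a `variable` dimension.

```python
from collections import OrderedDict
import numpy as np

def toy_px(x, coords):
    # Hypothetical prognostic step: advance every field by a constant
    return x + 1.0, coords.copy()

def toy_dx(x, coords):
    # Hypothetical diagnostic: wind speed "ws" derived from "u" and "v"
    axis = list(coords).index("variable")
    names = list(coords["variable"])
    u = np.take(x, [names.index("u")], axis=axis)
    v = np.take(x, [names.index("v")], axis=axis)
    out_coords = coords.copy()
    out_coords["variable"] = np.array(["ws"])
    return np.sqrt(u**2 + v**2), out_coords

def wrapped_step(x, coords, px, dx_models):
    # 1. Execute one step of the prognostic model
    x, coords = px(x, coords)
    outputs = [(x, coords)]
    for dx in dx_models:
        # 2. Prepare the prognostic output for this diagnostic (identity here;
        #    the documented default prepare step interpolates instead)
        dx_x, dx_coords = x, coords
        # 3. Run the diagnostic forward pass
        outputs.append(dx(dx_x, dx_coords))
    # 4. Concatenate all outputs along the variable dimension
    axis = list(coords).index("variable")
    out_x = np.concatenate([o[0] for o in outputs], axis=axis)
    out_coords = coords.copy()
    out_coords["variable"] = np.concatenate([o[1]["variable"] for o in outputs])
    return out_x, out_coords

coords = OrderedDict(
    variable=np.array(["u", "v"]),
    lat=np.linspace(90.0, -90.0, 3),
    lon=np.linspace(0.0, 270.0, 4),
)
x = np.stack([np.full((3, 4), 2.0), np.full((3, 4), 3.0)])  # u=2, v=3
y, c = wrapped_step(x, coords, toy_px, [toy_dx])
print(y.shape, c["variable"].tolist())  # (3, 3, 4) ['u', 'v', 'ws']
```

After the prognostic step u=3 and v=4, so the appended "ws" field is 5 everywhere.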

The wrapper provides customizable methods for preparing diagnostic model inputs and outputs. If these are not provided, the default methods have the following requirements:

  • All diagnostics must have the same output coordinate system, with the exception of the variable dimension.

  • Both the prognostic and diagnostic models must have lat/lon grid systems.
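The default requirements can be expressed as a simple check. The helper `check_default_requirements` below is hypothetical and assumes coordinate systems are `OrderedDict`s with `variable`, `lat`, and `lon` keys; the real default methods may validate differently or simply fail on incompatible inputs.

```python
from collections import OrderedDict
import numpy as np

def check_default_requirements(px_coords, dx_output_coords):
    """Hypothetical check mirroring the documented default requirements."""
    # Every model must use a lat/lon grid
    for c in [px_coords, *dx_output_coords]:
        if "lat" not in c or "lon" not in c:
            return False
    if not dx_output_coords:
        return True
    # All diagnostic outputs must share coordinates except "variable"
    ref = dx_output_coords[0]
    ref_dims = [k for k in ref if k != "variable"]
    for c in dx_output_coords[1:]:
        if [k for k in c if k != "variable"] != ref_dims:
            return False
        for k in ref_dims:
            if not np.array_equal(c[k], ref[k]):
                return False
    return True

grid = dict(lat=np.linspace(90.0, -90.0, 3), lon=np.linspace(0.0, 270.0, 4))
px_c = OrderedDict(variable=np.array(["u", "v"]), **grid)
dx_a = OrderedDict(variable=np.array(["ws"]), **grid)
dx_b = OrderedDict(variable=np.array(["tp"]), **grid)
dx_bad = OrderedDict(variable=np.array(["tp"]), lat=grid["lat"], lon=np.linspace(0.0, 180.0, 4))
print(check_default_requirements(px_c, [dx_a, dx_b]))    # True
print(check_default_requirements(px_c, [dx_a, dx_bad]))  # False
```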

Note

Custom callables or classes implementing the Protocol interfaces can be provided to override the default behavior, for example to skip interpolation or to change the concatenation logic. Many diagnostic models will require this. Each prepare function must implement the appropriate Protocol, i.e. a __call__ method with the matching signature:

  • PrepareDxInputCoords: Prepares coordinate systems

  • PrepareDxInputTensor: Prepares tensors with optional interpolation

  • PrepareOutputCoords: Prepares final output coordinate systems

  • PrepareOutputTensor: Prepares final output tensors
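Because these are Protocol interfaces, any object with a matching __call__ satisfies them structurally, whether a plain function or a class instance. The sketch below defines a stand-in protocol, `PrepareTensorLike`, with an assumed `(x, coords)` signature purely to show the pattern; it is not the real `PrepareDxInputTensor` signature, which should be taken from the earth2studio source. `passthrough` and `SelectVariables` are likewise hypothetical.

```python
from collections import OrderedDict
from typing import Protocol, runtime_checkable
import numpy as np

@runtime_checkable
class PrepareTensorLike(Protocol):
    """Stand-in protocol; the (x, coords) signature is an assumption."""
    def __call__(self, x: np.ndarray, coords: OrderedDict) -> tuple[np.ndarray, OrderedDict]:
        ...

def passthrough(x, coords):
    # A plain function: returns the input unchanged (no interpolation)
    return x, coords

class SelectVariables:
    """A class instance: keeps only the listed variables, in the given order."""
    def __init__(self, names):
        self.names = list(names)

    def __call__(self, x, coords):
        axis = list(coords).index("variable")
        idx = [list(coords["variable"]).index(n) for n in self.names]
        out = coords.copy()
        out["variable"] = np.array(self.names)
        return np.take(x, idx, axis=axis), out

coords = OrderedDict(variable=np.array(["u", "v", "t2m"]), lat=np.arange(3.0), lon=np.arange(4.0))
x = np.arange(36.0).reshape(3, 3, 4)
sel = SelectVariables(["t2m", "u"])
# isinstance on a runtime_checkable Protocol only checks that __call__ exists,
# not that its signature matches
print(isinstance(passthrough, PrepareTensorLike), isinstance(sel, PrepareTensorLike))  # True True
y, c = sel(x, coords)
print(y.shape, c["variable"].tolist())  # (2, 3, 4) ['t2m', 'u']
```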

Parameters:
  • px_model (PrognosticModel) – The prognostic model to use as the base model.

  • dx_model (DiagnosticModel | list[DiagnosticModel]) – Single diagnostic model or list of diagnostic models whose outputs are concatenated to the prognostic model output.

  • prepare_dx_input_coords (PrepareDxInputCoords | list[PrepareDxInputCoords] | None, optional) – Callable or Protocol-implementing object to prepare coordinate system for diagnostic model input. Can be a single instance (applied to all diagnostics) or a list (one per diagnostic). If None, uses PrepareInputCoordsDefault for each diagnostic, by default None

  • prepare_dx_input_tensor (PrepareDxInputTensor | list[PrepareDxInputTensor] | None, optional) – Callable or Protocol-implementing object to prepare tensor for diagnostic model input. Can be a single instance (applied to all diagnostics) or a list (one per diagnostic). If None, uses PrepareInputTensorDefault with interpolation for each diagnostic, by default None

  • prepare_output_coords (PrepareOutputCoords | None, optional) – Callable or Protocol-implementing object to prepare output coordinate system. If None, uses PrepareOutputCoordsDefault which concatenates all variables, by default None

  • prepare_output_tensor (PrepareOutputTensor | None, optional) – Callable or Protocol-implementing object to prepare output tensor. If None, uses PrepareOutputTensorDefault which concatenates all outputs, by default None
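The single-instance-or-list rule for prepare_dx_input_coords and prepare_dx_input_tensor can be sketched as a small normalization step. `normalize_prepare`, `Default`, and `custom` are hypothetical, and raising on a length mismatch is an assumption; the sketch only illustrates the documented semantics (None gives one default per diagnostic, a single instance is shared by all diagnostics, a list supplies one entry per diagnostic).

```python
def normalize_prepare(prepare, n_dx, default_factory):
    """Return a list with one prepare callable per diagnostic model."""
    if prepare is None:
        # None: a fresh default object for each diagnostic
        return [default_factory() for _ in range(n_dx)]
    if isinstance(prepare, (list, tuple)):
        # A list supplies one entry per diagnostic (mismatch check is assumed)
        if len(prepare) != n_dx:
            raise ValueError(f"expected {n_dx} prepare callables, got {len(prepare)}")
        return list(prepare)
    # A single instance is shared by all diagnostics
    return [prepare] * n_dx

class Default:
    # Hypothetical default prepare object (identity)
    def __call__(self, x, coords):
        return x, coords

def custom(x, coords):
    # Hypothetical user-supplied prepare function (identity)
    return x, coords

print(len(normalize_prepare(None, 3, Default)))                    # 3
print(normalize_prepare(custom, 2, Default) == [custom, custom])   # True
```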

__call__(x, coords)

Runs the prognostic model for one step.

Parameters:
  • x (torch.Tensor) – Input tensor

  • coords (CoordSystem) – Input coordinate system

Returns:

  • x (torch.Tensor)

  • coords (CoordSystem)

Return type:

tuple[Tensor, OrderedDict[str, ndarray]]
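The returned pair follows the CoordSystem convention: an OrderedDict mapping each dimension name to its coordinate array. Assuming the keys are ordered like the tensor's axes, the hypothetical helper `coords_match` below makes that relationship explicit; a NumPy array stands in for `torch.Tensor`, which exposes the same `.shape`, and the dimension names are illustrative.

```python
from collections import OrderedDict
import numpy as np

def coords_match(x, coords):
    """True if each tensor axis has the length of the corresponding coordinate array."""
    return tuple(x.shape) == tuple(len(v) for v in coords.values())

coords = OrderedDict(
    time=np.array([np.datetime64("2024-01-01T00")]),
    lead_time=np.array([np.timedelta64(6, "h")]),
    variable=np.array(["u", "v", "ws"]),
    lat=np.linspace(90.0, -90.0, 181),
    lon=np.linspace(0.0, 359.0, 360),
)
x = np.zeros((1, 1, 3, 181, 360))
print(coords_match(x, coords))              # True
print(coords_match(x[:, :, :2], coords))    # False: only 2 of 3 variables
```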

create_iterator(x, coords)

Creates an iterator that can be used to perform time integration of the prognostic model. The initial condition is returned first (0th step).

Parameters:
  • x (torch.Tensor) – Input tensor

  • coords (CoordSystem) – Input coordinate system

Yields:

Iterator[tuple[torch.Tensor, CoordSystem]] – Iterator that generates time steps of the prognostic model containing the output data tensor and coordinate system dictionary.

Return type:

Iterator[tuple[Tensor, OrderedDict[str, ndarray]]]
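The iterator contract (initial condition first, then one output per step) can be sketched as a Python generator. `toy_iterator` and `step` are hypothetical stand-ins: the real create_iterator advances the wrapped models rather than a user-supplied function, and the 6-hour step size is only an example.

```python
from collections import OrderedDict
from itertools import islice
import numpy as np

def toy_iterator(x, coords, step):
    """Yield the initial condition, then successive outputs of `step`."""
    # 0th step: the initial condition itself
    yield x, coords
    while True:
        # Each subsequent item is one application of the step function
        x, coords = step(x, coords)
        yield x, coords

def step(x, coords):
    # Hypothetical 6-hour step: increment the data and advance lead_time
    out = coords.copy()
    out["lead_time"] = coords["lead_time"] + np.timedelta64(6, "h")
    return x + 1.0, out

coords = OrderedDict(
    lead_time=np.array([np.timedelta64(0, "h")]),
    lat=np.arange(2.0),
    lon=np.arange(3.0),
)
x = np.zeros((1, 2, 3))
# This sketch is infinite, so take a fixed number of items with islice
steps = list(islice(toy_iterator(x, coords, step), 3))
for i, (y, c) in enumerate(steps):
    print(i, c["lead_time"][0], y[0, 0, 0])
```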