earth2studio.models.px.DiagnosticWrapper
- class earth2studio.models.px.DiagnosticWrapper(px_model, dx_model, prepare_dx_input_coords=None, prepare_dx_input_tensor=None, prepare_output_coords=None, prepare_output_tensor=None)
Wraps a prognostic model and one or more diagnostic models into a single prognostic model. The micro-pipeline this wrapper encapsulates has the following four steps:
1. Execute one step of the prognostic model.
2. Prepare the prognostic model output as input for each diagnostic model.
3. Execute the forward pass of each diagnostic model on the prepared prognostic data.
4. Prepare the combined prognostic and diagnostic outputs for the final return.
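The four steps above can be sketched in plain Python. This is a hypothetical illustration, not the library's implementation: the toy "tensors" are lists of per-variable values, the coordinate system is an OrderedDict holding only a variable dimension, and toy_prognostic, toy_diagnostic, prepare_dx_input, prepare_output and wrapped_step are invented names.

```python
from collections import OrderedDict


def toy_prognostic(x, coords):
    # Hypothetical prognostic model: advance every variable by adding 1.0
    return [v + 1.0 for v in x], OrderedDict(coords)


def toy_diagnostic(x, coords):
    # Hypothetical diagnostic model: derive one new variable, "total"
    return [sum(x)], OrderedDict(variable=["total"])


def prepare_dx_input(x, coords):
    # Pass-through preparation (no interpolation in this toy setup)
    return x, coords


def prepare_output(px_out, dx_outs):
    # Concatenate prognostic and diagnostic outputs along "variable"
    x = list(px_out[0])
    coords = OrderedDict(px_out[1])
    coords["variable"] = list(coords["variable"])
    for dx_x, dx_coords in dx_outs:
        x.extend(dx_x)
        coords["variable"].extend(dx_coords["variable"])
    return x, coords


def wrapped_step(x, coords, px_model, dx_models):
    # 1. Execute one step of the prognostic model
    px_x, px_coords = px_model(x, coords)
    dx_outs = []
    for dx_model in dx_models:
        # 2. Prepare the prognostic output for this diagnostic
        dx_x, dx_coords = prepare_dx_input(px_x, px_coords)
        # 3. Execute the diagnostic forward pass
        dx_outs.append(dx_model(dx_x, dx_coords))
    # 4. Prepare the combined output for the final return
    return prepare_output((px_x, px_coords), dx_outs)


x, coords = wrapped_step(
    [1.0, 2.0], OrderedDict(variable=["u10m", "v10m"]), toy_prognostic, [toy_diagnostic]
)
print(x, coords["variable"])  # [2.0, 3.0, 5.0] ['u10m', 'v10m', 'total']
```

The diagnostic output is appended after the prognostic variables, so the wrapped step behaves like a single prognostic model with a longer variable list.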
The wrapper provides customizable methods for preparing diagnostic model inputs and outputs. If none are provided, the default methods have the following requirements:
- All diagnostic models must have the same output coordinate system, with the exception of the variable dimension.
- Both the prognostic and diagnostic models must use lat/lon grid systems.
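A minimal sketch of what the first default requirement means for concatenation along the variable dimension, assuming coordinate systems are OrderedDicts of plain lists rather than NumPy arrays. The helper merge_dx_coords is hypothetical and is not part of earth2studio.

```python
from collections import OrderedDict


def merge_dx_coords(dx_coords_list):
    """Merge diagnostic coordinate systems along "variable" (hypothetical helper).

    Every other dimension must match the first diagnostic exactly.
    """
    reference = dx_coords_list[0]
    merged = OrderedDict((dim, list(values)) for dim, values in reference.items())
    for dx_coords in dx_coords_list[1:]:
        if list(dx_coords) != list(reference):
            raise ValueError("diagnostics must have the same dimensions in the same order")
        for dim, values in dx_coords.items():
            if dim == "variable":
                # The variable dimension is the only one allowed to differ
                merged[dim].extend(values)
            elif list(values) != list(reference[dim]):
                raise ValueError(f"dimension {dim!r} differs between diagnostics")
    return merged


a = OrderedDict(variable=["tp"], lat=[0.0, 0.25], lon=[0.0, 0.25])
b = OrderedDict(variable=["t2m"], lat=[0.0, 0.25], lon=[0.0, 0.25])
print(merge_dx_coords([a, b])["variable"])  # ['tp', 't2m']
```

A diagnostic on a different lat/lon grid would raise here, which is the situation where a custom prepare function becomes necessary.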
Note
Custom callables, or classes implementing the Protocol interfaces, can be provided to override the default behavior, for example to skip interpolation or to change the concatenation logic. This will be required for many diagnostic models. Each prepare function must implement the appropriate Protocol (a __call__ method with a matching signature):
- PrepareDxInputCoords: prepares the input coordinate systems of the diagnostic models
- PrepareDxInputTensor: prepares the input tensors, with optional interpolation
- PrepareOutputCoords: prepares the final output coordinate system
- PrepareOutputTensor: prepares the final output tensor
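A custom prepare object only needs a __call__ method with the expected signature. The sketch below uses typing.Protocol to illustrate that structural contract; the PrepareTensor protocol, its (x, coords) to (x, coords) signature, and the SelectVariables class are assumptions for illustration, and the real PrepareDxInputTensor signature may differ. SelectVariables skips interpolation and keeps only the variables a diagnostic needs.

```python
from collections import OrderedDict
from typing import Protocol, runtime_checkable

CoordSystem = OrderedDict  # simplified stand-in for earth2studio's CoordSystem


@runtime_checkable
class PrepareTensor(Protocol):
    # Hypothetical protocol; the real PrepareDxInputTensor signature may differ
    def __call__(self, x: list, coords: CoordSystem) -> tuple[list, CoordSystem]: ...


class SelectVariables:
    """Skip interpolation; keep only the variables a diagnostic needs."""

    def __init__(self, variables):
        self.variables = list(variables)

    def __call__(self, x, coords):
        # Look up the position of each requested variable in the input
        index = [list(coords["variable"]).index(v) for v in self.variables]
        out_coords = OrderedDict(coords)
        out_coords["variable"] = list(self.variables)
        return [x[i] for i in index], out_coords


prep = SelectVariables(["v10m"])
# isinstance on a runtime_checkable Protocol only checks that __call__ exists,
# not that its signature matches
print(isinstance(prep, PrepareTensor))  # True
```

Because the check is structural, a plain function with the same parameters would also be accepted wherever such a protocol is expected.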
- Parameters:
px_model (PrognosticModel) – The prognostic model to use as the base model.
dx_model (DiagnosticModel | list[DiagnosticModel]) – Single diagnostic model or list of diagnostic models whose outputs are concatenated to the prognostic model output.
prepare_dx_input_coords (PrepareDxInputCoords | list[PrepareDxInputCoords] | None, optional) – Callable or Protocol-implementing object to prepare the coordinate system of the diagnostic model input. Can be a single instance (applied to all diagnostics) or a list (one per diagnostic). If None, uses PrepareInputCoordsDefault for each diagnostic, by default None
prepare_dx_input_tensor (PrepareDxInputTensor | list[PrepareDxInputTensor] | None, optional) – Callable or Protocol-implementing object to prepare the tensor for the diagnostic model input. Can be a single instance (applied to all diagnostics) or a list (one per diagnostic). If None, uses PrepareInputTensorDefault with interpolation for each diagnostic, by default None
prepare_output_coords (PrepareOutputCoords | None, optional) – Callable or Protocol-implementing object to prepare output coordinate system. If None, uses PrepareOutputCoordsDefault which concatenates all variables, by default None
prepare_output_tensor (PrepareOutputTensor | None, optional) – Callable or Protocol-implementing object to prepare output tensor. If None, uses PrepareOutputTensorDefault which concatenates all outputs, by default None
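Both prepare_dx_input_* parameters accept None, a single instance, or a list with one entry per diagnostic. The hypothetical helper expand_preparers sketches how those three forms could be normalized into one preparer per diagnostic model; the length check on lists is an assumed validation step, not documented behavior.

```python
def expand_preparers(prepare, n_dx, default_factory):
    """Return a list with one preparer per diagnostic model (hypothetical helper)."""
    if prepare is None:
        # No preparer given: build a fresh default for each diagnostic
        return [default_factory() for _ in range(n_dx)]
    if isinstance(prepare, list):
        # One preparer per diagnostic: check that the lengths agree
        if len(prepare) != n_dx:
            raise ValueError(f"expected {n_dx} preparers, got {len(prepare)}")
        return list(prepare)
    # A single preparer is shared by all diagnostics
    return [prepare] * n_dx


def passthrough(x, coords):
    return x, coords


print(expand_preparers(passthrough, 2, object) == [passthrough, passthrough])  # True
```

Building a fresh default per diagnostic (rather than sharing one) keeps any per-diagnostic state in the default preparers independent.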
- __call__(x, coords)
Runs the prognostic model one step and applies the diagnostic models to its output.
- Parameters:
x (torch.Tensor) – Input tensor
coords (CoordSystem) – Input coordinate system
- Returns:
x (torch.Tensor)
coords (CoordSystem)
- Return type:
tuple[Tensor, OrderedDict[str, ndarray]]
- create_iterator(x, coords)
Creates an iterator which can be used to perform time integration of the prognostic model. The initial condition is returned first (0th step).
- Parameters:
x (torch.Tensor) – Input tensor
coords (CoordSystem) – Input coordinate system
- Yields:
Iterator[tuple[torch.Tensor, CoordSystem]] – Iterator that generates time steps of the prognostic model, each containing the output data tensor and coordinate system dictionary.
- Return type:
Iterator[tuple[Tensor, OrderedDict[str, ndarray]]]
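The "initial condition first" behavior of create_iterator can be sketched with a Python generator. The toy_step model, the integer lead_time in hours, and create_iterator_sketch are hypothetical stand-ins; this sketch runs indefinitely, so the caller decides how many steps to take.

```python
from collections import OrderedDict
from itertools import islice


def toy_step(x, coords):
    # Hypothetical one-step model: add 1.0 and advance lead time by 6 hours
    out = OrderedDict(coords)
    out["lead_time"] = [t + 6 for t in coords["lead_time"]]
    return [v + 1.0 for v in x], out


def create_iterator_sketch(x, coords, step):
    # 0th step: yield the initial condition unchanged
    yield x, coords
    # Subsequent steps: repeatedly apply the one-step model
    while True:
        x, coords = step(x, coords)
        yield x, coords


# The sketch is infinite, so take only the first three outputs
for x, coords in islice(create_iterator_sketch([0.0], OrderedDict(lead_time=[0]), toy_step), 3):
    print(x, coords["lead_time"])
```

The first printed pair is the unmodified initial condition ([0.0] [0]), followed by [1.0] [6] and [2.0] [12].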