Running Diagnostic Inference

Basic prognostic + diagnostic inference workflow.

This example demonstrates how to run a deterministic inference workflow that couples a prognostic model with a diagnostic model. The diagnostic model predicts a new atmospheric quantity, here total precipitation, from the fields predicted by the prognostic model.

In this example you will learn:

  • How to instantiate a prognostic model

  • How to instantiate a diagnostic model

  • Creating a data source and IO object

  • Running the built-in diagnostic workflow

  • Post-processing results

Set Up

For this example, we use the built-in diagnostic workflow earth2studio.run.diagnostic(), which has the following signature:

def diagnostic(
    time: list[str] | list[datetime] | list[np.datetime64],
    nsteps: int,
    prognostic: PrognosticModel,
    diagnostic: DiagnosticModel,
    data: DataSource,
    io: IOBackend,
    output_coords: CoordSystem = OrderedDict({}),
    device: torch.device | None = None,
) -> IOBackend:
    """Built in diagnostic workflow.
    This workflow creates a deterministic inference pipeline that couples a prognostic
    model with a diagnostic model.

    Parameters
    ----------
    time : list[str] | list[datetime] | list[np.datetime64]
        List of string, datetimes or np.datetime64
    nsteps : int
        Number of forecast steps
    prognostic : PrognosticModel
        Prognostic model
    diagnostic: DiagnosticModel
        Diagnostic model, must be on same coordinate axis as prognostic
    data : DataSource
        Data source
    io : IOBackend
        IO object
    output_coords: CoordSystem, optional
        IO output coordinate system override, by default OrderedDict({})
    device : torch.device, optional
        Device to run inference on, by default None

    Returns
    -------
    IOBackend
        Output IO object
    """

Thus, we need the following components: a prognostic model, a diagnostic model, a data source, and an IO backend.

import os

os.makedirs("outputs", exist_ok=True)
from dotenv import load_dotenv

load_dotenv()  # load environment variables from a local .env file, if present

from earth2studio.data import GFS
from earth2studio.io import ZarrBackend
from earth2studio.models.dx import PrecipitationAFNO
from earth2studio.models.px import FCN

# Load the default model package, which downloads the checkpoint from NGC
package = FCN.load_default_package()
prognostic_model = FCN.load_model(package)

package = PrecipitationAFNO.load_default_package()
diagnostic_model = PrecipitationAFNO.load_model(package)

# Create the data source
data = GFS()

# Create the IO handler, store in memory
io = ZarrBackend()
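The docstring requires the diagnostic model to be on the same coordinate axis as the prognostic model. The sketch below shows what such a compatibility check could look like. It assumes, as the OrderedDict default for output_coords suggests, that a coordinate system is an ordered mapping from dimension names to NumPy arrays, and it assumes a 0.25 degree grid consistent with the (720,) lat and (1440,) lon arrays in the output further down. check_spatial_coords and the example coordinate systems are hypothetical, not part of the Earth2Studio API.

```python
from collections import OrderedDict

import numpy as np

# Assumed 0.25 degree global grid: lat from 90 to -89.75, lon from 0 to 359.75
lat = np.linspace(90.0, -90.0, 721)[:-1]
lon = np.linspace(0.0, 360.0, 1441)[:-1]


def check_spatial_coords(
    prog_out: OrderedDict, diag_in: OrderedDict, dims: tuple[str, ...] = ("lat", "lon")
) -> None:
    """Hypothetical helper: raise ValueError if the spatial axes of two coordinate systems differ."""
    for dim in dims:
        if dim not in prog_out or dim not in diag_in:
            raise ValueError(f"coordinate '{dim}' missing")
        a, b = np.asarray(prog_out[dim]), np.asarray(diag_in[dim])
        # Compare shapes first so mismatched lengths give a clear error
        if a.shape != b.shape or not np.allclose(a, b):
            raise ValueError(f"coordinate '{dim}' does not match")


# Hypothetical coordinate systems: prognostic output vs. diagnostic input
prog_out = OrderedDict({"variable": np.array(["t2m", "tcwv"]), "lat": lat, "lon": lon})
diag_in = OrderedDict({"variable": np.array(["tcwv"]), "lat": lat, "lon": lon})
check_spatial_coords(prog_out, diag_in)  # passes silently: same lat/lon grid
```

Only the spatial axes are compared here; the variable axis is allowed to differ because a diagnostic model typically consumes a subset of the prognostic fields.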

Execute the Workflow

With all components initialized, running the workflow takes a single line of Python. The workflow returns the IO object it was given, which can then be used for post-processing. Some IO backends provide additional APIs that are handy for post-processing or saving to file; check the API docs for more information.

import earth2studio.run as run

nsteps = 8
io = run.diagnostic(
    ["2021-06-01"], nsteps, prognostic_model, diagnostic_model, data, io
)

print(io.root.tree())
2025-01-23 04:38:28.162 | INFO     | earth2studio.run:diagnostic:190 - Running diagnostic workflow!
2025-01-23 04:38:28.162 | INFO     | earth2studio.run:diagnostic:197 - Inference device: cuda
2025-01-23 04:38:28.268 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:209 - Fetching GFS index file: 2021-06-01 00:00:00 lead 0:00:00
2025-01-23 04:38:28.271 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: u10m at 2021-06-01 00:00:00_0:00:00
2025-01-23 04:38:28.300 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: v10m at 2021-06-01 00:00:00_0:00:00
2025-01-23 04:38:28.326 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: t2m at 2021-06-01 00:00:00_0:00:00
2025-01-23 04:38:28.352 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: sp at 2021-06-01 00:00:00_0:00:00
2025-01-23 04:38:28.379 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: msl at 2021-06-01 00:00:00_0:00:00
2025-01-23 04:38:28.404 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: t850 at 2021-06-01 00:00:00_0:00:00
2025-01-23 04:38:28.431 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: u1000 at 2021-06-01 00:00:00_0:00:00
2025-01-23 04:38:28.457 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: v1000 at 2021-06-01 00:00:00_0:00:00
2025-01-23 04:38:28.483 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: z1000 at 2021-06-01 00:00:00_0:00:00
2025-01-23 04:38:28.510 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: u850 at 2021-06-01 00:00:00_0:00:00
2025-01-23 04:38:28.536 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: v850 at 2021-06-01 00:00:00_0:00:00
2025-01-23 04:38:28.562 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: z850 at 2021-06-01 00:00:00_0:00:00
2025-01-23 04:38:28.589 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: u500 at 2021-06-01 00:00:00_0:00:00
2025-01-23 04:38:28.614 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: v500 at 2021-06-01 00:00:00_0:00:00
2025-01-23 04:38:28.641 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: z500 at 2021-06-01 00:00:00_0:00:00
2025-01-23 04:38:28.667 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: t500 at 2021-06-01 00:00:00_0:00:00
2025-01-23 04:38:28.693 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: z50 at 2021-06-01 00:00:00_0:00:00
2025-01-23 04:38:28.719 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: r500 at 2021-06-01 00:00:00_0:00:00
2025-01-23 04:38:28.745 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: r850 at 2021-06-01 00:00:00_0:00:00
2025-01-23 04:38:28.772 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: tcwv at 2021-06-01 00:00:00_0:00:00
2025-01-23 04:38:28.799 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: u100m at 2021-06-01 00:00:00_0:00:00
2025-01-23 04:38:28.825 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: v100m at 2021-06-01 00:00:00_0:00:00
2025-01-23 04:38:28.851 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: u250 at 2021-06-01 00:00:00_0:00:00
2025-01-23 04:38:28.877 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: v250 at 2021-06-01 00:00:00_0:00:00
2025-01-23 04:38:28.904 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: z250 at 2021-06-01 00:00:00_0:00:00
2025-01-23 04:38:28.930 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: t250 at 2021-06-01 00:00:00_0:00:00
Fetching GFS for 2021-06-01 00:00:00: 100%|██████████| 26/26 [00:00<00:00, 38.00it/s]
2025-01-23 04:38:29.051 | SUCCESS  | earth2studio.run:diagnostic:220 - Fetched data from GFS
2025-01-23 04:38:29.057 | INFO     | earth2studio.run:diagnostic:252 - Inference starting!

Running inference: 100%|██████████| 9/9 [00:03<00:00,  2.96it/s]
2025-01-23 04:38:32.096 | SUCCESS  | earth2studio.run:diagnostic:266 - Inference complete
/
 ├── lat (720,) float64
 ├── lead_time (9,) timedelta64[h]
 ├── lon (1440,) float64
 ├── time (1,) datetime64[ns]
 └── tp (1, 9, 720, 1440) float32
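The tree shows that nsteps = 8 produced nine lead times: the initial condition plus eight forecast steps. Assuming FCN's 6-hour time step (the plotting code below also labels step as 6*step hours), the lead-time axis and the index of a given lead time can be reconstructed as in this minimal sketch. It rebuilds the axis locally rather than reading io["lead_time"] from the store, so treat it as an illustration of the indexing, not a replacement for the stored coordinate.

```python
import numpy as np

nsteps = 8
dt = np.timedelta64(6, "h")  # assumed FCN time step of 6 hours

# nsteps forecast steps plus the initial condition at lead time 0
lead_time = np.arange(nsteps + 1) * dt

# Index along the lead_time axis for a 48-hour forecast
step = int(np.flatnonzero(lead_time == np.timedelta64(48, "h"))[0])
print(lead_time.shape, step)  # (9,) 8
```

This is the same index used as step = 8 in the plotting code.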

Post Processing

The last step is to plot the predicted total precipitation. The strength of diagnostic models is that they can predict variables the prognostic model does not output, such as total precipitation here, directly from the fields predicted by a pre-trained prognostic model.

Note

The built-in workflow only saves the direct outputs of the diagnostic model, so in this example only total precipitation is available for plotting. If you want to save the outputs of both the prognostic and the diagnostic model, we recommend writing a custom workflow.
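At its core, a custom workflow only has to keep the prognostic state alongside the diagnostic output at every step. The toy sketch below shows that idea with plain NumPy arrays collected in a dictionary. It does not use the Earth2Studio IO API, and toy_prognostic, toy_diagnostic, custom_workflow and the variable names are hypothetical.

```python
import numpy as np


def toy_prognostic(state: np.ndarray) -> np.ndarray:
    """Hypothetical prognostic stand-in: advance every field one step."""
    return 0.9 * state + 1.0


def toy_diagnostic(state: np.ndarray) -> np.ndarray:
    """Hypothetical diagnostic stand-in: derive one extra field from the state."""
    return state.mean(axis=0, keepdims=True)


def custom_workflow(
    x0: np.ndarray, nsteps: int, prog_vars: list[str], diag_vars: list[str]
) -> dict[str, np.ndarray]:
    """Record both prognostic and diagnostic fields at every lead time."""
    names = prog_vars + diag_vars
    store: dict[str, list[np.ndarray]] = {name: [] for name in names}
    x = x0
    for step in range(nsteps + 1):
        if step > 0:
            x = toy_prognostic(x)
        y = toy_diagnostic(x)
        # Stack prognostic and diagnostic fields on the variable axis so both are kept
        fields = np.concatenate([x, y], axis=0)
        for name, field in zip(names, fields):
            store[name].append(field)
    return {name: np.stack(values) for name, values in store.items()}


x0 = np.zeros((2, 4, 8))  # two hypothetical prognostic fields on a tiny grid
out = custom_workflow(x0, nsteps=8, prog_vars=["t2m", "tcwv"], diag_vars=["tp"])
print(sorted(out), out["tp"].shape)  # ['t2m', 'tcwv', 'tp'] (9, 4, 8)
```

In a real workflow the dictionary would be replaced by writes to an IO backend, but the loop structure stays the same.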

from datetime import datetime

import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import numpy as np

forecast = datetime(2021, 6, 1)
variable = "tp"
step = 8  # lead time = 48 hrs

plt.close("all")
# Create an Orthographic projection centered on the USA
projection = ccrs.Orthographic(-100, 40)

# Create a figure and axes with the specified projection
fig, ax = plt.subplots(subplot_kw={"projection": projection}, figsize=(10, 6))

# Plot the field using filled contours (contourf)
levels = np.arange(0.0, 0.01, 0.001)
im = ax.contourf(
    io["lon"][:],
    io["lat"][:],
    io[variable][0, step],
    levels,
    transform=ccrs.PlateCarree(),
    vmax=0.01,
    vmin=0.00,
    cmap="terrain",
)

# Set title
ax.set_title(f"{forecast.strftime('%Y-%m-%d')} - Lead time: {6*step}hrs")

# Zoom to the region of interest and add coastlines and gridlines
ax.set_extent([220, 340, 20, 70])  # [lon min, lon max, lat min, lat max]
ax.coastlines()
ax.gridlines()
plt.colorbar(
    im, ax=ax, ticks=levels, shrink=0.75, pad=0.04, label="Total precipitation (m)"
)

plt.savefig("outputs/02_tp_prediction.jpg")
[Figure: predicted total precipitation, 2021-06-01 - Lead time: 48hrs]

Total running time of the script: (0 minutes 23.284 seconds)
