Running Diagnostic Inference#
Basic prognostic + diagnostic inference workflow.
This example will demonstrate how to run a deterministic inference workflow that couples a prognostic model with a diagnostic model. This diagnostic model will predict a new atmospheric quantity from the predicted fields of the prognostic.
In this example you will learn:
How to instantiate a prognostic model
How to instantiate a diagnostic model
How to create a data source and IO object
How to run the built in diagnostic workflow
How to post-process the results
Set Up#
For this example, the built in diagnostic workflow earth2studio.run.diagnostic() will be used.
def diagnostic(
    time: list[str] | list[datetime] | list[np.datetime64],
    nsteps: int,
    prognostic: PrognosticModel,
    diagnostic: DiagnosticModel,
    data: DataSource,
    io: IOBackend,
    output_coords: CoordSystem = OrderedDict({}),
    device: torch.device | None = None,
) -> IOBackend:
    """Built in diagnostic workflow.

    This workflow creates a deterministic inference pipeline that couples a prognostic
    model with a diagnostic model.

    Parameters
    ----------
    time : list[str] | list[datetime] | list[np.datetime64]
        List of string, datetimes or np.datetime64
    nsteps : int
        Number of forecast steps
    prognostic : PrognosticModel
        Prognostic model
    diagnostic: DiagnosticModel
        Diagnostic model, must be on same coordinate axis as prognostic
    data : DataSource
        Data source
    io : IOBackend
        IO object
    output_coords: CoordSystem, optional
        IO output coordinate system override, by default OrderedDict({})
    device : torch.device, optional
        Device to run inference on, by default None

    Returns
    -------
    IOBackend
        Output IO object
    """
Thus, we need the following:
Prognostic Model: Use the built in FourCastNet Model earth2studio.models.px.FCN.
Diagnostic Model: Use the built in precipitation AFNO model earth2studio.models.dx.PrecipitationAFNO.
Datasource: Pull data from the GFS data API earth2studio.data.GFS.
IO Backend: Save the outputs into a Zarr store earth2studio.io.ZarrBackend.
import os
os.makedirs("outputs", exist_ok=True)
from dotenv import load_dotenv
load_dotenv() # TODO: make common example prep function
from earth2studio.data import GFS
from earth2studio.io import ZarrBackend
from earth2studio.models.dx import PrecipitationAFNO
from earth2studio.models.px import FCN
# Load the default model packages, which download the checkpoints from NGC
package = FCN.load_default_package()
prognostic_model = FCN.load_model(package)
package = PrecipitationAFNO.load_default_package()
diagnostic_model = PrecipitationAFNO.load_model(package)
# Create the data source
data = GFS()
# Create the IO handler, store in memory
io = ZarrBackend()
Downloading fcn.zip: 100%|██████████| 267M/267M [00:01<00:00, 250MB/s]
Downloading precipitation_afno.zip: 100%|██████████| 261M/261M [00:01<00:00, 238MB/s]
Execute the Workflow#
With all components initialized, running the workflow is a single line of Python code. The workflow returns the provided IO object to the user, which can then be used for post-processing. Some IO backends have additional APIs that can be handy for post-processing or saving to file. Check the API docs for more information.
import earth2studio.run as run
nsteps = 8
io = run.diagnostic(
["2021-06-01"], nsteps, prognostic_model, diagnostic_model, data, io
)
print(io.root.tree())
2025-05-15 03:03:16.172 | INFO | earth2studio.run:diagnostic:190 - Running diagnostic workflow!
2025-05-15 03:03:16.172 | INFO | earth2studio.run:diagnostic:197 - Inference device: cuda
2025-05-15 03:03:16.654 | DEBUG | earth2studio.data.gfs:fetch_array:352 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20210601/00/atmos/gfs.t00z.pgrb2.0p25.f000 427023039-950918
2025-05-15 03:03:16.656 | DEBUG | earth2studio.data.gfs:fetch_array:352 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20210601/00/atmos/gfs.t00z.pgrb2.0p25.f000 350106602-970343
2025-05-15 03:03:16.657 | DEBUG | earth2studio.data.gfs:fetch_array:352 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20210601/00/atmos/gfs.t00z.pgrb2.0p25.f000 349147153-959449
2025-05-15 03:03:16.659 | DEBUG | earth2studio.data.gfs:fetch_array:352 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20210601/00/atmos/gfs.t00z.pgrb2.0p25.f000 343243269-852573
2025-05-15 03:03:16.661 | DEBUG | earth2studio.data.gfs:fetch_array:352 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20210601/00/atmos/gfs.t00z.pgrb2.0p25.f000 265295671-828915
2025-05-15 03:03:16.662 | DEBUG | earth2studio.data.gfs:fetch_array:352 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20210601/00/atmos/gfs.t00z.pgrb2.0p25.f000 209303792-746874
2025-05-15 03:03:16.663 | DEBUG | earth2studio.data.gfs:fetch_array:352 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20210601/00/atmos/gfs.t00z.pgrb2.0p25.f000 433893232-1229275
2025-05-15 03:03:16.667 | DEBUG | earth2studio.data.gfs:fetch_array:352 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20210601/00/atmos/gfs.t00z.pgrb2.0p25.f000 422428446-507899
2025-05-15 03:03:16.668 | DEBUG | earth2studio.data.gfs:fetch_array:352 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20210601/00/atmos/gfs.t00z.pgrb2.0p25.f000 406739323-978379
2025-05-15 03:03:16.670 | DEBUG | earth2studio.data.gfs:fetch_array:352 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20210601/00/atmos/gfs.t00z.pgrb2.0p25.f000 215284398-610650
2025-05-15 03:03:16.671 | DEBUG | earth2studio.data.gfs:fetch_array:352 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20210601/00/atmos/gfs.t00z.pgrb2.0p25.f000 270673885-954744
2025-05-15 03:03:16.672 | DEBUG | earth2studio.data.gfs:fetch_array:352 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20210601/00/atmos/gfs.t00z.pgrb2.0p25.f000 263723951-825006
2025-05-15 03:03:16.674 | DEBUG | earth2studio.data.gfs:fetch_array:352 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20210601/00/atmos/gfs.t00z.pgrb2.0p25.f000 414253228-840220
2025-05-15 03:03:16.675 | DEBUG | earth2studio.data.gfs:fetch_array:352 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20210601/00/atmos/gfs.t00z.pgrb2.0p25.f000 407717702-959428
2025-05-15 03:03:16.677 | DEBUG | earth2studio.data.gfs:fetch_array:352 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20210601/00/atmos/gfs.t00z.pgrb2.0p25.f000 472734498-976207
2025-05-15 03:03:16.678 | DEBUG | earth2studio.data.gfs:fetch_array:352 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20210601/00/atmos/gfs.t00z.pgrb2.0p25.f000 344095842-882412
2025-05-15 03:03:16.679 | DEBUG | earth2studio.data.gfs:fetch_array:352 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20210601/00/atmos/gfs.t00z.pgrb2.0p25.f000 155987833-763786
2025-05-15 03:03:16.681 | DEBUG | earth2studio.data.gfs:fetch_array:352 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20210601/00/atmos/gfs.t00z.pgrb2.0p25.f000 426047081-975958
2025-05-15 03:03:16.682 | DEBUG | earth2studio.data.gfs:fetch_array:352 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20210601/00/atmos/gfs.t00z.pgrb2.0p25.f000 342333604-909665
2025-05-15 03:03:16.683 | DEBUG | earth2studio.data.gfs:fetch_array:352 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20210601/00/atmos/gfs.t00z.pgrb2.0p25.f000 215895048-622148
2025-05-15 03:03:16.685 | DEBUG | earth2studio.data.gfs:fetch_array:352 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20210601/00/atmos/gfs.t00z.pgrb2.0p25.f000 264548957-746714
2025-05-15 03:03:16.686 | DEBUG | earth2studio.data.gfs:fetch_array:352 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20210601/00/atmos/gfs.t00z.pgrb2.0p25.f000 412154325-999310
2025-05-15 03:03:16.687 | DEBUG | earth2studio.data.gfs:fetch_array:352 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20210601/00/atmos/gfs.t00z.pgrb2.0p25.f000 210050666-761020
2025-05-15 03:03:16.689 | DEBUG | earth2studio.data.gfs:fetch_array:352 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20210601/00/atmos/gfs.t00z.pgrb2.0p25.f000 0-1002486
2025-05-15 03:03:16.691 | DEBUG | earth2studio.data.gfs:fetch_array:352 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20210601/00/atmos/gfs.t00z.pgrb2.0p25.f000 269727465-946420
2025-05-15 03:03:16.692 | DEBUG | earth2studio.data.gfs:fetch_array:352 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20210601/00/atmos/gfs.t00z.pgrb2.0p25.f000 471746333-988165
Fetching GFS data: 100%|██████████| 26/26 [00:01<00:00, 20.55it/s]
2025-05-15 03:03:17.977 | SUCCESS | earth2studio.run:diagnostic:220 - Fetched data from GFS
2025-05-15 03:03:17.980 | WARNING | earth2studio.io.zarr:add_array:192 - Datetime64 not supported in zarr 3.0, converting to int64 nanoseconds since epoch
2025-05-15 03:03:17.983 | WARNING | earth2studio.io.zarr:add_array:198 - Timedelta64 not supported in zarr 3.0, converting to int64 nanoseconds since epoch
2025-05-15 03:03:17.994 | INFO | earth2studio.run:diagnostic:252 - Inference starting!
Running inference: 100%|██████████| 9/9 [00:03<00:00, 2.85it/s]
2025-05-15 03:03:21.152 | SUCCESS | earth2studio.run:diagnostic:266 - Inference complete
/
├── lat (720,) float64
├── lead_time (9,) int64
├── lon (1440,) float64
├── time (1,) int64
└── tp (1, 9, 720, 1440) float32
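The tree lists lead_time as int64, and the log above warns that timedelta values are stored as integer nanoseconds. A minimal sketch of decoding such values back to hours follows. It uses a synthetic array in place of io["lead_time"][:] and assumes a 6-hour model step, which matches the 6*step used for the plot title in this example; neither is read from the actual store.

```python
import numpy as np

# Synthetic stand-in for io["lead_time"][:]: 9 entries (initial state plus
# 8 forecast steps) stored as int64 nanoseconds. The 6-hour step is an assumption.
NS_PER_HOUR = 3_600_000_000_000
lead_time_ns = np.arange(9, dtype=np.int64) * 6 * NS_PER_HOUR

# Reinterpret the integers as timedelta64[ns], then express them in hours.
lead_time = lead_time_ns.astype("timedelta64[ns]")
lead_hours = lead_time / np.timedelta64(1, "h")

print(lead_hours)
```

With nsteps = 8 the lead_time axis has nine entries because the initial state is written alongside the eight forecast steps.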
Post Processing#
The last step is to plot the resulting predicted total precipitation. The power of diagnostic models is that they allow additional variables to be predicted from the output fields of a pre-trained prognostic model.
Note
The built in workflow will only save the direct outputs of the diagnostic. In this example only total precipitation is accessible for plotting. If you wish to save outputs of both the prognostic and diagnostic, we recommend writing a custom workflow.
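The shape of such a custom workflow can be sketched without any Earth2Studio calls. The example below is only an illustration of the coupling loop: toy_prognostic, toy_diagnostic and run_both are hypothetical stand-ins written in plain NumPy, not library APIs, and the time-stepping rule is a placeholder.

```python
import numpy as np


def toy_prognostic(x0, nsteps):
    """Hypothetical stand-in for a prognostic model: yields the initial
    state followed by nsteps rolled-out states."""
    x = x0
    yield x
    for _ in range(nsteps):
        x = 0.5 * x  # placeholder time-stepping rule
        yield x


def toy_diagnostic(x):
    """Hypothetical stand-in for a diagnostic model: maps one prognostic
    state to a single derived field."""
    return x.sum(axis=0, keepdims=True)


def run_both(x0, nsteps):
    """Roll out the prognostic, apply the diagnostic at every step, and
    keep both outputs instead of only the diagnostic one."""
    prog_out, diag_out = [], []
    for x in toy_prognostic(x0, nsteps):
        prog_out.append(x)
        diag_out.append(toy_diagnostic(x))
    return np.stack(prog_out), np.stack(diag_out)


x0 = np.ones((2, 4, 8))  # toy (variable, lat, lon) initial condition
prog, diag = run_both(x0, nsteps=8)
print(prog.shape, diag.shape)
```

In a real workflow the two lists would be replaced by writes to an IO backend, one array per prognostic and diagnostic variable.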
from datetime import datetime
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import numpy as np
forecast = datetime(2021, 6, 1)
variable = "tp"
step = 8 # lead time = 48 hrs
plt.close("all")
# Create an Orthographic projection centered over the USA
projection = ccrs.Orthographic(-100, 40)
# Create a figure and axes with the specified projection
fig, ax = plt.subplots(subplot_kw={"projection": projection}, figsize=(10, 6))
# Plot the field using filled contours (contourf)
levels = np.arange(0.0, 0.01, 0.001)
im = ax.contourf(
io["lon"][:],
io["lat"][:],
io[variable][0, step],
levels,
transform=ccrs.PlateCarree(),
vmax=0.01,
vmin=0.00,
cmap="terrain",
)
# Set title
ax.set_title(f"{forecast.strftime('%Y-%m-%d')} - Lead time: {6*step}hrs")
# Set the map extent, then add coastlines and gridlines
ax.set_extent([220, 340, 20, 70]) # [lon min, lon max, lat min, lat max]
ax.coastlines()
ax.gridlines()
plt.colorbar(
im, ax=ax, ticks=levels, shrink=0.75, pad=0.04, label="Total precipitation (m)"
)
plt.savefig("outputs/02_tp_prediction.jpg")
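The longitude bounds 220 and 340 passed to set_extent use a 0 to 360 degree convention. For readers more used to signed longitudes, the small helper below converts between the two. to_signed_lon is a hypothetical name introduced here for illustration and is not part of Earth2Studio or Cartopy.

```python
import numpy as np


def to_signed_lon(lon):
    """Map longitudes from the [0, 360) convention to [-180, 180)."""
    return (np.asarray(lon, dtype=float) + 180.0) % 360.0 - 180.0


# Longitude bounds passed to set_extent in the plot above.
print(to_signed_lon([220, 340]))
```

The plotted window therefore spans roughly 140°W to 20°W and 20°N to 70°N.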

Total running time of the script: (0 minutes 28.377 seconds)