Running Ensemble Inference#

Simple ensemble inference workflow.

This example demonstrates how to run a simple inference workflow to generate an ensemble forecast using one of the built-in models of Earth-2 Inference Studio.

In this example you will learn:

  • How to instantiate a built-in prognostic model

  • How to create a data source and IO object

  • How to select a perturbation method

  • How to run a simple built-in workflow for ensembling

  • How to post-process results

Set Up#

All workflows inside Earth2Studio require constructed components to be handed to them. In this example, we will use the built-in ensemble workflow earth2studio.run.ensemble().

def ensemble(
    time: list[str] | list[datetime] | list[np.datetime64],
    nsteps: int,
    nensemble: int,
    prognostic: PrognosticModel,
    data: DataSource,
    io: IOBackend,
    perturbation: Perturbation,
    batch_size: int | None = None,
    output_coords: CoordSystem = OrderedDict({}),
    device: torch.device | None = None,
) -> IOBackend:
    """Built in ensemble workflow.

    Parameters
    ----------
    time : list[str] | list[datetime] | list[np.datetime64]
        List of string, datetimes or np.datetime64
    nsteps : int
        Number of forecast steps
    nensemble : int
        Number of ensemble members to run inference for.
    prognostic : PrognosticModel
        Prognostic model
    data : DataSource
        Data source
    io : IOBackend
        IO object
    perturbation : Perturbation
        Method to perturb the initial condition to create an ensemble.
    batch_size: int, optional
        Number of ensemble members to run in a single batch,
        by default None.
    output_coords: CoordSystem, optional
        IO output coordinate system override, by default OrderedDict({})
    device : torch.device, optional
        Device to run inference on, by default None

    Returns
    -------
    IOBackend
        Output IO object
    """

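We need the following components, each constructed in the code below: a prognostic model (FCN), a perturbation method (SphericalGaussian), a data source (GFS), and an IO backend (ZarrBackend).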

import os

os.makedirs("outputs", exist_ok=True)
from dotenv import load_dotenv

load_dotenv()  # TODO: make common example prep function

import numpy as np

from earth2studio.data import GFS
from earth2studio.io import ZarrBackend
from earth2studio.models.px import FCN
from earth2studio.perturbation import SphericalGaussian
from earth2studio.run import ensemble

# Load the default model package, which downloads the checkpoint from NGC
package = FCN.load_default_package()
model = FCN.load_model(package)

# Instantiate the perturbation method
sg = SphericalGaussian(noise_amplitude=0.15)

# Create the data source
data = GFS()

# Create the IO handler, store the output in a Zarr file on disk
chunks = {"ensemble": 1, "time": 1, "lead_time": 1}
io = ZarrBackend(
    file_name="outputs/03_ensemble_sg.zarr",
    chunks=chunks,
    backend_kwargs={"overwrite": True},
)
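
To make the role of the perturbation method set up above more concrete, here is a minimal sketch of the general idea: copy one initial condition along an ensemble dimension and add noise scaled by an amplitude. It uses plain, spatially uncorrelated Gaussian noise from NumPy rather than the noise that SphericalGaussian generates, and the helper name perturb_initial_condition is hypothetical, not part of Earth2Studio.

```python
import numpy as np


def perturb_initial_condition(x0, nensemble, noise_amplitude, seed=0):
    """Hypothetical helper: build an ensemble by adding white Gaussian noise.

    This is NOT the SphericalGaussian implementation; it only illustrates
    how one initial state becomes several slightly different members.
    """
    rng = np.random.default_rng(seed)
    # Copy the initial condition once per ensemble member
    members = np.repeat(x0[np.newaxis, ...], nensemble, axis=0)
    # Draw independent noise for every member and grid point
    noise = rng.standard_normal(members.shape)
    return members + noise_amplitude * noise


# Toy (lat, lon) field standing in for a real initial condition
x0 = np.full((4, 8), 280.0)
ens = perturb_initial_condition(x0, nensemble=8, noise_amplitude=0.15)
print(ens.shape)
```

Each member then evolves independently under the prognostic model, so the spread between members at later lead times reflects how sensitive the forecast is to the initial state.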

Execute the Workflow#

With all components initialized, running the workflow is a single line of Python code. The workflow returns the provided IO object to the user, which can then be used for post-processing. Some IO backends have additional APIs that can be handy for post-processing or saving to file. Check the API docs for more information.

For the forecast, we will predict 10 steps (for FCN, this is 60 hours) with 8 ensemble members, which will be run in 4 batches with a batch size of 2.

nsteps = 10
nensemble = 8
batch_size = 2
io = ensemble(
    ["2024-01-01"],
    nsteps,
    nensemble,
    model,
    data,
    io,
    sg,
    batch_size=batch_size,
    output_coords={"variable": np.array(["t2m", "tcwv"])},
)
2025-05-15 03:03:39.654 | INFO     | earth2studio.run:ensemble:315 - Running ensemble inference!
2025-05-15 03:03:39.655 | INFO     | earth2studio.run:ensemble:323 - Inference device: cuda

2025-05-15 03:03:39.704 | DEBUG    | earth2studio.data.gfs:fetch_array:352 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20240101/00/atmos/gfs.t00z.pgrb2.0p25.f000 0-993995
[25 similar DEBUG lines, one per remaining GFS fetch, omitted]
Fetching GFS data: 100%|██████████| 26/26 [00:01<00:00, 19.45it/s]
2025-05-15 03:03:41.107 | SUCCESS  | earth2studio.run:ensemble:345 - Fetched data from GFS
2025-05-15 03:03:41.115 | WARNING  | earth2studio.io.zarr:add_array:192 - Datetime64 not supported in zarr 3.0, converting to int64 nanoseconds since epoch
2025-05-15 03:03:41.119 | WARNING  | earth2studio.io.zarr:add_array:198 - Timedelta64 not supported in zarr 3.0, converting to int64 nanoseconds since epoch
2025-05-15 03:03:41.139 | INFO     | earth2studio.run:ensemble:367 - Starting 8 Member Ensemble Inference with             4 number of batches.

Running batch 0 inference: 100%|██████████| 11/11 [00:05<00:00,  2.19it/s]
Running batch 2 inference: 100%|██████████| 11/11 [00:04<00:00,  2.18it/s]
Running batch 4 inference: 100%|██████████| 11/11 [00:04<00:00,  2.13it/s]
Running batch 6 inference: 100%|██████████| 11/11 [00:04<00:00,  2.20it/s]
Total Ensemble Batches: 100%|██████████| 4/4 [00:34<00:00,  8.51s/it]
2025-05-15 03:04:15.183 | SUCCESS  | earth2studio.run:ensemble:412 - Inference complete
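
The counts in the log above can be reproduced with a few lines of arithmetic. This sketch assumes a 6-hour model step (implied by 10 FCN steps covering 60 hours) and assumes that each batch iterates over the initial state plus nsteps forecast steps, which would explain the 11 iterations per batch. The batch labels 0, 2, 4 and 6 in the progress bars are consistent with the index of the first member in each batch, although the source does not state this.

```python
import math

nsteps = 10
nensemble = 8
batch_size = 2
step_hours = 6  # assumption: 10 FCN steps cover 60 hours

# Total forecast horizon in hours
forecast_hours = nsteps * step_hours

# Number of batches needed to cover all ensemble members
num_batches = math.ceil(nensemble / batch_size)

# Index of the first ensemble member in each batch (assumed labelling)
batch_starts = list(range(0, nensemble, batch_size))

# Assumed iterations per batch: initial state plus every forecast step
iterations_per_batch = nsteps + 1

print(forecast_hours, num_batches, batch_starts, iterations_per_batch)
```

Raising batch_size reduces the number of batches but increases the memory needed per batch, so it is worth tuning to the available device.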

Post Processing#

The last step is to post-process our results. Cartopy is a great library for plotting fields on projections of a sphere.

Notice that the Zarr IO backend has additional APIs to interact with the stored data.

import cartopy.crs as ccrs
import matplotlib.pyplot as plt

forecast = "2024-01-01"


def plot_(axi, data, title, cmap):
    """Convenience function for plotting pcolormesh."""
    # Plot the field using pcolormesh
    im = axi.pcolormesh(
        io["lon"][:],
        io["lat"][:],
        data,
        transform=ccrs.PlateCarree(),
        cmap=cmap,
    )
    plt.colorbar(im, ax=axi, shrink=0.6, pad=0.04)
    # Set title
    axi.set_title(title)
    # Add coastlines and gridlines
    axi.coastlines()
    axi.gridlines()


for variable, cmap in zip(["tcwv"], ["Blues"]):
    step = 4  # lead time = 24 hrs

    plt.close("all")
    # Create a Robinson projection
    projection = ccrs.Robinson()

    # Create a figure and axes with the specified projection
    fig, (ax1, ax2, ax3) = plt.subplots(
        nrows=1, ncols=3, subplot_kw={"projection": projection}, figsize=(16, 3)
    )

    plot_(
        ax1,
        io[variable][0, 0, step],
        f"{forecast} - Lead time: {6*step}hrs - Member: {0}",
        cmap,
    )
    plot_(
        ax2,
        io[variable][1, 0, step],
        f"{forecast} - Lead time: {6*step}hrs - Member: {1}",
        cmap,
    )
    plot_(
        ax3,
        np.std(io[variable][:, 0, step], axis=0),
        f"{forecast} - Lead time: {6*step}hrs - Std",
        cmap,
    )

    plt.savefig(f"outputs/03_{forecast}_{variable}_{step}_ensemble.jpg")
[Figure: three tcwv panels titled "2024-01-01 - Lead time: 24hrs - Member: 0", "2024-01-01 - Lead time: 24hrs - Member: 1", and "2024-01-01 - Lead time: 24hrs - Std"]
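
The indexing used in the plotting code can be tried on a toy array without running the model. This sketch assumes the stored variable has dimensions (ensemble, time, lead_time, lat, lon), as suggested by the chunking keys and the io[variable][0, 0, step] indexing above. It also assumes the lead times come back as int64 nanoseconds, following the Zarr warning in the log. The array and coordinate values are synthetic stand-ins, not data read from the store.

```python
import numpy as np

# Synthetic stand-in for io["tcwv"]; assumed dimension order:
# (ensemble, time, lead_time, lat, lon)
rng = np.random.default_rng(0)
nensemble, ntime, nlead, nlat, nlon = 8, 1, 11, 4, 8
field = rng.random((nensemble, ntime, nlead, nlat, nlon))

# Synthetic lead-time coordinate: 6-hourly steps as int64 nanoseconds
lead_time_ns = np.arange(nlead, dtype=np.int64) * 6 * 3600 * 10**9
# Decode the integers back into timedelta64 values
lead_time = lead_time_ns.astype("timedelta64[ns]")

# Look up the 24-hour lead time instead of hard-coding step = 4
step = int(np.argmax(lead_time == np.timedelta64(24, "h")))

# Single member, ensemble mean and ensemble spread at that lead time
member0 = field[0, 0, step]
ens_mean = field[:, 0, step].mean(axis=0)
ens_std = field[:, 0, step].std(axis=0)

print(step, member0.shape, ens_mean.shape, ens_std.shape)
```

Reducing over axis 0 collapses the ensemble dimension, which is why the standard deviation panel in the figure has the same latitude and longitude shape as a single member.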

Total running time of the script: (3 minutes 55.459 seconds)
