Running Ensemble Inference
Simple ensemble inference workflow.
This example demonstrates how to run a simple inference workflow that generates an ensemble forecast using one of the built-in models of Earth-2 Inference Studio.
In this example you will learn:
How to instantiate a built-in prognostic model
How to create a data source and IO object
How to select a perturbation method
How to run a simple built-in workflow for ensembling
How to post-process results
Set Up
All workflows inside Earth2Studio require constructed components to be
handed to them. In this example, we will use the built-in ensemble workflow
earth2studio.run.ensemble().
def ensemble(
    time: list[str] | list[datetime] | list[np.datetime64],
    nsteps: int,
    nensemble: int,
    prognostic: PrognosticModel,
    data: DataSource,
    io: IOBackend,
    perturbation: Perturbation,
    batch_size: int | None = None,
    output_coords: CoordSystem = OrderedDict({}),
    device: torch.device | None = None,
) -> IOBackend:
    """Built-in ensemble workflow.

    Parameters
    ----------
    time : list[str] | list[datetime] | list[np.datetime64]
        List of strings, datetimes or np.datetime64
    nsteps : int
        Number of forecast steps
    nensemble : int
        Number of ensemble members to run inference for.
    prognostic : PrognosticModel
        Prognostic model
    data : DataSource
        Data source
    io : IOBackend
        IO object
    perturbation : Perturbation
        Method to perturb the initial condition to create an ensemble.
    batch_size : int, optional
        Number of ensemble members to run in a single batch,
        by default None.
    output_coords : CoordSystem, optional
        IO output coordinate system override, by default OrderedDict({})
    device : torch.device, optional
        Device to run inference on, by default None

    Returns
    -------
    IOBackend
        Output IO object
    """
We need the following:
Prognostic Model: Use the built-in FourCastNet model earth2studio.models.px.FCN.
Perturbation Method: Use the Spherical Gaussian method earth2studio.perturbation.SphericalGaussian.
Data Source: Pull data from the GFS data API earth2studio.data.GFS.
IO Backend: Save the outputs into a Zarr store earth2studio.io.ZarrBackend.
import os
os.makedirs("outputs", exist_ok=True)
from dotenv import load_dotenv
load_dotenv()  # Load environment variables from a .env file, if one exists
import numpy as np
from earth2studio.data import GFS
from earth2studio.io import ZarrBackend
from earth2studio.models.px import FCN
from earth2studio.perturbation import SphericalGaussian
from earth2studio.run import ensemble
# Load the default model package, which downloads the checkpoint from NGC
package = FCN.load_default_package()
model = FCN.load_model(package)
# Instantiate the perturbation method
sg = SphericalGaussian(noise_amplitude=0.15)
# Create the data source
data = GFS()
# Create the IO handler, which writes to a Zarr store on disk
chunks = {"ensemble": 1, "time": 1, "lead_time": 1}
io = ZarrBackend(
file_name="outputs/03_ensemble_sg.zarr",
chunks=chunks,
backend_kwargs={"overwrite": True},
)
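The chunks dictionary passed to ZarrBackend above sets a chunk size of 1 along the ensemble, time and lead_time dimensions. This sketch assumes that dimensions not listed in chunks keep their full extent; that is an assumption about ZarrBackend, not confirmed by this example. Under that assumption, each chunk holds one complete lat/lon field, so a read such as io[variable][0, 0, step] touches a single chunk. The sketch works out the chunk shape and chunk count for one output variable. The lat/lon sizes are placeholders, not FCN's actual grid.

```python
import math

# Hypothetical output dims for one variable; lat/lon sizes are placeholders
dims = {"ensemble": 8, "time": 1, "lead_time": 11, "lat": 181, "lon": 360}
chunks = {"ensemble": 1, "time": 1, "lead_time": 1}

# Assumption: dims missing from `chunks` are stored unchunked (full extent)
chunk_shape = tuple(chunks.get(name, size) for name, size in dims.items())
# Number of chunks = product over dims of ceil(dim size / chunk size)
nchunks = math.prod(math.ceil(size / c) for size, c in zip(dims.values(), chunk_shape))
print(chunk_shape, nchunks)  # (1, 1, 1, 181, 360) 88
```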
Execute the Workflow
With all components initialized, running the workflow is a single function call. The workflow returns the provided IO object to the user, which can then be used for post-processing. Some IO backends have additional APIs that are handy for post-processing or saving to file; check the API docs for more information.
For the forecast we will predict 10 steps (60 hours for FCN, which uses a 6-hour time step) with 8 ensemble members, run in 4 batches of 2 members each (batch_size = 2).
nsteps = 10
nensemble = 8
batch_size = 2
io = ensemble(
["2024-01-01"],
nsteps,
nensemble,
model,
data,
io,
sg,
batch_size=batch_size,
output_coords={"variable": np.array(["t2m", "tcwv"])},
)
2025-01-23 04:38:47.879 | INFO | earth2studio.run:ensemble:315 - Running ensemble inference!
2025-01-23 04:38:47.880 | INFO | earth2studio.run:ensemble:323 - Inference device: cuda
2025-01-23 04:38:47.926 | DEBUG | earth2studio.data.gfs:_fetch_gfs_dataarray:209 - Fetching GFS index file: 2024-01-01 00:00:00 lead 0:00:00
2025-01-23 04:38:47.930 | DEBUG | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: u10m at 2024-01-01 00:00:00_0:00:00
2025-01-23 04:38:47.958 | DEBUG | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: v10m at 2024-01-01 00:00:00_0:00:00
2025-01-23 04:38:47.984 | DEBUG | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: t2m at 2024-01-01 00:00:00_0:00:00
2025-01-23 04:38:48.012 | DEBUG | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: sp at 2024-01-01 00:00:00_0:00:00
2025-01-23 04:38:48.038 | DEBUG | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: msl at 2024-01-01 00:00:00_0:00:00
2025-01-23 04:38:48.064 | DEBUG | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: t850 at 2024-01-01 00:00:00_0:00:00
2025-01-23 04:38:48.090 | DEBUG | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: u1000 at 2024-01-01 00:00:00_0:00:00
2025-01-23 04:38:48.117 | DEBUG | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: v1000 at 2024-01-01 00:00:00_0:00:00
2025-01-23 04:38:48.144 | DEBUG | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: z1000 at 2024-01-01 00:00:00_0:00:00
2025-01-23 04:38:48.171 | DEBUG | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: u850 at 2024-01-01 00:00:00_0:00:00
2025-01-23 04:38:48.197 | DEBUG | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: v850 at 2024-01-01 00:00:00_0:00:00
2025-01-23 04:38:48.224 | DEBUG | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: z850 at 2024-01-01 00:00:00_0:00:00
2025-01-23 04:38:48.252 | DEBUG | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: u500 at 2024-01-01 00:00:00_0:00:00
2025-01-23 04:38:48.278 | DEBUG | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: v500 at 2024-01-01 00:00:00_0:00:00
2025-01-23 04:38:48.305 | DEBUG | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: z500 at 2024-01-01 00:00:00_0:00:00
2025-01-23 04:38:48.331 | DEBUG | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: t500 at 2024-01-01 00:00:00_0:00:00
2025-01-23 04:38:48.358 | DEBUG | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: z50 at 2024-01-01 00:00:00_0:00:00
2025-01-23 04:38:48.385 | DEBUG | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: r500 at 2024-01-01 00:00:00_0:00:00
2025-01-23 04:38:48.411 | DEBUG | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: r850 at 2024-01-01 00:00:00_0:00:00
2025-01-23 04:38:48.438 | DEBUG | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: tcwv at 2024-01-01 00:00:00_0:00:00
2025-01-23 04:38:48.465 | DEBUG | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: u100m at 2024-01-01 00:00:00_0:00:00
2025-01-23 04:38:48.491 | DEBUG | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: v100m at 2024-01-01 00:00:00_0:00:00
2025-01-23 04:38:48.517 | DEBUG | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: u250 at 2024-01-01 00:00:00_0:00:00
2025-01-23 04:38:48.543 | DEBUG | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: v250 at 2024-01-01 00:00:00_0:00:00
2025-01-23 04:38:48.570 | DEBUG | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: z250 at 2024-01-01 00:00:00_0:00:00
2025-01-23 04:38:48.596 | DEBUG | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: t250 at 2024-01-01 00:00:00_0:00:00
Fetching GFS for 2024-01-01 00:00:00: 100%|██████████| 26/26 [00:00<00:00, 37.58it/s]
2025-01-23 04:38:48.735 | SUCCESS | earth2studio.run:ensemble:345 - Fetched data from GFS
2025-01-23 04:38:48.744 | INFO | earth2studio.run:ensemble:367 - Starting 8 Member Ensemble Inference with 4 number of batches.
Running batch 0 inference: 100%|██████████| 11/11 [00:05<00:00, 2.04it/s]
Running batch 2 inference: 100%|██████████| 11/11 [00:05<00:00, 2.04it/s]
Running batch 4 inference: 100%|██████████| 11/11 [00:05<00:00, 2.02it/s]
Running batch 6 inference: 100%|██████████| 11/11 [00:05<00:00, 2.03it/s]
Total Ensemble Batches: 100%|██████████| 4/4 [00:40<00:00, 10.17s/it]
2025-01-23 04:39:29.405 | SUCCESS | earth2studio.run:ensemble:412 - Inference complete
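Each batch in the log above runs 11 iterations for nsteps = 10. Presumably that is one iteration for the initial condition plus one per forecast step. Assuming FCN's 6-hour time step (consistent with the 60-hour horizon stated earlier), the lead times along the lead_time dimension can be reconstructed as follows. This is an illustration and is not read from the library. It also shows why step = 4 in the plotting code corresponds to a 24-hour lead time.

```python
import numpy as np

nsteps = 10
dt = np.timedelta64(6, "h")  # assumed FCN time step of 6 hours

# Index 0 holds the initial condition (lead time 0 h); index k is k steps ahead
lead_times = np.arange(nsteps + 1) * dt
lead_hours = lead_times.astype("int64")  # timedelta64[h] -> integer hours

print(len(lead_times))  # 11
print(lead_hours[4])  # 24
print(lead_hours[-1])  # 60
```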
Post Processing
The last step is to post-process our results. Cartopy is a great library for plotting fields on projections of a sphere.
Notice that the Zarr IO backend has additional APIs to interact with the stored data; for example, stored arrays can be read back by name, as in io["lon"] or io["tcwv"].
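The plotting code below indexes each variable as io[variable][member, time, step]. The stored arrays are therefore ordered (ensemble, time, lead_time, lat, lon). Given that layout, ensemble statistics reduce over axis 0. To keep the sketch self-contained, it uses a synthetic NumPy array as a stand-in for io["tcwv"][:]. The grid size and values are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(42)
# Synthetic stand-in for io["tcwv"][:] with dims (ensemble, time, lead_time, lat, lon)
nensemble, ntime, nlead, nlat, nlon = 8, 1, 11, 18, 36
field = 20.0 + rng.standard_normal((nensemble, ntime, nlead, nlat, nlon))

step = 4  # lead time index (24 h for a 6-hour model step)
# All members at the first initialization time and the chosen step: (ensemble, lat, lon)
members = field[:, 0, step]

ens_mean = members.mean(axis=0)  # ensemble mean field, shape (lat, lon)
ens_std = members.std(axis=0)  # ensemble spread, same reduction as the Std panel below
print(ens_mean.shape, ens_std.shape)  # (18, 36) (18, 36)
```

Note that np.std uses ddof=0 by default (population standard deviation); pass ddof=1 for the sample estimate, which differs noticeably for small ensembles like this one.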
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
forecast = "2024-01-01"
def plot_(axi, data, title, cmap):
    """Convenience function for plotting pcolormesh."""
    # Plot the field using pcolormesh
    im = axi.pcolormesh(
        io["lon"][:],
        io["lat"][:],
        data,
        transform=ccrs.PlateCarree(),
        cmap=cmap,
    )
    plt.colorbar(im, ax=axi, shrink=0.6, pad=0.04)
    # Set title
    axi.set_title(title)
    # Add coastlines and gridlines
    axi.coastlines()
    axi.gridlines()


for variable, cmap in zip(["tcwv"], ["Blues"]):
    step = 4  # lead time = 24 hrs
    plt.close("all")
    # Create a Robinson projection
    projection = ccrs.Robinson()
    # Create a figure and axes with the specified projection
    fig, (ax1, ax2, ax3) = plt.subplots(
        nrows=1, ncols=3, subplot_kw={"projection": projection}, figsize=(16, 3)
    )
    # Stored arrays are indexed as [ensemble, time, lead_time, lat, lon]
    # Ensemble member 0
    plot_(
        ax1,
        io[variable][0, 0, step],
        f"{forecast} - Lead time: {6*step}hrs - Member: {0}",
        cmap,
    )
    # Ensemble member 1
    plot_(
        ax2,
        io[variable][1, 0, step],
        f"{forecast} - Lead time: {6*step}hrs - Member: {1}",
        cmap,
    )
    # Standard deviation across all ensemble members
    plot_(
        ax3,
        np.std(io[variable][:, 0, step], axis=0),
        f"{forecast} - Lead time: {6*step}hrs - Std",
        cmap,
    )
    plt.savefig(f"outputs/03_{forecast}_{variable}_{step}_ensemble.jpg")
Total running time of the script: (2 minutes 9.081 seconds)