Generative Downscaling#
Generative downscaling over Taiwan using the CorrDiff diffusion model.
This example demonstrates how to use NVIDIA’s CorrDiff model, trained for predicting weather over Taiwan, to perform generative downscaling from quarter-degree global forecast data to ~3 km.
This checkpoint was trained on ERA5 and WRF data spanning 2018-2021 at one-hour time resolution. In this example, we apply it to GFS data to super-resolve a typhoon from 2023. The model’s performance on GFS data and on data from this year has not been evaluated.
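To put "quarter degree to ~3 km" in perspective, the short calculation below converts a 0.25° grid spacing to kilometres. It is a back-of-the-envelope sketch: it assumes a spherical Earth with a mean radius of 6371 km and uses 23.5°N as a representative latitude for Taiwan. Neither number comes from the CorrDiff documentation, and `grid_spacing_km` is a hypothetical helper.

```python
import math

EARTH_RADIUS_KM = 6371.0  # assumed mean Earth radius (spherical approximation)
TAIWAN_LAT_DEG = 23.5     # assumed representative latitude for Taiwan


def grid_spacing_km(spacing_deg: float, lat_deg: float) -> tuple[float, float]:
    """Return (north-south, east-west) spacing in km for a lat/lon grid."""
    km_per_deg = 2.0 * math.pi * EARTH_RADIUS_KM / 360.0
    north_south = spacing_deg * km_per_deg
    # Meridians converge toward the poles, so east-west spacing shrinks with cos(lat)
    east_west = north_south * math.cos(math.radians(lat_deg))
    return north_south, east_west


ns_km, ew_km = grid_spacing_km(0.25, TAIWAN_LAT_DEG)
print(f"0.25 degree input: ~{ns_km:.1f} km (N-S), ~{ew_km:.1f} km (E-W)")
print(f"Refinement factor to a 3 km grid: ~{ns_km / 3.0:.0f}x per dimension")
```

Under these assumptions the input grid is roughly an order of magnitude coarser per dimension than the ~3 km output, which is the gap the diffusion model is asked to fill.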
In this example you will learn how to:
Create a custom workflow for running CorrDiff inference
Create a data source for CorrDiff’s input
Initialize and run the CorrDiff diagnostic model
Post-process the results
Creating a Simple CorrDiff Workflow#
As usual, we start by creating a simple workflow to run CorrDiff in. To keep this
workflow as general as possible, we use dependency injection, following the pattern
provided inside earth2studio.run. Since CorrDiff is a diagnostic model, this
workflow won’t predict a time series, only an instantaneous prediction.
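The dependency-injection idea can be illustrated without Earth2Studio at all. In the sketch below, `ToyModel`, `ToySource`, and `toy_run` are hypothetical stand-ins invented for illustration; the point is only that the workflow receives its components as arguments instead of constructing them, so any object with the same interface can be swapped in.

```python
from typing import Protocol


class Source(Protocol):
    def __call__(self, time: str) -> list[float]: ...


class Model(Protocol):
    def __call__(self, x: list[float]) -> list[float]: ...


class ToySource:
    """Hypothetical data source returning a fixed coarse field."""

    def __call__(self, time: str) -> list[float]:
        return [1.0, 2.0, 3.0]


class ToyModel:
    """Hypothetical diagnostic model: maps one input state to one output state."""

    def __call__(self, x: list[float]) -> list[float]:
        return [2.0 * v for v in x]


def toy_run(time: list[str], model: Model, source: Source) -> dict[str, list[float]]:
    """The workflow only wires together whatever components it is handed."""
    # A diagnostic model is applied independently at each time; there is no
    # roll-out through time as there would be with a prognostic model.
    return {t: model(source(t)) for t in time}


result = toy_run(["2023-10-04T18:00:00"], ToyModel(), ToySource())
print(result)
```

Replacing `ToySource` with another callable of the same shape requires no change to `toy_run`, which is the property the real workflow relies on when a different data source or IO backend is passed in.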
For this workflow, we specify
time: Input list of datetimes / strings to run inference for
corrdiff: The initialized CorrDiffTaiwan model
data: Initialized data source to fetch initial conditions from
io: IOBackend
number_of_samples: Number of samples to generate from the model
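Since `time` may arrive as strings, `datetime` objects, or `np.datetime64` values, the workflow first normalizes it to a single array type; in the code that follows this is delegated to `to_time_array`. The sketch below shows the general idea with plain NumPy. `normalize_times` is a hypothetical helper written for illustration and is not the library's implementation.

```python
from datetime import datetime

import numpy as np


def normalize_times(times):
    """Hypothetical helper: coerce mixed time inputs to a datetime64[ns] array."""
    return np.array([np.datetime64(t, "ns") for t in times])


times = normalize_times(
    [
        "2023-10-04T18:00:00",              # ISO 8601 string
        datetime(2023, 10, 4, 18),          # naive Python datetime
        np.datetime64("2023-10-04T18:00"),  # NumPy datetime at minute precision
    ]
)
print(times.dtype, times)
```

All three inputs describe the same instant, so after normalization the array holds three identical entries with a common nanosecond precision.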
import os
os.makedirs("outputs", exist_ok=True)
from dotenv import load_dotenv
load_dotenv() # TODO: make common example prep function
from collections import OrderedDict
from datetime import datetime
import numpy as np
import torch
from loguru import logger
from earth2studio.data import DataSource, prep_data_array
from earth2studio.io import IOBackend
from earth2studio.models.dx import CorrDiffTaiwan
from earth2studio.utils.coords import map_coords, split_coords
from earth2studio.utils.time import to_time_array
def run(
    time: list[str] | list[datetime] | list[np.datetime64],
    corrdiff: CorrDiffTaiwan,
    data: DataSource,
    io: IOBackend,
    number_of_samples: int = 1,
) -> IOBackend:
    """CorrDiff inference workflow

    Parameters
    ----------
    time : list[str] | list[datetime] | list[np.datetime64]
        List of string, datetimes or np.datetime64
    corrdiff : CorrDiffTaiwan
        CorrDiff model
    data : DataSource
        Data source
    io : IOBackend
        IO object
    number_of_samples : int, optional
        Number of samples to generate, by default 1

    Returns
    -------
    IOBackend
        Output IO object
    """
    logger.info("Running corrdiff inference!")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Inference device: {device}")

    corrdiff = corrdiff.to(device)
    # Update the number of samples for corrdiff to generate
    corrdiff.number_of_samples = number_of_samples

    # Fetch data from data source and load onto device
    time = to_time_array(time)
    x, coords = prep_data_array(
        data(time, corrdiff.input_coords()["variable"]), device=device
    )
    x, coords = map_coords(x, coords, corrdiff.input_coords())

    logger.success(f"Fetched data from {data.__class__.__name__}")

    # Set up IO backend
    output_coords = corrdiff.output_coords(corrdiff.input_coords())
    total_coords = OrderedDict(
        {
            "time": coords["time"],
            "sample": output_coords["sample"],
            "lat": output_coords["lat"],
            "lon": output_coords["lon"],
        }
    )
    io.add_array(total_coords, output_coords["variable"])

    logger.info("Inference starting!")
    x, coords = corrdiff(x, coords)
    io.write(*split_coords(x, coords))

    logger.success("Inference complete")
    return io
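The `total_coords` dictionary passed to `io.add_array` fixes the layout of each stored array: one axis per key, in insertion order. The sketch below mimics that bookkeeping with plain NumPy. It assumes 1D coordinate axes for simplicity, the grid sizes and coordinate ranges are placeholders rather than the actual CorrDiffTaiwan grid, and `coords_shape` is a hypothetical helper.

```python
from collections import OrderedDict

import numpy as np


def coords_shape(coords: OrderedDict) -> tuple[int, ...]:
    """Hypothetical helper: array shape implied by 1D coordinate arrays."""
    return tuple(len(v) for v in coords.values())


total_coords = OrderedDict(
    {
        "time": np.array(["2023-10-04T18:00"], dtype="datetime64[ns]"),
        "sample": np.arange(2),               # two generated samples
        "lat": np.linspace(21.0, 26.0, 8),    # placeholder grid
        "lon": np.linspace(119.0, 123.0, 6),  # placeholder grid
    }
)
shape = coords_shape(total_coords)
field = np.zeros(shape)  # stands in for one stored output variable
print(list(total_coords), shape, field[0, 0].shape)
```

With this ordering, indexing a variable with `[0, 0]` selects the first time and the first sample and leaves a 2D spatial field.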
Set Up#
With the workflow defined, the next step is initializing the needed components from Earth-2 studio.
We need the following:
Diagnostic Model: CorrDiff model for Taiwan, earth2studio.models.dx.CorrDiffTaiwan.
Datasource: Pull data from the GFS data API, earth2studio.data.GFS.
IO Backend: Save the outputs into a Zarr store, earth2studio.io.ZarrBackend.
from earth2studio.data import GFS
from earth2studio.io import ZarrBackend
# Create CorrDiff model
package = CorrDiffTaiwan.load_default_package()
corrdiff = CorrDiffTaiwan.load_model(package)
# Create the data source
data = GFS()
# Create the IO handler, store in memory
io = ZarrBackend()
Downloading corrdiff_inference_package.zip: 100%|██████████| 684M/684M [00:03<00:00, 215MB/s]
Execute the Workflow#
With all components initialized, running the workflow is a single line of Python code. The workflow returns the provided IO object to the user, which can then be used for post-processing. Some IO backends have additional APIs that can be handy for post-processing or saving to file. Check the API docs for more information.
For the inference, we will predict one sample for a particular timestamp representing Typhoon Koinu.
io = run(["2023-10-04T18:00:00"], corrdiff, data, io, number_of_samples=1)
2025-05-16 00:22:21.759 | INFO | __main__:run:108 - Running corrdiff inference!
2025-05-16 00:22:21.759 | INFO | __main__:run:110 - Inference device: cuda
2025-05-16 00:22:22.345 | DEBUG | earth2studio.data.gfs:fetch_array:353 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20231004/18/atmos/gfs.t18z.pgrb2.0p25.f000 330494443-911634
2025-05-16 00:22:22.346 | DEBUG | earth2studio.data.gfs:fetch_array:353 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20231004/18/atmos/gfs.t18z.pgrb2.0p25.f000 337873698-579539
2025-05-16 00:22:22.347 | DEBUG | earth2studio.data.gfs:fetch_array:353 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20231004/18/atmos/gfs.t18z.pgrb2.0p25.f000 418951103-1205697
2025-05-16 00:22:22.347 | DEBUG | earth2studio.data.gfs:fetch_array:353 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20231004/18/atmos/gfs.t18z.pgrb2.0p25.f000 331406077-849817
2025-05-16 00:22:22.348 | DEBUG | earth2studio.data.gfs:fetch_array:353 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20231004/18/atmos/gfs.t18z.pgrb2.0p25.f000 258879946-735487
2025-05-16 00:22:22.349 | DEBUG | earth2studio.data.gfs:fetch_array:353 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20231004/18/atmos/gfs.t18z.pgrb2.0p25.f000 407465443-514270
2025-05-16 00:22:22.350 | DEBUG | earth2studio.data.gfs:fetch_array:353 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20231004/18/atmos/gfs.t18z.pgrb2.0p25.f000 264233423-556738
2025-05-16 00:22:22.350 | DEBUG | earth2studio.data.gfs:fetch_array:353 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20231004/18/atmos/gfs.t18z.pgrb2.0p25.f000 258060355-819591
2025-05-16 00:22:22.351 | DEBUG | earth2studio.data.gfs:fetch_array:353 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20231004/18/atmos/gfs.t18z.pgrb2.0p25.f000 412051984-953804
2025-05-16 00:22:22.352 | DEBUG | earth2studio.data.gfs:fetch_array:353 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20231004/18/atmos/gfs.t18z.pgrb2.0p25.f000 411076281-975703
2025-05-16 00:22:22.352 | DEBUG | earth2studio.data.gfs:fetch_array:353 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20231004/18/atmos/gfs.t18z.pgrb2.0p25.f000 264790161-556189
2025-05-16 00:22:22.353 | DEBUG | earth2studio.data.gfs:fetch_array:353 - Fetching GFS grib file: noaa-gfs-bdp-pds/gfs.20231004/18/atmos/gfs.t18z.pgrb2.0p25.f000 337303053-570645
Fetching GFS data: 100%|██████████| 12/12 [00:01<00:00, 9.46it/s]
2025-05-16 00:22:23.633 | SUCCESS | __main__:run:123 - Fetched data from GFS
2025-05-16 00:22:23.634 | WARNING | earth2studio.io.zarr:add_array:192 - Datetime64 not supported in zarr 3.0, converting to int64 nanoseconds since epoch
2025-05-16 00:22:23.668 | INFO | __main__:run:137 - Inference starting!
2025-05-16 00:22:25.348 | SUCCESS | __main__:run:141 - Inference complete
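Because CorrDiff is generative, requesting more than one sample yields an ensemble rather than a single field. One common way to summarize such an ensemble is the per-pixel mean and standard deviation across the sample axis. The sketch below uses random placeholder data in a (time, sample, lat, lon) layout matching the workflow's coordinate order; the array sizes and values are assumptions, and with real output the array would be read from the IO backend instead.

```python
import numpy as np

rng = np.random.default_rng(0)
# Placeholder for one output variable with layout (time, sample, lat, lon)
t2m = 290.0 + rng.normal(scale=1.5, size=(1, 8, 16, 16))

sample_axis = 1
ens_mean = t2m.mean(axis=sample_axis)            # best-estimate field
ens_spread = t2m.std(axis=sample_axis, ddof=1)   # sample-to-sample variability
print(ens_mean.shape, ens_spread.shape)
```

Regions where the spread is large are those where the generated samples disagree most, which can serve as a rough indication of where the downscaled detail is least certain.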
Post Processing#
The last step is to post-process our results. Cartopy is a great library for plotting fields on projections of a sphere.
Notice that the Zarr IO backend has additional APIs to interact with the stored data.
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
projection = ccrs.LambertConformal(
    central_longitude=io["lon"][:].mean(),
)

fig = plt.figure(figsize=(4 * 8, 8))

ax0 = fig.add_subplot(1, 3, 1, projection=projection)
c = ax0.pcolormesh(
    io["lon"],
    io["lat"],
    io["mrr"][0, 0],
    transform=ccrs.PlateCarree(),
    cmap="inferno",
)
plt.colorbar(c, ax=ax0, shrink=0.6, label="mrr dBz")
ax0.coastlines()
ax0.gridlines()
ax0.set_title("Radar Reflectivity")

ax1 = fig.add_subplot(1, 3, 2, projection=projection)
c = ax1.pcolormesh(
    io["lon"],
    io["lat"],
    io["t2m"][0, 0],
    transform=ccrs.PlateCarree(),
    cmap="RdBu_r",
)
plt.colorbar(c, ax=ax1, shrink=0.6, label="K")
ax1.coastlines()
ax1.gridlines()
ax1.set_title("2-meter Temperature")

ax2 = fig.add_subplot(1, 3, 3, projection=projection)
c = ax2.pcolormesh(
    io["lon"],
    io["lat"],
    np.sqrt(io["u10m"][0, 0] ** 2 + io["v10m"][0, 0] ** 2),
    transform=ccrs.PlateCarree(),
    cmap="Greens",
)
plt.colorbar(c, ax=ax2, shrink=0.6, label="w10m m s^-1")
ax2.coastlines()
ax2.gridlines()
ax2.set_title("10-meter Wind Speed")

plt.savefig("outputs/04_corr_diff_prediction.jpg")
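The wind-speed panel combines the two 10 m wind components as sqrt(u² + v²). An equivalent and slightly more compact formulation uses `np.hypot`, shown here on small placeholder arrays rather than on model output.

```python
import numpy as np

# Placeholder wind components in m/s (not model output)
u10m = np.array([[3.0, 0.0], [-5.0, 6.0]])
v10m = np.array([[4.0, 2.0], [12.0, -8.0]])

# Element-wise Euclidean norm of (u, v)
speed = np.hypot(u10m, v10m)
print(speed)
```

Either form gives the same magnitude; the sign of each component only affects the wind direction, not the speed.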