Generative Downscaling#
Generative downscaling over Taiwan using CorrDiff diffusion model.
This example demonstrates how to use NVIDIA’s CorrDiff model, trained to predict weather over Taiwan, to perform generative downscaling from quarter-degree global forecast data to ~3 km.
This checkpoint was trained on ERA5 and WRF data spanning 2018-2021 at one-hour time resolution. In this example, we apply it to GFS data to super-resolve a typhoon from 2023. The model’s performance on GFS data and on data from 2023 has not been evaluated.
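To put the resolution jump in perspective, the following back-of-the-envelope sketch converts a quarter-degree grid spacing to kilometres. It assumes a spherical Earth with a mean radius of 6371 km and uses 23.5°N as an approximate latitude for Taiwan. Both values are assumptions made for illustration and are not taken from the model.

```python
import math

EARTH_RADIUS_KM = 6371.0  # assumed mean Earth radius (spherical approximation)
TAIWAN_LAT_DEG = 23.5  # assumed representative latitude for Taiwan


def grid_spacing_km(spacing_deg: float, lat_deg: float = 0.0) -> tuple[float, float]:
    """Return (meridional, zonal) spacing in km for a regular lat/lon grid."""
    km_per_degree = math.pi * EARTH_RADIUS_KM / 180.0
    meridional = spacing_deg * km_per_degree
    # Meridians converge towards the poles, so zonal spacing shrinks with cos(lat)
    zonal = meridional * math.cos(math.radians(lat_deg))
    return meridional, zonal


north_south_km, east_west_km = grid_spacing_km(0.25, TAIWAN_LAT_DEG)
# Approximate linear refinement factor when going from 0.25 degrees to ~3 km
refinement = north_south_km / 3.0
print(f"0.25 deg is about {north_south_km:.1f} km north-south")
print(f"and about {east_west_km:.1f} km east-west at {TAIWAN_LAT_DEG} N")
print(f"linear refinement factor to 3 km: about {refinement:.1f}x")
```

Under these assumptions each quarter-degree cell is roughly 28 km across, so the downscaling refines the grid by about a factor of nine in each horizontal direction.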
In this example you will learn how to:
Create a custom workflow for running CorrDiff inference
Create a data source for CorrDiff’s input
Initialize and run the CorrDiff diagnostic model
Post-process the results
Creating a Simple CorrDiff Workflow#
As usual, we start by creating a simple workflow to run CorrDiff in. To keep this
workflow general, we use dependency injection following the pattern
provided inside earth2studio.run. Since CorrDiff is a diagnostic model, this
workflow does not predict a time series but rather a single instantaneous prediction.
For this workflow, we specify:
time: Input list of datetimes / strings to run inference for
corrdiff: The initialized CorrDiffTaiwan model
data: Initialized data source to fetch initial conditions from
io: IOBackend
number_of_samples: Number of samples to generate from the model
import os
os.makedirs("outputs", exist_ok=True)
from dotenv import load_dotenv
load_dotenv() # TODO: make common example prep function
from collections import OrderedDict
from datetime import datetime
import numpy as np
import torch
from loguru import logger
from earth2studio.data import DataSource, prep_data_array
from earth2studio.io import IOBackend
from earth2studio.models.dx import CorrDiffTaiwan
from earth2studio.utils.coords import map_coords, split_coords
from earth2studio.utils.time import to_time_array
def run(
    time: list[str] | list[datetime] | list[np.datetime64],
    corrdiff: CorrDiffTaiwan,
    data: DataSource,
    io: IOBackend,
    number_of_samples: int = 1,
) -> IOBackend:
    """CorrDiff inference workflow

    Parameters
    ----------
    time : list[str] | list[datetime] | list[np.datetime64]
        List of string, datetimes or np.datetime64
    corrdiff : CorrDiffTaiwan
        CorrDiff model
    data : DataSource
        Data source
    io : IOBackend
        IO object
    number_of_samples : int, optional
        Number of samples to generate, by default 1

    Returns
    -------
    IOBackend
        Output IO object
    """
    logger.info("Running corrdiff inference!")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Inference device: {device}")
    corrdiff = corrdiff.to(device)
    # Update the number of samples for corrdiff to generate
    corrdiff.number_of_samples = number_of_samples
    # Fetch data from data source and load onto device
    time = to_time_array(time)
    x, coords = prep_data_array(
        data(time, corrdiff.input_coords()["variable"]), device=device
    )
    x, coords = map_coords(x, coords, corrdiff.input_coords())
    logger.success(f"Fetched data from {data.__class__.__name__}")
    # Set up IO backend
    output_coords = corrdiff.output_coords(corrdiff.input_coords())
    total_coords = OrderedDict(
        {
            "time": coords["time"],
            "sample": output_coords["sample"],
            "ilat": output_coords["ilat"],
            "ilon": output_coords["ilon"],
        }
    )
    io.add_array(total_coords, output_coords["variable"])
    # Add lat/lon grid metadata arrays
    io.add_array(
        OrderedDict({"ilat": total_coords["ilat"], "ilon": total_coords["ilon"]}),
        "lat",
        data=corrdiff.out_lat,
    )
    io.add_array(
        OrderedDict({"ilat": total_coords["ilat"], "ilon": total_coords["ilon"]}),
        "lon",
        data=corrdiff.out_lon,
    )
    logger.info("Inference starting!")
    x, coords = corrdiff(x, coords)
    io.write(*split_coords(x, coords))
    logger.success("Inference complete")
    return io
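The dependency-injection idea behind run can be illustrated without Earth2Studio at all. The sketch below is a toy analogue that uses only the standard library: Source, Model, Store, ToySource, ToyModel, ToyStore and toy_run are hypothetical names invented here and are not part of the earth2studio API. It only shows how a workflow that receives its components as arguments can be reused with any object offering the expected methods.

```python
from typing import Protocol, TypeVar


class Source(Protocol):
    def __call__(self, time: list[str]) -> list[float]: ...


class Model(Protocol):
    def __call__(self, x: list[float]) -> list[float]: ...


class Store(Protocol):
    def write(self, name: str, values: list[float]) -> None: ...


S = TypeVar("S", bound=Store)


class ToySource:
    """Hypothetical data source: returns one value per requested time."""

    def __call__(self, time: list[str]) -> list[float]:
        return [float(len(t)) for t in time]


class ToyModel:
    """Hypothetical diagnostic model: maps inputs to outputs, no time stepping."""

    def __call__(self, x: list[float]) -> list[float]:
        return [2.0 * value for value in x]


class ToyStore:
    """Hypothetical in-memory store standing in for an IO backend."""

    def __init__(self) -> None:
        self.arrays: dict[str, list[float]] = {}

    def write(self, name: str, values: list[float]) -> None:
        self.arrays[name] = values


def toy_run(time: list[str], model: Model, data: Source, io: S) -> S:
    # The workflow relies only on the interfaces, not on concrete classes
    x = data(time)
    y = model(x)
    io.write("prediction", y)
    return io


store = toy_run(["2023-10-04T18:00:00"], ToyModel(), ToySource(), ToyStore())
print(store.arrays)
```

Swapping ToySource or ToyStore for another object with the same methods requires no change to toy_run, which is the property the real workflow aims for with its data and io arguments.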
Set Up#
With the workflow defined, the next step is to initialize the needed components from Earth2Studio.
We need the following:
Diagnostic Model: CorrDiff model for Taiwan, earth2studio.models.dx.CorrDiffTaiwan.
Datasource: Pull data from the GFS data API, earth2studio.data.GFS.
IO Backend: Save the outputs into a Zarr store, earth2studio.io.ZarrBackend.
from earth2studio.data import GFS
from earth2studio.io import ZarrBackend
# Create CorrDiff model
package = CorrDiffTaiwan.load_default_package()
corrdiff = CorrDiffTaiwan.load_model(package)
# Create the data source
data = GFS()
# Create the IO handler, store in memory
io = ZarrBackend()
Execute the Workflow#
With all components initialized, running the workflow is a single line of Python code. The workflow returns the provided IO object to the user, which can then be used for post-processing. Some IO backends have additional APIs that can be handy for post-processing or saving to file. Check the API docs for more information.
For the inference we will predict 1 sample for a particular timestamp representing Typhoon Koinu.
io = run(["2023-10-04T18:00:00"], corrdiff, data, io, number_of_samples=1)
2024-06-25 13:58:34.056 | INFO | __main__:run:108 - Running corrdiff inference!
2024-06-25 13:58:34.057 | INFO | __main__:run:110 - Inference device: cuda
2024-06-25 13:58:34.137 | DEBUG | earth2studio.data.gfs:fetch_gfs_dataarray:149 - Fetching GFS index file: 2023-10-04 18:00:00
2024-06-25 13:58:34.460 | DEBUG | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: tcwv at 2023-10-04 18:00:00
2024-06-25 13:58:39.739 | DEBUG | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: z500 at 2023-10-04 18:00:00
2024-06-25 13:58:43.174 | DEBUG | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: t500 at 2023-10-04 18:00:00
2024-06-25 13:58:45.506 | DEBUG | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: u500 at 2023-10-04 18:00:00
2024-06-25 13:58:47.062 | DEBUG | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: v500 at 2023-10-04 18:00:00
2024-06-25 13:58:48.152 | DEBUG | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: z850 at 2023-10-04 18:00:00
2024-06-25 13:58:49.169 | DEBUG | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: t850 at 2023-10-04 18:00:00
2024-06-25 13:58:50.010 | DEBUG | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: u850 at 2023-10-04 18:00:00
2024-06-25 13:58:50.712 | DEBUG | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: v850 at 2023-10-04 18:00:00
2024-06-25 13:58:51.383 | DEBUG | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: t2m at 2023-10-04 18:00:00
2024-06-25 13:58:51.991 | DEBUG | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: u10m at 2023-10-04 18:00:00
2024-06-25 13:58:52.728 | DEBUG | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: v10m at 2023-10-04 18:00:00
Fetching GFS for 2023-10-04 18:00:00: 100%|██████████| 12/12 [00:18<00:00, 1.58s/it]
2024-06-25 13:58:53.457 | SUCCESS | __main__:run:123 - Fetched data from GFS
2024-06-25 13:58:53.462 | INFO | __main__:run:149 - Inference starting!
2024-06-25 13:58:54.700 | SUCCESS | __main__:run:153 - Inference complete
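The timestamp used above, 18:00 UTC, coincides with one of the four daily GFS initialization cycles (00, 06, 12 and 18 UTC). When starting from an arbitrary time, it can help to round down to the latest cycle before requesting data. The helper below is a hypothetical utility, not part of earth2studio, and whether a given data source requires such alignment is an assumption to check against its documentation.

```python
from datetime import datetime


def floor_to_gfs_cycle(t: datetime, cycle_hours: int = 6) -> datetime:
    """Round a datetime down to the most recent cycle (hypothetical helper)."""
    cycle_hour = (t.hour // cycle_hours) * cycle_hours
    return t.replace(hour=cycle_hour, minute=0, second=0, microsecond=0)


requested = datetime.fromisoformat("2023-10-04T20:30:00")
aligned = floor_to_gfs_cycle(requested)
print(aligned.isoformat())  # 2023-10-04T18:00:00
```

The resulting ISO string, or the datetime itself, can be placed in the time list given to run.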
Post Processing#
The last step is to post-process our results. Cartopy is a great library for plotting fields on projections of a sphere.
Notice that the Zarr IO backend has additional APIs to interact with the stored data.
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
projection = ccrs.LambertConformal(
    central_longitude=io["lon"][:].mean(),
)
fig = plt.figure(figsize=(4 * 8, 8))

ax0 = fig.add_subplot(1, 3, 1, projection=projection)
c = ax0.pcolormesh(
    io["lon"],
    io["lat"],
    io["mrr"][0, 0],
    transform=ccrs.PlateCarree(),
    cmap="inferno",
)
plt.colorbar(c, ax=ax0, shrink=0.6, label="mrr dBz")
ax0.coastlines()
ax0.gridlines()
ax0.set_title("Radar Reflectivity")

ax1 = fig.add_subplot(1, 3, 2, projection=projection)
c = ax1.pcolormesh(
    io["lon"],
    io["lat"],
    io["t2m"][0, 0],
    transform=ccrs.PlateCarree(),
    cmap="RdBu_r",
)
plt.colorbar(c, ax=ax1, shrink=0.6, label="K")
ax1.coastlines()
ax1.gridlines()
ax1.set_title("2-meter Temperature")

ax2 = fig.add_subplot(1, 3, 3, projection=projection)
c = ax2.pcolormesh(
    io["lon"],
    io["lat"],
    np.sqrt(io["u10m"][0, 0] ** 2 + io["v10m"][0, 0] ** 2),
    transform=ccrs.PlateCarree(),
    cmap="Greens",
)
plt.colorbar(c, ax=ax2, shrink=0.6, label="w10m m s^-1")
ax2.coastlines()
ax2.gridlines()
ax2.set_title("10-meter Wind Speed")

plt.savefig("outputs/04_corr_diff_prediction.jpg")
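The [0, 0] index used in the plotting calls follows from the dimension order set up in the workflow's total_coords: time, sample, ilat, ilon. The sketch below mimics that layout with a plain OrderedDict and nested lists to show what such an index selects. The 4 x 5 grid and the two samples are made-up values chosen only to keep the example small; the real output grid is much larger.

```python
from collections import OrderedDict

# Same dimension order as total_coords in the workflow; sizes are made up
toy_coords = OrderedDict(
    {
        "time": ["2023-10-04T18:00:00"],
        "sample": [0, 1],
        "ilat": list(range(4)),
        "ilon": list(range(5)),
    }
)
shape = tuple(len(values) for values in toy_coords.values())

# Build a nested list with that shape; each cell stores its sample index
field = [
    [
        [[s for _ in toy_coords["ilon"]] for _ in toy_coords["ilat"]]
        for s in toy_coords["sample"]
    ]
    for _ in toy_coords["time"]
]

# Equivalent of array[0, 0]: first time, first sample -> a 2D (ilat, ilon) slice
first_map = field[0][0]
print(shape)
print(len(first_map), len(first_map[0]))
```

With more than one sample or timestamp, changing the first two indices selects a different 2D map to plot.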