
Generative Downscaling#

Generative downscaling over Taiwan using the CorrDiff diffusion model.

This example demonstrates how to use NVIDIA’s CorrDiff model, trained to predict weather over Taiwan, to perform generative downscaling from quarter-degree global forecast data to ~3 km.
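To get a feel for the size of this resolution jump, the following back-of-the-envelope sketch converts a quarter-degree grid spacing to kilometers and compares it with the ~3 km target. The Earth radius, the spherical approximation, and the reference latitude of 23.5°N (roughly central Taiwan) are assumptions chosen for illustration, not values taken from CorrDiff.

```python
import math

EARTH_RADIUS_KM = 6371.0  # mean Earth radius, spherical approximation (assumption)


def grid_spacing_km(degrees: float, latitude_deg: float = 0.0) -> tuple[float, float]:
    """Return (north-south, east-west) spacing in km for a regular lat/lon grid."""
    km_per_degree = 2.0 * math.pi * EARTH_RADIUS_KM / 360.0
    north_south = degrees * km_per_degree
    # Meridians converge toward the poles, so east-west spacing shrinks with cos(lat)
    east_west = north_south * math.cos(math.radians(latitude_deg))
    return north_south, east_west


# 23.5 N is an assumed reference latitude near central Taiwan
ns_km, ew_km = grid_spacing_km(0.25, latitude_deg=23.5)
refinement = ns_km / 3.0  # approximate linear refinement factor to a ~3 km grid
print(f"0.25 deg ~ {ns_km:.1f} km (N-S), {ew_km:.1f} km (E-W); ~{refinement:.0f}x finer at 3 km")
```

Under these assumptions a quarter-degree cell is a little under 28 km across, so the model refines the grid by roughly a factor of nine in each horizontal direction.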

This checkpoint was trained on ERA5 data and WRF data spanning 2018-2021 at one-hour time resolution. In this example, we apply it to GFS data to super-resolve a typhoon from 2023. The model’s performance on GFS data and on data from 2023 has not been evaluated.

In this example you will learn how to:

  • Create a custom workflow for running CorrDiff inference

  • Create a data source for CorrDiff’s input

  • Initialize and run the CorrDiff diagnostic model

  • Post-process the results

Creating a Simple CorrDiff Workflow#

As usual, we start by creating a simple workflow to run CorrDiff in. To keep this workflow general, we use dependency injection, following the pattern provided in earth2studio.run. Since CorrDiff is a diagnostic model, this workflow does not predict a time series; it produces an instantaneous prediction for each requested time.

For this workflow, we specify

  • time: Input list of datetimes / strings to run inference for

  • corrdiff: The initialized CorrDiffTaiwan model

  • data: Initialized data source to fetch initial conditions from

  • io: IOBackend

  • number_of_samples: Number of samples to generate from the model

import os

os.makedirs("outputs", exist_ok=True)
from dotenv import load_dotenv

load_dotenv()  # TODO: make common example prep function

from collections import OrderedDict
from datetime import datetime

import numpy as np
import torch
from loguru import logger

from earth2studio.data import DataSource, prep_data_array
from earth2studio.io import IOBackend
from earth2studio.models.dx import CorrDiffTaiwan
from earth2studio.utils.coords import map_coords, split_coords
from earth2studio.utils.time import to_time_array


def run(
    time: list[str] | list[datetime] | list[np.datetime64],
    corrdiff: CorrDiffTaiwan,
    data: DataSource,
    io: IOBackend,
    number_of_samples: int = 1,
) -> IOBackend:
    """CorrDiff inference workflow

    Parameters
    ----------
    time : list[str] | list[datetime] | list[np.datetime64]
        List of strings, datetimes or np.datetime64 objects
    corrdiff : CorrDiffTaiwan
        CorrDiff model
    data : DataSource
        Data source
    io : IOBackend
        IO object
    number_of_samples : int, optional
        Number of samples to generate, by default 1

    Returns
    -------
    IOBackend
        Output IO object
    """
    logger.info("Running corrdiff inference!")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Inference device: {device}")

    corrdiff = corrdiff.to(device)
    # Update the number of samples for corrdiff to generate
    corrdiff.number_of_samples = number_of_samples

    # Fetch data from data source and load onto device
    time = to_time_array(time)
    x, coords = prep_data_array(
        data(time, corrdiff.input_coords()["variable"]), device=device
    )
    x, coords = map_coords(x, coords, corrdiff.input_coords())

    logger.success(f"Fetched data from {data.__class__.__name__}")

    # Set up IO backend
    output_coords = corrdiff.output_coords(corrdiff.input_coords())
    total_coords = OrderedDict(
        {
            "time": coords["time"],
            "sample": output_coords["sample"],
            "ilat": output_coords["ilat"],
            "ilon": output_coords["ilon"],
        }
    )
    io.add_array(total_coords, output_coords["variable"])

    # Add lat/lon grid metadata arrays
    io.add_array(
        OrderedDict({"ilat": total_coords["ilat"], "ilon": total_coords["ilon"]}),
        "lat",
        data=corrdiff.out_lat,
    )
    io.add_array(
        OrderedDict({"ilat": total_coords["ilat"], "ilon": total_coords["ilon"]}),
        "lon",
        data=corrdiff.out_lon,
    )

    logger.info("Inference starting!")
    x, coords = corrdiff(x, coords)
    io.write(*split_coords(x, coords))

    logger.success("Inference complete")
    return io
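The dependency-injection idea behind run can be illustrated without Earth2Studio at all. In the minimal sketch below, every class and function name (Source, Sink, ConstantSource, ListSink, run_workflow) is hypothetical and stands in for the real DataSource, IOBackend, and model objects. It only shows that the workflow depends on interfaces rather than on concrete implementations.

```python
from typing import Callable, Protocol


class Source(Protocol):
    """Hypothetical stand-in for a data source: maps a time stamp to values."""

    def __call__(self, time: str) -> list[float]: ...


class Sink(Protocol):
    """Hypothetical stand-in for an IO backend."""

    def write(self, time: str, values: list[float]) -> None: ...


class ConstantSource:
    """Toy source that returns the same values for every time stamp."""

    def __call__(self, time: str) -> list[float]:
        return [1.0, 2.0, 3.0]


class ListSink:
    """Toy sink that keeps everything it is given in memory."""

    def __init__(self) -> None:
        self.records: list[tuple[str, list[float]]] = []

    def write(self, time: str, values: list[float]) -> None:
        self.records.append((time, values))


def run_workflow(
    times: list[str],
    model: Callable[[list[float]], list[float]],
    data: Source,
    io: Sink,
) -> Sink:
    """Fetch, transform, and store one instantaneous result per time stamp."""
    for t in times:
        io.write(t, model(data(t)))
    return io


# Any object satisfying the protocols can be injected without touching run_workflow
sink = ListSink()
run_workflow(["2023-10-04T18:00:00"], lambda x: [2.0 * v for v in x], ConstantSource(), sink)
print(sink.records)
```

Swapping ConstantSource for another source, or ListSink for a file-backed sink, requires no change to run_workflow, which is the same property that lets run accept different data sources and IO backends.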

Set Up#

With the workflow defined, the next step is to initialize the needed components from Earth2Studio.

As the code below shows, we need a CorrDiffTaiwan model, a GFS data source, and a Zarr IO backend:

from earth2studio.data import GFS
from earth2studio.io import ZarrBackend

# Create CorrDiff model
package = CorrDiffTaiwan.load_default_package()
corrdiff = CorrDiffTaiwan.load_model(package)

# Create the data source
data = GFS()

# Create the IO handler, store in memory
io = ZarrBackend()

Execute the Workflow#

With all components initialized, running the workflow is a single line of Python code. The workflow returns the provided IO object to the user, which can then be used for post-processing. Some IO backends have additional APIs that can be handy for post-processing or saving to file. Check the API docs for more information.

For the inference we will predict 1 sample for a particular timestamp representing Typhoon Koinu.

io = run(["2023-10-04T18:00:00"], corrdiff, data, io, number_of_samples=1)
2024-06-25 13:58:34.056 | INFO     | __main__:run:108 - Running corrdiff inference!
2024-06-25 13:58:34.057 | INFO     | __main__:run:110 - Inference device: cuda
2024-06-25 13:58:34.137 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:149 - Fetching GFS index file: 2023-10-04 18:00:00
2024-06-25 13:58:34.460 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: tcwv at 2023-10-04 18:00:00
2024-06-25 13:58:39.739 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: z500 at 2023-10-04 18:00:00
2024-06-25 13:58:43.174 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: t500 at 2023-10-04 18:00:00
2024-06-25 13:58:45.506 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: u500 at 2023-10-04 18:00:00
2024-06-25 13:58:47.062 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: v500 at 2023-10-04 18:00:00
2024-06-25 13:58:48.152 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: z850 at 2023-10-04 18:00:00
2024-06-25 13:58:49.169 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: t850 at 2023-10-04 18:00:00
2024-06-25 13:58:50.010 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: u850 at 2023-10-04 18:00:00
2024-06-25 13:58:50.712 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: v850 at 2023-10-04 18:00:00
2024-06-25 13:58:51.383 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: t2m at 2023-10-04 18:00:00
2024-06-25 13:58:51.991 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: u10m at 2023-10-04 18:00:00
2024-06-25 13:58:52.728 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: v10m at 2023-10-04 18:00:00
Fetching GFS for 2023-10-04 18:00:00: 100%|██████████| 12/12 [00:18<00:00,  1.58s/it]
2024-06-25 13:58:53.457 | SUCCESS  | __main__:run:123 - Fetched data from GFS
2024-06-25 13:58:53.462 | INFO     | __main__:run:149 - Inference starting!
2024-06-25 13:58:54.700 | SUCCESS  | __main__:run:153 - Inference complete
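When number_of_samples is greater than 1, each sample is one plausible high-resolution realization, and a common next step is to summarize them per grid point. The sketch below uses small hand-made nested lists in place of the real output (an assumption made so that it runs without the model) and computes the per-pixel ensemble mean and spread with the standard library. With real arrays, the same reduction would typically be done with NumPy along the sample axis.

```python
from statistics import mean, pstdev

# Hypothetical stand-in for one output variable at one time: samples[sample][ilat][ilon]
samples = [
    [[280.0, 281.0], [282.0, 283.0]],
    [[282.0, 281.0], [284.0, 285.0]],
    [[284.0, 281.0], [286.0, 287.0]],
]

n_lat, n_lon = len(samples[0]), len(samples[0][0])

# Reduce over the sample dimension independently for every grid point
ens_mean = [[mean([s[i][j] for s in samples]) for j in range(n_lon)] for i in range(n_lat)]
ens_spread = [[pstdev([s[i][j] for s in samples]) for j in range(n_lon)] for i in range(n_lat)]

print(ens_mean)
print(ens_spread)
```

A spread of zero at a grid point means all samples agree there, while larger values mark regions where the generated realizations differ most.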

Post Processing#

The last step is to post process our results. Cartopy is a great library for plotting fields on projections of a sphere.

Notice that the Zarr IO backend has additional APIs to interact with the stored data.

import cartopy.crs as ccrs
import matplotlib.pyplot as plt

projection = ccrs.LambertConformal(
    central_longitude=io["lon"][:].mean(),
)

fig = plt.figure(figsize=(4 * 8, 8))

ax0 = fig.add_subplot(1, 3, 1, projection=projection)
c = ax0.pcolormesh(
    io["lon"],
    io["lat"],
    io["mrr"][0, 0],
    transform=ccrs.PlateCarree(),
    cmap="inferno",
)
plt.colorbar(c, ax=ax0, shrink=0.6, label="mrr dBz")
ax0.coastlines()
ax0.gridlines()
ax0.set_title("Radar Reflectivity")

ax1 = fig.add_subplot(1, 3, 2, projection=projection)
c = ax1.pcolormesh(
    io["lon"],
    io["lat"],
    io["t2m"][0, 0],
    transform=ccrs.PlateCarree(),
    cmap="RdBu_r",
)
plt.colorbar(c, ax=ax1, shrink=0.6, label="K")
ax1.coastlines()
ax1.gridlines()
ax1.set_title("2-meter Temperature")

ax2 = fig.add_subplot(1, 3, 3, projection=projection)
c = ax2.pcolormesh(
    io["lon"],
    io["lat"],
    np.sqrt(io["u10m"][0, 0] ** 2 + io["v10m"][0, 0] ** 2),
    transform=ccrs.PlateCarree(),
    cmap="Greens",
)
plt.colorbar(c, ax=ax2, shrink=0.6, label="w10m m s^-1")
ax2.coastlines()
ax2.gridlines()
ax2.set_title("10-meter Wind Speed")

plt.savefig("outputs/04_corr_diff_prediction.jpg")
[Figure: three panels showing Radar Reflectivity, 2-meter Temperature, and 10-meter Wind Speed]
/usr/local/lib/python3.10/dist-packages/cartopy/io/__init__.py:241: DownloadWarning: Downloading: https://naturalearth.s3.amazonaws.com/10m_physical/ne_10m_coastline.zip
  warnings.warn(f'Downloading: {url}', DownloadWarning)
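The wind-speed panel combines the two 10-meter wind components as the square root of u² + v². The sketch below repeats that calculation on a tiny hand-made grid (the values are invented for illustration) and additionally locates the grid point with the strongest wind, which is one simple way to find the most intense part of a storm in a downscaled field.

```python
import math

# Hypothetical 2 x 3 grids of 10 m wind components in m/s: grid[ilat][ilon]
u10m = [[3.0, -6.0, 0.0], [8.0, 5.0, -12.0]]
v10m = [[4.0, 8.0, 1.0], [6.0, 12.0, 0.0]]

# Wind speed is the Euclidean norm of the (u, v) vector at every grid point
speed = [
    [math.hypot(u, v) for u, v in zip(u_row, v_row)]
    for u_row, v_row in zip(u10m, v10m)
]

# Index (ilat, ilon) of the strongest wind on the grid
indices = [(i, j) for i in range(len(speed)) for j in range(len(speed[0]))]
max_idx = max(indices, key=lambda ij: speed[ij[0]][ij[1]])
max_speed = speed[max_idx[0]][max_idx[1]]

print(f"strongest wind {max_speed:.1f} m/s at (ilat, ilon) = {max_idx}")
```

On the real output, the resulting index pair can be used to look up the matching entries of the lat and lon arrays and report a geographic position.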

Total running time of the script: (0 minutes 38.621 seconds)
