.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/04_corrdiff_inference.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_04_corrdiff_inference.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_04_corrdiff_inference.py:

Generative Downscaling
======================

Generative downscaling over Taiwan using the CorrDiff diffusion model.

This example demonstrates how to use NVIDIA's CorrDiff model, trained for
predicting weather over Taiwan, to perform generative downscaling from
quarter-degree global forecast data to a roughly 3 km grid. The checkpoint was
trained on ERA5 and WRF data spanning 2018-2021 at one-hour time resolution.
Here we apply it to GFS data to super-resolve a typhoon from 2023; note that
the model's performance on GFS data, and on data from this year, has not been
evaluated.

In this example you will learn:

- Creating a custom workflow for running CorrDiff inference
- Creating a data source for CorrDiff's input
- Initializing and running the CorrDiff diagnostic model
- Post-processing results

.. GENERATED FROM PYTHON SOURCE LINES 42-58

Creating a Simple CorrDiff Workflow
-----------------------------------

As usual, we start with creating a simple workflow to run CorrDiff in. To
maximize the generalization of this workflow, we use dependency injection
following the pattern provided inside :py:obj:`earth2studio.run`. Since
CorrDiff is a diagnostic model, this workflow does not predict a time series;
it produces a single instantaneous prediction.

For this workflow, we specify:

- time: Input list of datetimes / strings to run inference for
- corrdiff: The initialized CorrDiffTaiwan model
- data: Initialized data source to fetch initial conditions from
- io: IOBackend to write results to
- number_of_samples: Number of samples to generate from the model
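Before writing the workflow, it can help to peek at the coordinate dictionaries
the model exposes, since the workflow below relies on them to map the fetched
data onto the model's input grid and to set up the output arrays. The snippet
below is an optional sketch, not part of the generated example; the exact keys
and sizes come from the model package, and only ``variable``, ``sample``,
``ilat`` and ``ilon`` are relied on later.

.. code-block:: Python

    # Optional: inspect the coordinate systems CorrDiffTaiwan expects and produces.
    # The model package is loaded again in the Set Up section further below.
    from earth2studio.models.dx import CorrDiffTaiwan

    package = CorrDiffTaiwan.load_default_package()
    corrdiff = CorrDiffTaiwan.load_model(package)

    # Coarse input grid coordinates (includes "variable", among others)
    print(list(corrdiff.input_coords().keys()))
    # Downscaled output coordinates (includes "sample", "ilat", "ilon", "variable")
    print(list(corrdiff.output_coords(corrdiff.input_coords()).keys()))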
.. GENERATED FROM PYTHON SOURCE LINES 60-157

.. code-block:: Python

    import os

    os.makedirs("outputs", exist_ok=True)
    from dotenv import load_dotenv

    load_dotenv()  # TODO: make common example prep function

    from collections import OrderedDict
    from datetime import datetime

    import numpy as np
    import torch
    from loguru import logger

    from earth2studio.data import DataSource, prep_data_array
    from earth2studio.io import IOBackend
    from earth2studio.models.dx import CorrDiffTaiwan
    from earth2studio.utils.coords import map_coords, split_coords
    from earth2studio.utils.time import to_time_array


    def run(
        time: list[str] | list[datetime] | list[np.datetime64],
        corrdiff: CorrDiffTaiwan,
        data: DataSource,
        io: IOBackend,
        number_of_samples: int = 1,
    ) -> IOBackend:
        """CorrDiff inference workflow

        Parameters
        ----------
        time : list[str] | list[datetime] | list[np.datetime64]
            List of string, datetimes or np.datetime64
        corrdiff : CorrDiffTaiwan
            CorrDiff model
        data : DataSource
            Data source
        io : IOBackend
            IO object
        number_of_samples : int, optional
            Number of samples to generate, by default 1

        Returns
        -------
        IOBackend
            Output IO object
        """
        logger.info("Running corrdiff inference!")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Inference device: {device}")

        corrdiff = corrdiff.to(device)
        # Update the number of samples for corrdiff to generate
        corrdiff.number_of_samples = number_of_samples

        # Fetch data from data source and load onto device
        time = to_time_array(time)
        x, coords = prep_data_array(
            data(time, corrdiff.input_coords()["variable"]), device=device
        )
        x, coords = map_coords(x, coords, corrdiff.input_coords())
        logger.success(f"Fetched data from {data.__class__.__name__}")

        # Set up IO backend
        output_coords = corrdiff.output_coords(corrdiff.input_coords())
        total_coords = OrderedDict(
            {
                "time": coords["time"],
                "sample": output_coords["sample"],
                "ilat": output_coords["ilat"],
                "ilon": output_coords["ilon"],
            }
        )
        io.add_array(total_coords, output_coords["variable"])

        # Add lat/lon grid metadata arrays
        io.add_array(
            OrderedDict({"ilat": total_coords["ilat"], "ilon": total_coords["ilon"]}),
            "lat",
            data=corrdiff.out_lat,
        )
        io.add_array(
            OrderedDict({"ilat": total_coords["ilat"], "ilon": total_coords["ilon"]}),
            "lon",
            data=corrdiff.out_lon,
        )

        logger.info("Inference starting!")
        x, coords = corrdiff(x, coords)
        io.write(*split_coords(x, coords))

        logger.success("Inference complete")
        return io

.. GENERATED FROM PYTHON SOURCE LINES 158-169

Set Up
------

With the workflow defined, the next step is initializing the needed components
from Earth2Studio. We need the following:

- Diagnostic Model: CorrDiff model for Taiwan :py:class:`earth2studio.models.dx.CorrDiffTaiwan`.
- Data source: Pull data from the GFS data API :py:class:`earth2studio.data.GFS`.
- IO Backend: Save the outputs into a Zarr store :py:class:`earth2studio.io.ZarrBackend`.

.. GENERATED FROM PYTHON SOURCE LINES 171-184

.. code-block:: Python

    from earth2studio.data import GFS
    from earth2studio.io import ZarrBackend

    # Create CorrDiff model
    package = CorrDiffTaiwan.load_default_package()
    corrdiff = CorrDiffTaiwan.load_model(package)

    # Create the data source
    data = GFS()

    # Create the IO handler, store in memory
    io = ZarrBackend()

.. GENERATED FROM PYTHON SOURCE LINES 185-194

Execute the Workflow
--------------------

With all components initialized, running the workflow is a single line of
Python code. The workflow returns the provided IO object back to the user,
which can then be used for post-processing.
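For instance, once inference has finished, the arrays held by the in-memory
Zarr backend can be wrapped into an ``xarray.Dataset`` for more convenient
slicing and analysis. The helper below is a minimal sketch rather than part of
the generated example; the variable names and dimension ordering are taken from
the workflow above and the plots below.

.. code-block:: Python

    import xarray as xr


    def to_dataset(io, variables=("mrr", "t2m", "u10m", "v10m")):
        """Wrap CorrDiff outputs from the IO backend into an xarray.Dataset."""
        return xr.Dataset(
            {v: (("time", "sample", "ilat", "ilon"), io[v][:]) for v in variables},
            coords={
                "lat": (("ilat", "ilon"), io["lat"][:]),
                "lon": (("ilat", "ilon"), io["lon"][:]),
            },
        )


    # Usage, after the run call below: ds = to_dataset(io)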
Some IO backends have additional APIs that can be handy for post-processing or
saving to file; check the API docs for more information. For this inference we
will generate a single sample for a timestamp during Typhoon Koinu.

.. GENERATED FROM PYTHON SOURCE LINES 196-198

.. code-block:: Python

    io = run(["2023-10-04T18:00:00"], corrdiff, data, io, number_of_samples=1)

.. GENERATED FROM PYTHON SOURCE LINES 199-205

Post Processing
---------------

The last step is to post-process our results. Cartopy is a great library for
plotting fields on projections of a sphere. Notice that the Zarr IO backend has
additional APIs to interact with the stored data.

.. GENERATED FROM PYTHON SOURCE LINES 207-256

.. code-block:: Python

    import cartopy.crs as ccrs
    import matplotlib.pyplot as plt

    projection = ccrs.LambertConformal(
        central_longitude=io["lon"][:].mean(),
    )

    fig = plt.figure(figsize=(4 * 8, 8))

    ax0 = fig.add_subplot(1, 3, 1, projection=projection)
    c = ax0.pcolormesh(
        io["lon"],
        io["lat"],
        io["mrr"][0, 0],
        transform=ccrs.PlateCarree(),
        cmap="inferno",
    )
    plt.colorbar(c, ax=ax0, shrink=0.6, label="mrr dBZ")
    ax0.coastlines()
    ax0.gridlines()
    ax0.set_title("Radar Reflectivity")

    ax1 = fig.add_subplot(1, 3, 2, projection=projection)
    c = ax1.pcolormesh(
        io["lon"],
        io["lat"],
        io["t2m"][0, 0],
        transform=ccrs.PlateCarree(),
        cmap="RdBu_r",
    )
    plt.colorbar(c, ax=ax1, shrink=0.6, label="K")
    ax1.coastlines()
    ax1.gridlines()
    ax1.set_title("2-meter Temperature")

    ax2 = fig.add_subplot(1, 3, 3, projection=projection)
    c = ax2.pcolormesh(
        io["lon"],
        io["lat"],
        np.sqrt(io["u10m"][0, 0] ** 2 + io["v10m"][0, 0] ** 2),
        transform=ccrs.PlateCarree(),
        cmap="Greens",
    )
    plt.colorbar(c, ax=ax2, shrink=0.6, label="w10m m s^-1")
    ax2.coastlines()
    ax2.gridlines()
    ax2.set_title("10-meter Wind Speed")

    plt.savefig("outputs/04_corr_diff_prediction.jpg")


.. _sphx_glr_download_examples_04_corrdiff_inference.py:

.. only:: html

    .. container:: sphx-glr-footer sphx-glr-footer-example

        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: 04_corrdiff_inference.ipynb <04_corrdiff_inference.ipynb>`

        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: 04_corrdiff_inference.py <04_corrdiff_inference.py>`

        .. container:: sphx-glr-download sphx-glr-download-zip

            :download:`Download zipped: 04_corrdiff_inference.zip <04_corrdiff_inference.zip>`

.. only:: html

    .. rst-class:: sphx-glr-signature

        `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_