.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/extend/02_custom_diagnostic.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_extend_02_custom_diagnostic.py:


Extending Diagnostic Models
===========================

Implementing a custom diagnostic model

This example will demonstrate how to extend Earth2Studio by implementing a custom
diagnostic model and running it in a general workflow.

In this example you will learn:

- API requirements of diagnostic models
- Implementing a custom diagnostic model
- Running this custom model in a workflow with a built-in prognostic model

.. GENERATED FROM PYTHON SOURCE LINES 35-47

Custom Diagnostic
-----------------
As discussed in the :ref:`diagnostic_model_userguide` section of the user guide,
Earth2Studio defines a diagnostic model through a simple interface
:py:class:`earth2studio.models.dx.base.DiagnosticModel`. This interface can be used
to guide which APIs we need to implement to successfully create our own model.

In this example, let's consider a simple diagnostic that converts the surface
temperature (the 2-meter temperature, :code:`t2m`) from Kelvin to Celsius to make it
more readable for the average person.

Our diagnostic model uses :py:class:`torch.nn.Module` as its base class, which gives
us the required :py:obj:`to(device)` method for free.

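
The interface boils down to the handful of methods that the implementation below
provides: :py:func:`input_coords`, :py:func:`output_coords`, :py:func:`__call__` and
:py:obj:`to`. As a rough sketch of that contract, the snippet below expresses it as a
:py:class:`typing.Protocol`. ``DiagnosticLike``, ``ToyDiagnostic`` and ``MissingCall``
are hypothetical names used only for illustration; this is an assumption-laden
stand-in, not the actual definition of
:py:class:`earth2studio.models.dx.base.DiagnosticModel`.

.. code-block:: Python

    from collections import OrderedDict
    from typing import Any, Protocol, runtime_checkable

    # Assumption: a coordinate system is an ordered mapping of dimension name -> coordinates
    CoordSystem = OrderedDict


    @runtime_checkable
    class DiagnosticLike(Protocol):
        """Hypothetical, simplified stand-in for the diagnostic model interface."""

        def input_coords(self) -> CoordSystem: ...

        def output_coords(self, input_coords: CoordSystem) -> CoordSystem: ...

        def __call__(self, x: Any, coords: CoordSystem) -> tuple[Any, CoordSystem]: ...

        def to(self, device: Any) -> Any: ...


    class ToyDiagnostic:
        """Hypothetical toy model that satisfies the protocol structurally."""

        def input_coords(self) -> CoordSystem:
            return OrderedDict(variable=["t2m"])

        def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
            return OrderedDict(variable=["t2m_c"])

        def __call__(self, x: Any, coords: CoordSystem) -> tuple[Any, CoordSystem]:
            return x, self.output_coords(coords)

        def to(self, device: Any) -> "ToyDiagnostic":
            return self


    class MissingCall:
        """Lacks ``__call__``, so it does not satisfy the protocol."""

        def input_coords(self) -> CoordSystem:
            return OrderedDict()

        def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
            return OrderedDict()

        def to(self, device: Any) -> "MissingCall":
            return self


    print(isinstance(ToyDiagnostic(), DiagnosticLike))  # True
    print(isinstance(MissingCall(), DiagnosticLike))  # False

Keep in mind that a runtime-checkable protocol only verifies that the methods exist, not
their signatures or the coordinate logic inside them, so it is a checklist rather than a
validation of the model.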
.. GENERATED FROM PYTHON SOURCE LINES 49-141

.. code-block:: Python

    import os

    os.makedirs("outputs", exist_ok=True)
    from dotenv import load_dotenv

    load_dotenv()  # TODO: make common example prep function

    from collections import OrderedDict

    import numpy as np
    import torch

    from earth2studio.models.batch import batch_coords, batch_func
    from earth2studio.utils import handshake_coords, handshake_dim
    from earth2studio.utils.type import CoordSystem


    class CustomDiagnostic(torch.nn.Module):
        """Custom diagnostic model"""

        def __init__(self):
            super().__init__()

        def input_coords(self) -> CoordSystem:
            """Input coordinate system of the diagnostic model

            Returns
            -------
            CoordSystem
                Coordinate system dictionary
            """
            return OrderedDict(
                {
                    "batch": np.empty(0),
                    "variable": np.array(["t2m"]),
                    "lat": np.linspace(90, -90, 721),
                    "lon": np.linspace(0, 360, 1440, endpoint=False),
                }
            )

        @batch_coords()
        def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
            """Output coordinate system of the diagnostic model

            Parameters
            ----------
            input_coords : CoordSystem
                Input coordinate system to transform into output_coords

            Returns
            -------
            CoordSystem
                Coordinate system dictionary
            """
            # Check input coordinates are valid
            target_input_coords = self.input_coords()
            for i, (key, value) in enumerate(target_input_coords.items()):
                if key != "batch":
                    handshake_dim(input_coords, key, i)
                    handshake_coords(input_coords, target_input_coords, key)

            output_coords = OrderedDict(
                {
                    "batch": np.empty(0),
                    "variable": np.array(["t2m_c"]),
                    "lat": np.linspace(90, -90, 721),
                    "lon": np.linspace(0, 360, 1440, endpoint=False),
                }
            )
            output_coords["batch"] = input_coords["batch"]
            return output_coords

        @batch_func()
        def __call__(
            self,
            x: torch.Tensor,
            coords: CoordSystem,
        ) -> tuple[torch.Tensor, CoordSystem]:
            """Runs diagnostic model

            Parameters
            ----------
            x : torch.Tensor
                Input tensor
            coords : CoordSystem
                Input coordinate system

            Returns
            -------
            tuple[torch.Tensor, CoordSystem]
                Temperature in Celsius and the output coordinate system
            """
            out_coords = self.output_coords(coords)
            out = x - 273.15  # Kelvin to Celsius
            return out, out_coords

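
The validation loop in :py:func:`output_coords` above delegates to
:py:func:`handshake_dim` and :py:func:`handshake_coords`. To illustrate the kind of check
involved, here is a NumPy-only sketch. ``check_coords`` is a hypothetical helper written
for this illustration, not the Earth2Studio implementation, and it assumes the checks
reduce to "same dimension position" and "identical coordinate values", with the
``batch`` dimension left free.

.. code-block:: Python

    from collections import OrderedDict

    import numpy as np


    def check_coords(input_coords: OrderedDict, target: OrderedDict) -> None:
        """Hypothetical check mirroring the loop in ``output_coords``.

        Every dimension except the free "batch" dimension must sit at the same
        position as in ``target`` and carry identical coordinate values.
        """
        names = list(input_coords.keys())
        for i, (key, expected) in enumerate(target.items()):
            if key == "batch":
                continue  # batch size is free, so it is not checked
            if i >= len(names) or names[i] != key:
                raise ValueError(f"expected dimension '{key}' at index {i}")
            if not np.array_equal(np.asarray(input_coords[key]), expected):
                raise ValueError(f"coordinate values for '{key}' do not match")


    # Target grid taken from CustomDiagnostic.input_coords (0.25 degree global grid)
    target = OrderedDict(
        batch=np.empty(0),
        variable=np.array(["t2m"]),
        lat=np.linspace(90, -90, 721),
        lon=np.linspace(0, 360, 1440, endpoint=False),
    )

    # Same grid with a batch of two samples: passes silently
    good = OrderedDict(target, batch=np.arange(2))
    check_coords(good, target)

    # Swapped lat/lon order: rejected
    swapped = OrderedDict(
        batch=np.arange(2),
        variable=target["variable"],
        lon=target["lon"],
        lat=target["lat"],
    )
    try:
        check_coords(swapped, target)
    except ValueError as err:
        print(err)  # expected dimension 'lat' at index 2

Any batch size passes, while a reordered grid or a different variable name (for example
``u10m`` instead of ``t2m``) is rejected before any computation happens.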
.. GENERATED FROM PYTHON SOURCE LINES 142-154

Input/Output Coordinates
~~~~~~~~~~~~~~~~~~~~~~~~
Defining the input/output coordinate systems is essential for any model in
Earth2Studio, since this is how both the package and users learn what type of data the
model expects. This requires the definition of :py:func:`input_coords` and
:py:func:`output_coords`. Have a look at :ref:`coordinates_userguide` for details on
coordinate systems.

For this diagnostic model, we simply define the input coordinates to be the global
surface (2-meter) temperature :code:`t2m`, as specified in
:file:`earth2studio/lexicon/base.py`, on a 0.25 degree (721 x 1440) latitude/longitude
grid. The output is a custom variable :code:`t2m_c` that represents the same
temperature in Celsius on the same grid.

.. GENERATED FROM PYTHON SOURCE LINES 156-167

:py:func:`__call__` API
~~~~~~~~~~~~~~~~~~~~~~~
The call function is the main API of a diagnostic model: it takes a data tensor and
its coordinate system as input and returns a transformed tensor and coordinate system.
This function first validates that the input coordinate system is correct (here through
:py:func:`output_coords`), then computes the updated data tensor and coordinate system
and returns both.

.. note::
    You may notice the :py:func:`batch_func` and :py:func:`batch_coords` decorators,
    which are used to make batched operations easier. For more details about this
    refer to the :ref:`batch_function_userguide` section of the user guide.

.. GENERATED FROM PYTHON SOURCE LINES 169-173

Set Up
------
With the custom diagnostic model defined, the next step is to set up and run a
workflow. We will use the built-in workflow :py:meth:`earth2studio.run.diagnostic`.

.. GENERATED FROM PYTHON SOURCE LINES 175-181

Let's instantiate the components needed.

- Prognostic Model: Use the built-in DLWP model :py:class:`earth2studio.models.px.DLWP`.
- Diagnostic Model: The custom diagnostic model defined above.
- Datasource: Pull data from the GFS data API :py:class:`earth2studio.data.GFS`.
- IO Backend: Save the outputs into a Zarr store :py:class:`earth2studio.io.ZarrBackend`.

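
Before wiring up the real components, it may help to see how they fit together. The
sketch below is a conceptual toy version of a diagnostic workflow built on stated
assumptions: the hypothetical ``toy_prognostic`` warms the field by 0.5 K per step,
``toy_diagnostic`` applies the same Kelvin-to-Celsius shift as ``CustomDiagnostic``, and
a plain dictionary stands in for the IO backend. It is not how
:py:meth:`earth2studio.run.diagnostic` is implemented; it only illustrates the data flow
of an initial state followed by prognostic steps, with the diagnostic applied at every
lead time.

.. code-block:: Python

    import numpy as np


    def toy_prognostic(x: np.ndarray) -> np.ndarray:
        """Hypothetical stand-in for a prognostic step: warm the field by 0.5 K."""
        return x + 0.5


    def toy_diagnostic(x: np.ndarray) -> np.ndarray:
        """Same Kelvin-to-Celsius shift as CustomDiagnostic."""
        return x - 273.15


    def toy_workflow(x0: np.ndarray, nsteps: int) -> dict[str, list[np.ndarray]]:
        """Conceptual data flow: the initial state, then nsteps prognostic steps,
        with the diagnostic applied to every state (including lead time 0)."""
        store: dict[str, list[np.ndarray]] = {"t2m_c": []}  # stands in for the IO backend
        x = x0
        for step in range(nsteps + 1):
            if step > 0:
                x = toy_prognostic(x)
            store["t2m_c"].append(toy_diagnostic(x))
        return store


    # Tiny 2 x 3 "grid" at 280 K in place of GFS data
    x0 = np.full((2, 3), 280.0)
    out = toy_workflow(x0, nsteps=4)
    print(len(out["t2m_c"]))  # 5 lead times: 0 through 4 steps
    print(round(float(out["t2m_c"][-1][0, 0]), 2))  # 8.85

The point of the design is that the diagnostic never needs to know which prognostic
model produced its input, only that the coordinates match what it declared.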
.. GENERATED FROM PYTHON SOURCE LINES 183-204

.. code-block:: Python

    from dotenv import load_dotenv

    load_dotenv()  # TODO: make common example prep function

    from earth2studio.data import GFS
    from earth2studio.io import ZarrBackend
    from earth2studio.models.px import DLWP

    # Load the default model package, which downloads the checkpoint from NGC
    package = DLWP.load_default_package()
    model = DLWP.load_model(package)

    # Diagnostic model
    diagnostic = CustomDiagnostic()

    # Create the data source
    data = GFS()

    # Create the IO handler, store in memory
    io = ZarrBackend()

.. GENERATED FROM PYTHON SOURCE LINES 205-209

Execute the Workflow
--------------------
Running our workflow with a built-in prognostic model and a custom diagnostic is the
same as running a built-in diagnostic.

.. GENERATED FROM PYTHON SOURCE LINES 211-218

.. code-block:: Python

    import earth2studio.run as run

    nsteps = 20
    io = run.diagnostic(["2024-01-01"], nsteps, model, diagnostic, data, io)

    print(io.root.tree())

.. GENERATED FROM PYTHON SOURCE LINES 219-222

Post Processing
---------------
Let's plot the Celsius temperature field from our custom diagnostic model.

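
The plotting code below selects every fourth lead time and labels each panel in hours.
The ``step = 4  # 24hrs`` comment implies a 6-hour model time step; assuming that, the
following NumPy sketch reproduces the index and label arithmetic without any forecast
data (``lead_time`` here is synthetic, not read from the Zarr store).

.. code-block:: Python

    import numpy as np

    nsteps = 20
    # Assumption: 6-hour time step, implied by "step = 4  # 24hrs"
    lead_time = np.arange(nsteps + 1) * np.timedelta64(6, "h")

    # Same conversion the plotting code applies to io["lead_time"]
    times = lead_time.astype("timedelta64[h]").astype(int)

    step = 4
    panels = list(range(0, 20, step))
    print(panels)  # [0, 4, 8, 12, 16]
    print([int(times[t]) for t in panels])  # [0, 24, 48, 72, 96]
    print(int(times[-1]))  # 120

The five selected indices fill the 1 x 5 subplot grid; the final 120 hour lead time is
stored but not plotted, because ``range(0, 20, step)`` stops before index 20.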
.. GENERATED FROM PYTHON SOURCE LINES 224-268

.. code-block:: Python

    import cartopy.crs as ccrs
    import matplotlib.pyplot as plt

    forecast = "2024-01-01"
    variable = "t2m_c"

    plt.close("all")

    # Create a figure and axes with the specified projection
    fig, ax = plt.subplots(
        1,
        5,
        figsize=(12, 4),
        subplot_kw={"projection": ccrs.Orthographic()},
        constrained_layout=True,
    )

    times = io["lead_time"].astype("timedelta64[h]").astype(int)
    step = 4  # 24hrs
    for i, t in enumerate(range(0, 20, step)):
        ctr = ax[i].contourf(
            io["lon"][:],
            io["lat"][:],
            io[variable][0, t],
            vmin=-10,
            vmax=30,
            transform=ccrs.PlateCarree(),
            levels=20,
            cmap="coolwarm",
        )
        ax[i].set_title(f"{times[t]}hrs")
        ax[i].coastlines()
        ax[i].gridlines()

    plt.suptitle(f"{variable} - {forecast}")

    cbar = plt.cm.ScalarMappable(cmap="coolwarm")
    cbar.set_array(io[variable][0, 0])
    cbar.set_clim(-10.0, 30)
    cbar = fig.colorbar(cbar, ax=ax[-1], orientation="vertical", label="C", shrink=0.8)

    plt.savefig("outputs/custom_diagnostic_dlwp_prediction.jpg")


.. _sphx_glr_download_examples_extend_02_custom_diagnostic.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 02_custom_diagnostic.ipynb <02_custom_diagnostic.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 02_custom_diagnostic.py <02_custom_diagnostic.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 02_custom_diagnostic.zip <02_custom_diagnostic.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_