.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/extend/02_custom_diagnostic.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_extend_02_custom_diagnostic.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_extend_02_custom_diagnostic.py:


Extending Diagnostic Models
===========================

Implementing a custom diagnostic model

This example will demonstrate how to extend Earth2Studio by implementing a custom
diagnostic model and running it in a general workflow.

In this example you will learn:

- API requirements of diagnostic models
- Implementing a custom diagnostic model
- Running this custom model in a workflow with a built-in prognostic model

.. GENERATED FROM PYTHON SOURCE LINES 35-47

Custom Diagnostic
-----------------
As discussed in the :ref:`diagnostic_model_userguide` section of the user guide,
Earth2Studio defines a diagnostic model through a simple interface,
:py:class:`earth2studio.models.dx.base.DiagnosticModel`. We can use this interface as
a guide to the APIs our own model needs to implement.

In this example, let's consider a simple diagnostic that converts the 2-meter air
temperature (:code:`t2m`) from Kelvin to Celsius to make it more readable for the
average person.
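
The conversion itself is a constant offset:

.. math::

   T_{\mathrm{C}} = T_{\mathrm{K}} - 273.15

For example, a temperature of 300 K corresponds to 26.85 °C.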

Our diagnostic model subclasses :py:class:`torch.nn.Module`, which gives us the
required :py:obj:`to(device)` method for free.
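
To make the required API concrete, the sketch below writes the four members discussed
in this example (:py:func:`__call__`, :py:func:`input_coords`, :py:func:`output_coords`
and :py:obj:`to`) as a :py:class:`typing.Protocol` and checks a toy class against it.
The names ``DiagnosticLike`` and ``ToyDiagnostic`` are hypothetical, the coordinates are
trimmed down, and this is not Earth2Studio's own definition; the authoritative interface
is :py:class:`earth2studio.models.dx.base.DiagnosticModel`.

.. code-block:: Python

    from collections import OrderedDict
    from typing import Any, Protocol, runtime_checkable

    import numpy as np


    @runtime_checkable
    class DiagnosticLike(Protocol):
        """Hypothetical mirror of the members a diagnostic model exposes."""

        def __call__(self, x: Any, coords: OrderedDict) -> tuple[Any, OrderedDict]: ...

        def input_coords(self) -> OrderedDict: ...

        def output_coords(self, input_coords: OrderedDict) -> OrderedDict: ...

        def to(self, device: Any) -> "DiagnosticLike": ...


    class ToyDiagnostic:
        """Hypothetical stand-in; torch.nn.Module would normally supply ``to``."""

        def __call__(self, x, coords):
            # Kelvin to Celsius, returning data and coordinates together
            return x - 273.15, self.output_coords(coords)

        def input_coords(self):
            return OrderedDict({"variable": np.array(["t2m"])})

        def output_coords(self, input_coords):
            out = OrderedDict(input_coords)
            out["variable"] = np.array(["t2m_c"])
            return out

        def to(self, device):
            # Nothing to move for a numpy-only toy
            return self


    # runtime_checkable only verifies that the members exist, not their signatures
    print(isinstance(ToyDiagnostic(), DiagnosticLike))  # True
    print(isinstance(object(), DiagnosticLike))  # False

Note that a runtime protocol check is shallow: it confirms the methods are present but
says nothing about coordinate handling, which is why the model itself validates its
inputs.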

.. GENERATED FROM PYTHON SOURCE LINES 49-141

.. code-block:: Python

    import os

    os.makedirs("outputs", exist_ok=True)
    from dotenv import load_dotenv

    load_dotenv()  # TODO: make common example prep function

    from collections import OrderedDict

    import numpy as np
    import torch

    from earth2studio.models.batch import batch_coords, batch_func
    from earth2studio.utils import handshake_coords, handshake_dim
    from earth2studio.utils.type import CoordSystem


    class CustomDiagnostic(torch.nn.Module):
        """Custom diagnostic model"""

        def __init__(self):
            super().__init__()

        def input_coords(self) -> CoordSystem:
            """Input coordinate system of the diagnostic model

            Returns
            -------
            CoordSystem
                Coordinate system dictionary
            """
            return OrderedDict(
                {
                    "batch": np.empty(0),
                    "variable": np.array(["t2m"]),
                    "lat": np.linspace(90, -90, 721),
                    "lon": np.linspace(0, 360, 1440, endpoint=False),
                }
            )

        @batch_coords()
        def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
            """Output coordinate system of the diagnostic model

            Parameters
            ----------
            input_coords : CoordSystem
                Input coordinate system to transform into output_coords

            Returns
            -------
            CoordSystem
                Coordinate system dictionary
            """
            # Check input coordinates are valid
            target_input_coords = self.input_coords()
            for i, (key, value) in enumerate(target_input_coords.items()):
                if key != "batch":
                    handshake_dim(input_coords, key, i)
                    handshake_coords(input_coords, target_input_coords, key)

            output_coords = OrderedDict(
                {
                    "batch": np.empty(0),
                    "variable": np.array(["t2m_c"]),
                    "lat": np.linspace(90, -90, 721),
                    "lon": np.linspace(0, 360, 1440, endpoint=False),
                }
            )
            output_coords["batch"] = input_coords["batch"]
            return output_coords

        @batch_func()
        def __call__(
            self,
            x: torch.Tensor,
            coords: CoordSystem,
        ) -> tuple[torch.Tensor, CoordSystem]:
            """Runs diagnostic model

            Parameters
            ----------
            x : torch.Tensor
                Input tensor
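            coords : CoordSystem
                Input coordinate system

            Returns
            -------
            tuple[torch.Tensor, CoordSystem]
                Output tensor in Celsius and its coordinate system
            """
            # Validate the input coordinates and build the output coordinates
            out_coords = self.output_coords(coords)
            out = x - 273.15  # Kelvin to Celsius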
            return out, out_coords









.. GENERATED FROM PYTHON SOURCE LINES 142-154

Input/Output Coordinates
~~~~~~~~~~~~~~~~~~~~~~~~
Defining the input/output coordinate systems is essential for any model in
Earth2Studio since this is how both the package and users can learn what type of data
the model expects. This requires defining :py:func:`input_coords` and
:py:func:`output_coords`. Have a look at :ref:`coordinates_userguide` for details on
coordinate systems.

For this diagnostic model, we simply define the input coordinates to be the global
2-meter air temperature, :code:`t2m` (specified in :file:`earth2studio/lexicon/base.py`),
on a 0.25° latitude/longitude grid (721 × 1440 points). The output is a custom variable
:code:`t2m_c` that represents the same temperature in Celsius.
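
The coordinate contract can be sketched with numpy alone: the incoming coordinates are
compared with the expected ones dimension by dimension (skipping the free ``batch``
dimension), and the output coordinates are the input ones with ``variable`` replaced by
``t2m_c``. The helper ``check_coords`` is hypothetical and only approximates the idea
behind :py:func:`handshake_dim` and :py:func:`handshake_coords`; it is not their
implementation.

.. code-block:: Python

    from collections import OrderedDict

    import numpy as np


    def expected_input_coords() -> OrderedDict:
        # Same layout as CustomDiagnostic.input_coords above
        return OrderedDict(
            {
                "batch": np.empty(0),
                "variable": np.array(["t2m"]),
                "lat": np.linspace(90, -90, 721),
                "lon": np.linspace(0, 360, 1440, endpoint=False),
            }
        )


    def check_coords(coords: OrderedDict, target: OrderedDict) -> None:
        """Hypothetical check: same dimension order and identical values, except batch."""
        keys = list(coords)
        for i, key in enumerate(target):
            if key == "batch":
                continue
            if i >= len(keys) or keys[i] != key:
                raise ValueError(f"dimension {key!r} expected at position {i}")
            if not np.array_equal(coords[key], target[key]):
                raise ValueError(f"coordinate values for {key!r} do not match")


    def output_coords(coords: OrderedDict) -> OrderedDict:
        check_coords(coords, expected_input_coords())
        out = OrderedDict(coords)  # keeps batch, lat and lon unchanged
        out["variable"] = np.array(["t2m_c"])
        return out


    coords = expected_input_coords()
    print(output_coords(coords)["variable"])  # ['t2m_c']

Passing coordinates with, say, a different variable or a coarser grid raises a
``ValueError`` before any data is touched, which is the same fail-early behaviour the
real handshake utilities are meant to provide.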

.. GENERATED FROM PYTHON SOURCE LINES 156-167

:py:func:`__call__` API
~~~~~~~~~~~~~~~~~~~~~~~
The :py:func:`__call__` function is the main API of a diagnostic model: it takes a data
tensor and its coordinate system as input and returns a tensor and coordinate system as
output. Here it first validates the input coordinate system (by calling
:py:func:`output_coords`), then converts the data and returns it together with the
updated coordinate system.

.. note::
  You may notice the :py:func:`batch_func` and :py:func:`batch_coords` decorators,
  which make batched operations easier. For more details, refer to the
  :ref:`batch_function_userguide` section of the user guide.
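
To illustrate the idea behind batching, the sketch below folds any extra leading
dimensions (for example ``time`` and ``lead_time``) into a single ``batch`` dimension,
applies the Kelvin-to-Celsius conversion, and unfolds the result again. The decorator
``fold_leading_dims`` is a hypothetical, simplified stand-in written for this example;
it is not the implementation of :py:func:`batch_func` and may differ from it in detail.

.. code-block:: Python

    from collections import OrderedDict

    import numpy as np


    def to_celsius(x: np.ndarray, coords: OrderedDict) -> tuple[np.ndarray, OrderedDict]:
        # Core operation: expects exactly (batch, variable, lat, lon)
        assert x.ndim == 4, "expected a single leading batch dimension"
        out_coords = OrderedDict(coords)
        out_coords["variable"] = np.array(["t2m_c"])
        return x - 273.15, out_coords


    def fold_leading_dims(fn, n_core: int = 3):
        """Hypothetical decorator: merge leading dims into one batch dim around fn."""

        def wrapper(x: np.ndarray, coords: OrderedDict):
            items = list(coords.items())
            lead_items, core_items = items[:-n_core], items[-n_core:]
            lead_shape = x.shape[:-n_core]
            # Collapse e.g. (time, lead_time) into one batch axis
            xb = x.reshape((-1,) + x.shape[-n_core:])
            core_coords = OrderedDict([("batch", np.empty(0))] + core_items)
            out, out_core = fn(xb, core_coords)
            # Restore the original leading dimensions and their coordinates
            out = out.reshape(lead_shape + out.shape[1:])
            out_coords = OrderedDict(lead_items + list(out_core.items())[1:])
            return out, out_coords

        return wrapper


    batched_to_celsius = fold_leading_dims(to_celsius)
    coords = OrderedDict(
        {
            "time": np.array(["2024-01-01"], dtype="datetime64[ns]"),
            "lead_time": np.array([0, 6], dtype="timedelta64[h]"),
            "variable": np.array(["t2m"]),
            "lat": np.linspace(90, -90, 3),
            "lon": np.linspace(0, 360, 4, endpoint=False),
        }
    )
    x = np.full((1, 2, 1, 3, 4), 273.15)
    out, out_coords = batched_to_celsius(x, coords)
    print(out.shape, list(out_coords))

The core function only ever sees a 4D tensor, while callers can pass any number of
leading dimensions; that separation is the convenience the decorators are there to
provide.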

.. GENERATED FROM PYTHON SOURCE LINES 169-173

Set Up
------
With the custom diagnostic model defined, the next step is to set up and run a
workflow. We will use the built-in workflow :py:func:`earth2studio.run.diagnostic`.

.. GENERATED FROM PYTHON SOURCE LINES 175-181

Let's instantiate the components needed.

- Prognostic Model: Use the built-in DLWP model :py:class:`earth2studio.models.px.DLWP`.
- Diagnostic Model: The custom diagnostic model defined above.
- Datasource: Pull data from the GFS data API :py:class:`earth2studio.data.GFS`.
- IO Backend: Save the outputs into a Zarr store :py:class:`earth2studio.io.ZarrBackend`.

.. GENERATED FROM PYTHON SOURCE LINES 183-204

.. code-block:: Python

    from earth2studio.data import GFS
    from earth2studio.io import ZarrBackend
    from earth2studio.models.px import DLWP

    # Load the default model package, which downloads the checkpoint from NGC
    package = DLWP.load_default_package()
    model = DLWP.load_model(package)

    # Diagnostic model
    diagnostic = CustomDiagnostic()

    # Create the data source
    data = GFS()

    # Create the IO handler, store in memory
    io = ZarrBackend()








.. GENERATED FROM PYTHON SOURCE LINES 205-209

Execute the Workflow
--------------------
Running our workflow with a built-in prognostic model and a custom diagnostic model is
the same as running it with a built-in diagnostic model.

.. GENERATED FROM PYTHON SOURCE LINES 211-218

.. code-block:: Python

    import earth2studio.run as run

    nsteps = 20
    io = run.diagnostic(["2024-01-01"], nsteps, model, diagnostic, data, io)

    print(io.root.tree())





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    2025-03-27 08:10:48.978 | INFO     | earth2studio.run:diagnostic:190 - Running diagnostic workflow!
    2025-03-27 08:10:48.978 | INFO     | earth2studio.run:diagnostic:197 - Inference device: cuda
    2025-03-27 08:10:48.987 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:214 - Fetching GFS index file: 2023-12-31 18:00:00 lead 0:00:00
    2025-03-27 08:10:48.993 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:260 - Fetching GFS grib file for variable: t850 at 2023-12-31 18:00:00_0:00:00
    2025-03-27 08:10:49.022 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:260 - Fetching GFS grib file for variable: z1000 at 2023-12-31 18:00:00_0:00:00
    2025-03-27 08:10:49.049 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:260 - Fetching GFS grib file for variable: z700 at 2023-12-31 18:00:00_0:00:00
    2025-03-27 08:10:49.077 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:260 - Fetching GFS grib file for variable: z500 at 2023-12-31 18:00:00_0:00:00
    2025-03-27 08:10:49.106 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:260 - Fetching GFS grib file for variable: z300 at 2023-12-31 18:00:00_0:00:00
    2025-03-27 08:10:49.134 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:260 - Fetching GFS grib file for variable: tcwv at 2023-12-31 18:00:00_0:00:00
    2025-03-27 08:10:49.161 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:260 - Fetching GFS grib file for variable: t2m at 2023-12-31 18:00:00_0:00:00
    Fetching GFS for 2023-12-31 18:00:00: 100%|██████████| 7/7 [00:00<00:00, 35.76it/s]
    2025-03-27 08:10:49.196 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:214 - Fetching GFS index file: 2024-01-01 00:00:00 lead 0:00:00
    2025-03-27 08:10:49.201 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:260 - Fetching GFS grib file for variable: t850 at 2024-01-01 00:00:00_0:00:00
    2025-03-27 08:10:49.229 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:260 - Fetching GFS grib file for variable: z1000 at 2024-01-01 00:00:00_0:00:00
    2025-03-27 08:10:49.257 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:260 - Fetching GFS grib file for variable: z700 at 2024-01-01 00:00:00_0:00:00
    2025-03-27 08:10:49.285 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:260 - Fetching GFS grib file for variable: z500 at 2024-01-01 00:00:00_0:00:00
    2025-03-27 08:10:49.313 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:260 - Fetching GFS grib file for variable: z300 at 2024-01-01 00:00:00_0:00:00
    2025-03-27 08:10:49.341 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:260 - Fetching GFS grib file for variable: tcwv at 2024-01-01 00:00:00_0:00:00
    2025-03-27 08:10:49.369 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:260 - Fetching GFS grib file for variable: t2m at 2024-01-01 00:00:00_0:00:00
    Fetching GFS for 2024-01-01 00:00:00: 100%|██████████| 7/7 [00:00<00:00, 35.68it/s]
    2025-03-27 08:10:49.421 | SUCCESS  | earth2studio.run:diagnostic:220 - Fetched data from GFS
    2025-03-27 08:10:49.427 | INFO     | earth2studio.run:diagnostic:252 - Inference starting!
    Running inference: 100%|██████████| 21/21 [00:01<00:00, 13.91it/s]
    2025-03-27 08:10:50.937 | SUCCESS  | earth2studio.run:diagnostic:266 - Inference complete
    /
     ├── lat (721,) float64
     ├── lead_time (21,) timedelta64[h]
     ├── lon (1440,) float64
     ├── t2m_c (1, 21, 721, 1440) float32
     └── time (1,) datetime64[ns]
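
Conceptually, the diagnostic workflow alternates between the two models: the prognostic
model advances the state one step at a time, and the diagnostic model is applied to
every state, including the initial condition. That is why ``nsteps = 20`` yields the 21
lead times of ``t2m_c`` in the tree above. The toy loop below illustrates the idea with
numpy only; ``toy_prognostic_step`` (a hypothetical stand-in that warms the field by
0.5 K per step) and the other helper names are assumptions for illustration, and this is
not the implementation of :py:func:`earth2studio.run.diagnostic`.

.. code-block:: Python

    import numpy as np


    def toy_prognostic_step(x: np.ndarray) -> np.ndarray:
        # Hypothetical stand-in for a prognostic model: warm the field by 0.5 K per step
        return x + 0.5


    def toy_diagnostic(x: np.ndarray) -> np.ndarray:
        # Kelvin to Celsius, as in CustomDiagnostic
        return x - 273.15


    def toy_diagnostic_workflow(x0: np.ndarray, nsteps: int) -> np.ndarray:
        outputs = [toy_diagnostic(x0)]  # lead time 0: the initial condition
        x = x0
        for _ in range(nsteps):
            x = toy_prognostic_step(x)
            outputs.append(toy_diagnostic(x))
        # Stack along a new lead_time axis: (lead_time, lat, lon)
        return np.stack(outputs, axis=0)


    x0 = np.full((4, 8), 280.0)
    result = toy_diagnostic_workflow(x0, nsteps=20)
    print(result.shape)  # (21, 4, 8)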




.. GENERATED FROM PYTHON SOURCE LINES 219-222

Post Processing
---------------
Let's plot the Celsius temperature field from our custom diagnostic model.
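
The panel titles are derived from the stored ``lead_time`` coordinate. Because every
fourth index is plotted and the panels are 24 hours apart, consecutive lead times are
6 hours apart. The sketch below reproduces the same conversion on a synthetic 6-hourly
lead-time array, an assumption standing in for ``io["lead_time"][:]``, so the indexing
can be checked without running the model.

.. code-block:: Python

    import numpy as np

    # Synthetic stand-in: 21 lead times spaced 6 hours apart (nsteps=20 plus the initial state)
    lead_time = np.arange(21) * np.timedelta64(6, "h")

    # Same conversion chain as the plotting code: timedelta -> whole hours -> integers
    hours = lead_time.astype("timedelta64[ns]").astype("timedelta64[h]").astype(int)

    step = 4  # 4 steps x 6 h = 24 h between panels
    titles = [f"{hours[t]}hrs" for t in range(0, 20, step)]
    print(titles)  # ['0hrs', '24hrs', '48hrs', '72hrs', '96hrs']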

.. GENERATED FROM PYTHON SOURCE LINES 224-270

.. code-block:: Python

    import cartopy.crs as ccrs
    import matplotlib.pyplot as plt

    forecast = "2024-01-01"
    variable = "t2m_c"

    plt.close("all")

    # Create a figure and axes with the specified projection
    fig, ax = plt.subplots(
        1,
        5,
        figsize=(12, 4),
        subplot_kw={"projection": ccrs.Orthographic()},
        constrained_layout=True,
    )

    times = (
        io["lead_time"][:].astype("timedelta64[ns]").astype("timedelta64[h]").astype(int)
    )
    step = 4  # 4 steps x 6 h = 24 h between panels
    for i, t in enumerate(range(0, 20, step)):

        ctr = ax[i].contourf(
            io["lon"][:],
            io["lat"][:],
            io[variable][0, t],
            vmin=-10,
            vmax=30,
            transform=ccrs.PlateCarree(),
            levels=20,
            cmap="coolwarm",
        )
        ax[i].set_title(f"{times[t]}hrs")
        ax[i].coastlines()
        ax[i].gridlines()

    plt.suptitle(f"{variable} - {forecast}")

    # Shared colorbar using the same -10 to 30 range as the panels
    sm = plt.cm.ScalarMappable(cmap="coolwarm")
    sm.set_array(io[variable][0, 0])
    sm.set_clim(-10.0, 30)
    cbar = fig.colorbar(sm, ax=ax[-1], orientation="vertical", label="°C", shrink=0.8)


    plt.savefig("outputs/02_custom_diagnostic_dlwp_prediction.jpg")



.. image-sg:: /examples/extend/images/sphx_glr_02_custom_diagnostic_001.png
   :alt: t2m_c - 2024-01-01, 0hrs, 24hrs, 48hrs, 72hrs, 96hrs
   :srcset: /examples/extend/images/sphx_glr_02_custom_diagnostic_001.png, /examples/extend/images/sphx_glr_02_custom_diagnostic_001_2_00x.png 2.00x
   :class: sphx-glr-single-img






.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 46.744 seconds)


.. _sphx_glr_download_examples_extend_02_custom_diagnostic.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 02_custom_diagnostic.ipynb <02_custom_diagnostic.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 02_custom_diagnostic.py <02_custom_diagnostic.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 02_custom_diagnostic.zip <02_custom_diagnostic.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_