.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/extend/02_custom_diagnostic.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_extend_02_custom_diagnostic.py:

Extending Diagnostic Models
===========================

Implementing a custom diagnostic model

This example will demonstrate how to extend Earth2Studio by implementing a custom
diagnostic model and running it in a general workflow.

In this example you will learn:

- API requirements of diagnostic models
- Implementing a custom diagnostic model
- Running this custom model in a workflow with a built-in prognostic

.. GENERATED FROM PYTHON SOURCE LINES 35-47

Custom Diagnostic
-----------------
As discussed in the :ref:`diagnostic_model_userguide` section of the user guide,
Earth2Studio defines a diagnostic model through a simple interface
:py:class:`earth2studio.models.dx.base.DiagnosticModel`. This can be used to help
guide the required APIs needed to successfully create our own model.

In this example, let's consider a simple diagnostic that converts the surface
temperature from Kelvin to Celsius to make it more readable for the average person.

Our diagnostic model has a base class of :py:class:`torch.nn.Module`, which gives us
the required :py:obj:`to(device)` method for free.

.. GENERATED FROM PYTHON SOURCE LINES 49-110

.. code-block:: Python

    from collections import OrderedDict

    import numpy as np
    import torch

    from earth2studio.models.batch import batch_func
    from earth2studio.utils import handshake_coords, handshake_dim
    from earth2studio.utils.type import CoordSystem


    class CustomDiagnostic(torch.nn.Module):
        """Custom diagnostic model"""

        def __init__(self):
            super().__init__()
            # Assign to self so __call__ can validate against these coordinates
            self.input_coords = OrderedDict(
                {
                    "batch": np.empty(1),
                    "variable": np.array(["t2m"]),
                    "lat": np.linspace(90, -90, 721),
                    "lon": np.linspace(0, 360, 1440, endpoint=False),
                }
            )
            self.output_coords = OrderedDict(
                {
                    "batch": np.empty(1),
                    "variable": np.array(["t2m_c"]),
                    "lat": np.linspace(90, -90, 721),
                    "lon": np.linspace(0, 360, 1440, endpoint=False),
                }
            )

        @batch_func()
        def __call__(
            self,
            x: torch.Tensor,
            coords: CoordSystem,
        ) -> tuple[torch.Tensor, CoordSystem]:
            """Runs diagnostic model

            Parameters
            ----------
            x : torch.Tensor
                Input tensor
            coords : CoordSystem
                Input coordinate system
            """
            # Validate the incoming coordinate system against the expected one
            for i, (key, value) in enumerate(self.input_coords.items()):
                if key != "batch":
                    handshake_dim(coords, key, i)
                    handshake_coords(coords, self.input_coords, key)

            out_coords = coords.copy()
            out_coords["variable"] = self.output_coords["variable"]
            out = x - 273.15  # Kelvin to Celsius
            return out, out_coords

.. GENERATED FROM PYTHON SOURCE LINES 111-120

Input/Output Coordinates
~~~~~~~~~~~~~~~~~~~~~~~~
Defining the input/output coordinate systems is essential for any model in
Earth2Studio, since this is how both the package and users learn what type of data a
model expects. Have a look at :ref:`coordinates_userguide` for details on coordinate
systems. For this diagnostic model, we simply define the input coordinates to be the
global surface temperature specified in :py:file:`earth2studio.lexicon.base.py`. The
output is a custom variable :py:var:`t2m_c` that represents the temperature in
Celsius.
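The coordinate bookkeeping the diagnostic performs can be sketched with plain NumPy,
independent of an Earth2Studio install. This is a minimal, hypothetical stand-in for
the tensor/``CoordSystem`` pair (the array contents and names below are illustrative
only, not part of the package):

```python
from collections import OrderedDict

import numpy as np

# Hypothetical input mimicking the diagnostic's expected coordinate system:
# one batch member of global 2 m temperature in Kelvin on a 721 x 1440 grid.
coords = OrderedDict(
    {
        "batch": np.empty(1),
        "variable": np.array(["t2m"]),
        "lat": np.linspace(90, -90, 721),
        "lon": np.linspace(0, 360, 1440, endpoint=False),
    }
)
x = np.full((1, 1, 721, 1440), 300.0)  # a uniform 300 K field

# The diagnostic leaves the grid untouched: it only shifts the units from
# Kelvin to Celsius and relabels the output variable.
out = x - 273.15
out_coords = coords.copy()
out_coords["variable"] = np.array(["t2m_c"])

print(round(float(out.max()), 2))  # 26.85
print(out_coords["variable"])  # ['t2m_c']
```

Note that the dimension order of ``out`` still matches the keys of ``out_coords``;
keeping the tensor and its coordinate dictionary in lockstep is the core contract of
the diagnostic API.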
.. GENERATED FROM PYTHON SOURCE LINES 122-133

:py:func:`__call__` API
~~~~~~~~~~~~~~~~~~~~~~~
The call function is the main API of diagnostic models; it takes a tensor and a
coordinate system as input and returns both as output. The function first validates
that the incoming coordinate system is correct, then updates the data tensor and the
coordinate system and returns them.

.. note::
    You may notice the :py:func:`batch_func` decorator, which is used to make batched
    operations easier. For more details, refer to the :ref:`batch_function_userguide`
    section of the user guide.

.. GENERATED FROM PYTHON SOURCE LINES 135-140

Set Up
------
With the custom diagnostic model defined, it's now easily usable in a workflow. Let's
create our own simple diagnostic workflow based on the ones that already exist in
Earth2Studio.

.. GENERATED FROM PYTHON SOURCE LINES 142-248

.. code-block:: Python

    from datetime import datetime
    from typing import Optional

    import numpy as np
    import torch
    from loguru import logger
    from tqdm import tqdm

    from earth2studio.data import DataSource, fetch_data
    from earth2studio.io import IOBackend
    from earth2studio.models.dx import DiagnosticModel
    from earth2studio.models.px import PrognosticModel
    from earth2studio.utils.coords import extract_coords, map_coords
    from earth2studio.utils.time import to_time_array


    def run(
        time: list[str] | list[datetime] | list[np.datetime64],
        nsteps: int,
        prognostic: PrognosticModel,
        diagnostic: DiagnosticModel,
        data: DataSource,
        io: IOBackend,
        device: Optional[torch.device] = None,
    ) -> IOBackend:
        """Simple diagnostic workflow

        Parameters
        ----------
        time : list[str] | list[datetime] | list[np.datetime64]
            List of string, datetimes or np.datetime64
        nsteps : int
            Number of forecast steps
        prognostic : PrognosticModel
            Prognostic model
        diagnostic : DiagnosticModel
            Diagnostic model
        data : DataSource
            Data source
        io : IOBackend
            IO object
        device : Optional[torch.device], optional
            Device to run inference on, by default None

        Returns
        -------
        IOBackend
            Output IO object
        """
        logger.info("Running diagnostic workflow!")

        # Load model onto the device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Inference device: {device}")
        prognostic = prognostic.to(device)

        # Fetch data from data source and load onto device
        time = to_time_array(time)
        x, coords = fetch_data(
            source=data,
            time=time,
            lead_time=prognostic.input_coords["lead_time"],
            variable=prognostic.input_coords["variable"],
            device=device,
        )
        logger.success(f"Fetched data from {data.__class__.__name__}")

        # Set up IO backend
        total_coords = prognostic.output_coords.copy()
        del total_coords["batch"]  # Unsafe if batch not supported
        # Drop empty coordinates (iterate over a copy so deletion is safe)
        for key, value in list(total_coords.items()):
            if value.shape == (0,):
                del total_coords[key]
        total_coords["time"] = time
        total_coords["lead_time"] = np.asarray(
            [prognostic.output_coords["lead_time"] * i for i in range(nsteps + 1)]
        ).flatten()
        total_coords.move_to_end("lead_time", last=False)
        total_coords.move_to_end("time", last=False)
        for name, value in diagnostic.output_coords.items():
            if name == "batch":
                continue
            total_coords[name] = value
        var_names = total_coords.pop("variable")
        io.add_array(total_coords, var_names)

        # Map lat and lon if needed
        x, coords = map_coords(x, coords, prognostic.input_coords)

        # Create prognostic iterator
        model = prognostic.create_iterator(x, coords)

        logger.info("Inference starting!")
        with tqdm(total=nsteps + 1, desc="Running inference") as pbar:
            for step, (x, coords) in enumerate(model):
                # Run diagnostic
                x, coords = map_coords(x, coords, diagnostic.input_coords)
                x, coords = diagnostic(x, coords)
                io.write(*extract_coords(x, coords))
                pbar.update(1)
                if step == nsteps:
                    break

        logger.success("Inference complete")
        return io

.. GENERATED FROM PYTHON SOURCE LINES 249-255

Let's instantiate the components needed:

- Prognostic Model: Use the built-in DLWP model :py:class:`earth2studio.models.px.DLWP`.
- Diagnostic Model: The custom diagnostic model defined above.
- Datasource: Pull data from the GFS data API :py:class:`earth2studio.data.GFS`.
- IO Backend: Save the outputs into a Zarr store :py:class:`earth2studio.io.ZarrBackend`.

.. GENERATED FROM PYTHON SOURCE LINES 257-281

.. code-block:: Python

    from collections import OrderedDict

    import numpy as np
    from dotenv import load_dotenv

    load_dotenv()  # TODO: make common example prep function

    from earth2studio.data import GFS
    from earth2studio.io import ZarrBackend
    from earth2studio.models.px import DLWP

    # Load the default model package which downloads the checkpoint from NGC
    package = DLWP.load_default_package()
    model = DLWP.load_model(package)

    # Diagnostic model
    diagnostic = CustomDiagnostic()

    # Create the data source
    data = GFS()

    # Create the IO handler, store in memory
    io = ZarrBackend()

.. GENERATED FROM PYTHON SOURCE LINES 282-286

Execute the Workflow
--------------------
Running our workflow with a built-in prognostic model and a custom diagnostic is as
simple as the following.

.. GENERATED FROM PYTHON SOURCE LINES 288-294

.. code-block:: Python

    nsteps = 20
    io = run(["2024-01-01"], nsteps, model, diagnostic, data, io)
    print(io.root.tree())

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    2024-04-19 00:36:48.079 | INFO     | __main__:run:189 - Running diagnostic workflow!
    2024-04-19 00:36:48.079 | INFO     | __main__:run:192 - Inference device: cuda
    2024-04-19 00:36:48.086 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:151 - Fetching GFS index file: 2023-12-31 18:00:00
    Fetching GFS for 2023-12-31 18:00:00:   0%|          | 0/7 [00:00

.. container:: sphx-glr-download sphx-glr-download-python

    :download:`Download Python source code: 02_custom_diagnostic.py <02_custom_diagnostic.py>`

.. only:: html

    .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_