.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/05_ensemble_workflow_extend.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_05_ensemble_workflow_extend.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_05_ensemble_workflow_extend.py:


Single Variable Perturbation Method
===================================

Intermediate ensemble inference using a custom perturbation method.

This example demonstrates how to run an ensemble inference workflow with a
custom perturbation method that only applies noise to a specific variable.

In this example you will learn:

- How to extend an existing perturbation method
- How to instantiate a built-in prognostic model
- How to create a data source and IO object
- How to run a simple built-in workflow
- How to extend a built-in method using custom code
- How to post-process results

.. GENERATED FROM PYTHON SOURCE LINES 38-43

Set Up
------
All workflows inside Earth2Studio require constructed components to be
handed to them. In this example, we will use the built-in ensemble workflow
:py:meth:`earth2studio.run.ensemble`.

.. GENERATED FROM PYTHON SOURCE LINES 45-48

.. literalinclude:: ../../earth2studio/run.py
    :language: python
    :lines: 116-156

.. GENERATED FROM PYTHON SOURCE LINES 50-56

We need the following:

- Prognostic Model: Use the built-in DLWP model :py:class:`earth2studio.models.px.DLWP`.
- perturbation_method: Extend the Spherical Gaussian method :py:class:`earth2studio.perturbation.SphericalGaussian`.
- Datasource: Pull data from the GFS data API :py:class:`earth2studio.data.GFS`.
- IO Backend: Save the outputs into a Zarr store :py:class:`earth2studio.io.ZarrBackend`.
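
Conceptually, an ensemble workflow replicates the initial condition once per
member, perturbs each copy, and rolls the copies forward with the prognostic
model, processing ``batch_size`` members at a time. The NumPy sketch below
illustrates only that loop, under stated assumptions: ``toy_model``,
``toy_perturbation`` and ``run_ensemble_sketch`` are hypothetical stand-ins,
not Earth2Studio APIs, and the real :py:meth:`earth2studio.run.ensemble`
additionally fetches data, tracks coordinate systems and writes to the IO
backend.

.. code-block:: Python

    import numpy as np


    def toy_model(x: np.ndarray) -> np.ndarray:
        """Hypothetical stand-in for a prognostic model: advances one time step."""
        return 0.9 * x + 0.1 * np.roll(x, 1, axis=-1)


    def toy_perturbation(x: np.ndarray, rng: np.random.Generator) -> np.ndarray:
        """Hypothetical stand-in for a perturbation method: additive Gaussian noise."""
        return x + rng.normal(scale=1.0, size=x.shape)


    def run_ensemble_sketch(
        x0: np.ndarray, nsteps: int, nensemble: int, batch_size: int, seed: int = 0
    ) -> np.ndarray:
        """Hypothetical helper: roll out perturbed copies of x0, batch by batch.

        Returns an array of shape (nensemble, nsteps + 1, *x0.shape), where index 0
        along the second axis is the perturbed initial state (lead time 0).
        """
        rng = np.random.default_rng(seed)
        out = np.empty((nensemble, nsteps + 1, *x0.shape))
        for start in range(0, nensemble, batch_size):
            stop = min(start + batch_size, nensemble)
            # Replicate the initial state once per member in this batch, then perturb
            x = toy_perturbation(np.repeat(x0[None], stop - start, axis=0), rng)
            out[start:stop, 0] = x
            for step in range(1, nsteps + 1):
                x = toy_model(x)
                out[start:stop, step] = x
        return out


    # Illustrative state with shape (variable, lat, lon)
    x0 = np.zeros((3, 4, 8))
    result = run_ensemble_sketch(x0, nsteps=10, nensemble=8, batch_size=4)
    print(result.shape)  # (8, 11, 3, 4, 8)

With 8 members and a batch size of 4, the outer loop runs twice. Storing the
perturbed initial state at index 0 mirrors how step 0 is treated as lead time 0
in the plots of this example.
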

.. GENERATED FROM PYTHON SOURCE LINES 58-82

.. code-block:: Python


    import os

    os.makedirs("outputs", exist_ok=True)
    from dotenv import load_dotenv

    load_dotenv()  # TODO: make common example prep function

    import numpy as np
    import torch

    from earth2studio.data import GFS
    from earth2studio.io import ZarrBackend
    from earth2studio.models.px import DLWP
    from earth2studio.perturbation import Perturbation, SphericalGaussian
    from earth2studio.run import ensemble
    from earth2studio.utils.type import CoordSystem

    # Load the default model package, which downloads the checkpoint from NGC
    package = DLWP.load_default_package()
    model = DLWP.load_model(package)

    # Create the data source
    data = GFS()

.. GENERATED FROM PYTHON SOURCE LINES 83-86

The perturbation method in :ref:`sphx_glr_examples_03_ensemble_workflow.py` is
naive because it applies the same noise amplitude to every variable. We can
create a custom wrapper that applies the perturbation method only to a
particular variable instead.

.. GENERATED FROM PYTHON SOURCE LINES 88-122

.. code-block:: Python


    class ApplyToVariable:
        """Apply a perturbation to only a particular variable."""

        def __init__(self, pm: Perturbation, variable: str | list[str]):
            self.pm = pm
            if isinstance(variable, str):
                variable = [variable]
            self.variable = variable

        @torch.inference_mode()
        def __call__(
            self,
            x: torch.Tensor,
            coords: CoordSystem,
        ) -> tuple[torch.Tensor, CoordSystem]:
            # Apply the perturbation to the full tensor
            xp, _ = self.pm(x, coords)
            # Copy only the perturbed slices of the selected variable(s) back into
            # the original tensor (np.isin replaces the deprecated np.in1d)
            ind = np.isin(coords["variable"], self.variable)
            x[..., ind, :, :] = xp[..., ind, :, :]
            return x, coords


    # Create a perturbation that only targets 't2m', with a 1 K noise amplitude
    avsg = ApplyToVariable(SphericalGaussian(noise_amplitude=1.0), "t2m")

    # Create the IO handler, writing to a Zarr store on disk
    chunks = {"ensemble": 1, "time": 1}
    io = ZarrBackend(
        file_name="outputs/05_ensemble_avsg.zarr",
        chunks=chunks,
        backend_kwargs={"overwrite": True},
    )
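
The masking step inside ``ApplyToVariable`` can be checked in isolation. The
sketch below reproduces it with NumPy arrays instead of PyTorch tensors,
assuming the variable axis is third from last (as the ``x[..., ind, :, :]``
indexing implies). ``add_noise`` and ``apply_to_variable_sketch`` are
hypothetical stand-ins, not part of Earth2Studio, and the ``z500`` label is
purely illustrative. It uses ``np.isin``, the non-deprecated equivalent of
``np.in1d``.

.. code-block:: Python

    import numpy as np


    def add_noise(x: np.ndarray, rng: np.random.Generator) -> np.ndarray:
        """Hypothetical perturbation: Gaussian noise on every variable."""
        return x + rng.normal(scale=1.0, size=x.shape)


    def apply_to_variable_sketch(x, coords, variables, rng):
        """Hypothetical helper mirroring ApplyToVariable.__call__ with NumPy."""
        xp = add_noise(x, rng)
        # Boolean mask over the variable axis (third from last)
        ind = np.isin(coords["variable"], variables)
        out = x.copy()  # copy instead of writing into x in place
        out[..., ind, :, :] = xp[..., ind, :, :]
        return out


    # Illustrative coordinates and a zero state with shape (batch, variable, lat, lon)
    coords = {"variable": np.array(["t2m", "tcwv", "z500"])}
    x = np.zeros((2, 3, 4, 8))

    y = apply_to_variable_sketch(x, coords, ["t2m"], np.random.default_rng(0))
    print(np.isin(coords["variable"], ["t2m"]))  # [ True False False]
    print(np.abs(y[:, 1:]).max())  # 0.0 -> tcwv and z500 are untouched

Only the ``t2m`` slice differs from the input, which is the behaviour the
wrapper relies on to leave ``tcwv`` unperturbed at the initial time.
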

.. GENERATED FROM PYTHON SOURCE LINES 123-132

Execute the Workflow
--------------------
With all components initialized, running the workflow is a single line of
Python code. The workflow returns the provided IO object, which can then be
used for post-processing. Some IO backends have additional APIs that are handy
for post-processing or saving to file; check the API docs for more information.

For the forecast we will predict 10 steps (for DLWP, which uses a 6-hour time
step, this is 60 hours) with 8 ensemble members, which will be run in 2 batches
of size 4.

.. GENERATED FROM PYTHON SOURCE LINES 134-149

.. code-block:: Python


    nsteps = 10
    nensemble = 8
    batch_size = 4
    io = ensemble(
        ["2024-01-01"],
        nsteps,
        nensemble,
        model,
        data,
        io,
        avsg,
        batch_size=batch_size,
        output_coords={"variable": np.array(["t2m", "tcwv"])},
    )

.. GENERATED FROM PYTHON SOURCE LINES 150-157

Post Processing
---------------
The last step is to post-process our results. Let's plot both the perturbed
t2m field and the unperturbed tcwv field. First, to confirm that the
perturbation method works as expected, the initial state is plotted.

Notice that the Zarr IO backend has additional APIs to interact with the stored
data.
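
The plots reduce each stored field over the ensemble axis at a fixed lead-time
index. The sketch below shows the same reduction on a synthetic NumPy array.
The random values and the small 4 by 8 grid are placeholders rather than real
output, and the dimension order ``(ensemble, time, lead_time, lat, lon)`` is an
assumption inferred from the ``io["t2m"][:, 0, step]`` indexing.

.. code-block:: Python

    import numpy as np

    # Synthetic stand-in for io["t2m"] with dims (ensemble, time, lead_time, lat, lon);
    # the dim order is assumed from the indexing used in the plotting code
    nensemble, nsteps = 8, 10
    rng = np.random.default_rng(0)
    field = rng.normal(size=(nensemble, 1, nsteps + 1, 4, 8))

    dt_hours = 6  # DLWP time step, matching the 6*step lead-time labels
    for step in (0, 4):
        members = field[:, 0, step]  # (ensemble, lat, lon) for the first initial time
        mean = members.mean(axis=0)  # ensemble mean per grid point
        spread = members.std(axis=0)  # ensemble std per grid point (ddof=0)
        print(f"lead time {step * dt_hours} hrs: mean {mean.shape}, std {spread.shape}")

``np.std`` defaults to ``ddof=0`` (population standard deviation); passing
``ddof=1`` gives the sample estimate, which for an 8-member ensemble is larger
by a factor of sqrt(8/7), roughly 7%.
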

.. GENERATED FROM PYTHON SOURCE LINES 159-203

.. code-block:: Python


    import matplotlib.pyplot as plt

    forecast = "2024-01-01"


    def plot_(axi, data, title, cmap):
        """Simple plot util function"""
        im = axi.imshow(data, cmap=cmap)
        plt.colorbar(im, ax=axi, shrink=0.5, pad=0.04)
        axi.set_title(title)


    step = 0  # lead time = 0 hrs (initial state)
    plt.close("all")
    # Create a 2x2 grid of axes: t2m on the top row, tcwv on the bottom row
    fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(10, 6))
    plot_(
        ax[0, 0],
        np.mean(io["t2m"][:, 0, step], axis=0),
        f"{forecast} - t2m - Lead time: {6*step}hrs - Mean",
        "coolwarm",
    )
    plot_(
        ax[0, 1],
        np.std(io["t2m"][:, 0, step], axis=0),
        f"{forecast} - t2m - Lead time: {6*step}hrs - Std",
        "coolwarm",
    )
    plot_(
        ax[1, 0],
        np.mean(io["tcwv"][:, 0, step], axis=0),
        f"{forecast} - tcwv - Lead time: {6*step}hrs - Mean",
        "Blues",
    )
    plot_(
        ax[1, 1],
        np.std(io["tcwv"][:, 0, step], axis=0),
        f"{forecast} - tcwv - Lead time: {6*step}hrs - Std",
        "Blues",
    )

    plt.savefig(f"outputs/05_{forecast}_{step}_ensemble.jpg")

.. GENERATED FROM PYTHON SOURCE LINES 204-208

Due to the intrinsic coupling between all fields, we should expect all
variables to have some uncertainty at later lead times. Here the total column
water vapor (tcwv) is plotted at a lead time of 24 hours; note the spread among
the members even though only the 2-metre temperature (t2m) field was perturbed.

.. GENERATED FROM PYTHON SOURCE LINES 210-229

.. code-block:: Python


    step = 4  # lead time = 24 hrs
    plt.close("all")
    # Create a row of two axes: ensemble mean (left) and standard deviation (right)
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 3))
    plot_(
        ax[0],
        np.mean(io["tcwv"][:, 0, step], axis=0),
        f"{forecast} - tcwv - Lead time: {6*step}hrs - Mean",
        "Blues",
    )
    plot_(
        ax[1],
        np.std(io["tcwv"][:, 0, step], axis=0),
        f"{forecast} - tcwv - Lead time: {6*step}hrs - Std",
        "Blues",
    )

    plt.savefig(f"outputs/05_{forecast}_{step}_ensemble.jpg")


.. _sphx_glr_download_examples_05_ensemble_workflow_extend.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 05_ensemble_workflow_extend.ipynb <05_ensemble_workflow_extend.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 05_ensemble_workflow_extend.py <05_ensemble_workflow_extend.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 05_ensemble_workflow_extend.zip <05_ensemble_workflow_extend.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    Gallery generated by Sphinx-Gallery