
Single Variable Perturbation Method

Intermediate ensemble inference using a custom perturbation method.

This example demonstrates how to run an ensemble inference workflow with a custom perturbation method that applies noise only to a specific variable.

In this example you will learn:

  • How to extend an existing perturbation method

  • How to instantiate a built-in prognostic model

  • How to create a data source and IO object

  • How to run a simple built-in workflow

  • How to extend a built-in method using custom code

  • How to post-process results

Set Up

All workflows inside Earth2Studio require constructed components to be handed to them. In this example, we will use the built-in ensemble workflow earth2studio.run.ensemble().

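We need the following components:

  • Prognostic model: the built-in DLWP model, earth2studio.models.px.DLWP.

  • Perturbation method: a custom wrapper around the spherical Gaussian method, earth2studio.perturbation.SphericalGaussian.

  • Data source: GFS data, fetched with earth2studio.data.GFS.

  • IO backend: a Zarr store, earth2studio.io.ZarrBackend.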

import os

os.makedirs("outputs", exist_ok=True)
from dotenv import load_dotenv

load_dotenv()  # TODO: make common example prep function

import numpy as np
import torch

from earth2studio.data import GFS
from earth2studio.io import ZarrBackend
from earth2studio.models.px import DLWP
from earth2studio.perturbation import Perturbation, SphericalGaussian
from earth2studio.run import ensemble
from earth2studio.utils.type import CoordSystem

# Load the default model package, which downloads the checkpoint from NGC
package = DLWP.load_default_package()
model = DLWP.load_model(package)

# Create the data source
data = GFS()

The perturbation method used in the Running Ensemble Inference example is naive: it applies the same noise amplitude to every variable. Instead, we can create a custom wrapper that applies the perturbation method only to a particular variable.

class ApplyToVariable:
    """Apply a perturbation to only a particular variable."""

    def __init__(self, pm: Perturbation, variable: str | list[str]):
        self.pm = pm
        if isinstance(variable, str):
            variable = [variable]
        self.variable = variable

    @torch.inference_mode()
    def __call__(
        self,
        x: torch.Tensor,
        coords: CoordSystem,
    ) -> tuple[torch.Tensor, CoordSystem]:
        # Apply perturbation
        xp, _ = self.pm(x, coords)
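        # Copy the perturbed slice for the selected variable(s) back into the
        # original tensor; the variable dimension is third from last.
        # np.isin replaces np.in1d, which is deprecated in NumPy 2.0
        ind = np.isin(coords["variable"], self.variable)
        x[..., ind, :, :] = xp[..., ind, :, :]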
        return x, coords


# Create a perturbation that targets only 't2m', with a noise amplitude of 1 K
avsg = ApplyToVariable(SphericalGaussian(noise_amplitude=1.0), "t2m")

# Create the IO handler, storing outputs in a Zarr store on disk
chunks = {"ensemble": 1, "time": 1}
io = ZarrBackend(file_name="outputs/05_ensemble_avsg.zarr", chunks=chunks)
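The masking logic inside ApplyToVariable can be checked without a GPU or a model checkpoint. The sketch below is a minimal NumPy stand-in. It assumes an array laid out as (..., variable, lat, lon), as in the class above. The helpers add_noise and apply_to_variable are hypothetical and exist only for illustration. In particular, add_noise is not the SphericalGaussian implementation.

```python
import numpy as np


def add_noise(x, amplitude=1.0, seed=0):
    """Hypothetical perturbation: add i.i.d. Gaussian noise to every element."""
    rng = np.random.default_rng(seed)
    return x + amplitude * rng.standard_normal(x.shape)


def apply_to_variable(x, variables, targets, perturb):
    """Perturb the whole array, then keep perturbed values only for `targets`."""
    xp = perturb(x)
    # Boolean mask over the variable axis (third from last)
    ind = np.isin(variables, targets)
    out = x.copy()  # leave the input untouched
    out[..., ind, :, :] = xp[..., ind, :, :]
    return out


variables = np.array(["t850", "z500", "tcwv", "t2m"])
x = np.zeros((2, len(variables), 3, 4))  # (batch, variable, lat, lon)
y = apply_to_variable(x, variables, ["t2m"], add_noise)

# Largest absolute change per variable: only the last entry (t2m) is non-zero
print(np.abs(y - x).max(axis=(0, 2, 3)))
```

Unlike ApplyToVariable, which writes into x in place, this sketch returns a copy so that the input can be compared against the output.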

Execute the Workflow

With all components initialized, running the workflow is a single line of Python code. The workflow returns the provided IO object, which can then be used for post-processing. Some IO backends have additional APIs that are handy for post-processing or saving to file; check the API docs for more information.

For the forecast, we will predict 10 steps (for DLWP, with its 6-hour time step, this is 60 hours) with 8 ensemble members, which will be run in 2 batches of 4.

nsteps = 10
nensemble = 8
batch_size = 4
io = ensemble(
    ["2024-01-01"],
    nsteps,
    nensemble,
    model,
    data,
    io,
    avsg,
    batch_size=batch_size,
    output_coords={"variable": np.array(["t2m", "tcwv"])},
)
2024-06-25 13:59:01.558 | INFO     | earth2studio.run:ensemble:294 - Running ensemble inference!
2024-06-25 13:59:01.558 | INFO     | earth2studio.run:ensemble:302 - Inference device: cuda
2024-06-25 13:59:01.565 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:149 - Fetching GFS index file: 2023-12-31 18:00:00
2024-06-25 13:59:01.902 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: t850 at 2023-12-31 18:00:00
2024-06-25 13:59:02.661 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: z1000 at 2023-12-31 18:00:00
2024-06-25 13:59:03.044 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: z700 at 2023-12-31 18:00:00
2024-06-25 13:59:03.413 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: z500 at 2023-12-31 18:00:00
2024-06-25 13:59:03.778 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: z300 at 2023-12-31 18:00:00
2024-06-25 13:59:04.095 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: tcwv at 2023-12-31 18:00:00
2024-06-25 13:59:04.488 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: t2m at 2023-12-31 18:00:00
Fetching GFS for 2023-12-31 18:00:00: 100%|██████████| 7/7 [00:02<00:00,  2.34it/s]
2024-06-25 13:59:04.901 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:149 - Fetching GFS index file: 2024-01-01 00:00:00
2024-06-25 13:59:05.008 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: t850 at 2024-01-01 00:00:00
2024-06-25 13:59:05.029 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: z1000 at 2024-01-01 00:00:00
2024-06-25 13:59:05.047 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: z700 at 2024-01-01 00:00:00
2024-06-25 13:59:05.420 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: z500 at 2024-01-01 00:00:00
2024-06-25 13:59:05.438 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: z300 at 2024-01-01 00:00:00
2024-06-25 13:59:05.812 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: tcwv at 2024-01-01 00:00:00
2024-06-25 13:59:05.831 | DEBUG    | earth2studio.data.gfs:fetch_gfs_dataarray:196 - Fetching GFS grib file for variable: t2m at 2024-01-01 00:00:00
Fetching GFS for 2024-01-01 00:00:00: 100%|██████████| 7/7 [00:00<00:00,  8.31it/s]
2024-06-25 13:59:05.911 | SUCCESS  | earth2studio.run:ensemble:315 - Fetched data from GFS
2024-06-25 13:59:05.919 | INFO     | earth2studio.run:ensemble:337 - Starting 8 Member Ensemble Inference with             2 number of batches.
Running batch 0 inference: 100%|██████████| 11/11 [00:04<00:00,  1.95it/s]
Running batch 4 inference: 100%|██████████| 11/11 [00:04<00:00,  1.96it/s]
Total Ensemble Batches: 100%|██████████| 2/2 [00:15<00:00,  7.76s/it]
2024-06-25 13:59:21.438 | SUCCESS  | earth2studio.run:ensemble:382 - Inference complete
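The forecast bookkeeping described before the workflow call, and visible in the log, can be reproduced with a few lines of arithmetic. This sketch makes two assumptions. First, it uses DLWP's 6-hour time step, which is the same assumption the plot titles make with 6*step. Second, it assumes each batch is labeled by the index of its first member, which is what the "batch 0" and "batch 4" progress bars suggest.

```python
import math

nsteps = 10     # forecast steps, as in the workflow call
nensemble = 8   # ensemble members
batch_size = 4  # members per batch
dt_hours = 6    # assumed DLWP time step, matching the 6*step plot titles

forecast_hours = nsteps * dt_hours                    # total forecast horizon
nbatches = math.ceil(nensemble / batch_size)          # number of batched model runs
batch_starts = list(range(0, nensemble, batch_size))  # first member in each batch
frames_per_batch = nsteps + 1                         # initial state plus nsteps predictions

print(forecast_hours, nbatches, batch_starts, frames_per_batch)
# 60 2 [0, 4] 11
```

The 11 frames per batch match the 11/11 progress bars above. The initial state counts as one frame.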

Post Processing

The last step is to post-process our results. Let's plot both the perturbed t2m field and the unperturbed tcwv field. First, to confirm that the perturbation method works as expected, we plot the initial state.

Notice that the Zarr IO backend has additional APIs for interacting with the stored data; for example, io["t2m"] returns the stored t2m array.

import matplotlib.pyplot as plt

forecast = "2024-01-01"


def plot_(axi, data, title, cmap):
    """Simple plot util function"""
    im = axi.imshow(data, cmap=cmap)
    plt.colorbar(im, ax=axi, shrink=0.5, pad=0.04)
    axi.set_title(title)


step = 0  # lead time = 0 hrs (initial state)
plt.close("all")

# Create a figure with a 2x2 grid of axes
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(10, 6))
plot_(
    ax[0, 0],
    np.mean(io["t2m"][:, 0, step], axis=0),
    f"{forecast} - t2m - Lead time: {6*step}hrs - Mean",
    "coolwarm",
)
plot_(
    ax[0, 1],
    np.std(io["t2m"][:, 0, step], axis=0),
    f"{forecast} - t2m - Lead time: {6*step}hrs - Std",
    "coolwarm",
)
plot_(
    ax[1, 0],
    np.mean(io["tcwv"][:, 0, step], axis=0),
    f"{forecast} - tcwv - Lead time: {6*step}hrs - Mean",
    "Blues",
)
plot_(
    ax[1, 1],
    np.std(io["tcwv"][:, 0, step], axis=0),
    f"{forecast} - tcwv - Lead time: {6*step}hrs - Std",
    "Blues",
)

plt.savefig(f"outputs/05_{forecast}_{step}_ensemble.jpg")
Figure: 2024-01-01 forecast at a lead time of 0 hrs. Panels: t2m mean, t2m std, tcwv mean, tcwv std.
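The plotting code above reduces each stored array over its first dimension. The sketch below shows that reduction on synthetic stand-in arrays rather than the stored forecast. It assumes the layout (ensemble, time, lead_time, lat, lon), which is implied by the [:, 0, step] indexing. It also shows why an unperturbed field should have (numerically) zero spread at lead time 0.

```python
import numpy as np

rng = np.random.default_rng(42)
# Hypothetical sizes; the real store would have 8 members, 1 initial time,
# 11 lead times and a full lat/lon grid
nens, ntime, nlead, nlat, nlon = 8, 1, 11, 5, 6
shape = (nens, ntime, nlead, nlat, nlon)

# Synthetic stand-ins: every member shares the same tcwv field, while the
# t2m members differ by 1 K Gaussian noise, mimicking the t2m-only perturbation
tcwv = np.broadcast_to(rng.random((nlat, nlon)), shape).copy()
t2m = 280.0 + rng.standard_normal(shape)

step = 0
# Select (all members, first initial time, lead time `step`), then reduce over members
t2m_mean = np.mean(t2m[:, 0, step], axis=0)
t2m_std = np.std(t2m[:, 0, step], axis=0)
tcwv_std = np.std(tcwv[:, 0, step], axis=0)

print(t2m_mean.shape, np.allclose(tcwv_std, 0.0))
# (5, 6) True
```

The check uses np.allclose rather than exact equality. The mean of identical floats can differ from them by rounding, so the standard deviation may be a tiny non-zero value instead of exactly 0.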

Due to the intrinsic coupling between all fields, we should expect all variables to develop some uncertainty at later lead times. Here the total column water vapor (tcwv) is plotted at a lead time of 24 hours. Note the spread among the members, even though only the t2m field was perturbed.

step = 4  # lead time = 24 hrs
plt.close("all")

# Create a figure with a 1x2 grid of axes
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 3))
plot_(
    ax[0],
    np.mean(io["tcwv"][:, 0, step], axis=0),
    f"{forecast} - tcwv - Lead time: {6*step}hrs - Mean",
    "Blues",
)
plot_(
    ax[1],
    np.std(io["tcwv"][:, 0, step], axis=0),
    f"{forecast} - tcwv - Lead time: {6*step}hrs - Std",
    "Blues",
)

plt.savefig(f"outputs/05_{forecast}_{step}_ensemble.jpg")
Figure: 2024-01-01 forecast at a lead time of 24 hrs. Panels: tcwv mean, tcwv std.
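The cross-variable spread seen in the tcwv panels can be illustrated with a toy model. This is a hypothetical stand-in, not DLWP. It uses a two-variable linear system x_{k+1} = A x_k, where the off-diagonal entry of A couples the second variable to the first. Only the first variable is perturbed at the start, yet the second variable gains ensemble spread after one step.

```python
import numpy as np

rng = np.random.default_rng(0)
nens, nsteps = 8, 4

# Hypothetical coupled dynamics: variable 1 is driven partly by variable 0,
# while variable 0 evolves on its own
A = np.array([[0.9, 0.0],
              [0.3, 0.9]])

# Identical initial state for every member, then perturb only variable 0
x = np.tile(np.array([1.0, 2.0]), (nens, 1))  # shape (nens, 2)
x[:, 0] += rng.standard_normal(nens)

spread = [x.std(axis=0)]  # per-variable ensemble std at each step
for _ in range(nsteps):
    x = x @ A.T  # advance every member one step: x_{k+1} = A x_k
    spread.append(x.std(axis=0))

for k, s in enumerate(spread):
    print(f"step {k}: std(var0)={s[0]:.3f}, std(var1)={s[1]:.3f}")
```

After one step, the spread of variable 1 is exactly 0.3 times the initial spread of variable 0. In DLWP the coupling is learned and nonlinear, so this sketch only illustrates the qualitative effect.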

Total running time of the script: (0 minutes 24.119 seconds)
