Model Hook Injection: Perturbation

Adding model noise by using custom hooks.

This example will demonstrate how to run an ensemble inference workflow to generate a perturbed ensemble forecast. This perturbation is done by injecting code into the model front and rear hooks. These hooks are applied to the tensor data before/after the model forward call.
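To make the hook mechanism concrete, here is a minimal pure-Python sketch of a model that applies a front hook before its forward call and a rear hook after it. `ToyModel`, `identity_hook`, and the doubling "forward" step are hypothetical stand-ins for illustration only; they are not Earth2Studio classes, and the real hooks operate on tensors rather than lists.

```python
from typing import Callable, Dict, List, Tuple

Coords = Dict[str, List[str]]
Hook = Callable[[List[float], Coords], Tuple[List[float], Coords]]


def identity_hook(x: List[float], coords: Coords) -> Tuple[List[float], Coords]:
    """Default hook: return data and coordinates unchanged."""
    return x, coords


class ToyModel:
    """Hypothetical stand-in for a prognostic model (not an Earth2Studio class)."""

    def __init__(self) -> None:
        self.front_hook: Hook = identity_hook
        self.rear_hook: Hook = identity_hook

    def _forward(self, x: List[float]) -> List[float]:
        # Placeholder forward pass: double every value.
        return [2.0 * v for v in x]

    def __call__(self, x: List[float], coords: Coords) -> Tuple[List[float], Coords]:
        x, coords = self.front_hook(x, coords)  # applied before the forward call
        x = self._forward(x)
        x, coords = self.rear_hook(x, coords)  # applied after the forward call
        return x, coords


toy = ToyModel()
coords = {"variable": ["tcwv", "z500"]}
baseline, _ = toy([1.0, 2.0], coords)

# Inject a front hook that shifts the input before the forward call.
toy.front_hook = lambda x, c: ([v + 0.5 for v in x], c)
perturbed, _ = toy([1.0, 2.0], coords)
print(baseline, perturbed)
```

Because the hooks are plain attributes, swapping one changes the model's behavior without touching the forward pass itself.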

This example also illustrates how you can subselect data for IO. In this example we will only output two variables: total column water vapor (tcwv) and 500 hPa geopotential (z500). To run this, make sure that the selected model predicts these variables, or change them appropriately.

In this example you will learn:

  • How to instantiate a built-in prognostic model

  • How to create a data source and IO object

  • How to change the model front/rear hooks

  • How to choose a subselection of coordinates to save to an IO object

  • How to post-process results

Creating an Ensemble Workflow

Let’s begin by choosing an ensemble workflow to use. We encourage users to explore and experiment with their own custom workflows that borrow ideas from the built-in workflows inside earth2studio.run or the examples.

Creating a generalizable ensemble workflow is easy when we rely on the component interfaces defined in Earth2Studio (that is, when we use dependency injection). Here we use a run method (earth2studio.run.ensemble, imported in the setup below) that accepts the following:

  • time: Input list of datetimes / strings to run inference for

  • nsteps: Number of forecast steps to predict

  • nensemble: Number of ensemble members to run

  • prognostic: Our initialized prognostic model

  • data: Initialized data source to fetch initial conditions from

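  • io: IO store that data is written to

  • perturbation: Perturbation method used to create the ensemble (here earth2studio.perturbation.Gaussian)

  • batch_size: Number of ensemble members to run per batch

  • output_coords: CoordSystem of output coordinates that should be saved. Should be a proper subset of model output coordinates.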
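The subset requirement on output_coords can be sketched as a small check that maps each requested coordinate value to its position in the model's coordinates. This is an illustrative helper only: `subselect_indices` is a hypothetical name, not an Earth2Studio function, and the variable list below is an assumption taken from the variables fetched in this example's logs.

```python
from typing import Dict, List, Sequence


def subselect_indices(
    model_coords: Dict[str, Sequence],
    output_coords: Dict[str, Sequence],
) -> Dict[str, List[int]]:
    """Return, per dimension, the positions of the requested values.

    Raises if a requested dimension or value is not provided by the model.
    """
    indices: Dict[str, List[int]] = {}
    for dim, wanted in output_coords.items():
        if dim not in model_coords:
            raise KeyError(f"Dimension '{dim}' is not a model output dimension")
        available = list(model_coords[dim])
        missing = [v for v in wanted if v not in available]
        if missing:
            raise ValueError(f"Values {missing} not available along '{dim}'")
        indices[dim] = [available.index(v) for v in wanted]
    return indices


# Assumed model variables (illustrative only).
model_coords = {
    "variable": ["t850", "z1000", "z700", "z500", "z300", "tcwv", "t2m"],
}
selection = subselect_indices(model_coords, {"variable": ["tcwv", "z500"]})
print(selection)
```

Requesting a variable the model does not predict raises a `ValueError`, which mirrors why the selected model must provide tcwv and z500 in this example.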

Set Up

With the ensemble workflow defined, we now need to create the individual components.

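We need the following:

  • Prognostic model: the built-in DLWP model (earth2studio.models.px.DLWP)

  • Data source: GFS (earth2studio.data.GFS), used to fetch the initial conditions

  • IO backend: a Zarr store (earth2studio.io.ZarrBackend) that the outputs are written to

We will first run the ensemble workflow using an unmodified model, that is, a model that has the default (identity) front and rear hooks. Then we will define new hooks for the model and rerun the inference request.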

import os

os.makedirs("outputs", exist_ok=True)
from dotenv import load_dotenv

load_dotenv()  # TODO: make common example prep function

import numpy as np

from earth2studio.data import GFS
from earth2studio.io import ZarrBackend
from earth2studio.models.px import DLWP
from earth2studio.perturbation import Gaussian
from earth2studio.run import ensemble

# Load the default model package which downloads the check point from NGC
package = DLWP.load_default_package()
model = DLWP.load_model(package)

# Create the data source
data = GFS()

# Create the IO handler, stored as a Zarr store on disk
chunks = {"ensemble": 1, "time": 1, "lead_time": 1}
io_unperturbed = ZarrBackend(file_name="outputs/05_ensemble.zarr", chunks=chunks)

Execute the Workflow

First, we will run the ensemble workflow with an earth2studio.perturbation.Gaussian() perturbation and the unmodified model as the control.

The workflow returns the provided IO object back to the user, which can then be used for post-processing. Some IO backends have additional APIs that can be handy for post-processing or saving to file. Check the API docs for more information.

nsteps = 4 * 12
nensemble = 16
batch_size = 4
forecast_date = "2024-01-30"
output_coords = {
    "lat": np.arange(25.0, 60.0, 0.25),
    "lon": np.arange(230.0, 300.0, 0.25),
    "variable": np.array(["tcwv", "z500"]),
}

# First run with no model perturbation
io_unperturbed = ensemble(
    [forecast_date],
    nsteps,
    nensemble,
    model,
    data,
    io_unperturbed,
    Gaussian(noise_amplitude=0.01),
    output_coords=output_coords,
    batch_size=batch_size,
)
2025-01-23 04:41:38.660 | INFO     | earth2studio.run:ensemble:315 - Running ensemble inference!
2025-01-23 04:41:38.661 | INFO     | earth2studio.run:ensemble:323 - Inference device: cuda
2025-01-23 04:41:39.126 | SUCCESS  | earth2studio.run:ensemble:345 - Fetched data from GFS
2025-01-23 04:41:39.143 | INFO     | earth2studio.run:ensemble:367 - Starting 16 Member Ensemble Inference with             4 number of batches.
2025-01-23 04:41:45.925 | SUCCESS  | earth2studio.run:ensemble:412 - Inference complete
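
The batch and step counts reported by the run follow directly from the settings above. The short sketch below reproduces that arithmetic. Two points are assumptions rather than documented behavior: that each member stores the initial state plus nsteps forecast steps, and that the model step is 6 hours (inferred from a lead time of 240 hours at lead-time index 40 in this example's output).

```python
import math

nensemble = 16
batch_size = 4
nsteps = 4 * 12

# Members are processed in groups of batch_size.
number_of_batches = math.ceil(nensemble / batch_size)
# Index of the first ensemble member in each batch.
batch_starts = list(range(0, nensemble, batch_size))

# Assumption: the initial state is stored in addition to the nsteps forecast steps.
outputs_per_member = nsteps + 1

# Assumption: 6-hour model step, inferred from this example's output.
step_hours = 6
forecast_hours = nsteps * step_hours

print(number_of_batches, batch_starts, outputs_per_member, forecast_hours)
```

Under these assumptions, nsteps = 4 * 12 corresponds to four steps per day over twelve days.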

Now let’s introduce a slight model perturbation using the prognostic model hooks defined in earth2studio.models.px.utils.PrognosticMixin. Note that center.unsqueeze(-1) is DLWP-specific, since DLWP operates on a cubed sphere with grid dimensions (nface, lat, lon) instead of just (lat, lon). If you switch to a different model, consider removing the unsqueeze() calls.

model.front_hook = lambda x, coords: (
    x
    - 0.1
    * x.var(dim=0)
    * (x - model.center.unsqueeze(-1))
    / (model.scale.unsqueeze(-1)) ** 2
    + 0.1 * (x - x.mean(dim=0)),
    coords,
)
# Also could use model.rear_hook = ...

io_perturbed = ZarrBackend(
    file_name="outputs/05_ensemble_model_perturbation.zarr",
    chunks=chunks,
    backend_kwargs={"overwrite": True},
)
io_perturbed = ensemble(
    [forecast_date],
    nsteps,
    nensemble,
    model,
    data,
    io_perturbed,
    Gaussian(noise_amplitude=0.01),
    output_coords=output_coords,
    batch_size=batch_size,
)
2025-01-23 04:41:45.927 | INFO     | earth2studio.run:ensemble:315 - Running ensemble inference!
2025-01-23 04:41:45.927 | INFO     | earth2studio.run:ensemble:323 - Inference device: cuda
2025-01-23 04:41:46.383 | SUCCESS  | earth2studio.run:ensemble:345 - Fetched data from GFS
2025-01-23 04:41:46.400 | INFO     | earth2studio.run:ensemble:367 - Starting 16 Member Ensemble Inference with             4 number of batches.
2025-01-23 04:41:53.193 | SUCCESS  | earth2studio.run:ensemble:412 - Inference complete
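
The front hook used for the perturbed run combines two terms: one that pulls each member toward the model's normalization center in proportion to the variance along dimension 0, and one that inflates each member's deviation from the mean along dimension 0. The sketch below applies the same formula to a single scalar per member using only the standard library. It assumes that dimension 0 indexes the ensemble members in a batch; `front_hook_scalar` and the sample values are hypothetical, and the real hook acts on full tensors.

```python
import statistics
from typing import List


def front_hook_scalar(
    members: List[float], center: float, scale: float, alpha: float = 0.1
) -> List[float]:
    """Scalar version of the front hook formula used in this example.

    x' = x - alpha * var(x) * (x - center) / scale**2 + alpha * (x - mean(x))
    """
    mean = statistics.fmean(members)
    # Sample variance (n - 1 denominator), matching the default of torch.Tensor.var.
    var = statistics.variance(members)
    return [
        x - alpha * var * (x - center) / scale**2 + alpha * (x - mean)
        for x in members
    ]


# Illustrative values only.
spread_members = front_hook_scalar([1.0, 2.0, 3.0, 4.0], center=2.0, scale=2.0)
identical_members = front_hook_scalar([3.0, 3.0, 3.0], center=1.0, scale=2.0)
print(spread_members)
print(identical_members)
```

When all members are identical, both the variance and the deviation from the mean vanish, so the hook leaves the input unchanged. The perturbation therefore only acts once the members have already started to diverge.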

Post Processing

The last step is to post-process our results. Here we plot and compare the ensemble mean and standard deviation obtained with the unperturbed and the perturbed model.

Notice that the Zarr IO backend has additional APIs to interact with the stored data.
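
As a reference for what the plotting code computes, the sketch below takes the mean and standard deviation over ensemble members for a toy data set using only the standard library. It assumes the stored arrays are ordered (ensemble, time, lead_time, lat, lon), so that an index such as [:, 0, frame] selects all members at the first time and one lead time. `ensemble_mean_std` and the values are hypothetical. Note that NumPy's std defaults to the population standard deviation, which is what `statistics.pstdev` mirrors here.

```python
import statistics
from typing import List, Sequence, Tuple


def ensemble_mean_std(values: Sequence[Sequence[float]]) -> Tuple[List[float], List[float]]:
    """Mean and population standard deviation over the member axis.

    `values[member][point]` holds one field value per grid point and member.
    """
    columns = list(zip(*values))  # regroup so each column holds all members at one point
    means = [statistics.fmean(c) for c in columns]
    stds = [statistics.pstdev(c) for c in columns]
    return means, stds


# Four members, two grid points (illustrative values only).
toy_ensemble = [
    [10.0, 20.0],
    [12.0, 20.0],
    [14.0, 20.0],
    [16.0, 20.0],
]
means, stds = ensemble_mean_std(toy_ensemble)
print(means, stds)
```

A grid point where all members agree has zero spread, which is why the standard deviation panels highlight regions where the ensemble has diverged.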

import cartopy.crs as ccrs
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

levels_unperturbed = np.linspace(0, io_unperturbed["tcwv"][:].max())
levels_perturbed = np.linspace(0, io_perturbed["tcwv"][:].max())


std_levels_perturbed = np.linspace(0, io_perturbed["tcwv"][:].std(axis=0).max())

plt.close("all")
fig = plt.figure(figsize=(20, 10), tight_layout=True)
ax0 = fig.add_subplot(2, 2, 1, projection=ccrs.PlateCarree())
ax1 = fig.add_subplot(2, 2, 2, projection=ccrs.PlateCarree())
ax2 = fig.add_subplot(2, 2, 3, projection=ccrs.PlateCarree())
ax3 = fig.add_subplot(2, 2, 4, projection=ccrs.PlateCarree())


def update(frame):
    """This function updates the frame with a new lead time for animation."""
    import warnings

    warnings.filterwarnings("ignore")
    ax0.clear()
    ax1.clear()
    ax2.clear()
    ax3.clear()

    ## Update unperturbed image
    im0 = ax0.contourf(
        io_unperturbed["lon"][:],
        io_unperturbed["lat"][:],
        io_unperturbed["tcwv"][:, 0, frame].mean(axis=0),
        transform=ccrs.PlateCarree(),
        cmap="Blues",
        levels=levels_unperturbed,
    )
    ax0.coastlines()
    ax0.gridlines()

    im1 = ax1.contourf(
        io_unperturbed["lon"][:],
        io_unperturbed["lat"][:],
        io_unperturbed["tcwv"][:, 0, frame].std(axis=0),
        transform=ccrs.PlateCarree(),
        cmap="RdPu",
        levels=std_levels_perturbed,
        norm=LogNorm(vmin=1e-1, vmax=std_levels_perturbed[-1]),
    )
    ax1.coastlines()
    ax1.gridlines()

    im2 = ax2.contourf(
        io_perturbed["lon"][:],
        io_perturbed["lat"][:],
        io_perturbed["tcwv"][:, 0, frame].mean(axis=0),
        transform=ccrs.PlateCarree(),
        cmap="Blues",
        levels=levels_perturbed,
    )
    ax2.coastlines()
    ax2.gridlines()

    im3 = ax3.contourf(
        io_perturbed["lon"][:],
        io_perturbed["lat"][:],
        io_perturbed["tcwv"][:, 0, frame].std(axis=0),
        transform=ccrs.PlateCarree(),
        cmap="RdPu",
        levels=std_levels_perturbed,
        norm=LogNorm(vmin=1e-1, vmax=std_levels_perturbed[-1]),
    )
    ax3.coastlines()
    ax3.gridlines()

    for i in range(nensemble):
        ax0.contour(
            io_unperturbed["lon"][:],
            io_unperturbed["lat"][:],
            io_unperturbed["z500"][i, 0, frame] / 100.0,
            transform=ccrs.PlateCarree(),
            levels=np.arange(485, 580, 15),
            colors="black",
            linestyles="dashed",
        )

        ax2.contour(
            io_perturbed["lon"][:],
            io_perturbed["lat"][:],
            io_perturbed["z500"][i, 0, frame] / 100.0,
            transform=ccrs.PlateCarree(),
            levels=np.arange(485, 580, 15),
            colors="black",
            linestyles="dashed",
        )
    plt.suptitle(
        f'Forecast Starting on {forecast_date} - Lead Time - {io_perturbed["lead_time"][frame]}'
    )

    ax0.set_title("Unperturbed Ensemble Mean - tcwv + z500 contours")
    ax1.set_title("Unperturbed Ensemble Std - tcwv")
    ax2.set_title("Perturbed Ensemble Mean - tcwv + z500 contours")
    ax3.set_title("Perturbed Ensemble Std - tcwv")

    if frame == 0:
        plt.colorbar(
            im0, ax=ax0, shrink=0.75, pad=0.04, label="kg m^-2", format="%2.1f"
        )
        plt.colorbar(
            im1, ax=ax1, shrink=0.75, pad=0.04, label="kg m^-2", format="%1.2e"
        )
        plt.colorbar(
            im2, ax=ax2, shrink=0.75, pad=0.04, label="kg m^-2", format="%2.1f"
        )
        plt.colorbar(
            im3, ax=ax3, shrink=0.75, pad=0.04, label="kg m^-2", format="%1.2e"
        )


# Uncomment this for animation
# import matplotlib.animation as animation
# update(0)
# ani = animation.FuncAnimation(
# fig=fig, func=update, frames=range(1, nsteps), cache_frame_data=False
# )
# ani.save(f"outputs/05_model_perturbation_{forecast_date}.gif", dpi=300)


for lt in [10, 20, 30, 40]:
    update(lt)
    plt.savefig(
        f"outputs/05_model_perturbation_{forecast_date}_leadtime_{lt}.png",
        dpi=300,
        bbox_inches="tight",
    )

Figure: Forecast starting on 2024-01-30, lead time 240 hours. Panels: Unperturbed Ensemble Mean (tcwv + z500 contours), Unperturbed Ensemble Std (tcwv), Perturbed Ensemble Mean (tcwv + z500 contours), Perturbed Ensemble Std (tcwv).

Total running time of the script: (0 minutes 27.013 seconds)
