Extending Data Sources

Implementing a custom data source

This example demonstrates how to extend Earth2Studio by implementing a custom data source for use in a built-in workflow.

In this example you will learn:

  • API requirements of data sources

  • Implementing a custom data source

Custom Data Source

Earth2Studio defines the required API for data sources in earth2studio.data.base.DataSource, which requires only a __call__ function. In this example, we extend an existing remote data source with an additional atmospheric field that we can calculate.
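As a rough illustration of how small that interface is, the sketch below defines a stand-in protocol with a single __call__ method and a toy class that satisfies it without inheriting from anything. The names DataSourceLike and EchoSource are hypothetical and not part of Earth2Studio, and the toy returns a plain dictionary where a real data source would return an xarray.DataArray with time, variable, lat and lon dimensions.

```python
from datetime import datetime
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class DataSourceLike(Protocol):
    """Hypothetical stand-in for the data source interface: a single call method."""

    def __call__(self, time: Any, variable: Any) -> Any: ...


class EchoSource:
    """Toy source that only echoes the request it received (illustration only)."""

    def __call__(self, time: list[datetime], variable: list[str]) -> dict:
        # A real data source would fetch or compute fields here and return
        # an xarray.DataArray; this toy just reports what was asked for.
        return {"time": list(time), "variable": list(variable)}


source = EchoSource()
result = source([datetime(2022, 1, 1)], ["t500", "q500"])

# Structural check: EchoSource never subclasses DataSourceLike
print(isinstance(source, DataSourceLike))
print(result["variable"])
```

The point of the sketch is only that any callable object with this signature can stand in for a data source; what it returns is what matters to the workflow.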

The earth2studio.data.ARCO data source provides the ERA5 dataset in a cloud-optimized format; however, it only provides specific humidity. This is a problem for models that use relative humidity as an input. Following ECMWF documentation, we can calculate relative humidity from temperature and specific humidity at a given pressure level.

import os

os.makedirs("outputs", exist_ok=True)
from dotenv import load_dotenv

load_dotenv()  # TODO: make common example prep function

from datetime import datetime

import numpy as np
import xarray as xr

from earth2studio.data import ARCO, GFS
from earth2studio.data.utils import prep_data_inputs
from earth2studio.utils.type import TimeArray, VariableArray


class CustomDataSource:
    """Custom ARCO datasource"""

    relative_humidity_ids = [
        "r50",
        "r100",
        "r150",
        "r200",
        "r250",
        "r300",
        "r400",
        "r500",
        "r600",
        "r700",
        "r850",
        "r925",
        "r1000",
    ]

    def __init__(self, cache: bool = True, verbose: bool = True):
        self.arco = ARCO(cache, verbose)

    def __call__(
        self,
        time: datetime | list[datetime] | TimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Function to get data.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for (UTC).
        variable : str | list[str] | VariableArray
            String, list of strings or array of strings that refer to variables to
            return. Must be in IFS lexicon.

        Returns
        -------
        xr.DataArray
        """
        time, variable = prep_data_inputs(time, variable)

        # Replace relative humidity with respective temperature
        # and specific humidity fields
        variable_expanded = []
        for v in variable:
            if v in self.relative_humidity_ids:
                level = int(v[1:])
                variable_expanded.extend([f"t{level}", f"q{level}"])
            else:
                variable_expanded.append(v)
        variable_expanded = list(set(variable_expanded))

        # Fetch from ARCO
        da_exp = self.arco(time, variable_expanded)

        # Calculate relative humidity when needed
        arrays = []
        for v in variable:
            if v in self.relative_humidity_ids:
                level = int(v[1:])
                t = da_exp.sel(variable=f"t{level}").values
                q = da_exp.sel(variable=f"q{level}").values
                rh = self.calc_relative_humdity(t, q, 100 * level)
                arrays.append(rh)
            else:
                arrays.append(da_exp.sel(variable=v).values)

        da = xr.DataArray(
            data=np.stack(arrays, axis=1),
            dims=["time", "variable", "lat", "lon"],
            coords=dict(
                time=da_exp.coords["time"].values,
                variable=np.array(variable),
                lat=da_exp.coords["lat"].values,
                lon=da_exp.coords["lon"].values,
            ),
        )
        return da

    def calc_relative_humdity(
        self, temperature: np.ndarray, specific_humidity: np.ndarray, pressure: float
    ) -> np.ndarray:
        """Relative humidity calculation

        Parameters
        ----------
        temperature : np.ndarray
            Temperature field (K)
        specific_humidity : np.ndarray
            Specific humidity field (kg.kg-1)
        pressure : float
            Pressure (Pa)

        Returns
        -------
        np.ndarray
            Relative humidity field (%)
        """
        epsilon = 0.621981
        p = pressure
        q = specific_humidity
        t = temperature

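        # Vapor pressure (Pa) from specific humidity and pressure
        e = (p * q * (1.0 / epsilon)) / (1 + q * (1.0 / (epsilon) - 1))

        # Saturation vapor pressure (Pa) over water and over ice
        es_w = 611.21 * np.exp(17.502 * (t - 273.16) / (t - 32.19))
        es_i = 611.21 * np.exp(22.587 * (t - 273.16) / (t + 0.7))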

        # Mixed-phase weight: 0 (all ice) at 250.16 K and below,
        # 1 (all water) at 273.16 K and above
        alpha = np.clip((t - 250.16) / (273.16 - 250.16), 0, 1) ** 2
        es = alpha * es_w + (1 - alpha) * es_i
        rh = 100 * e / es

        return rh

__call__() API

The __call__ function is the main API of a data source; it returns an Xarray data array with the requested data. In this custom data source, we intercept relative humidity variables, replace them with temperature and specific humidity requests, and then calculate relative humidity from those fields. Note that the ARCO data source handles the remote complexity; we are just manipulating NumPy arrays.
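The interception step can be sketched in isolation with plain Python. The helper names below (expand_variables, level_to_pascal) are hypothetical and exist only for this illustration. Unlike the class above, this sketch deduplicates with dict.fromkeys so that the request order stays deterministic; that is a design choice of the sketch, not something ARCO is known to require.

```python
RELATIVE_HUMIDITY_IDS = {
    f"r{level}"
    for level in (50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000)
}


def expand_variables(variables: list[str]) -> list[str]:
    """Swap each relative humidity id for the temperature and specific humidity ids it needs."""
    expanded = []
    for v in variables:
        if v in RELATIVE_HUMIDITY_IDS:
            level = int(v[1:])  # "r500" -> 500 (hPa)
            expanded.extend([f"t{level}", f"q{level}"])
        else:
            expanded.append(v)
    # dict.fromkeys drops duplicates while keeping first-seen order
    return list(dict.fromkeys(expanded))


def level_to_pascal(variable: str) -> int:
    """Convert the pressure level in an id such as "r500" from hPa to Pa."""
    return 100 * int(variable[1:])


print(expand_variables(["r500", "t500", "z500"]))  # ['t500', 'q500', 'z500']
print(level_to_pascal("r500"))  # 50000
```

Requesting "r500" together with "t500" therefore fetches "t500" only once, and the factor of 100 matches the 100 * level conversion passed to the humidity calculation.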

calc_relative_humdity()

This function follows the calculation ECMWF uses in its IFS numerical model, which blends the saturation vapor pressures over water and over ice according to temperature.
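For a single grid point the calculation can be followed step by step. The sketch below repeats the same steps with Python's math module. The function name and the sample values (260 K and 1 g/kg of water vapor at 500 hPa) are made up for illustration, specific humidity is assumed to be in kg/kg, and the interpolation weight is assumed to stay within [0, 1].

```python
import math

EPSILON = 0.621981  # ratio of the gas constants of dry air and water vapor
T_ZERO = 273.16  # K, at and above this the saturation value over water is used
T_ICE = 250.16  # K, at and below this the saturation value over ice is used


def relative_humidity_point(t: float, q: float, p: float) -> float:
    """Relative humidity (%) from temperature (K), specific humidity (kg/kg), pressure (Pa)."""
    # Vapor pressure implied by the specific humidity
    e = (p * q / EPSILON) / (1.0 + q * (1.0 / EPSILON - 1.0))
    # Saturation vapor pressure over liquid water and over ice
    es_w = 611.21 * math.exp(17.502 * (t - T_ZERO) / (t - 32.19))
    es_i = 611.21 * math.exp(22.587 * (t - T_ZERO) / (t + 0.7))
    # Quadratic weight between the ice and water values, assumed limited to [0, 1]
    alpha = min(max((t - T_ICE) / (T_ZERO - T_ICE), 0.0), 1.0) ** 2
    es = alpha * es_w + (1.0 - alpha) * es_i
    return 100.0 * e / es


rh = relative_humidity_point(t=260.0, q=0.001, p=50000.0)
print(f"{rh:.1f} %")
```

At 260 K the point lies between the two thresholds, so the saturation value is a weighted mix of the ice and water curves rather than either one alone.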

Verification

Before plugging this into our workflow, let’s quickly verify that our data source is consistent with what GFS provides for relative humidity.

ds = CustomDataSource()
da_custom = ds(time=datetime(2022, 1, 1, hour=0), variable=["r500"])

ds_gfs = GFS()
da_gfs = ds_gfs(time=datetime(2022, 1, 1, hour=0), variable=["r500"])

print(da_custom)
2025-01-23 05:03:21.814 | DEBUG    | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: t500 at 2022-01-01T00:00:00
2025-01-23 05:03:22.012 | DEBUG    | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: q500 at 2022-01-01T00:00:00
Fetching ARCO data: 100%|██████████| 2/2 [00:00<00:00,  4.88it/s]
2025-01-23 05:03:22.278 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:209 - Fetching GFS index file: 2022-01-01 00:00:00 lead 0:00:00
2025-01-23 05:03:22.282 | DEBUG    | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: r500 at 2022-01-01 00:00:00_0:00:00
Fetching GFS for 2022-01-01 00:00:00: 100%|██████████| 1/1 [00:00<00:00, 31.36it/s]
<xarray.DataArray (time: 1, variable: 1, lat: 721, lon: 1440)> Size: 8MB
array([[[[ 28.01413468,  28.01413468,  28.01413468, ...,  28.01413468,
           28.01413468,  28.01413468],
         [ 29.75579658,  29.75314258,  29.81191126, ...,  29.77167483,
           29.76636404,  29.76376009],
         [ 31.28908409,  31.28075462,  31.27237381, ...,  31.31699177,
           31.30860053,  31.29746972],
         ...,
         [ 97.36539026,  97.35685903,  97.39891442, ...,  97.34011022,
           97.32321665,  97.31468906],
         [ 97.50236648,  97.49380369,  97.54603948, ...,  97.5280597 ,
           97.51949448,  97.51949448],
         [102.58417544, 102.58417544, 102.58417544, ..., 102.58417544,
          102.58417544, 102.58417544]]]])
Coordinates:
  * time      (time) datetime64[ns] 8B 2022-01-01
  * variable  (variable) <U4 16B 'r500'
  * lat       (lat) float64 6kB 90.0 89.75 89.5 89.25 ... -89.5 -89.75 -90.0
  * lon       (lon) float64 12kB 0.0 0.25 0.5 0.75 ... 359.0 359.2 359.5 359.8
import cartopy.crs as ccrs
import matplotlib.pyplot as plt

fig, ax = plt.subplots(
    1,
    2,
    figsize=(10, 3),
    subplot_kw={"projection": ccrs.Mollweide()},
    constrained_layout=True,
)

ax[0].imshow(
    da_custom.sel(variable="r500")[0], transform=ccrs.PlateCarree(), vmin=0, vmax=100
)
ax[1].imshow(
    da_gfs.sel(variable="r500")[0], transform=ccrs.PlateCarree(), vmin=0, vmax=100
)

ax[0].set_title("Custom ARCO")
ax[1].set_title("GFS")
plt.suptitle("r500", fontsize=24)
cbar = plt.cm.ScalarMappable()
cbar.set_array(da_custom.sel(variable="r500")[0])
cbar.set_clim(0, 100)
cbar = fig.colorbar(cbar, ax=ax[-1], orientation="vertical", shrink=0.8)

plt.savefig("outputs/custom_datasource_gfs_versus_custom.jpg")
Figure: r500 relative humidity from the custom ARCO data source (left) and GFS (right).

Execute Workflow

We will use this custom data source to run deterministic inference with a model that requires relative humidity; earth2studio.models.px.FCN is one such model. Because ARCO provides ERA5 data, we can run inference from an initial time far in the past.

Let’s instantiate the components needed.

from dotenv import load_dotenv

load_dotenv()  # TODO: make common example prep function

import earth2studio.run as run
from earth2studio.io import ZarrBackend
from earth2studio.models.px import FCN

package = FCN.load_default_package()
model = FCN.load_model(package)

# Create the data source
data = CustomDataSource()

# Create the IO handler, store in memory
io = ZarrBackend()

nsteps = 4
io = run.deterministic(["1993-04-05"], nsteps, model, data, io)

print(io.root.tree())
2025-01-23 05:03:37.838 | INFO     | earth2studio.run:deterministic:75 - Running simple workflow!
2025-01-23 05:03:37.838 | INFO     | earth2studio.run:deterministic:82 - Inference device: cuda
2025-01-23 05:03:37.889 | DEBUG    | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: t500 at 1993-04-05T00:00:00
2025-01-23 05:03:38.087 | DEBUG    | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: msl at 1993-04-05T00:00:00
2025-01-23 05:03:38.813 | DEBUG    | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: v500 at 1993-04-05T00:00:00
2025-01-23 05:03:39.017 | DEBUG    | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: t2m at 1993-04-05T00:00:00
2025-01-23 05:03:39.672 | DEBUG    | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: u1000 at 1993-04-05T00:00:00
2025-01-23 05:03:39.869 | DEBUG    | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: u10m at 1993-04-05T00:00:00
2025-01-23 05:03:39.874 | DEBUG    | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: z850 at 1993-04-05T00:00:00
2025-01-23 05:03:40.064 | DEBUG    | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: z500 at 1993-04-05T00:00:00
2025-01-23 05:03:40.255 | DEBUG    | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: t850 at 1993-04-05T00:00:00
2025-01-23 05:03:40.441 | DEBUG    | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: z50 at 1993-04-05T00:00:00
2025-01-23 05:03:40.620 | DEBUG    | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: t250 at 1993-04-05T00:00:00
2025-01-23 05:03:40.808 | DEBUG    | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: sp at 1993-04-05T00:00:00
2025-01-23 05:03:40.820 | DEBUG    | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: q850 at 1993-04-05T00:00:00
2025-01-23 05:03:41.013 | DEBUG    | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: z250 at 1993-04-05T00:00:00
2025-01-23 05:03:41.192 | DEBUG    | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: v100m at 1993-04-05T00:00:00
2025-01-23 05:03:42.278 | DEBUG    | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: v250 at 1993-04-05T00:00:00
2025-01-23 05:03:42.509 | DEBUG    | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: v1000 at 1993-04-05T00:00:00
2025-01-23 05:03:42.713 | DEBUG    | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: v850 at 1993-04-05T00:00:00
2025-01-23 05:03:42.915 | DEBUG    | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: u850 at 1993-04-05T00:00:00
2025-01-23 05:03:43.110 | DEBUG    | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: u250 at 1993-04-05T00:00:00
2025-01-23 05:03:43.315 | DEBUG    | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: z1000 at 1993-04-05T00:00:00
2025-01-23 05:03:43.495 | DEBUG    | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: q500 at 1993-04-05T00:00:00
2025-01-23 05:03:43.688 | DEBUG    | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: u100m at 1993-04-05T00:00:00
2025-01-23 05:03:43.693 | DEBUG    | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: v10m at 1993-04-05T00:00:00
2025-01-23 05:03:44.541 | DEBUG    | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: tcwv at 1993-04-05T00:00:00
2025-01-23 05:03:44.546 | DEBUG    | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: u500 at 1993-04-05T00:00:00
Fetching ARCO data: 100%|██████████| 26/26 [00:06<00:00,  3.79it/s]
2025-01-23 05:03:44.934 | SUCCESS  | earth2studio.run:deterministic:106 - Fetched data from CustomDataSource
2025-01-23 05:03:44.944 | INFO     | earth2studio.run:deterministic:136 - Inference starting!
Running inference: 100%|██████████| 5/5 [00:04<00:00,  1.02it/s]
2025-01-23 05:03:49.840 | SUCCESS  | earth2studio.run:deterministic:146 - Inference complete
/
 ├── lat (720,) float64
 ├── lead_time (5,) timedelta64[h]
 ├── lon (1440,) float64
 ├── msl (1, 5, 720, 1440) float32
 ├── r500 (1, 5, 720, 1440) float32
 ├── r850 (1, 5, 720, 1440) float32
 ├── sp (1, 5, 720, 1440) float32
 ├── t250 (1, 5, 720, 1440) float32
 ├── t2m (1, 5, 720, 1440) float32
 ├── t500 (1, 5, 720, 1440) float32
 ├── t850 (1, 5, 720, 1440) float32
 ├── tcwv (1, 5, 720, 1440) float32
 ├── time (1,) datetime64[ns]
 ├── u1000 (1, 5, 720, 1440) float32
 ├── u100m (1, 5, 720, 1440) float32
 ├── u10m (1, 5, 720, 1440) float32
 ├── u250 (1, 5, 720, 1440) float32
 ├── u500 (1, 5, 720, 1440) float32
 ├── u850 (1, 5, 720, 1440) float32
 ├── v1000 (1, 5, 720, 1440) float32
 ├── v100m (1, 5, 720, 1440) float32
 ├── v10m (1, 5, 720, 1440) float32
 ├── v250 (1, 5, 720, 1440) float32
 ├── v500 (1, 5, 720, 1440) float32
 ├── v850 (1, 5, 720, 1440) float32
 ├── z1000 (1, 5, 720, 1440) float32
 ├── z250 (1, 5, 720, 1440) float32
 ├── z50 (1, 5, 720, 1440) float32
 ├── z500 (1, 5, 720, 1440) float32
 └── z850 (1, 5, 720, 1440) float32
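The lead_time axis in the tree above has five entries because the output holds the initial state plus nsteps = 4 forecast steps. A small sketch of that bookkeeping follows, assuming a 6-hour model step (consistent with the 0, 6, 12 and 18 hour lead times plotted in this example). The variable names are illustrative only.

```python
from datetime import datetime, timedelta

start = datetime(1993, 4, 5)
nsteps = 4
step = timedelta(hours=6)  # assumed model time step

# Step 0 is the initial condition, so there are nsteps + 1 entries
lead_times = [i * step for i in range(nsteps + 1)]
valid_times = [start + lt for lt in lead_times]

for lt, vt in zip(lead_times, valid_times):
    print(f"lead {int(lt.total_seconds() // 3600):2d} h -> {vt:%Y-%m-%d %H:%M}")
```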

Post Processing

To confirm that our model is working as expected, we will plot the total column water vapor field for a few time-steps.

forecast = "1993-04-05"
variable = "tcwv"

plt.close("all")

# Create a figure with a 2x2 grid of axes
fig, ax = plt.subplots(2, 2, figsize=(6, 4))

# Plot tcwv every 6 hours
ax[0, 0].imshow(io[variable][0, 0], vmin=0, vmax=80, cmap="magma")
ax[0, 1].imshow(io[variable][0, 1], vmin=0, vmax=80, cmap="magma")
ax[1, 0].imshow(io[variable][0, 2], vmin=0, vmax=80, cmap="magma")
ax[1, 1].imshow(io[variable][0, 3], vmin=0, vmax=80, cmap="magma")

# Set title
plt.suptitle(f"{variable} - {forecast}")
times = io["lead_time"].astype("timedelta64[h]").astype(int)
ax[0, 0].set_title(f"Lead time: {times[0]}hrs")
ax[0, 1].set_title(f"Lead time: {times[1]}hrs")
ax[1, 0].set_title(f"Lead time: {times[2]}hrs")
ax[1, 1].set_title(f"Lead time: {times[3]}hrs")

plt.savefig("outputs/custom_datasource_prediction.jpg", bbox_inches="tight")
Figure: tcwv forecast initialized 1993-04-05, four panels at lead times 0, 6, 12 and 18 hours.

