Extending Data Sources
Implementing a custom data source
This example demonstrates how to extend Earth2Studio by implementing a custom data source to use in a built-in workflow.
In this example you will learn:
API requirements of data sources
Implementing a custom data source
Custom Data Source
Earth2Studio defines the required API for data sources in
earth2studio.data.base.DataSource,
which requires only a __call__ method.
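To make that requirement concrete, here is a minimal sketch of a toy data source. It assumes, as stated above, that a class only needs a __call__(time, variable) that returns an xr.DataArray, and it uses the same (time, variable, lat, lon) layout that the custom source in this example builds. ConstantDataSource, its parameters, and its tiny grid are hypothetical; unlike a real source it does not validate variable names against a lexicon or fetch anything remotely.

```python
from datetime import datetime

import numpy as np
import xarray as xr


class ConstantDataSource:
    """Hypothetical toy data source that returns a constant field for every request."""

    def __init__(self, value: float = 1.0, nlat: int = 4, nlon: int = 8):
        self.value = value
        self.lat = np.linspace(90.0, -90.0, nlat)
        self.lon = np.linspace(0.0, 360.0, nlon, endpoint=False)

    def __call__(self, time, variable) -> xr.DataArray:
        # Simplified stand-in for input normalisation: accept a scalar or a list
        times = [time] if isinstance(time, datetime) else list(time)
        variables = [variable] if isinstance(variable, str) else list(variable)
        data = np.full(
            (len(times), len(variables), self.lat.size, self.lon.size), self.value
        )
        return xr.DataArray(
            data=data,
            dims=["time", "variable", "lat", "lon"],
            coords={
                "time": np.array(times, dtype="datetime64[ns]"),
                "variable": np.array(variables),
                "lat": self.lat,
                "lon": self.lon,
            },
        )


ds = ConstantDataSource(value=2.0)
da = ds(datetime(2022, 1, 1), ["t2m", "msl"])
print(da.dims, da.shape)
```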
For this example, we will consider extending an existing remote data source with
another atmospheric field we can calculate.
The earth2studio.data.ARCO
data source provides the ERA5 dataset in a cloud-optimized
format; however, it provides only specific humidity. This is a problem for
models that use relative humidity as an input. Following ECMWF documentation, we can
calculate relative humidity from temperature and specific humidity.
import os
os.makedirs("outputs", exist_ok=True)
from dotenv import load_dotenv
load_dotenv() # TODO: make common example prep function
from datetime import datetime
import numpy as np
import xarray as xr
from earth2studio.data import ARCO, GFS
from earth2studio.data.utils import prep_data_inputs
from earth2studio.utils.type import TimeArray, VariableArray
class CustomDataSource:
    """Custom ARCO datasource"""

    relative_humidity_ids = [
        "r50",
        "r100",
        "r150",
        "r200",
        "r250",
        "r300",
        "r400",
        "r500",
        "r600",
        "r700",
        "r850",
        "r925",
        "r1000",
    ]

    def __init__(self, cache: bool = True, verbose: bool = True):
        self.arco = ARCO(cache, verbose)

    def __call__(
        self,
        time: datetime | list[datetime] | TimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Function to get data.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for (UTC).
        variable : str | list[str] | VariableArray
            String, list of strings or array of strings that refer to variables to
            return. Must be in IFS lexicon.

        Returns
        -------
        xr.DataArray
        """
        time, variable = prep_data_inputs(time, variable)

        # Replace relative humidity with respective temperature
        # and specific humidity fields
        variable_expanded = []
        for v in variable:
            if v in self.relative_humidity_ids:
                level = int(v[1:])
                variable_expanded.extend([f"t{level}", f"q{level}"])
            else:
                variable_expanded.append(v)
        variable_expanded = list(set(variable_expanded))

        # Fetch from ARCO
        da_exp = self.arco(time, variable_expanded)

        # Calculate relative humidity when needed
        arrays = []
        for v in variable:
            if v in self.relative_humidity_ids:
                level = int(v[1:])
                t = da_exp.sel(variable=f"t{level}").values
                q = da_exp.sel(variable=f"q{level}").values
                # Pressure level is in hPa; the calculation expects Pa
                rh = self.calc_relative_humdity(t, q, 100 * level)
                arrays.append(rh)
            else:
                arrays.append(da_exp.sel(variable=v).values)

        da = xr.DataArray(
            data=np.stack(arrays, axis=1),
            dims=["time", "variable", "lat", "lon"],
            coords=dict(
                time=da_exp.coords["time"].values,
                variable=np.array(variable),
                lat=da_exp.coords["lat"].values,
                lon=da_exp.coords["lon"].values,
            ),
        )
        return da
    def calc_relative_humdity(
        self, temperature: np.ndarray, specific_humidity: np.ndarray, pressure: float
    ) -> np.ndarray:
        """Relative humidity calculation

        Parameters
        ----------
        temperature : np.ndarray
            Temperature field (K)
        specific_humidity : np.ndarray
            Specific humidity field (kg kg-1)
        pressure : float
            Pressure (Pa)

        Returns
        -------
        np.ndarray
        """
        epsilon = 0.621981  # Ratio of gas constants R_dry / R_vap
        p = pressure
        q = specific_humidity
        t = temperature

        # Water vapor partial pressure (Pa)
        e = (p * q * (1.0 / epsilon)) / (1 + q * (1.0 / (epsilon) - 1))

        # Saturation vapor pressure over liquid water and over ice (Pa)
        es_w = 611.21 * np.exp(17.502 * (t - 273.16) / (t - 32.19))
        es_i = 611.21 * np.exp(22.587 * (t - 273.16) / (t + 0.7))

        # Mixed-phase weight: 0 at or below 250.16 K (ice only),
        # 1 at or above 273.16 K (liquid water only)
        alpha = np.clip((t - 250.16) / (273.16 - 250.16), 0, 1) ** 2
        es = alpha * es_w + (1 - alpha) * es_i

        rh = 100 * e / es
        return rh
__call__() API
The __call__ function is the main API of a data source; it returns an Xarray DataArray with the requested data. For this custom data source, we intercept relative humidity variables, replace them with temperature and specific humidity requests, and then calculate relative humidity from those fields. Note that the ARCO data source handles all of the remote data access; we only manipulate NumPy arrays.
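The request-rewriting step can be isolated into plain Python. This is a sketch of the same logic, not code from Earth2Studio: expand_variables is a hypothetical helper name, and it swaps list(set(...)) for dict.fromkeys so that the order of the expanded request is deterministic.

```python
def expand_variables(variables: list[str], rh_ids: set[str]) -> list[str]:
    """Hypothetical helper: replace each relative humidity id (e.g. "r500") with
    the temperature and specific humidity ids at the same level, deduplicated."""
    expanded: list[str] = []
    for v in variables:
        if v in rh_ids:
            level = int(v[1:])  # "r500" -> 500 (hPa)
            expanded.extend([f"t{level}", f"q{level}"])
        else:
            expanded.append(v)
    # dict.fromkeys drops duplicates while keeping first-seen order
    return list(dict.fromkeys(expanded))


rh_ids = {"r500", "r850"}
print(expand_variables(["r500", "t500", "u10m", "r850"], rh_ids))
```

Here t500 is requested both directly and via r500, yet it is fetched only once.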
calc_relative_humdity()
This method follows the calculation ECMWF uses in its IFS numerical model, which blends the saturation vapor pressure over liquid water and over ice as a function of temperature to account for mixed-phase conditions.
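Written out, the method evaluates the following, with T in K, q in kg kg⁻¹ and p in Pa. The weight α is shown capped at 1, as in the IFS mixed-phase formulation, so that above 273.16 K only the liquid-water saturation pressure is used.

```latex
\begin{aligned}
e &= \frac{p\,q/\epsilon}{1 + q\left(1/\epsilon - 1\right)}, \qquad \epsilon = 0.621981 \\
e_{s,w}(T) &= 611.21\,\exp\!\left(17.502\,\frac{T - 273.16}{T - 32.19}\right) \\
e_{s,i}(T) &= 611.21\,\exp\!\left(22.587\,\frac{T - 273.16}{T + 0.7}\right) \\
\alpha(T) &= \left[\min\!\left(\max\!\left(\frac{T - 250.16}{273.16 - 250.16},\, 0\right),\, 1\right)\right]^{2} \\
e_s &= \alpha\, e_{s,w} + (1 - \alpha)\, e_{s,i} \\
\mathrm{RH} &= 100\,\frac{e}{e_s}
\end{aligned}
```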
Note
See reference, equation 7.98 onwards: https://www.ecmwf.int/en/elibrary/81370-ifs-documentation-cy48r1-part-iv-physical-processes
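One way to sanity-check the formula without fetching any data is a round trip. The sketch below copies the calculation into standalone functions; relative_humidity, saturation_vapor_pressure and saturation_specific_humidity are hypothetical names, not part of Earth2Studio, and α is capped at 1. Solving e = pq / (ε + q(1 − ε)) for q at e = e_s gives the saturation specific humidity; feeding that back in should yield 100 % in the ice-only, mixed-phase and water-only temperature ranges.

```python
import numpy as np

EPSILON = 0.621981  # R_dry / R_vap, as in calc_relative_humdity


def saturation_vapor_pressure(t):
    """Mixed-phase saturation vapor pressure (Pa) for temperature t (K)."""
    es_w = 611.21 * np.exp(17.502 * (t - 273.16) / (t - 32.19))
    es_i = 611.21 * np.exp(22.587 * (t - 273.16) / (t + 0.7))
    alpha = np.clip((t - 250.16) / (273.16 - 250.16), 0, 1) ** 2
    return alpha * es_w + (1 - alpha) * es_i


def relative_humidity(t, q, p):
    """Relative humidity (%) from t (K), q (kg/kg) and p (Pa)."""
    e = (p * q / EPSILON) / (1 + q * (1.0 / EPSILON - 1))
    return 100 * e / saturation_vapor_pressure(t)


def saturation_specific_humidity(t, p):
    """Specific humidity (kg/kg) at which the vapor pressure equals e_s."""
    es = saturation_vapor_pressure(t)
    return EPSILON * es / (p - es * (1 - EPSILON))


t = np.array([230.0, 260.0, 290.0])  # ice-only, mixed-phase, water-only
p = 50000.0  # 500 hPa expressed in Pa
q_sat = saturation_specific_humidity(t, p)
print(relative_humidity(t, q_sat, p))  # 100 in each case, up to float rounding
```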
Verification
Before plugging this into our workflow, let’s quickly verify that our data source is consistent with what GFS provides for relative humidity.
ds = CustomDataSource()
da_custom = ds(time=datetime(2022, 1, 1, hour=0), variable=["r500"])
ds_gfs = GFS()
da_gfs = ds_gfs(time=datetime(2022, 1, 1, hour=0), variable=["r500"])
print(da_custom)
Fetching ARCO data: 0%| | 0/2 [00:00<?, ?it/s]
2025-01-23 05:03:21.814 | DEBUG | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: t500 at 2022-01-01T00:00:00
Fetching ARCO data: 0%| | 0/2 [00:00<?, ?it/s]
2025-01-23 05:03:22.012 | DEBUG | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: q500 at 2022-01-01T00:00:00
Fetching ARCO data: 0%| | 0/2 [00:00<?, ?it/s]
Fetching ARCO data: 50%|█████ | 1/2 [00:00<00:00, 2.45it/s]
Fetching ARCO data: 100%|██████████| 2/2 [00:00<00:00, 4.88it/s]
2025-01-23 05:03:22.278 | DEBUG | earth2studio.data.gfs:_fetch_gfs_dataarray:209 - Fetching GFS index file: 2022-01-01 00:00:00 lead 0:00:00
Fetching GFS for 2022-01-01 00:00:00: 0%| | 0/1 [00:00<?, ?it/s]
2025-01-23 05:03:22.282 | DEBUG | earth2studio.data.gfs:_fetch_gfs_dataarray:255 - Fetching GFS grib file for variable: r500 at 2022-01-01 00:00:00_0:00:00
Fetching GFS for 2022-01-01 00:00:00: 0%| | 0/1 [00:00<?, ?it/s]
Fetching GFS for 2022-01-01 00:00:00: 100%|██████████| 1/1 [00:00<00:00, 31.36it/s]
<xarray.DataArray (time: 1, variable: 1, lat: 721, lon: 1440)> Size: 8MB
array([[[[ 28.01413468, 28.01413468, 28.01413468, ..., 28.01413468,
28.01413468, 28.01413468],
[ 29.75579658, 29.75314258, 29.81191126, ..., 29.77167483,
29.76636404, 29.76376009],
[ 31.28908409, 31.28075462, 31.27237381, ..., 31.31699177,
31.30860053, 31.29746972],
...,
[ 97.36539026, 97.35685903, 97.39891442, ..., 97.34011022,
97.32321665, 97.31468906],
[ 97.50236648, 97.49380369, 97.54603948, ..., 97.5280597 ,
97.51949448, 97.51949448],
[102.58417544, 102.58417544, 102.58417544, ..., 102.58417544,
102.58417544, 102.58417544]]]])
Coordinates:
* time (time) datetime64[ns] 8B 2022-01-01
* variable (variable) <U4 16B 'r500'
* lat (lat) float64 6kB 90.0 89.75 89.5 89.25 ... -89.5 -89.75 -90.0
* lon (lon) float64 12kB 0.0 0.25 0.5 0.75 ... 359.0 359.2 359.5 359.8
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
fig, ax = plt.subplots(
1,
2,
figsize=(10, 3),
subplot_kw={"projection": ccrs.Mollweide()},
constrained_layout=True,
)
ax[0].imshow(
da_custom.sel(variable="r500")[0], transform=ccrs.PlateCarree(), vmin=0, vmax=100
)
ax[1].imshow(
da_gfs.sel(variable="r500")[0], transform=ccrs.PlateCarree(), vmin=0, vmax=100
)
ax[0].set_title("Custom ARCO")
ax[1].set_title("GFS")
plt.suptitle("r500", fontsize=24)
cbar = plt.cm.ScalarMappable()
cbar.set_array(da_custom.sel(variable="r500")[0])
cbar.set_clim(0, 100)
cbar = fig.colorbar(cbar, ax=ax[-1], orientation="vertical", shrink=0.8)
plt.savefig("outputs/custom_datasource_gfs_versus_custom.jpg")
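The side-by-side plot is a visual check; a single number can complement it. The sketch below computes a cos(latitude)-weighted bias and RMSE between two fields on the same regular lat/lon grid. Area weighting by cos(latitude) is one common choice, not something this example prescribes, and lat_weighted_stats is a hypothetical helper. It is demonstrated on synthetic arrays.

```python
import numpy as np


def lat_weighted_stats(a: np.ndarray, b: np.ndarray, lat: np.ndarray) -> dict:
    """Hypothetical helper: area-weighted bias and RMSE of (a - b).

    a, b: 2D arrays shaped (lat, lon); lat: 1D latitudes in degrees.
    """
    # cos(latitude) approximates the relative area of each grid row
    w = np.broadcast_to(np.cos(np.deg2rad(lat))[:, None], a.shape)
    diff = a - b
    bias = np.sum(w * diff) / np.sum(w)
    rmse = np.sqrt(np.sum(w * diff**2) / np.sum(w))
    return {"bias": float(bias), "rmse": float(rmse)}


# Synthetic check: a uniform offset of 2 gives bias = rmse = 2
lat = np.linspace(90.0, -90.0, 5)
b = np.zeros((5, 8))
print(lat_weighted_stats(b + 2.0, b, lat))
```

With the arrays from this section, a call such as lat_weighted_stats(da_custom.sel(variable="r500")[0].values, da_gfs.sel(variable="r500")[0].values, da_custom.lat.values) would summarize the difference, assuming both sources return the same 721 × 1440 grid; check the GFS coordinates before relying on that.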
Execute Workflow
We will use this custom data source to run deterministic inference with a model that
requires relative humidity. earth2studio.models.px.FCN
is one such model. Since ARCO provides the ERA5 archive, we can run inference
from a date well in the past (here, 1993).
Let’s instantiate the components needed.
Prognostic Model: Use the built-in FourCastNet model
earth2studio.models.px.FCN.
Data Source: Use the custom data source above.
IO Backend: Save the outputs into a Zarr store
earth2studio.io.ZarrBackend.
from dotenv import load_dotenv
load_dotenv() # TODO: make common example prep function
import earth2studio.run as run
from earth2studio.io import ZarrBackend
from earth2studio.models.px import FCN
package = FCN.load_default_package()
model = FCN.load_model(package)
# Create the data source
data = CustomDataSource()
# Create the IO handler, store in memory
io = ZarrBackend()
nsteps = 4
io = run.deterministic(["1993-04-05"], nsteps, model, data, io)
print(io.root.tree())
2025-01-23 05:03:37.838 | INFO | earth2studio.run:deterministic:75 - Running simple workflow!
2025-01-23 05:03:37.838 | INFO | earth2studio.run:deterministic:82 - Inference device: cuda
Fetching ARCO data: 0%| | 0/26 [00:00<?, ?it/s]
2025-01-23 05:03:37.889 | DEBUG | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: t500 at 1993-04-05T00:00:00
Fetching ARCO data: 0%| | 0/26 [00:00<?, ?it/s]
2025-01-23 05:03:38.087 | DEBUG | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: msl at 1993-04-05T00:00:00
Fetching ARCO data: 0%| | 0/26 [00:00<?, ?it/s]
2025-01-23 05:03:38.813 | DEBUG | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: v500 at 1993-04-05T00:00:00
Fetching ARCO data: 0%| | 0/26 [00:00<?, ?it/s]
2025-01-23 05:03:39.017 | DEBUG | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: t2m at 1993-04-05T00:00:00
Fetching ARCO data: 0%| | 0/26 [00:01<?, ?it/s]
Fetching ARCO data: 4%|▍ | 1/26 [00:01<00:44, 1.78s/it]
2025-01-23 05:03:39.672 | DEBUG | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: u1000 at 1993-04-05T00:00:00
Fetching ARCO data: 15%|█▌ | 4/26 [00:01<00:39, 1.78s/it]
2025-01-23 05:03:39.869 | DEBUG | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: u10m at 1993-04-05T00:00:00
Fetching ARCO data: 15%|█▌ | 4/26 [00:01<00:39, 1.78s/it]
2025-01-23 05:03:39.874 | DEBUG | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: z850 at 1993-04-05T00:00:00
Fetching ARCO data: 15%|█▌ | 4/26 [00:01<00:39, 1.78s/it]
2025-01-23 05:03:40.064 | DEBUG | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: z500 at 1993-04-05T00:00:00
Fetching ARCO data: 15%|█▌ | 4/26 [00:02<00:39, 1.78s/it]
Fetching ARCO data: 19%|█▉ | 5/26 [00:02<00:08, 2.57it/s]
2025-01-23 05:03:40.255 | DEBUG | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: t850 at 1993-04-05T00:00:00
Fetching ARCO data: 31%|███ | 8/26 [00:02<00:06, 2.57it/s]
2025-01-23 05:03:40.441 | DEBUG | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: z50 at 1993-04-05T00:00:00
Fetching ARCO data: 31%|███ | 8/26 [00:02<00:06, 2.57it/s]
2025-01-23 05:03:40.620 | DEBUG | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: t250 at 1993-04-05T00:00:00
Fetching ARCO data: 31%|███ | 8/26 [00:02<00:06, 2.57it/s]
2025-01-23 05:03:40.808 | DEBUG | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: sp at 1993-04-05T00:00:00
Fetching ARCO data: 31%|███ | 8/26 [00:02<00:06, 2.57it/s]
Fetching ARCO data: 35%|███▍ | 9/26 [00:02<00:04, 3.95it/s]
2025-01-23 05:03:40.820 | DEBUG | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: q850 at 1993-04-05T00:00:00
Fetching ARCO data: 46%|████▌ | 12/26 [00:02<00:03, 3.95it/s]
2025-01-23 05:03:41.013 | DEBUG | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: z250 at 1993-04-05T00:00:00
Fetching ARCO data: 46%|████▌ | 12/26 [00:03<00:03, 3.95it/s]
2025-01-23 05:03:41.192 | DEBUG | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: v100m at 1993-04-05T00:00:00
Fetching ARCO data: 46%|████▌ | 12/26 [00:03<00:03, 3.95it/s]
2025-01-23 05:03:42.278 | DEBUG | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: v250 at 1993-04-05T00:00:00
Fetching ARCO data: 46%|████▌ | 12/26 [00:04<00:03, 3.95it/s]
Fetching ARCO data: 50%|█████ | 13/26 [00:04<00:04, 3.05it/s]
2025-01-23 05:03:42.509 | DEBUG | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: v1000 at 1993-04-05T00:00:00
Fetching ARCO data: 62%|██████▏ | 16/26 [00:04<00:03, 3.05it/s]
2025-01-23 05:03:42.713 | DEBUG | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: v850 at 1993-04-05T00:00:00
Fetching ARCO data: 62%|██████▏ | 16/26 [00:04<00:03, 3.05it/s]
2025-01-23 05:03:42.915 | DEBUG | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: u850 at 1993-04-05T00:00:00
Fetching ARCO data: 62%|██████▏ | 16/26 [00:05<00:03, 3.05it/s]
2025-01-23 05:03:43.110 | DEBUG | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: u250 at 1993-04-05T00:00:00
Fetching ARCO data: 62%|██████▏ | 16/26 [00:05<00:03, 3.05it/s]
Fetching ARCO data: 65%|██████▌ | 17/26 [00:05<00:02, 3.59it/s]
2025-01-23 05:03:43.315 | DEBUG | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: z1000 at 1993-04-05T00:00:00
Fetching ARCO data: 77%|███████▋ | 20/26 [00:05<00:01, 3.59it/s]
2025-01-23 05:03:43.495 | DEBUG | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: q500 at 1993-04-05T00:00:00
Fetching ARCO data: 77%|███████▋ | 20/26 [00:05<00:01, 3.59it/s]
2025-01-23 05:03:43.688 | DEBUG | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: u100m at 1993-04-05T00:00:00
Fetching ARCO data: 77%|███████▋ | 20/26 [00:05<00:01, 3.59it/s]
2025-01-23 05:03:43.693 | DEBUG | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: v10m at 1993-04-05T00:00:00
Fetching ARCO data: 77%|███████▋ | 20/26 [00:05<00:01, 3.59it/s]
Fetching ARCO data: 81%|████████ | 21/26 [00:06<00:01, 3.46it/s]
2025-01-23 05:03:44.541 | DEBUG | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: tcwv at 1993-04-05T00:00:00
Fetching ARCO data: 92%|█████████▏| 24/26 [00:06<00:00, 3.46it/s]
2025-01-23 05:03:44.546 | DEBUG | earth2studio.data.arco:fetch_array:227 - Fetching ARCO zarr array for variable: u500 at 1993-04-05T00:00:00
Fetching ARCO data: 92%|█████████▏| 24/26 [00:06<00:00, 3.46it/s]
Fetching ARCO data: 96%|█████████▌| 25/26 [00:06<00:00, 4.78it/s]
Fetching ARCO data: 100%|██████████| 26/26 [00:06<00:00, 3.79it/s]
2025-01-23 05:03:44.934 | SUCCESS | earth2studio.run:deterministic:106 - Fetched data from CustomDataSource
2025-01-23 05:03:44.944 | INFO | earth2studio.run:deterministic:136 - Inference starting!
Running inference: 0%| | 0/5 [00:00<?, ?it/s]
Running inference: 20%|██ | 1/5 [00:00<00:03, 1.23it/s]
Running inference: 40%|████ | 2/5 [00:01<00:02, 1.08it/s]
Running inference: 60%|██████ | 3/5 [00:02<00:01, 1.04it/s]
Running inference: 80%|████████ | 4/5 [00:03<00:00, 1.01it/s]
Running inference: 100%|██████████| 5/5 [00:04<00:00, 1.01s/it]
Running inference: 100%|██████████| 5/5 [00:04<00:00, 1.02it/s]
2025-01-23 05:03:49.840 | SUCCESS | earth2studio.run:deterministic:146 - Inference complete
/
├── lat (720,) float64
├── lead_time (5,) timedelta64[h]
├── lon (1440,) float64
├── msl (1, 5, 720, 1440) float32
├── r500 (1, 5, 720, 1440) float32
├── r850 (1, 5, 720, 1440) float32
├── sp (1, 5, 720, 1440) float32
├── t250 (1, 5, 720, 1440) float32
├── t2m (1, 5, 720, 1440) float32
├── t500 (1, 5, 720, 1440) float32
├── t850 (1, 5, 720, 1440) float32
├── tcwv (1, 5, 720, 1440) float32
├── time (1,) datetime64[ns]
├── u1000 (1, 5, 720, 1440) float32
├── u100m (1, 5, 720, 1440) float32
├── u10m (1, 5, 720, 1440) float32
├── u250 (1, 5, 720, 1440) float32
├── u500 (1, 5, 720, 1440) float32
├── u850 (1, 5, 720, 1440) float32
├── v1000 (1, 5, 720, 1440) float32
├── v100m (1, 5, 720, 1440) float32
├── v10m (1, 5, 720, 1440) float32
├── v250 (1, 5, 720, 1440) float32
├── v500 (1, 5, 720, 1440) float32
├── v850 (1, 5, 720, 1440) float32
├── z1000 (1, 5, 720, 1440) float32
├── z250 (1, 5, 720, 1440) float32
├── z50 (1, 5, 720, 1440) float32
├── z500 (1, 5, 720, 1440) float32
└── z850 (1, 5, 720, 1440) float32
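Each forecast variable in the tree has shape (time, lead_time, lat, lon). With nsteps = 4 the workflow stores the initial condition plus four forecast steps, hence 5 lead times. FourCastNet advances in 6-hour steps, consistent with the 6-hourly tcwv panels in the post-processing step. Note also that the stored output grid has 720 latitudes, one fewer than the 721 returned by the data source. The lead-time arithmetic, as a small NumPy sketch:

```python
import numpy as np

nsteps = 4
step = np.timedelta64(6, "h")  # assumed FCN time step of 6 hours
# Index 0 is the initial condition; indices 1..nsteps are forecast steps
lead_time = np.arange(nsteps + 1) * step
print(lead_time.astype(int))  # lead times in hours
```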
Post Processing
To confirm that our model is working as expected, we will plot the total column water vapor field for a few time-steps.
forecast = "1993-04-05"
variable = "tcwv"
plt.close("all")
# Create a 2x2 grid of plain (unprojected) axes
fig, ax = plt.subplots(2, 2, figsize=(6, 4))
# Lead times in hours, used for the panel titles
times = io["lead_time"].astype("timedelta64[h]").astype(int)
# Plot tcwv for the first four lead times (every 6 hours)
for i, axis in enumerate(ax.flat):
    axis.imshow(io[variable][0, i], vmin=0, vmax=80, cmap="magma")
    axis.set_title(f"Lead time: {times[i]}hrs")
# Set title
plt.suptitle(f"{variable} - {forecast}")
plt.savefig("outputs/custom_datasource_prediction.jpg", bbox_inches="tight")
Total running time of the script: (0 minutes 43.443 seconds)