# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import functools
import inspect
import os
import pathlib
import shutil
from datetime import datetime
from importlib.metadata import version
from typing import Literal
import gcsfs
import nest_asyncio
import numpy as np
import xarray as xr
import zarr
from loguru import logger
from tqdm.asyncio import tqdm
from earth2studio.data.utils import (
AsyncCachingFileSystem,
datasource_cache_root,
prep_data_inputs,
)
from earth2studio.lexicon import WB2ClimatetologyLexicon, WB2Lexicon
from earth2studio.utils.type import TimeArray, VariableArray
class _WB2Base:
"""Base class for weather bench 2 ERA5 datasets"""
WB2_ERA5_LAT = np.empty(0)
WB2_ERA5_LON = np.empty(0)
def __init__(
self,
wb2_zarr_store: str,
wb2_product: str = "era5",
cache: bool = True,
verbose: bool = True,
async_timeout: int = 600,
):
self._zarr_store_name = wb2_zarr_store
self._product = wb2_product
self._cache = cache
self._verbose = verbose
self.async_timeout = async_timeout
# Check Zarr version and use appropriate method
try:
zarr_version = version("zarr")
zarr_major_version = int(zarr_version.split(".")[0])
except Exception:
# Fallback to older method if version check fails
zarr_major_version = 2 # Assume older version if we can't determine
# Only zarr 3.0 support
if zarr_major_version < 3:
raise ModuleNotFoundError("Zarr 3.0 and above support only")
# Check to see if there is a running loop (initialized in async)
try:
nest_asyncio.apply() # Monkey patch asyncio to work in notebooks
loop = asyncio.get_running_loop()
loop.run_until_complete(self._async_init())
except RuntimeError:
# Else we assume that async calls will be used which in that case
# we will init the group in the call function when we have the loop
self.zarr_group = None
self.level_coords = None
async def _async_init(self) -> None:
"""Async initialization of zarr group
Note
----
Async fsspec expects initialization inside of the execution loop
"""
fs = gcsfs.GCSFileSystem(
cache_timeout=-1,
token="anon", # noqa: S106 # nosec B106
access="read_only",
block_size=8**20,
asynchronous=True,
skip_instance_cache=True,
)
fs._loop = asyncio.get_event_loop()
if self._cache:
cache_options = {
"cache_storage": self.cache,
"expiry_time": 31622400, # 1 year
}
fs = AsyncCachingFileSystem(fs=fs, **cache_options, asynchronous=True)
zstore = zarr.storage.FsspecStore(
fs,
path=f"/weatherbench2/datasets/{self._product}/{self._zarr_store_name}",
)
self.zarr_group = await zarr.api.asynchronous.open(store=zstore, mode="r")
self.level_coords = await (await self.zarr_group.get("level")).getitem( # type: ignore
slice(None)
)
def __call__(
self,
time: datetime | list[datetime] | TimeArray,
variable: str | list[str] | VariableArray,
) -> xr.DataArray:
"""Function to get data
Parameters
----------
time : datetime | list[datetime] | TimeArray
Timestamps to return data for (UTC).
variable : str | list[str] | VariableArray
String, list of strings or array of strings that refer to variables to
return. Must be in the WB2 lexicon.
Returns
-------
xr.DataArray
ERA5 weather data array from weather bench 2
"""
try:
loop = asyncio.get_event_loop()
except RuntimeError:
# If no event loop exists, create one
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
if self.zarr_group is None:
loop.run_until_complete(self._async_init())
xr_array = loop.run_until_complete(
asyncio.wait_for(self.fetch(time, variable), timeout=self.async_timeout)
)
# Delete cache if needed
if not self._cache:
shutil.rmtree(self.cache)
return xr_array
async def fetch(
self,
time: datetime | list[datetime] | TimeArray,
variable: str | list[str] | VariableArray,
) -> xr.DataArray:
"""Async function to get data
Parameters
----------
time : datetime | list[datetime] | TimeArray
Timestamps to return data for (UTC).
variable : str | list[str] | VariableArray
String, list of strings or array of strings that refer to variables to
return. Must be in the WB2 lexicon.
Returns
-------
xr.DataArray
ERA5 weather data array from weather bench 2
"""
if self.zarr_group is None:
raise ValueError(
"Zarr group is not initialized! If you are calling this \
function directly make sure the data source is initialized inside the async \
loop!"
)
time, variable = prep_data_inputs(time, variable)
# Create cache dir if doesnt exist
pathlib.Path(self.cache).mkdir(parents=True, exist_ok=True)
# Make sure input time is valid
self._validate_time(time)
xr_array = xr.DataArray(
data=np.empty(
(
len(time),
len(variable),
len(self.WB2_ERA5_LAT),
len(self.WB2_ERA5_LON),
)
),
dims=["time", "variable", "lat", "lon"],
coords={
"time": time,
"variable": variable,
"lat": self.WB2_ERA5_LAT,
"lon": self.WB2_ERA5_LON,
},
)
args = [
(t, i, v, j) for j, v in enumerate(variable) for i, t in enumerate(time)
]
func_map = map(functools.partial(self.fetch_wrapper, xr_array=xr_array), args)
# Launch all fetch requests
await tqdm.gather(
*func_map, desc="Fetching WB2 data", disable=(not self._verbose)
)
return xr_array
async def fetch_wrapper(
self,
e: tuple[datetime, int, str, int],
xr_array: xr.DataArray,
) -> None:
"""Small wrapper to pack arrays into the DataArray"""
out = await self.fetch_array(e[0], e[2])
xr_array[e[1], e[3]] = out
async def fetch_array(self, time: datetime, variable: str) -> np.ndarray:
"""Fetches requested array from remote store
Parameters
----------
time : datetime
Time to fetch
variable : str
Variable to fetch
Returns
-------
np.ndarray
Data
"""
if self.zarr_group is None:
raise ValueError("Zarr group is not initialized")
# Get time index (vanilla zarr doesnt support date indices)
time_index = self._get_time_index(time)
logger.debug(
f"Fetching WB2 zarr array for variable: {variable} at {time.isoformat()}"
)
try:
wb2_name, modifier = WB2Lexicon[variable] # type: ignore
except KeyError as e:
logger.error(f"variable id {variable} not found in WB2 lexicon")
raise e
wb2_name, level = wb2_name.split("::")
zarr_array = await self.zarr_group.get(wb2_name)
shape = zarr_array.shape
# Static variables
if len(shape) == 2:
data = await zarr_array.getitem(slice(None))
output = modifier(data)
# Surface variable
elif len(shape) == 3:
data = await zarr_array.getitem(time_index)
output = modifier(data)
# Atmospheric variable
else:
# Load levels coordinate system from Zarr store and check
level_index = np.searchsorted(self.level_coords, int(level))
data = await zarr_array.getitem((time_index, level_index))
output = modifier(data)
# Some WB2 data Zarr stores are saved [lon, lat] with lat flipped
# Namely its the lower resolutions ones with this issue
if output.shape[0] > output.shape[1]:
output = np.flip(output, axis=-1).T
return output
@property
def cache(self) -> str:
"""Get the appropriate cache location."""
cache_location = os.path.join(datasource_cache_root(), "wb2era5")
if not self._cache:
cache_location = os.path.join(cache_location, "tmp_wb2era5")
return cache_location
@classmethod
def _validate_time(cls, times: list[datetime]) -> None:
"""Verify if date time is valid for Weatherbench 2 ERA5
Parameters
----------
times : list[datetime]
list of date times to fetch data
"""
for time in times:
if not (time - datetime(1900, 1, 1)).total_seconds() % 21600 == 0:
raise ValueError(
f"Requested date time {time} needs to be 6 hour interval for Weatherbench2 ERA5"
)
if time < datetime(year=1959, month=1, day=1):
raise ValueError(
f"Requested date time {time} needs to be after January 1st, 1959 for Weatherbench2 ERA5"
)
if time > datetime(year=2023, month=1, day=10, hour=18):
raise ValueError(
f"Requested date time {time} needs to be before January 11th, 2023 for Weatherbench2 ERA5"
)
@classmethod
def _get_time_index(cls, time: datetime) -> int:
"""Little index converter to go from datetime to integer index for hour
and day of year.
Parameters
----------
time : datetime
Input date time
Returns
-------
int
hour coordinate index of data
int
day_of_year coordinate index of data
"""
start_date = datetime(year=1959, month=1, day=1)
duration = time - start_date
return int(divmod(duration.total_seconds(), 21600)[0])
[docs]
class WB2ERA5(_WB2Base):
"""
ERA5 reanalysis data with several derived variables on a 0.25 degree lat-lon grid
from 1959 to 2023 (incl) to 6 hour intervals on 13 pressure levels. Provided by the
WeatherBench2 data repository.
Parameters
----------
cache : bool, optional
Cache data source on local memory, by default True
verbose : bool, optional
Print download progress, by default True
async_timeout : int, optional
Time in sec after which download will be cancelled if not finished successfully,
by default 600
Warning
-------
This is a remote data source and can potentially download a large amount of data
to your local machine for large requests.
Note
----
Additional information on the data repository can be referenced here:
- https://weatherbench2.readthedocs.io/en/latest/data-guide.html#era5
- https://arxiv.org/abs/2308.15560
"""
WB2_ERA5_LAT = np.linspace(90, -90, 721)
WB2_ERA5_LON = np.linspace(0, 359.75, 1440)
def __init__(
self,
cache: bool = True,
verbose: bool = True,
async_timeout: int = 600,
):
super().__init__(
wb2_zarr_store="1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr",
cache=cache,
verbose=verbose,
async_timeout=async_timeout,
)
[docs]
class WB2ERA5_121x240(_WB2Base):
"""
ERA5 reanalysis data with several derived variables down sampled to a 1.5 degree
lat-lon grid from 1959 to 2023 (incl) to 6 hour intervals on 13 pressure levels.
Provided by the WeatherBench2 data repository.
Parameters
----------
cache : bool, optional
Cache data source on local memory, by default True
verbose : bool, optional
Print download progress, by default True
async_timeout : int, optional
Time in sec after which download will be cancelled if not finished successfully,
by default 600
Warning
-------
This is a remote data source and can potentially download a large amount of data
to your local machine for large requests.
Note
----
Additional information on the data repository can be referenced here:
- https://weatherbench2.readthedocs.io/en/latest/data-guide.html#era5
- https://arxiv.org/abs/2308.15560
"""
WB2_ERA5_LAT = np.linspace(90, -90, 121)
WB2_ERA5_LON = np.linspace(0, 359.5, 240)
def __init__(
self,
cache: bool = True,
verbose: bool = True,
async_timeout: int = 600,
):
super().__init__(
wb2_zarr_store="1959-2023_01_10-6h-240x121_equiangular_with_poles_conservative.zarr",
cache=cache,
verbose=verbose,
async_timeout=async_timeout,
)
[docs]
class WB2ERA5_32x64(_WB2Base):
"""
ERA5 reanalysis data with several derived variables down sampled to a 5.625 degree
lat-lon grid from 1959 to 2023 (incl) to 6 hour intervals on 13 pressure levels.
Provided by the WeatherBench2 data repository.
Parameters
----------
cache : bool, optional
Cache data source on local memory, by default True
verbose : bool, optional
Print download progress, by default True
async_timeout : int, optional
Time in sec after which download will be cancelled if not finished successfully,
by default 600
Warning
-------
This is a remote data source and can potentially download a large amount of data
to your local machine for large requests.
Note
----
Additional information on the data repository can be referenced here:
- https://weatherbench2.readthedocs.io/en/latest/data-guide.html#era5
- https://arxiv.org/abs/2308.15560
"""
WB2_ERA5_LAT = np.linspace(-87.1875, 87.1875, 32)
WB2_ERA5_LON = np.linspace(0, 360, 64, endpoint=False)
def __init__(
self,
cache: bool = True,
verbose: bool = True,
async_timeout: int = 600,
):
super().__init__(
"1959-2023_01_10-6h-64x32_equiangular_conservative.zarr",
cache=cache,
verbose=verbose,
async_timeout=async_timeout,
)
ClimatologyZarrStore = Literal[
"1990-2017_6h_1440x721.zarr",
"1990-2017_6h_512x256_equiangular_conservative.zarr",
"1990-2017_6h_240x121_equiangular_with_poles_conservative.zarr",
"1990-2017_6h_64x32_equiangular_conservative.zarr",
"1990-2019_6h_1440x721.zarr",
"1990-2019_6h_512x256_equiangular_conservative.zarr",
"1990-2019_6h_240x121_equiangular_with_poles_conservative.zarr",
"1990-2019_6h_64x32_equiangular_conservative.zarr",
]
[docs]
class WB2Climatology(_WB2Base):
"""
Climatology provided by WeatherBench2,
| A climatology is used for e.g. computing anomaly metrics such as the ACC.
| For WeatherBench 2, the climatology was computed using a running window for
| smoothing (see paper and script) for each day of year and sixth hour of day.
| We have computed climatologies for 1990-2017 and 1990-2019.
Parameters
----------
climatology_zarr_store : ClimatologyZarrStore, optional
Stores within `gs://weatherbench2/datasets/era5-hourly-climatology/` to select
As of 05/03/2024 this is the following list of available files:
- 1990-2017_6h_1440x721.zarr
- 1990-2017_6h_512x256_equiangular_conservative.zarr
- 1990-2017_6h_240x121_equiangular_with_poles_conservative.zarr
- 1990-2017_6h_64x32_equiangular_conservative.zarr
- 1990-2019_6h_1440x721.zarr
- 1990-2019_6h_512x256_equiangular_conservative.zarr
- 1990-2019_6h_240x121_equiangular_with_poles_conservative.zarr
- 1990-2019_6h_64x32_equiangular_conservative.zarr
by default `1990-2019_6h_1440x721.zarr`
cache : bool, optional
Cache data source on local memory, by default True
verbose : bool, optional
Print download progress, by default True
async_timeout : int, optional
Time in sec after which download will be cancelled if not finished successfully,
by default 600
Warning
-------
This is a remote data source and can potentially download a large amount of data
to your local machine for large requests.
Note
----
Additional information on the data repository can be referenced here:
- https://weatherbench2.readthedocs.io/en/latest/data-guide.html#era5-climatology
- https://arxiv.org/abs/2308.15560
"""
def __init__(
self,
climatology_zarr_store: ClimatologyZarrStore = "1990-2017_6h_1440x721.zarr",
cache: bool = True,
verbose: bool = True,
async_timeout: int = 600,
):
super().__init__(
climatology_zarr_store,
wb2_product="era5-hourly-climatology",
cache=cache,
verbose=verbose,
async_timeout=async_timeout,
)
[docs]
async def fetch(
self,
time: datetime | list[datetime] | TimeArray,
variable: str | list[str] | VariableArray,
) -> xr.DataArray:
"""Async function to get data
Parameters
----------
time : datetime | list[datetime] | TimeArray
Timestamps to return data for (UTC).
variable : str | list[str] | VariableArray
String, list of strings or array of strings that refer to variables to
return. Must be in the WB2 Climatology lexicon.
Returns
-------
xr.DataArray
ERA5 weather data array from weather bench 2
"""
if self.zarr_group is None:
raise ValueError(
"Zarr group is not initialized! If you are calling this \
function directly make sure the data source is initialized inside the async \
loop!"
)
time, variable = prep_data_inputs(time, variable)
# Create cache dir if doesnt exist
pathlib.Path(self.cache).mkdir(parents=True, exist_ok=True)
# Make sure input time is valid
self._validate_time(time)
# Before anything wait until the group gets opened
if inspect.isawaitable(self.zarr_group):
self.zarr_group = await self.zarr_group
WB2_CLIMATE_LAT = await (await self.zarr_group.get("latitude")).getitem(
slice(None)
)
WB2_CLIMATE_LON = await (await self.zarr_group.get("longitude")).getitem(
slice(None)
)
xr_array = xr.DataArray(
data=np.empty(
(len(time), len(variable), len(WB2_CLIMATE_LAT), len(WB2_CLIMATE_LON))
),
dims=["time", "variable", "lat", "lon"],
coords={
"time": time,
"variable": variable,
"lat": WB2_CLIMATE_LAT[:],
"lon": WB2_CLIMATE_LON[:],
},
)
args = [
(t, i, v, j) for j, v in enumerate(variable) for i, t in enumerate(time)
]
func_map = map(functools.partial(self.fetch_wrapper, xr_array=xr_array), args)
self.level_coords = await (await self.zarr_group.get("level")).getitem(
slice(None)
)
# Launch all fetch requests
await tqdm.gather(
*func_map, desc="Fetching WB2 climatology data", disable=(not self._verbose)
)
return xr_array
async def fetch_array(self, time: datetime, variable: str) -> np.ndarray:
"""Fetches requested array from remote store
Parameters
----------
time : datetime
Time to fetch
variable : str
Variable to fetch
Returns
-------
np.ndarray
Data
"""
if self.zarr_group is None:
raise ValueError("Zarr group is not initialized")
# Get time index (vanilla zarr doesnt support date indices)
hour_index, day_of_year_index = self._get_time_index(time)
logger.debug(
f"Fetching WB2 climatology zarr array for variable: {variable} at {time.isoformat()}"
)
try:
wb2_name, modifier = WB2ClimatetologyLexicon[variable] # type: ignore
except KeyError as e:
logger.error(f"variable id {variable} not found in WB2 lexicon")
raise e
wb2_name, level = wb2_name.split("::")
zarr_array = await self.zarr_group.get(wb2_name)
shape = zarr_array.shape
# Surface variable [hour idx (6 hour), day index, lat, lon]
if len(shape) == 4:
data = await zarr_array.getitem((hour_index, day_of_year_index))
output = modifier(data)
# Atmospheric variable [hour idx (6 hour), day index, level lat, lon]
else:
# Load levels coordinate system from Zarr store and check
level_index = np.searchsorted(self.level_coords, int(level))
data = await zarr_array.getitem(
(hour_index, day_of_year_index, level_index)
)
output = modifier(data)
return output
@classmethod
def _get_time_index(cls, time: datetime) -> tuple[int, int]: # type: ignore[override]
"""Little index converter to go from datetime to integer index for hour
and day of year.
Parameters
----------
time : datetime
Input date time
Returns
-------
int
hour coordinate index of data
int
day_of_year coordinate index of data
"""
tt = time.timetuple()
return tt.tm_hour // 6, tt.tm_yday - 1
@classmethod
def _validate_time(cls, times: list[datetime]) -> None:
"""Verify if date time is valid for WeatherBench 2 climatology.
Parameters
----------
times : list[datetime]
list of date times to fetch data
"""
for time in times:
if not (time - datetime(1900, 1, 1)).total_seconds() % 21600 == 0:
raise ValueError(
f"Requested date time {time} needs to be 6 hour interval for WeatherBench 2 climatology"
)