# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import functools
import hashlib
import os
import pathlib
import shutil
import uuid
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any, Literal
import numpy as np
import pygrib
import xarray as xr
from loguru import logger
from tqdm.asyncio import tqdm
from earth2studio.data.utils import datasource_cache_root, prep_forecast_inputs
from earth2studio.lexicon import AIFSLexicon, IFSLexicon
from earth2studio.lexicon.ecmwf import ECMWFOpenDataLexicon
from earth2studio.utils.imports import (
OptionalDependencyFailure,
check_optional_dependencies,
)
from earth2studio.utils.type import LeadTimeArray, TimeArray, VariableArray
# ecmwf-opendata is an optional dependency. On import failure the name is bound to
# None so this module still imports; the data source constructor checks for None.
try:
    import ecmwf.opendata as opendata
except ImportError:
    # NOTE(review): presumably registers the missing "data" extra with the
    # optional-dependency machinery; confirm against earth2studio.utils.imports
    OptionalDependencyFailure("data")
    opendata = None

# Swap loguru's default sink for one that emits through tqdm.write, so log output
# goes through the same writer as the tqdm progress bars used for downloads.
logger.remove()
logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True)
@dataclass
class ECMWFOpenDataAsyncTask:
    """Small helper struct for Async tasks

    One instance describes a single download: one variable at one initialization
    time and one lead time.
    """

    # (time, lead_time, variable) position in the output array filled by this task
    data_array_indices: tuple[int, int, int]
    # Forecast initialization time passed to the download request as "date"
    time: datetime
    # Forecast lead time, converted to an integer hour "step" for the request
    lead_time: timedelta
    # Remote parameter name (first "::"-separated field of the lexicon entry)
    variable: str
    # Level type (second field of the lexicon entry); "pl" and "sl" add a level list
    levtype: str
    # Level identifier(s) (third field of the lexicon entry)
    level: str | list[str]
    # Callable applied to the decoded values before they are stored in the array
    modifier: Callable
@check_optional_dependencies()
class _ECMWFOpenDataSource(ABC):
    """Base class for ECMWF open data sources

    Parameters
    ----------
    source : str, optional
        Data source to fetch data from. For possible options refer to ECMWF's open data
        Python SDK, by default "aws".
    model: str, optional
        Model to fetch data for, by default "ifs".
    fc_type: str, optional
        Forecast type (e.g., deterministic, control, perturbed). For possible options
        refer to ECMWF's open data Python SDK, by default "fc".
    members: list[int] | None, optional
        List of ensemble members used if a perturbed forecast is requested. If None,
        [0] is used, by default None.
    cache : bool, optional
        Keep downloaded files in the local cache directory. If False a temporary
        directory is used and removed after each fetch, by default True.
    verbose : bool, optional
        Print download progress, by default True.
    async_timeout: int, optional
        Time in seconds after which the download will be cancelled if not finished
        successfully, by default 600.
    """

    LAT = np.linspace(90, -90, 721)
    LON = np.linspace(0, 359.75, 1440)
    LEXICON: type[ECMWFOpenDataLexicon]

    def __init__(
        self,
        source: Literal["aws", "ecmwf", "azure"] = "aws",
        model: Literal["ifs", "aifs-single", "aifs-ens"] = "ifs",
        fc_type: Literal["fc", "cf", "pf"] = "fc",
        members: list[int] | None = None,
        cache: bool = True,
        verbose: bool = True,
        async_timeout: int = 600,
    ):
        # Optional import not installed error
        if opendata is None:
            raise ImportError(
                "ecmwf-opendata is not installed, install manually or using `pip install earth2studio[data]`"
            )

        self.client = opendata.Client(source=source, model=model)
        self._fc_type = fc_type
        # Copy so the instance never aliases a caller-owned (or shared default) list
        self._members = [0] if members is None else list(members)
        self._cache = cache
        self._tmp_cache_hash: str | None = None
        self._verbose = verbose
        self.async_timeout = async_timeout

        # Model name for caching and logging
        if model == "ifs":
            self._model = "IFS" if fc_type == "fc" else "IFS-ENS"
        elif model == "aifs-single":
            self._model = "AIFS"
        else:
            self._model = "AIFS-ENS"

    @abstractmethod
    def __call__(  # type: ignore[override]
        self,
        time: datetime | list[datetime] | TimeArray,
        lead_time: timedelta | list[timedelta] | LeadTimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Retrieve ECMWF data. The child class should override this"""
        pass

    @abstractmethod
    async def fetch(  # type: ignore[override]
        self,
        time: datetime | list[datetime] | TimeArray,
        lead_time: timedelta | list[timedelta] | LeadTimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Async function to get data, the child class should override this and call
        `_fetch`"""
        pass

    def _call(  # type: ignore[override]
        self,
        time: datetime | list[datetime] | TimeArray,
        lead_time: timedelta | list[timedelta] | LeadTimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Retrieve ECMWF data.

        Synchronous entry point: runs `_fetch` on the current event loop (creating
        one if needed) and bounds it with `async_timeout`.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for (UTC).
        lead_time: timedelta | list[timedelta] | LeadTimeArray
            Forecast lead times to fetch.
        variable : str | list[str] | VariableArray
            String, list of strings or array of strings that refer to variables to
            return. Must be in the data lexicon.

        Note
        ----
        For perturbed data from ensemble models, the returned data array will have an
        extra `sample` dimension added to it.

        Returns
        -------
        xr.DataArray
            ECMWF weather data array
        """
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # If no event loop exists, create one
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        xr_array = loop.run_until_complete(
            asyncio.wait_for(
                self._fetch(time, lead_time, variable), timeout=self.async_timeout
            )
        )
        return xr_array

    async def _fetch(  # type: ignore[override]
        self,
        time: datetime | list[datetime] | TimeArray,
        lead_time: timedelta | list[timedelta] | LeadTimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Async method to retrieve ECMWF data.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for (UTC).
        lead_time: timedelta | list[timedelta] | LeadTimeArray
            Forecast lead times to fetch.
        variable : str | list[str] | VariableArray
            String, list of strings or array of strings that refer to variables to
            return. Must be in the data lexicon.

        Returns
        -------
        xr.DataArray
            Data array with dims [time, lead_time, variable, lat, lon]; perturbed
            forecasts ("pf") carry an extra `sample` dim after `variable`.
        """
        time, lead_time, variable = prep_forecast_inputs(time, lead_time, variable)

        # Make sure the request is valid before touching the file system
        self._validate_time(time)
        self._validate_leadtime(time, lead_time)

        # Pre-allocate full array (could be made more efficient). The coordinate
        # dict is insertion ordered, so it also defines dimension order and shape.
        coords: dict[str, Any] = {
            "time": time,
            "lead_time": lead_time,
            "variable": variable,
        }
        if self._fc_type == "pf":
            coords["sample"] = np.array(self._members)
        coords["lat"] = self.LAT
        coords["lon"] = self.LON
        xr_array = xr.DataArray(
            data=np.zeros([len(values) for values in coords.values()]),
            dims=list(coords),
            coords=coords,
        )

        try:
            # Create cache dir if it doesn't exist
            pathlib.Path(self.cache).mkdir(parents=True, exist_ok=True)

            async_tasks = await self._create_tasks(time, lead_time, variable)
            func_map = map(
                functools.partial(self.fetch_wrapper, xr_array=xr_array), async_tasks
            )
            await tqdm.gather(
                *func_map,
                desc=f"Fetching {self._model} data",
                disable=(not self._verbose),
            )
        finally:
            # Delete the temporary cache even when a download or decode failed.
            # Best effort, so a cleanup error never masks the original exception.
            if not self._cache:
                shutil.rmtree(self.cache, ignore_errors=True)

        return xr_array

    async def _create_tasks(
        self,
        time: list[datetime],
        lead_time: list[timedelta],
        variable: list[str],
    ) -> list[ECMWFOpenDataAsyncTask]:
        """Create download tasks.

        Parameters
        ----------
        time : list[datetime]
            Timestamps to be downloaded (UTC).
        lead_time : list[timedelta]
            Lead times to be downloaded.
        variable : list[str]
            List of variables to be downloaded.

        Returns
        -------
        list[ECMWFOpenDataAsyncTask]
            List of download tasks.

        Raises
        ------
        KeyError
            If a requested variable is not part of the lexicon.
        """
        tasks: list[ECMWFOpenDataAsyncTask] = []
        for i, t in enumerate(time):
            for j, lt in enumerate(lead_time):
                for k, var in enumerate(variable):
                    try:
                        ifs_name, modifier = self.LEXICON[var]  # type: ignore[index]
                    except KeyError:
                        logger.error(f"Variable {var} not found in lexicon")
                        raise

                    # Lexicon entries are "<param>::<levtype>::<level>"
                    ifs_var, levtype, level = ifs_name.split("::")
                    tasks.append(
                        ECMWFOpenDataAsyncTask(
                            data_array_indices=(i, j, k),
                            time=t,
                            lead_time=lt,
                            variable=ifs_var,
                            levtype=levtype,
                            level=level,
                            modifier=modifier,
                        )
                    )
        return tasks

    async def fetch_wrapper(
        self,
        task: ECMWFOpenDataAsyncTask,
        xr_array: xr.DataArray,
    ) -> None:
        """Small wrapper to pack arrays into the DataArray.

        Parameters
        ----------
        task : ECMWFOpenDataAsyncTask
            Download task to execute.
        xr_array : xr.DataArray
            Pre-allocated output array, written in place at the task's indices.
        """
        grib_file = await self._download_ifs_grib_cached(
            time=task.time,
            lead_time=task.lead_time,
            variable=task.variable,
            levtype=task.levtype,
            level=task.level,
        )

        # Open with pygrib for faster, lower-memory access and roll longitudes
        try:
            grbs = pygrib.open(grib_file)
        except Exception:
            logger.error(f"Failed to open GRIB file {grib_file}")
            raise

        try:
            # Handle ensemble (pf) by stacking members in requested order
            if self._fc_type == "pf" and len(self._members) > 0:
                member_arrays: list[np.ndarray] = []
                for m in self._members:
                    msgs = grbs.select(number=m)
                    if not msgs:
                        raise RuntimeError(
                            f"No GRIB messages found for ensemble member {m} in {grib_file}"
                        )
                    member_arrays.append(msgs[0].values)
                values = np.stack(member_arrays, axis=0)  # [sample, y, x]
            else:
                values = grbs[1].values  # [y, x]

            # Provided [-180, 180], roll to [0, 360] along x dimension
            values = np.roll(values, shift=-len(self.LON) // 2, axis=-1)
            xr_array[task.data_array_indices] = task.modifier(values)
        except Exception:
            logger.error(f"Failed to read data from GRIB file {grib_file}")
            raise
        finally:
            grbs.close()

    async def _download_ifs_grib_cached(
        self,
        time: datetime,
        lead_time: timedelta,
        variable: str,
        levtype: str,
        level: str | list[str],
    ) -> str:
        """Download GRIB2 file to (temporary) cache.

        The file name is a hash of the request, so an existing file is reused
        without contacting the server.

        Parameters
        ----------
        time : datetime
            Forecast initialization time (UTC).
        lead_time : timedelta
            Forecast lead time; truncated to whole hours for the request step.
        variable : str
            Remote parameter name.
        levtype : str
            Level type; "pl" and "sl" add the level list to the request.
        level : str | list[str]
            Level identifier(s).

        Returns
        -------
        str
            Path of the cached GRIB2 file.
        """
        if isinstance(level, str):
            level = [level]

        hash_parts = [
            time,
            lead_time,
            variable,
            levtype,
            *level,
            self._fc_type,
            *self._members,
        ]
        filename = hashlib.sha256(
            "_".join(str(x) for x in hash_parts).encode()
        ).hexdigest()
        cache_path = os.path.join(self.cache, filename)
        if pathlib.Path(cache_path).is_file():
            return cache_path

        # Download into a uniquely named partial file and rename it once complete,
        # so an interrupted or timed-out download is never mistaken for a valid
        # cache entry on the next call.
        partial_path = f"{cache_path}.{uuid.uuid4().hex[:8]}.part"
        request: dict[str, Any] = {
            "date": time,
            "type": self._fc_type,  # "fc", "cf", "pf"
            "param": variable,
            # "levtype": levtype, # NOTE: Commenting this out fixes what seems to be a bug with Opendata API on soil levels
            "step": int(lead_time.total_seconds() // 3600),
            "target": partial_path,
        }
        if levtype == "pl" or levtype == "sl":  # Pressure levels or soil levels
            request["levelist"] = level
        if self._fc_type == "pf":
            request["number"] = self._members

        try:
            # The client is blocking, run it in a worker thread
            await asyncio.to_thread(self.client.retrieve, **request)
            os.replace(partial_path, cache_path)
        except BaseException:
            # Includes cancellation from the timeout; drop the incomplete file
            try:
                os.remove(partial_path)
            except OSError:
                pass
            raise

        return cache_path

    def _validate_time(self, times: list[datetime]) -> None:
        """Verify all times are valid based on offline knowledge.

        The child class should override this method as needed.

        Parameters
        ----------
        times : list[datetime]
            List of date times to fetch data for.
        """
        pass

    def _validate_leadtime(
        self, times: list[datetime], lead_times: list[timedelta]
    ) -> None:
        """Verify all lead times are valid based on offline knowledge.

        The child class should override this method as needed.

        Parameters
        ----------
        times : list[datetime]
            List of date times to fetch data for.
        lead_times : list[timedelta]
            List of lead times to fetch data for.
        """
        pass

    @property
    def cache(self) -> str:
        """Get the appropriate cache location."""
        cache_dir = (
            self._model.lower() + "-opendata"
        )  # note that model is not part of cache hash
        cache_location = os.path.join(datasource_cache_root(), cache_dir)
        if not self._cache:
            if self._tmp_cache_hash is None:
                # First access for temp cache: create a random suffix to avoid collisions
                self._tmp_cache_hash = uuid.uuid4().hex[:8]
            return os.path.join(
                cache_location, f"tmp_{cache_dir}_{self._tmp_cache_hash}"
            )
        return cache_location
# =============
# Child classes
# =============
class IFS(_ECMWFOpenDataSource):
    """Integrated forecast system (IFS) HRES initial state (analysis) data source on an
    equirectangular grid at 0.25 degree resolution. IFS is a forecast model developed by
    ECMWF. Data for the most recent 4 days can be retrieved from ECMWF's servers
    (source `ecmwf`). Historical data is part of ECMWF's open data project on AWS
    (source `aws`).

    Parameters
    ----------
    source : str, optional
        Data source to fetch data from. For possible options refer to ECMWF's open data
        Python SDK, by default "aws".
    cache : bool, optional
        Cache data source in local memory, by default True.
    verbose : bool, optional
        Print download progress, by default True.
    async_timeout: int, optional
        Time in seconds after which the download will be cancelled if not finished
        successfully, by default 600.

    Warning
    -------
    This is a remote data source and can potentially download a large amount of data
    to your local machine for large requests.

    Note
    ----
    Additional information on the data repository can be referenced here:

    - https://github.com/ecmwf/ecmwf-opendata
    - https://confluence.ecmwf.int/display/DAC/ECMWF+open+data%3A+real-time+forecasts
    - https://registry.opendata.aws/ecmwf-forecasts/
    - https://console.cloud.google.com/storage/browser/ecmwf-open-data/
    """

    LEXICON = IFSLexicon

    def __init__(
        self,
        source: Literal["aws", "ecmwf", "azure"] = "aws",
        cache: bool = True,
        verbose: bool = True,
        async_timeout: int = 600,
    ):
        super().__init__(
            source=source,
            model="ifs",
            cache=cache,
            verbose=verbose,
            async_timeout=async_timeout,
        )

    def __call__(  # type: ignore[override]
        self,
        time: datetime | list[datetime] | TimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Retrieve IFS data.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for (UTC).
        variable : str | list[str] | VariableArray
            String, list of strings or array of strings that refer to variables to
            return. Must be in the data lexicon.

        Returns
        -------
        xr.DataArray
            IFS analysis data array
        """
        # The initial state is the zero-hour step; a lead time is a duration, so
        # the array must be timedelta64 rather than datetime64
        zero_lead = np.array([0], dtype="timedelta64[h]")
        da = self._call(time, zero_lead, variable)
        return da.isel(lead_time=0)

    async def fetch(  # type: ignore[override]
        self,
        time: datetime | list[datetime] | TimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Async function to get data.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for (UTC).
        variable : str | list[str] | VariableArray
            String, list of strings or array of strings that refer to variables to
            return. Must be in the data lexicon.

        Returns
        -------
        xr.DataArray
            IFS analysis data array.
        """
        # See __call__: zero-hour step expressed as a timedelta64 lead time
        zero_lead = np.array([0], dtype="timedelta64[h]")
        da = await self._fetch(time, zero_lead, variable)
        return da.isel(lead_time=0)

    def _validate_time(self, times: list[datetime]) -> None:
        """Verify all times are valid based on offline knowledge.

        Parameters
        ----------
        times : list[datetime]
            List of date times to fetch data for.
        """
        validate_time(
            self._model, self.client.source, times, min_time=datetime(2024, 3, 1)
        )
class IFS_FX(_ECMWFOpenDataSource):
    """Integrated forecast system (IFS) HRES forecast data source on an equirectangular
    grid at 0.25 degree resolution. IFS is a forecast model developed by
    ECMWF. Data for the most recent 4 days can be retrieved from ECMWF's servers
    (source `ecmwf`). Historical data is part of ECMWF's open data project on AWS
    (source `aws`).

    Parameters
    ----------
    source : str, optional
        Data source to fetch data from. For possible options refer to ECMWF's open data
        Python SDK, by default "aws".
    cache : bool, optional
        Cache data source in local memory, by default True.
    verbose : bool, optional
        Print download progress, by default True.
    async_timeout: int, optional
        Time in seconds after which the download will be cancelled if not finished
        successfully, by default 600.

    Warning
    -------
    This is a remote data source and can potentially download a large amount of data
    to your local machine for large requests.

    Note
    ----
    Additional information on the data repository can be referenced here:

    - https://github.com/ecmwf/ecmwf-opendata
    - https://confluence.ecmwf.int/display/DAC/ECMWF+open+data%3A+real-time+forecasts
    - https://registry.opendata.aws/ecmwf-forecasts/
    - https://console.cloud.google.com/storage/browser/ecmwf-open-data/
    """

    LEXICON = IFSLexicon

    def __init__(
        self,
        source: Literal["aws", "ecmwf", "azure"] = "aws",
        cache: bool = True,
        verbose: bool = True,
        async_timeout: int = 600,
    ):
        base_kwargs = {
            "source": source,
            "model": "ifs",
            "cache": cache,
            "verbose": verbose,
            "async_timeout": async_timeout,
        }
        super().__init__(**base_kwargs)

    def __call__(
        self,
        time: datetime | list[datetime] | TimeArray,
        lead_time: timedelta | list[timedelta] | LeadTimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Retrieve IFS forecast data.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for (UTC).
        lead_time: timedelta | list[timedelta] | LeadTimeArray
            Forecast lead times to fetch.
        variable : str | list[str] | VariableArray
            String, list of strings or array of strings that refer to variables to
            return. Must be in the data lexicon.

        Returns
        -------
        xr.DataArray
            IFS forecast data array
        """
        forecast = self._call(time, lead_time, variable)
        return forecast

    async def fetch(  # type: ignore[override]
        self,
        time: datetime | list[datetime] | TimeArray,
        lead_time: timedelta | list[timedelta] | LeadTimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Async function to get data.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for (UTC).
        lead_time: timedelta | list[timedelta] | LeadTimeArray
            Forecast lead times to fetch.
        variable : str | list[str] | VariableArray
            String, list of strings or array of strings that refer to variables to
            return. Must be in the data lexicon.

        Returns
        -------
        xr.DataArray
            IFS forecast data array.
        """
        forecast = await self._fetch(time, lead_time, variable)
        return forecast

    def _validate_time(self, times: list[datetime]) -> None:
        """Verify all times are valid based on offline knowledge.

        Parameters
        ----------
        times : list[datetime]
            List of date times to fetch data for.
        """
        earliest = datetime(2024, 3, 1)
        validate_time(self._model, self.client.source, times, min_time=earliest)

    def _validate_leadtime(
        self,
        times: list[datetime],
        lead_times: list[timedelta],
    ) -> None:
        """Verify all lead times are valid based on offline knowledge.

        Parameters
        ----------
        times : list[datetime]
            List of date times to fetch data for.
        lead_times : list[timedelta]
            List of lead times to fetch data for.
        """
        validate_leadtime(self._model, lead_times, interval=3, max_lead_time=360)

        # Shorter rollouts for forecasts starting at 06Z, 18Z
        has_short_run = any(time.hour in (6, 18) for time in times)
        for delta in lead_times:
            lead_hours = int(delta.total_seconds() // 3600)
            if lead_hours <= 144:
                # Only the extended range has additional restrictions
                continue
            if has_short_run:
                raise ValueError(
                    f"Requested lead time {delta} can not be more than 144 hours for {self._model} starting at 06Z, 18Z"
                )
            if lead_hours % 6 != 0:
                raise ValueError(
                    f"Requested lead time {delta} needs to be 6 hour interval for {self._model} after hour 144"
                )
class IFS_ENS(_ECMWFOpenDataSource):
    """Integrated forecast system (IFS) ensemble (ENS) initial state data source on an
    equirectangular grid at 0.25 degree resolution. IFS is a forecast model developed by
    ECMWF. Data for the most recent 4 days can be retrieved from ECMWF's servers
    (source `ecmwf`). Historical data is part of ECMWF's open data project on AWS
    (source `aws`).

    Parameters
    ----------
    source : str, optional
        Data source to fetch data from. For possible options refer to ECMWF's open data
        Python SDK, by default "aws".
    member: int, optional
        Ensemble member id to use. If 0 the control forecast will be requested, if
        greater than 0 perturbed ensemble member will be requested, by default 0.
    cache : bool, optional
        Cache data source in local memory, by default True.
    verbose : bool, optional
        Print download progress, by default True.
    async_timeout: int, optional
        Time in seconds after which the download will be cancelled if not finished
        successfully, by default 600.

    Warning
    -------
    This is a remote data source and can potentially download a large amount of data
    to your local machine for large requests.

    Note
    ----
    Additional information on the data repository can be referenced here:

    - https://github.com/ecmwf/ecmwf-opendata
    - https://confluence.ecmwf.int/display/DAC/ECMWF+open+data%3A+real-time+forecasts
    - https://registry.opendata.aws/ecmwf-forecasts/
    - https://console.cloud.google.com/storage/browser/ecmwf-open-data/
    """

    LEXICON = IFSLexicon

    def __init__(
        self,
        source: Literal["aws", "ecmwf", "azure"] = "aws",
        member: int = 0,
        cache: bool = True,
        verbose: bool = True,
        async_timeout: int = 600,
    ):
        fc_type: Literal["cf", "pf"]
        if member == 0:
            fc_type = "cf"  # control forecast
        elif member > 0:
            fc_type = "pf"  # perturbed forecast
        else:
            raise ValueError(f"Invalid member id provide {member}")

        super().__init__(
            source=source,
            model="ifs",
            fc_type=fc_type,
            members=[member],
            cache=cache,
            verbose=verbose,
            async_timeout=async_timeout,
        )

    def __call__(  # type: ignore[override]
        self,
        time: datetime | list[datetime] | TimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Retrieve IFS ENS data.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for (UTC).
        variable : str | list[str] | VariableArray
            String, list of strings or array of strings that refer to variables to
            return. Must be in the data lexicon.

        Returns
        -------
        xr.DataArray
            IFS ENS initial state data array.
        """
        # The initial state is the zero-hour step; a lead time is a duration, so
        # the array must be timedelta64 rather than datetime64
        zero_lead = np.array([0], dtype="timedelta64[h]")
        da = self._call(time, zero_lead, variable)
        # Only perturbed requests carry a (single member) sample dimension
        if "sample" in da.dims:
            da = da.isel(sample=0)
        return da.isel(lead_time=0)

    async def fetch(  # type: ignore[override]
        self,
        time: datetime | list[datetime] | TimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Async function to get data.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for (UTC).
        variable : str | list[str] | VariableArray
            String, list of strings or array of strings that refer to variables to
            return. Must be in the data lexicon.

        Returns
        -------
        xr.DataArray
            IFS ENS initial state data array.
        """
        # See __call__: zero-hour step expressed as a timedelta64 lead time
        zero_lead = np.array([0], dtype="timedelta64[h]")
        da = await self._fetch(time, zero_lead, variable)
        # Only perturbed requests carry a (single member) sample dimension
        if "sample" in da.dims:
            da = da.isel(sample=0)
        return da.isel(lead_time=0)

    def _validate_time(self, times: list[datetime]) -> None:
        """Verify all times are valid based on offline knowledge.

        Parameters
        ----------
        times : list[datetime]
            List of date times to fetch data for.
        """
        validate_time(
            self._model, self.client.source, times, min_time=datetime(2024, 3, 1)
        )

    def _validate_leadtime(
        self,
        times: list[datetime],
        lead_times: list[timedelta],
    ) -> None:
        """Verify all lead times are valid based on offline knowledge.

        Parameters
        ----------
        times : list[datetime]
            List of date times to fetch data for.
        lead_times : list[timedelta]
            List of lead times to fetch data for.
        """
        validate_leadtime(self._model, lead_times, interval=3, max_lead_time=360)

        # Shorter rollouts for forecasts starting at 06Z, 18Z (loop invariant)
        has_short_run = any(time.hour in (6, 18) for time in times)
        for delta in lead_times:
            hours = int(delta.total_seconds() // 3600)
            if hours <= 144:
                continue
            if has_short_run:
                raise ValueError(
                    f"Requested lead time {delta} can not be more than 144 hours for {self._model} starting at 06Z, 18Z"
                )
            if hours % 6 != 0:
                raise ValueError(
                    f"Requested lead time {delta} needs to be 6 hour interval for {self._model} after hour 144"
                )
class IFS_ENS_FX(_ECMWFOpenDataSource):
    """Integrated forecast system (IFS) ensemble (ENS) forecast data source on an
    equirectangular grid at 0.25 degree resolution. IFS is a forecast model developed by
    ECMWF. Data for the most recent 4 days can be retrieved from ECMWF's servers
    (source `ecmwf`). Historical data is part of ECMWF's open data project on AWS
    (source `aws`).

    Parameters
    ----------
    source : str, optional
        Data source to fetch data from. For possible options refer to ECMWF's open data
        Python SDK, by default "aws".
    member: int, optional
        Ensemble member id to use. If 0 the control forecast will be requested, if
        greater than 0 perturbed ensemble member will be requested, by default 0.
    cache : bool, optional
        Cache data source in local memory, by default True.
    verbose : bool, optional
        Print download progress, by default True.
    async_timeout: int, optional
        Time in seconds after which the download will be cancelled if not finished
        successfully, by default 600.

    Warning
    -------
    This is a remote data source and can potentially download a large amount of data
    to your local machine for large requests.

    Note
    ----
    Additional information on the data repository can be referenced here:

    - https://github.com/ecmwf/ecmwf-opendata
    - https://confluence.ecmwf.int/display/DAC/ECMWF+open+data%3A+real-time+forecasts
    - https://registry.opendata.aws/ecmwf-forecasts/
    - https://console.cloud.google.com/storage/browser/ecmwf-open-data/
    """

    LEXICON = IFSLexicon

    def __init__(
        self,
        source: Literal["aws", "ecmwf", "azure"] = "aws",
        member: int = 0,
        cache: bool = True,
        verbose: bool = True,
        async_timeout: int = 600,
    ):
        # Reject negative ids up front, then pick control vs. perturbed forecast
        if member < 0:
            raise ValueError(f"Invalid member id provide {member}")
        fc_type: Literal["cf", "pf"] = "cf" if member == 0 else "pf"

        super().__init__(
            source=source,
            model="ifs",
            fc_type=fc_type,
            members=[member],
            cache=cache,
            verbose=verbose,
            async_timeout=async_timeout,
        )

    def __call__(
        self,
        time: datetime | list[datetime] | TimeArray,
        lead_time: timedelta | list[timedelta] | LeadTimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Retrieve IFS ENS forecast data.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for (UTC).
        lead_time: timedelta | list[timedelta] | LeadTimeArray
            Forecast lead times to fetch.
        variable : str | list[str] | VariableArray
            String, list of strings or array of strings that refer to variables to
            return. Must be in the data lexicon.

        Returns
        -------
        xr.DataArray
            IFS ENS forecast data array
        """
        result = self._call(time, lead_time, variable)
        # Drop the sample dimension when present by selecting its first entry
        return result.isel(sample=0) if "sample" in result.dims else result

    async def fetch(  # type: ignore[override]
        self,
        time: datetime | list[datetime] | TimeArray,
        lead_time: timedelta | list[timedelta] | LeadTimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Async function to get data.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for (UTC).
        lead_time: timedelta | list[timedelta] | LeadTimeArray
            Forecast lead times to fetch.
        variable : str | list[str] | VariableArray
            String, list of strings or array of strings that refer to variables to
            return. Must be in the data lexicon.

        Returns
        -------
        xr.DataArray
            IFS ENS forecast data array.
        """
        result = await self._fetch(time, lead_time, variable)
        # Drop the sample dimension when present by selecting its first entry
        return result.isel(sample=0) if "sample" in result.dims else result

    def _validate_time(self, times: list[datetime]) -> None:
        """Verify all times are valid based on offline knowledge.

        Parameters
        ----------
        times : list[datetime]
            List of date times to fetch data for.
        """
        earliest = datetime(2024, 3, 1)
        validate_time(self._model, self.client.source, times, min_time=earliest)

    def _validate_leadtime(
        self,
        times: list[datetime],
        lead_times: list[timedelta],
    ) -> None:
        """Verify all lead times are valid based on offline knowledge.

        Parameters
        ----------
        times : list[datetime]
            List of date times to fetch data for.
        lead_times : list[timedelta]
            List of lead times to fetch data for.
        """
        validate_leadtime(self._model, lead_times, interval=3, max_lead_time=360)

        # Shorter rollouts for forecasts starting at 06Z, 18Z
        has_short_run = any(time.hour in (6, 18) for time in times)
        for delta in lead_times:
            lead_hours = int(delta.total_seconds() // 3600)
            if lead_hours <= 144:
                # Only the extended range has additional restrictions
                continue
            if has_short_run:
                raise ValueError(
                    f"Requested lead time {delta} can not be more than 144 hours for {self._model} starting at 06Z, 18Z"
                )
            if lead_hours % 6 != 0:
                raise ValueError(
                    f"Requested lead time {delta} needs to be 6 hour interval for {self._model} after hour 144"
                )
class AIFS_FX(_ECMWFOpenDataSource):
    """Artificial intelligence forecast system (AIFS) SINGLE forecast data on an
    equirectangular grid at 0.25 degree resolution. AIFS is an AI forecast model
    developed by ECMWF. Data for the most recent 4 days can be retrieved from ECMWF's
    servers (source `ecmwf`). Historical data is part of ECMWF's open data project on
    AWS (source `aws`).

    Parameters
    ----------
    source : str, optional
        Data source to fetch data from. For possible options refer to ECMWF's open data
        Python SDK, by default "aws".
    cache : bool, optional
        Cache data source in local memory, by default True.
    verbose : bool, optional
        Print download progress, by default True.
    async_timeout: int, optional
        Time in seconds after which the download will be cancelled if not finished
        successfully, by default 600.

    Warning
    -------
    This is a remote data source and can potentially download a large amount of data
    to your local machine for large requests.

    Note
    ----
    Additional information on the data repository can be referenced here:

    - https://github.com/ecmwf/ecmwf-opendata
    - https://confluence.ecmwf.int/display/DAC/ECMWF+open+data%3A+real-time+forecasts
    - https://registry.opendata.aws/ecmwf-forecasts/
    - https://console.cloud.google.com/storage/browser/ecmwf-open-data/
    """

    LEXICON = AIFSLexicon

    def __init__(
        self,
        source: Literal["aws", "ecmwf", "azure"] = "aws",
        cache: bool = True,
        verbose: bool = True,
        async_timeout: int = 600,
    ):
        base_kwargs = {
            "source": source,
            "model": "aifs-single",
            "cache": cache,
            "verbose": verbose,
            "async_timeout": async_timeout,
        }
        super().__init__(**base_kwargs)

    def __call__(
        self,
        time: datetime | list[datetime] | TimeArray,
        lead_time: timedelta | list[timedelta] | LeadTimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Retrieve AIFS forecast data.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for (UTC).
        lead_time: timedelta | list[timedelta] | LeadTimeArray
            Forecast lead times to fetch.
        variable : str | list[str] | VariableArray
            String, list of strings or array of strings that refer to variables to
            return. Must be in the data lexicon.

        Returns
        -------
        xr.DataArray
            AIFS forecast data array
        """
        forecast = self._call(time, lead_time, variable)
        return forecast

    async def fetch(  # type: ignore[override]
        self,
        time: datetime | list[datetime] | TimeArray,
        lead_time: timedelta | list[timedelta] | LeadTimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Async function to get data.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for (UTC).
        lead_time: timedelta | list[timedelta] | LeadTimeArray
            Forecast lead times to fetch.
        variable : str | list[str] | VariableArray
            String, list of strings or array of strings that refer to variables to
            return. Must be in the data lexicon.

        Returns
        -------
        xr.DataArray
            AIFS forecast data array.
        """
        forecast = await self._fetch(time, lead_time, variable)
        return forecast

    def _validate_time(self, times: list[datetime]) -> None:
        """Verify all times are valid based on offline knowledge.

        Parameters
        ----------
        times : list[datetime]
            List of date times to fetch data for.
        """
        earliest = datetime(2025, 7, 1, 6)
        validate_time(self._model, self.client.source, times, min_time=earliest)

    def _validate_leadtime(
        self, times: list[datetime], lead_times: list[timedelta]
    ) -> None:
        """Verify all lead times are valid based on offline knowledge.

        Parameters
        ----------
        times : list[datetime]
            List of date times to fetch data for (not used by this check).
        lead_times : list[timedelta]
            List of lead times to fetch data for.
        """
        step_hours, horizon_hours = 6, 360
        validate_leadtime(
            self._model,
            lead_times,
            interval=step_hours,
            max_lead_time=horizon_hours,
        )
class AIFS_ENS_FX(_ECMWFOpenDataSource):
    """Artificial intelligence forecast system (AIFS) ENS forecast data on an
    equirectangular grid at 0.25 degree resolution. AIFS is an AI forecast model
    developed by ECMWF. Data for the most recent 4 days can be retrieved from ECMWF's
    servers (source `ecmwf`). Historical data is part of ECMWF's open data project on
    AWS (source `aws`).

    Parameters
    ----------
    source : str, optional
        Data source to fetch data from. For possible options refer to ECMWF's open data
        Python SDK, by default "aws".
    member: int, optional
        Ensemble member id to use. If 0 the control forecast will be requested, if
        greater than 0 perturbed ensemble member will be requested, by default 0.
    cache : bool, optional
        Cache data source in local memory, by default True.
    verbose : bool, optional
        Print download progress, by default True.
    async_timeout: int, optional
        Time in seconds after which the download will be cancelled if not finished
        successfully, by default 600.

    Warning
    -------
    This is a remote data source and can potentially download a large amount of data
    to your local machine for large requests.

    Note
    ----
    Additional information on the data repository can be referenced here:

    - https://github.com/ecmwf/ecmwf-opendata
    - https://confluence.ecmwf.int/display/DAC/ECMWF+open+data%3A+real-time+forecasts
    - https://registry.opendata.aws/ecmwf-forecasts/
    - https://console.cloud.google.com/storage/browser/ecmwf-open-data/
    """

    LEXICON = AIFSLexicon

    def __init__(
        self,
        source: Literal["aws", "ecmwf", "azure"] = "aws",
        member: int = 0,
        cache: bool = True,
        verbose: bool = True,
        async_timeout: int = 600,
    ):
        fc_type: Literal["cf", "pf"]
        if member == 0:
            fc_type = "cf"  # control forecast
        elif member > 0:
            fc_type = "pf"  # perturbed forecast
        else:
            raise ValueError(f"Invalid member id provide {member}")

        super().__init__(
            source=source,
            model="aifs-ens",
            fc_type=fc_type,
            members=[member],
            cache=cache,
            verbose=verbose,
            async_timeout=async_timeout,
        )

    def __call__(
        self,
        time: datetime | list[datetime] | TimeArray,
        lead_time: timedelta | list[timedelta] | LeadTimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Retrieve AIFS ENS forecast data.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for (UTC).
        lead_time: timedelta | list[timedelta] | LeadTimeArray
            Forecast lead times to fetch.
        variable : str | list[str] | VariableArray
            String, list of strings or array of strings that refer to variables to
            return. Must be in the data lexicon.

        Returns
        -------
        xr.DataArray
            AIFS ENS forecast data array
        """
        da = self._call(time, lead_time, variable)
        # Only perturbed requests carry a (single member) sample dimension
        if "sample" in da.dims:
            da = da.isel(sample=0)
        return da

    async def fetch(  # type: ignore[override]
        self,
        time: datetime | list[datetime] | TimeArray,
        lead_time: timedelta | list[timedelta] | LeadTimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Async function to get data.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for (UTC).
        lead_time: timedelta | list[timedelta] | LeadTimeArray
            Forecast lead times to fetch.
        variable : str | list[str] | VariableArray
            String, list of strings or array of strings that refer to variables to
            return. Must be in the data lexicon.

        Returns
        -------
        xr.DataArray
            AIFS ENS forecast data array.
        """
        da = await self._fetch(time, lead_time, variable)
        # The control forecast (member 0) has no sample dimension, so selecting it
        # unconditionally would fail; mirror the guard used in __call__
        if "sample" in da.dims:
            da = da.isel(sample=0)
        return da

    def _validate_time(self, times: list[datetime]) -> None:
        """Verify all times are valid based on offline knowledge.

        Parameters
        ----------
        times : list[datetime]
            List of date times to fetch data for.
        """
        validate_time(
            self._model, self.client.source, times, min_time=datetime(2025, 7, 1, 6)
        )

    def _validate_leadtime(
        self, times: list[datetime], lead_times: list[timedelta]
    ) -> None:
        """Verify all lead times are valid based on offline knowledge.

        Parameters
        ----------
        times : list[datetime]
            List of date times to fetch data for (not used by this check).
        lead_times : list[timedelta]
            List of lead times to fetch data for.
        """
        validate_leadtime(self._model, lead_times, interval=6, max_lead_time=360)
def validate_time(
    model: str,
    source: str,
    times: list[datetime],
    min_time: datetime,
) -> None:
    """Verify all times are valid based on offline knowledge.

    Parameters
    ----------
    model : str
        Model name.
    source : str
        ECMWF client source.
    times : list[datetime]
        List of date times to fetch data for (naive datetimes in UTC).
    min_time : datetime
        Earliest available datetime.

    Raises
    ------
    ValueError
        If a time is not on a 6-hour interval, is earlier than `min_time`, or is
        older than 4 days when the source is "ecmwf".
    """
    from datetime import timezone

    # Request times are naive UTC, so compare against naive UTC "now" rather than
    # the machine's local clock. Computed once for the whole batch.
    utc_now = datetime.now(timezone.utc).replace(tzinfo=None)

    for time in times:
        if not (time - datetime(1900, 1, 1)).total_seconds() % 21600 == 0:
            raise ValueError(
                f"Requested start time {time} needs to be 6-hour interval for {model}"
            )

        if time < min_time:
            raise ValueError(
                f"Requested start time {time} needs to be at least {min_time} for {model}"
            )

        # ECMWF's own servers only keep a rolling window of recent forecasts
        if source == "ecmwf" and (utc_now - time).days > 4:
            raise ValueError(
                f"Requested start time {time} needs to be within the past 4 days for {model} with source ECMWF"
            )
def validate_leadtime(
    model: str,
    lead_times: list[timedelta],
    interval: int,
    max_lead_time: int,
) -> None:
    """Verify all lead times are valid based on offline knowledge.

    Parameters
    ----------
    model : str
        Model name.
    lead_times : list[timedelta]
        List of lead times to fetch data for.
    interval : int
        Required lead time interval in hours.
    max_lead_time : int
        Maximum available lead time in hours.

    Raises
    ------
    ValueError
        If a lead time is negative, is not a whole multiple of `interval` hours,
        or exceeds `max_lead_time` hours.
    """
    for delta in lead_times:
        # See https://github.com/ecmwf/ecmwf-opendata?tab=readme-ov-file#time-steps
        # But there seem to be some inaccuracies, e.g., HRES is available up to hour 360
        # See S3: aws s3 ls --no-sign-request s3://ecmwf-forecasts/20251016/06z/aifs-single/0p25/oper/
        seconds = delta.total_seconds()
        hours = int(seconds // 3600)

        # A negative multiple of the interval would otherwise pass the modulo check
        if seconds < 0:
            raise ValueError(
                f"Requested lead time {delta} cannot be negative for {model}"
            )

        if seconds % 3600 != 0 or hours % interval != 0:
            raise ValueError(
                f"Requested lead time {delta} needs to be {interval}-hour interval for {model}"
            )

        if hours > max_lead_time:
            raise ValueError(
                f"Requested lead time {delta} cannot be more than {max_lead_time} hours for {model}"
            )