# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from datetime import datetime
from typing import Any
import numpy as np
import xarray as xr
from numpy import ndarray
from pandas import to_datetime
from earth2studio.utils.type import TimeArray, VariableArray
class DataArrayFile:
    """A local xarray dataarray file data source. This file should be compatible with
    xarray. For example, a netCDF file.

    Parameters
    ----------
    file_path : str
        Path to xarray data array compatible file.
    **xr_args : Any
        Additional keyword arguments forwarded to ``xarray.open_dataarray``.
    """

    def __init__(self, file_path: str, **xr_args: Any):
        self.file_path = file_path
        self.da = xr.open_dataarray(self.file_path, **xr_args)

    def __call__(
        self,
        time: datetime | list[datetime] | TimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Function to get data.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for.
        variable : str | list[str] | VariableArray
            Strings or list of strings that refer to variables to return.

        Returns
        -------
        xr.DataArray
            Loaded data array
        """
        # Coerce scalar inputs to lists so the selection keeps the time and
        # variable dimensions (consistent with the other data sources here,
        # e.g. DataSetFile / DataArrayDirectory)
        if not isinstance(time, (list, ndarray)):
            time = [time]
        if not isinstance(variable, (list, ndarray)):
            variable = [variable]
        return self.da.sel(time=time, variable=variable)
class InferenceOutputSource:
    """Adapt a inference output into a data source.

    This data source loads an existing xarray Dataset, such as a NetCDF file or Zarr
    store from an Earth2Studio forecast inference pipeline, which can then be filtered
    to provide a data array give a variable and time.

    Time, lead_time and variable are expected dimensions to be present in the DataSet.

    Note
    ----
    This data source performs automatic transformation of time coordinates based on the
    time and lead_time coordinates:

    - **If lead_time has length 1**: The single lead_time value is added to all time
      coordinates to produce valid forecast times (time + lead_time). The lead_time
      dimension is then removed.
    - **If time has length 1**: The single time value is broadcast to match the
      lead_time dimension, then lead_time values are added to produce valid forecast
      times. The time dimension is removed and lead_time is renamed to time.

    Either time or lead_time must have length 1 after applying the filter_dict
    - both cannot have length > 1 simultaneously. The resulting dataset will have a
    single time dimension containing the computed valid forecast timestamps.

    Parameters
    ----------
    inference_output : str | xarray.Dataset
        An Xarray dataset or a path to an xarray-compatible dataset file
        (e.g., NetCDF/Zarr).
    filter_dict : dict, optional
        Dictionary of selections applied before transformation (e.g.
        ``{"ensemble": 0}``). Coordinates not in the required
        dimensions are dropped after selection, by default None
    **xr_args : Any
        Additional keyword arguments forwarded to ``xarray.open_dataset``.
    """

    def __init__(
        self,
        inference_output: str | xr.Dataset,
        filter_dict: dict | None = None,
        **xr_args: Any,
    ):
        if isinstance(inference_output, str):
            self.da = xr.open_dataset(inference_output, **xr_args)
        elif isinstance(inference_output, xr.Dataset):
            self.da = inference_output
        else:
            raise TypeError(
                f"Expected `inference_output` to be a string or xarray.Dataset, not {type(inference_output)}."
            )
        self.da = self.da.to_array("variable")
        # Work on a copy so the caller's dict is never mutated (also avoids the
        # mutable-default-argument pitfall of the previous `filter_dict: dict = {}`)
        filter_dict = dict(filter_dict) if filter_dict is not None else {}
        # Need to keep these dims, so make them a list if scalar value
        for k in ("time", "lead_time"):
            if k in filter_dict and not isinstance(
                filter_dict[k], (list, tuple, np.ndarray)
            ):
                filter_dict[k] = [filter_dict[k]]
        self.da = self.da.sel(filter_dict)
        # The following dimensions and their order is required for the data to be
        # used as a datasource
        required_dims = ["time", "lead_time", "variable"]
        # Validate remaining dimensions
        if not set(required_dims).issubset(set(self.da.dims)):
            raise ValueError(
                f"Missing required dims. Data array loaded has dims {self.da.dims} but "
                + f"needs {required_dims}. Use filter_dict to select a subset of the data."
            )
        if len(self.da["time"]) > 1 and len(self.da["lead_time"]) > 1:
            raise ValueError(
                "Either time or lead_time should have length of one. "
                + f"Length of time: {len(self.da['time'])}, lead_time: {len(self.da['lead_time'])}."
                + "Use filter_dict to select a subset of the data."
            )
        # Ensure ["time", "lead_time", "variable"] ordering of required_dims by
        # rewriting those dims into their sorted positions (other dims unaffected)
        dims = list(self.da.dims)
        required_dims_ind = [dims.index(dim) for dim in required_dims]
        for i, ind in enumerate(sorted(required_dims_ind)):
            dims[ind] = self.da.dims[required_dims_ind[i]]
        if dims != list(self.da.dims):
            self.da = self.da.transpose(*dims)
        # Add "time" and "lead_time" so that only "time" remains
        if self.da["lead_time"].shape[0] == 1:
            # Single lead time: shift every init time by it, then drop lead_time
            time_array = (
                self.da.coords["time"].values + self.da.coords["lead_time"].values[0]
            )
            self.da = self.da.isel(lead_time=0).drop_vars(["lead_time"])
            self.da = self.da.assign_coords(time=time_array)
        else:
            # Single init time: broadcast it over lead_time, then rename the
            # lead_time dimension to time carrying the valid forecast timestamps
            time_array = np.repeat(
                self.da.coords["time"].values[0], self.da["lead_time"].shape[0]
            )
            time_array = time_array + self.da["lead_time"].values
            self.da = self.da.isel(time=0).drop_vars(["time"])
            self.da = self.da.rename({"lead_time": "time"})
            self.da = self.da.assign_coords(time=time_array)

    def __call__(
        self,
        time: datetime | list[datetime] | TimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Retrieve data for specified valid times and variables.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            One or more valid forecast timestamps (after ``time + lead_time``
            transformation).
        variable : str | list[str] | VariableArray
            One or more variable names to return.

        Returns
        -------
        xr.DataArray
            Data array subset for the requested ``time`` and ``variable``.
        """
        # Normalize scalars to lists so dimensions are preserved in the selection
        if not isinstance(time, (list, ndarray)):
            time = [time]
        if not isinstance(variable, (list, ndarray)):
            variable = [variable]
        return self.da.sel(time=time, variable=variable)
class DataSetFile:
    """A local xarray dataset file data source. This file should be compatible with
    xarray. For example, a netCDF file.

    Parameters
    ----------
    file_path : str
        Path to xarray dataset compatible file.
    array_name : str
        Data array name in xarray dataset
    """

    def __init__(self, file_path: str, array_name: str, **xr_args: Any):
        self.file_path = file_path
        # Open the dataset and keep only the named data array
        self.da = xr.open_dataset(self.file_path, **xr_args)[array_name]

    def __call__(
        self,
        time: datetime | list[datetime] | TimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Function to get data.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for.
        variable : str | list[str] | VariableArray
            Strings or list of strings that refer to variables to return.

        Returns
        -------
        xr.DataArray
            Loaded data array
        """
        # Normalize scalar inputs to sequences before selecting
        times = time if isinstance(time, (list, ndarray)) else [time]
        variables = variable if isinstance(variable, (list, ndarray)) else [variable]
        return self.da.sel(time=times, variable=variables)
class DataArrayDirectory:
    """A local xarray dataarray directory data source. This file should be compatible with
    xarray. For example, a netCDF file. the structure of the directory should be like

    path/to/monthly/files
    |___2020
    |   |___2020_01.nc
    |   |___2020_02.nc
    |   |___ ...
    |
    |___2021
        |___2021_01.nc
        |___...

    Parameters
    ----------
    dir_path : str
        Path to the root directory containing per-year sub-directories of
        monthly xarray data array compatible files.
    xr_args : Any
        Keyword arguments to send to the xarray opening method.
    """

    def __init__(self, dir_path: str, **xr_args: Any):
        self.dir_path = dir_path
        # Nested mapping: year string -> zero-padded month string ("01".."12") -> array
        self.das: dict[str, dict[str, xr.DataArray]] = {}
        for yr in os.listdir(self.dir_path):
            yr_dir = os.path.join(self.dir_path, yr)
            if os.path.isdir(yr_dir):
                self.das[yr] = {}
                for fl in os.listdir(yr_dir):
                    pth = os.path.join(yr_dir, fl)
                    if os.path.isfile(pth):
                        try:
                            arr = xr.open_dataarray(pth, **xr_args)
                        except Exception:  # noqa: BLE001
                            # Best-effort scan: skip unreadable files, but do not
                            # swallow KeyboardInterrupt/SystemExit like a bare except
                            continue
                        # File names are expected as <year>_<month>.<ext>
                        mon = fl.split(".")[0].split("_")[-1]
                        self.das[yr][mon] = arr

    def __call__(
        self,
        time: datetime | list[datetime] | TimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Function to get data.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for.
        variable : str | list[str] | VariableArray
            Strings or list of strings that refer to variables to return.

        Returns
        -------
        xr.DataArray
            Loaded data array
        """
        if not isinstance(time, (list, ndarray)):
            time = [time]
        if not isinstance(variable, (list, ndarray)):
            variable = [variable]
        arrs = []
        for tt in time:
            # Parse once and reuse (previously to_datetime was called twice)
            ts = to_datetime(tt)
            yr = str(ts.year)
            mon = str(ts.month).zfill(2)
            arrs.append(self.das[yr][mon].sel(time=tt, variable=variable))
        return xr.concat(arrs, dim="time")
class DataArrayPathList:
    """A local xarray dataarray directory data source that handles multiple files.

    This class provides functionality to work with multiple xarray-compatible files
    (e.g., netCDF) as a single data source. All input files must have consistent
    dimensions and variables. Under the hood, it uses xarray's open_mfdataset which
    leverages Dask for parallel and memory-efficient data processing.

    Parameters
    ----------
    paths : str | list[str]
        Either a string glob pattern (e.g., "path/to/files/*.nc") or an explicit list
        of files. All specified files must exist and be readable.
    xr_args : Any
        Additional keyword arguments passed to xarray's open_mfdataset method.

    Raises
    ------
    FileNotFoundError
        If no files match the provided path pattern or if any specified file doesn't
        exist.
    ValueError
        If the files have inconsistent dimensions or variables.
    RuntimeError
        If there are issues opening or processing the dataset.

    Notes
    -----
    - The class uses Dask arrays internally through xarray's open_mfdataset, providing
      efficient parallel processing and lazy evaluation. Operations are only computed
      when data is actually requested through the __call__ method.
    - All files must share the same coordinate system and variable structure.
    - Required dimensions are: time, variable, lat, and lon.
    """

    def __init__(self, paths: str | list[str], **xr_args: Any):
        self.paths = paths
        # Open multiple files as a single (lazily evaluated) dataset
        dataset = xr.open_mfdataset(self.paths, **xr_args)
        # Collapse into one DataArray, squeezing out any singleton dimension
        # introduced by the dataset-to-array conversion
        stacked = dataset.to_dataarray().data.squeeze()
        self.da = xr.DataArray(stacked, dims=dataset.dims, coords=dataset.coords)
        # Validate required dimensions
        missing_dims = {"time", "variable", "lat", "lon"} - set(self.da.dims)
        if missing_dims:
            raise ValueError(f"Dataset missing required dimensions: {missing_dims}")

    def __call__(
        self,
        time: datetime | list[datetime] | TimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Retrieve data for specified timestamps and variables.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Single timestamp or list of timestamps to retrieve data for.
        variable : str | list[str] | VariableArray
            Single variable name or list of variable names to retrieve.

        Returns
        -------
        xr.DataArray
            Data array containing the requested time and variable selections.

        Raises
        ------
        ValueError
            If requested time or variable values are not present in the dataset.
        """
        # Ensure list form so selection semantics are consistent
        if not isinstance(time, (list, ndarray)):
            time = [time]
        if not isinstance(variable, (list, ndarray)):
            variable = [variable]
        # Select, then concatenate the per-time slices along "time"
        selected = self.da.sel(time=time, variable=variable)
        return xr.concat(selected, dim="time")