# Source code for earth2studio.data.xr

# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from datetime import datetime
from typing import Any

import numpy as np
import xarray as xr
from numpy import ndarray
from pandas import to_datetime

from earth2studio.utils.type import TimeArray, VariableArray


class DataArrayFile:
    """A local xarray dataarray file data source. This file should be compatible
    with xarray. For example, a netCDF file.

    Parameters
    ----------
    file_path : str
        Path to xarray data array compatible file.
    **xr_args : Any
        Keyword arguments passed to ``xarray.open_dataarray``.
    """

    def __init__(self, file_path: str, **xr_args: Any):
        self.file_path = file_path
        # Open eagerly so an invalid path fails at construction time
        self.da = xr.open_dataarray(self.file_path, **xr_args)

    def __call__(
        self,
        time: datetime | list[datetime] | TimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Function to get data.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for.
        variable : str | list[str] | VariableArray
            Strings or list of strings that refer to variables to return.

        Returns
        -------
        xr.DataArray
            Loaded data array
        """
        # Wrap scalars into lists so the returned array retains its time and
        # variable dimensions, consistent with the other data sources in this
        # module (a scalar .sel would drop the dimension entirely)
        if not (isinstance(time, list) or isinstance(time, ndarray)):
            time = [time]
        if not (isinstance(variable, list) or isinstance(variable, ndarray)):
            variable = [variable]
        return self.da.sel(time=time, variable=variable)
class InferenceOutputSource:
    """Adapt an inference output into a data source.

    This data source loads an existing xarray Dataset, such as a NetCDF file or
    Zarr store from an Earth2Studio forecast inference pipeline, which can then
    be filtered to provide a data array given a variable and time. Time,
    lead_time and variable are expected dimensions to be present in the DataSet.

    Note
    ----
    This data source performs automatic transformation of time coordinates based
    on the time and lead_time coordinates:

    - **If lead_time has length 1**: The single lead_time value is added to all
      time coordinates to produce valid forecast times (time + lead_time). The
      lead_time dimension is then removed.
    - **If time has length 1**: The single time value is broadcast to match the
      lead_time dimension, then lead_time values are added to produce valid
      forecast times. The time dimension is removed and lead_time is renamed to
      time.

    Either time or lead_time must have length 1 after applying the filter_dict -
    both cannot have length > 1 simultaneously. The resulting dataset will have
    a single time dimension containing the computed valid forecast timestamps.

    Parameters
    ----------
    inference_output : str | xarray.Dataset
        An Xarray dataset or a path to an xarray-compatible dataset file
        (e.g., NetCDF/Zarr).
    filter_dict : dict, optional
        Dictionary of selections applied before transformation
        (e.g. ``{"ensemble": 0}``). Coordinates not in the required dimensions
        are dropped after selection, by default None (no filtering)
    **xr_args : Any
        Additional keyword arguments forwarded to ``xarray.open_dataset``.
    """

    def __init__(
        self,
        inference_output: str | xr.Dataset,
        filter_dict: dict | None = None,
        **xr_args: Any,
    ):
        if isinstance(inference_output, str):
            self.da = xr.open_dataset(inference_output, **xr_args)
        elif isinstance(inference_output, xr.Dataset):
            self.da = inference_output
        else:
            raise TypeError(
                f"Expected `inference_output` to be a string or xarray.Dataset, not {type(inference_output)}."
            )
        self.da = self.da.to_array("variable")

        # Work on a copy so neither the caller's dict nor a shared default is
        # mutated by the scalar -> list normalization below (the original code
        # used a mutable `{}` default and wrote into it)
        filter_dict = dict(filter_dict) if filter_dict is not None else {}
        # Need to keep these dims, so make them a list if scalar value
        for k in ("time", "lead_time"):
            if k in filter_dict and not isinstance(
                filter_dict[k], (list, tuple, np.ndarray)
            ):
                filter_dict[k] = [filter_dict[k]]
        self.da = self.da.sel(filter_dict)

        # The following dimensions and their order is required for the data to
        # be used as a datasource
        required_dims = ["time", "lead_time", "variable"]

        # Validate remaining dimensions
        if not set(required_dims).issubset(set(self.da.dims)):
            raise ValueError(
                f"Missing required dims. Data array loaded has dims {self.da.dims} but "
                + f"needs {required_dims}. Use filter_dict to select a subset of the data."
            )
        if len(self.da["time"]) > 1 and len(self.da["lead_time"]) > 1:
            raise ValueError(
                "Either time or lead_time should have length of one. "
                + f"Length of time: {len(self.da['time'])}, lead_time: {len(self.da['lead_time'])}."
                + "Use filter_dict to select a subset of the data."
            )

        # ensure ["time", "lead_time", "variable"] ordering of required_dims
        # (other dims unaffected): the required dims are written back into the
        # sorted positions they already occupy, in the required order
        dims = list(self.da.dims)
        required_dims_ind = [dims.index(dim) for dim in required_dims]
        for i, ind in enumerate(sorted(required_dims_ind)):
            dims[ind] = self.da.dims[required_dims_ind[i]]
        if dims != list(self.da.dims):
            self.da = self.da.transpose(*dims)

        # add "time" and "lead_time" so that only "time" remains
        if self.da["lead_time"].shape[0] == 1:
            time_array = (
                self.da.coords["time"].values + self.da.coords["lead_time"].values[0]
            )
            self.da = self.da.isel(lead_time=0).drop_vars(["lead_time"])
            self.da = self.da.assign_coords(time=time_array)
        else:
            time_array = np.repeat(
                self.da.coords["time"].values[0], self.da["lead_time"].shape[0]
            )
            time_array = time_array + self.da["lead_time"].values
            self.da = self.da.isel(time=0).drop_vars(["time"])
            self.da = self.da.rename({"lead_time": "time"})
            self.da = self.da.assign_coords(time=time_array)

    def __call__(
        self,
        time: datetime | list[datetime] | TimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Retrieve data for specified valid times and variables.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            One or more valid forecast timestamps (after ``time + lead_time``
            transformation).
        variable : str | list[str] | VariableArray
            One or more variable names to return.

        Returns
        -------
        xr.DataArray
            Data array subset for the requested ``time`` and ``variable``.
        """
        if not (isinstance(time, list) or isinstance(time, ndarray)):
            time = [time]
        if not (isinstance(variable, list) or isinstance(variable, ndarray)):
            variable = [variable]
        return self.da.sel(time=time, variable=variable)
class DataSetFile:
    """A local xarray dataset file data source. This file should be compatible
    with xarray. For example, a netCDF file.

    Parameters
    ----------
    file_path : str
        Path to xarray dataset compatible file.
    array_name : str
        Data array name in xarray dataset
    """

    def __init__(self, file_path: str, array_name: str, **xr_args: Any):
        self.file_path = file_path
        # Open the dataset and pull out the single named array of interest
        dataset = xr.open_dataset(self.file_path, **xr_args)
        self.da = dataset[array_name]

    def __call__(
        self,
        time: datetime | list[datetime] | TimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Function to get data.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for.
        variable : str | list[str] | VariableArray
            Strings or list of strings that refer to variables to return.

        Returns
        -------
        xr.DataArray
            Loaded data array
        """
        # Normalize scalar arguments to single-element lists so .sel keeps
        # the time and variable dimensions in the result
        if isinstance(time, (list, ndarray)):
            times = time
        else:
            times = [time]
        if isinstance(variable, (list, ndarray)):
            variables = variable
        else:
            variables = [variable]
        return self.da.sel(time=times, variable=variables)
class DataArrayDirectory:
    """A local xarray dataarray directory data source. Each file should be
    compatible with xarray. For example, a netCDF file. The structure of the
    directory should be like

    path/to/monthly/files
    |___2020
    |   |___2020_01.nc
    |   |___2020_02.nc
    |   |___ ...
    |
    |___2021
        |___2021_01.nc
        |___...

    Parameters
    ----------
    dir_path : str
        Path to directory of year folders containing monthly files.
    **xr_args : Any
        Keyword arguments to send to the xarray opening method.
    """

    def __init__(self, dir_path: str, **xr_args: Any):
        self.dir_path = dir_path
        # Nested mapping: year string -> month string ("01".."12") -> array
        self.das: dict[str, dict[str, xr.DataArray]] = {}
        for yr in os.listdir(self.dir_path):
            yr_dir = os.path.join(self.dir_path, yr)
            if os.path.isdir(yr_dir):
                self.das[yr] = {}
                for fl in os.listdir(yr_dir):
                    pth = os.path.join(yr_dir, fl)
                    if os.path.isfile(pth):
                        # Best-effort directory scan: skip files xarray cannot
                        # open, but never swallow KeyboardInterrupt/SystemExit
                        # as the original bare `except:` did
                        try:
                            arr = xr.open_dataarray(pth, **xr_args)
                        except Exception:  # noqa: BLE001
                            continue
                        # Month key comes from the <year>_<month>.<ext> naming
                        mon = fl.split(".")[0].split("_")[-1]
                        self.das[yr][mon] = arr

    def __call__(
        self,
        time: datetime | list[datetime] | TimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Function to get data.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for.
        variable : str | list[str] | VariableArray
            Strings or list of strings that refer to variables to return.

        Returns
        -------
        xr.DataArray
            Loaded data array
        """
        if not (isinstance(time, list) or isinstance(time, ndarray)):
            time = [time]
        if not (isinstance(variable, list) or isinstance(variable, ndarray)):
            variable = [variable]

        arrs = []
        for tt in time:
            # Route each timestamp to its monthly file via year/month keys
            yr = str(to_datetime(tt).year)
            mon = str(to_datetime(tt).month).zfill(2)
            arrs.append(self.das[yr][mon].sel(time=tt, variable=variable))
        return xr.concat(arrs, dim="time")
class DataArrayPathList:
    """A local xarray dataarray directory data source that handles multiple files.

    This class provides functionality to work with multiple xarray-compatible
    files (e.g., netCDF) as a single data source. All input files must have
    consistent dimensions and variables. Under the hood, it uses xarray's
    open_mfdataset which leverages Dask for parallel and memory-efficient data
    processing.

    Parameters
    ----------
    paths : str | list[str]
        Either a string glob pattern (e.g., "path/to/files/*.nc") or an
        explicit list of files. All specified files must exist and be readable.
    xr_args : Any
        Additional keyword arguments passed to xarray's open_mfdataset method.

    Raises
    ------
    FileNotFoundError
        If no files match the provided path pattern or if any specified file
        doesn't exist.
    ValueError
        If the files have inconsistent dimensions or variables.
    RuntimeError
        If there are issues opening or processing the dataset.

    Notes
    -----
    - The class uses Dask arrays internally through xarray's open_mfdataset,
      providing efficient parallel processing and lazy evaluation. Operations
      are only computed when data is actually requested through the __call__
      method.
    - All files must share the same coordinate system and variable structure.
    - Required dimensions are: time, variable, lat, and lon.
    """

    def __init__(self, paths: str | list[str], **xr_args: Any):
        self.paths = paths

        # Open multiple files as a single dataset
        dataset = xr.open_mfdataset(self.paths, **xr_args)

        # Stack the data variables into a "variable" dimension directly.
        # The previous implementation rebuilt a DataArray from
        # `to_dataarray().data.squeeze()` with `dims=dataset.dims`, which
        # silently dropped any size-1 dimension (e.g. a single timestep) and
        # mismatched dims for multi-variable datasets.
        self.da = dataset.to_dataarray(dim="variable")

        # Validate required dimensions
        required_dims = {"time", "variable", "lat", "lon"}
        missing_dims = required_dims - set(self.da.dims)
        if missing_dims:
            raise ValueError(f"Dataset missing required dimensions: {missing_dims}")

    def __call__(
        self,
        time: datetime | list[datetime] | TimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Retrieve data for specified timestamps and variables.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Single timestamp or list of timestamps to retrieve data for.
        variable : str | list[str] | VariableArray
            Single variable name or list of variable names to retrieve.

        Returns
        -------
        xr.DataArray
            Data array containing the requested time and variable selections.

        Raises
        ------
        ValueError
            If requested time or variable values are not present in the dataset.
        """
        # Ensure inputs are lists for consistent processing
        times = [time] if not isinstance(time, (list, ndarray)) else time
        variables = (
            [variable] if not isinstance(variable, (list, ndarray)) else variable
        )

        # Label-based selection keeps the time and variable dimensions; the
        # previous `xr.concat(selection, dim="time")` iterated the selection's
        # *first* dimension and re-concatenated it along time, which is only
        # a no-op when time happens to be the leading dimension.
        return self.da.sel(time=times, variable=variables)