Source code for earth2studio.data.ncar

# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import calendar
import concurrent.futures
import hashlib
import os
import shutil
import uuid
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any

import nest_asyncio
import numpy as np
import pandas as pd
import s3fs
import xarray as xr
from loguru import logger
from tqdm.asyncio import tqdm

from earth2studio.data.utils import (
    datasource_cache_root,
    prep_data_inputs,
)
from earth2studio.lexicon import NCAR_ERA5Lexicon
from earth2studio.utils.type import TimeArray, VariableArray

logger.remove()
logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True)


@dataclass
class NCARAsyncTask:
    """Small helper struct for Async tasks"""

    ncar_file_uri: str
    ncar_data_variable: str
    # Dictionary mapping time index -> timestamp
    ncar_time_indices: dict[int, np.datetime64]
    # Dictionary mapping level index -> variable id
    ncar_level_indices: dict[int, str]
    # Extra metadata keyed by time index, currently only used for accumulation files
    ncar_meta: dict[int, dict[str, Any]]
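
    # Illustrative example (hypothetical values): a task for a single pressure-level
    # request might look roughly like this; the exact URI, data variable and level
    # index come from NCAR_ERA5Lexicon and are shown here only schematically.
    #
    #   NCARAsyncTask(
    #       ncar_file_uri="s3://nsf-ncar-era5/e5.oper.an.pl/...nc",
    #       ncar_data_variable="T",
    #       ncar_time_indices={0: np.datetime64("2023-01-01T00:00")},
    #       ncar_level_indices={21: "t500"},
    #       ncar_meta={0: {}},
    #   )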


class NCAR_ERA5:
    """ERA5 data provided by NSF NCAR via the AWS Open Data Sponsorship Program. ERA5
    is the fifth generation of the ECMWF global reanalysis and is available on a 0.25
    degree WGS84 grid at hourly intervals spanning from 1940 to the present.

    Parameters
    ----------
    max_workers : int, optional
        Max workers in the async IO thread pool. Only applied when using the sync call
        function and will modify the default async loop if one exists, by default 24
    cache : bool, optional
        Cache data source on local disk, by default True
    verbose : bool, optional
        Print download progress, by default True
    async_timeout : int, optional
        Timeout in seconds for async operations, by default 600

    Warning
    -------
    This is a remote data source and can potentially download a large amount of data
    to your local machine for large requests.

    Note
    ----
    Additional resources:
    https://registry.opendata.aws/nsf-ncar-era5/
    """

    NCAR_EAR5_LAT = np.linspace(90, -90, 721)
    NCAR_EAR5_LON = np.linspace(0, 360, 1440, endpoint=False)

    def __init__(
        self,
        max_workers: int = 24,
        cache: bool = True,
        verbose: bool = True,
        async_timeout: int = 600,
    ):
        self._max_workers = max_workers
        self._cache = cache
        self._verbose = verbose
        self.async_timeout = async_timeout
        self._tmp_cache_hash: str | None = None
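
    # Minimal usage sketch (illustrative only; the variable names are assumed to be
    # valid keys of NCAR_ERA5Lexicon and are hypothetical here):
    #
    #   ds = NCAR_ERA5(max_workers=8, verbose=False)
    #   da = ds([datetime(2023, 1, 1, 0)], ["t2m", "z500"])
    #   # Expected layout: (time, variable, lat, lon) on the 0.25 degree grid,
    #   # i.e. roughly (1, 2, 721, 1440)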
    def __call__(
        self,
        time: datetime | list[datetime] | TimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Function to get data

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for (UTC).
        variable : str | list[str] | VariableArray
            String, list of strings or array of strings that refer to variables to
            return. Must be in the NCAR lexicon.

        Returns
        -------
        xr.DataArray
            ERA5 weather data array from NCAR ERA5
        """
        nest_asyncio.apply()  # Patch asyncio to work in notebooks
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # If no event loop exists, create one
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        # Modify the worker amount
        loop.set_default_executor(
            concurrent.futures.ThreadPoolExecutor(max_workers=self._max_workers)
        )

        xr_array = loop.run_until_complete(
            asyncio.wait_for(self.fetch(time, variable), timeout=self.async_timeout)
        )

        # Delete cache if needed
        if not self._cache:
            shutil.rmtree(self.cache)

        return xr_array
    async def fetch(
        self,
        time: datetime | list[datetime] | TimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Async function to get data

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for (UTC).
        variable : str | list[str] | VariableArray
            String, list of strings or array of strings that refer to variables to
            return. Must be in the NCAR lexicon.

        Returns
        -------
        xr.DataArray
            ERA5 weather data array from NCAR ERA5
        """
        time, variable = prep_data_inputs(time, variable)
        # Create cache dir if it doesn't exist
        Path(self.cache).mkdir(parents=True, exist_ok=True)

        # Make sure input time is valid
        self._validate_time(time)

        # Create tasks and group based on variable
        data_arrays: dict[str, list[xr.DataArray]] = {}
        async_tasks = []
        for task in self._create_tasks(time, variable).values():
            future = self.fetch_wrapper(task)
            async_tasks.append(future)

        # Now wait
        results = await tqdm.gather(
            *async_tasks, desc="Fetching NCAR ERA5 data", disable=(not self._verbose)
        )

        # Group based on variable
        for result in results:
            key = str(result.coords["variable"])
            if key not in data_arrays:
                data_arrays[key] = []
            data_arrays[key].append(result)

        # Concat times for same variable groups
        array_list = []
        for arrs in data_arrays.values():
            if len(arrs) > 1 and "time" in arrs[0].dims:
                # Only concat on time if multiple arrays and time dimension exists
                array_list.append(xr.concat(arrs, dim="time"))
            else:
                # For single arrays or arrays without time dim, just take the first
                array_list.append(arrs[0])

        # Now concat variables
        res = xr.concat(array_list, dim="variable", coords="minimal")
        res.name = None  # remove name, which is kept from one of the arrays

        # Delete cache if needed
        if not self._cache:
            shutil.rmtree(self.cache)

        if "time" in res.dims:
            return res.sel(time=time, variable=variable)
        else:
            # For files without time dimension, just select variables
            logger.warning(
                "No time dimension found in dataset, selecting variables only"
            )
            return res.sel(variable=variable)
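
    # Async usage sketch (illustrative only): fetch() can be awaited directly from an
    # existing event loop instead of going through __call__; "t2m" is a hypothetical
    # lexicon key.
    #
    #   async def main() -> None:
    #       ds = NCAR_ERA5(cache=False)
    #       da = await ds.fetch([datetime(2023, 1, 1, 0)], ["t2m"])
    #
    #   asyncio.run(main())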
    def _create_tasks(
        self, time: list[datetime], variable: list[str]
    ) -> dict[str, NCARAsyncTask]:
        """Create download tasks, each corresponding to one NetCDF file on S3.

        Parameters
        ----------
        time : list[datetime]
            Timestamps to be downloaded (UTC).
        variable : list[str]
            List of variables to be downloaded.

        Returns
        -------
        dict[str, NCARAsyncTask]
            Dictionary of download tasks keyed by S3 file URI.
        """
        tasks: dict[str, NCARAsyncTask] = {}  # Tasks grouped by file URI

        s3_pattern = "s3://nsf-ncar-era5/{product}/{year}{month:02}/{product}.{variable}.{grid}.{year}{month:02}{daystart:02}00_{year}{month:02}{dayend:02}23.nc"
        s3_pattern_accum = "s3://nsf-ncar-era5/{product}/{year1}{month1:02}/{product}.{variable}.{grid}.{year1}{month1:02}{daystart:02}06_{year2}{month2:02}{dayend:02}06.nc"

        for i, t in enumerate(time):
            for j, v in enumerate(variable):
                ncar_name, _ = NCAR_ERA5Lexicon[v]
                product = ncar_name.split("::")[0]
                variable_name = ncar_name.split("::")[1]
                grid = ncar_name.split("::")[2]
                level_index = int(ncar_name.split("::")[3])
                data_variable = f"{variable_name.split('_')[-1].upper()}"

                # Pressure-level data is held in daily nc files
                if product == "e5.oper.an.pl":
                    daystart = t.day
                    dayend = t.day
                    time_index = t.hour
                    file_name = s3_pattern.format(
                        product=product,
                        variable=variable_name,
                        grid=grid,
                        year=t.year,
                        month=t.month,
                        daystart=daystart,
                        dayend=dayend,
                    )
                    meta = {}
                # Accumulated products are split into bi-monthly files which cover
                # the range (start, end], for example the file:
                # e5.oper.fc.sfc.accumu.128_142_lsp.ll025sc.2025020106_2025021606.nc
                # will include lsp measurements for the times
                # 20250201T07:00:00, 20250201T08:00:00, ... , 20250216T06:00:00
                elif product == "e5.oper.fc.sfc.accumu":
                    # Data is stored with two time dims: forecast_initial_time, forecast_hour
                    # forecast_initial_time is at hours 06 and 18
                    # forecast_hour is in the range [1, 12]
                    # (see the worked example at the bottom of this module)
                    initial_time = t.replace(
                        hour=0, minute=0, second=0, microsecond=0
                    ) + pd.Timedelta(hours=((t.hour - 7) // 12) * 12 + 6)
                    fc_hour = int((t - initial_time).total_seconds() / 3600)
                    # Determine the start and end day of the bi-monthly s3 file interval
                    if initial_time.day >= 16:
                        date1 = initial_time.replace(day=16)
                        if initial_time.month == 12:
                            date2 = initial_time.replace(
                                year=initial_time.year + 1, month=1, day=1
                            )
                        else:
                            date2 = initial_time.replace(
                                month=initial_time.month + 1, day=1
                            )
                    else:
                        date1 = initial_time.replace(day=1)
                        date2 = initial_time.replace(day=16)
                    file_name = s3_pattern_accum.format(
                        product=product,
                        variable=variable_name,
                        grid=grid,
                        year1=date1.year,
                        year2=date2.year,
                        month1=date1.month,
                        month2=date2.month,
                        daystart=date1.day,
                        dayend=date2.day,
                    )
                    time_index = i
                    meta = {
                        "forecast_initial_time": initial_time,
                        "forecast_hour": fc_hour,
                        "time": np.datetime64(t),
                    }
                # Surface data is held in monthly files
                else:
                    daystart = 1
                    dayend = calendar.monthrange(t.year, t.month)[-1]
                    time_index = int(
                        (t - datetime(t.year, t.month, 1)).total_seconds() / 3600
                    )
                    file_name = s3_pattern.format(
                        product=product,
                        variable=variable_name,
                        grid=grid,
                        year=t.year,
                        month=t.month,
                        daystart=daystart,
                        dayend=dayend,
                    )
                    meta = {}

                # Place into dict; if we already have a request for a certain file
                # just append the time and variable needed
                if file_name in tasks:
                    tasks[file_name].ncar_time_indices[time_index] = np.datetime64(t)
                    tasks[file_name].ncar_level_indices[level_index] = v
                    tasks[file_name].ncar_meta[time_index] = meta
                else:
                    tasks[file_name] = NCARAsyncTask(
                        ncar_file_uri=file_name,
                        ncar_data_variable=data_variable,
                        ncar_time_indices={time_index: np.datetime64(t)},
                        ncar_level_indices={level_index: v},
                        ncar_meta={time_index: meta},
                    )

        return tasks

    async def fetch_wrapper(
        self,
        task: NCARAsyncTask,
    ) -> xr.DataArray:
        """Small wrapper to pack arrays into the DataArray"""
        out = await self.fetch_array(
            task.ncar_file_uri,
            task.ncar_data_variable,
            list(task.ncar_time_indices.keys()),
            list(task.ncar_level_indices.keys()),
            task.ncar_meta,
        )
        # Rename level coord to variable
        out = out.rename({"level": "variable", "longitude": "lon", "latitude": "lat"})
        out = out.assign_coords(variable=list(task.ncar_level_indices.values()))
        # Shouldn't be needed, but assign just in case to validate
        out = out.assign_coords(time=np.array(list(task.ncar_time_indices.values())))
        return out

    async def fetch_array(
        self,
        nc_file_uri: str,
        data_variable: str,
        time_idx: list[int],
        level_idx: list[int],
        ncar_meta: dict,
    ) -> xr.DataArray:
        """Fetches requested array from remote store

        Parameters
        ----------
        nc_file_uri : str
            S3 URI to NetCDF file
        data_variable : str
            Data variable name of the array to use in the NetCDF file
        time_idx : list[int]
            Time indexes (hours since start time of file)
        level_idx : list[int]
            Pressure level indices if applicable, should be same length as time_idx
        ncar_meta : dict
            Per-time-index metadata, used for accumulation products

        Returns
        -------
        xr.DataArray
            Data array loaded from requested file
        """
        logger.debug(
            f"Fetching NCAR ERA5 variable: {data_variable} in file {nc_file_uri}"
        )
        # Here we manually cache the extracted data arrays rather than letting fsspec
        # cache the whole NetCDF file. Not optimal: the same data can be stored under
        # different hashes for different level / time indexes, but this is better than
        # saving the entire file on disk for a single date
        sha = hashlib.sha256(
            (
                str(nc_file_uri)
                + str(data_variable)
                + str(time_idx)
                + str(level_idx)
                + str(ncar_meta)
            ).encode()
        )
        filename = sha.hexdigest()
        cache_path = os.path.join(self.cache, filename)

        if os.path.exists(cache_path):
            ds = await asyncio.to_thread(
                xr.open_dataarray, cache_path, engine="h5netcdf", cache=False
            )
        else:
            # New fs every call so we don't block; netcdf reads do not seem to support
            # open_async -> S3AsyncStreamedFile
            fs = s3fs.S3FileSystem(
                anon=True, asynchronous=False, skip_instance_cache=True
            )
            with fs.open(nc_file_uri, "rb", block_size=4 * 1400 * 720) as f:
                ds = await asyncio.to_thread(
                    xr.open_dataset, f, engine="h5netcdf", cache=False
                )
                # Sometimes data fields have VAR_ prepended
                if f"VAR_{data_variable}" in ds:
                    data_variable = f"VAR_{data_variable}"

                if data_variable not in ds:
                    raise ValueError(
                        f"Variable '{data_variable}' or 'VAR_{data_variable}' from task not found in dataset. "
                        + f"Available variables: {list(ds.keys())}."
                    )
                # Pressure level variable
                if "level" in ds.dims:
                    ds = ds.isel(time=list(time_idx), level=list(level_idx))[
                        data_variable
                    ]
                # Other product indexing
                else:
                    if "e5.oper.an.sfc" in nc_file_uri:
                        ds = ds.isel(time=list(time_idx))[data_variable]
                    elif "e5.oper.fc.sfc.accumu" in nc_file_uri:
                        # This is annoying here because we are dealing with mapping
                        # two dimensions to a single time coord
                        outputs = []
                        ds = ds[data_variable]
                        for i in time_idx:
                            out = ds.sel(forecast_hour=ncar_meta[i]["forecast_hour"])
                            out = out.sel(
                                forecast_initial_time=ncar_meta[i][
                                    "forecast_initial_time"
                                ]
                            )
                            out = out.expand_dims(
                                {"time": [ncar_meta[i]["time"]]}, axis=0
                            )
                            out = out.drop_vars(
                                ["forecast_hour", "forecast_initial_time"],
                                errors="ignore",
                            )
                            outputs.append(out)
                        ds = xr.concat(outputs, dim="time", coords="minimal")
                    else:
                        raise ValueError("Unknown product")
                    ds = ds.expand_dims({"level": [0]}, axis=1)

                # Load the data, this is the actual download
                ds = await asyncio.to_thread(ds.load)

                # Cache nc file if requested
                if self._cache:
                    await asyncio.to_thread(ds.to_netcdf, cache_path, engine="h5netcdf")

        return ds

    @classmethod
    def _validate_time(cls, times: list[datetime]) -> None:
        """Verify that date time is valid for ERA5 based on offline knowledge.

        Parameters
        ----------
        times : list[datetime]
            Timestamps to be downloaded (UTC).
        """
        for time in times:
            if time < datetime(1940, 1, 1):
                raise ValueError(
                    f"Requested date time {time} must be after January 1st, 1940 for NCAR ERA5"
                )
            if not (time - datetime(1900, 1, 1)).total_seconds() % 3600 == 0:
                raise ValueError(
                    f"Requested date time {time} needs to be on a 1 hour interval for NCAR ERA5"
                )

    @property
    def cache(self) -> str:
        """Return appropriate cache location."""
        cache_location = os.path.join(datasource_cache_root(), "ncar_era5")
        if not self._cache:
            if self._tmp_cache_hash is None:
                # First access for temp cache: create a random suffix to avoid collisions
                self._tmp_cache_hash = uuid.uuid4().hex[:8]
            cache_location = os.path.join(
                cache_location, f"tmp_ncar_{self._tmp_cache_hash}"
            )
        return cache_location
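
# Worked example (informal, derived from the _create_tasks logic above; timestamps are
# illustrative) of how a requested time maps onto the accumulated-product coordinates:
#
#   t = 2023-01-01T09:00 -> forecast_initial_time = 2023-01-01T06:00, forecast_hour = 3
#   t = 2023-01-01T03:00 -> forecast_initial_time = 2022-12-31T18:00, forecast_hour = 9
#   t = 2023-01-01T19:00 -> forecast_initial_time = 2023-01-01T18:00, forecast_hour = 1
#
# The first case falls in the bi-monthly file covering (2023-01-01T06, 2023-01-16T06],
# i.e. a file name ending in ...2023010106_2023011606.nc, while the second falls in the
# preceding file ending in ...2022121606_2023010106.nc.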