Source code for earth2studio.data.ncar

# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import calendar
import concurrent.futures
import hashlib
import os
import shutil
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path

import nest_asyncio
import numpy as np
import s3fs
import xarray as xr
from loguru import logger
from tqdm.asyncio import tqdm

from earth2studio.data.utils import (
    datasource_cache_root,
    prep_data_inputs,
)
from earth2studio.lexicon import NCAR_ERA5Lexicon
from earth2studio.utils.type import TimeArray, VariableArray

logger.remove()
logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True)


@dataclass
class NCARAsyncTask:
    """Small helper struct for Async tasks"""

    ncar_file_uri: str
    ncar_data_variable: str
    # Dictionary mapping time index -> time id
    ncar_time_indices: dict[int, datetime]
    # Dictionary mapping level index -> variable id
    ncar_level_indices: dict[int, str]
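
# Illustrative sketch only (not part of the datasource): one NCARAsyncTask maps a
# single S3 NetCDF file to the time/level slices needed from it. For a hypothetical
# request of 2 m temperature at 2000-01-01 00:00, a task could look roughly like
# (the URI tail and data variable are illustrative; real values come from
# NCAR_ERA5Lexicon):
#
#   NCARAsyncTask(
#       ncar_file_uri="s3://nsf-ncar-era5/e5.oper.an.sfc/200001/...nc",
#       ncar_data_variable="2T",
#       ncar_time_indices={0: datetime(2000, 1, 1)},
#       ncar_level_indices={0: "t2m"},
#   )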


class NCAR_ERA5:
    """ERA5 data provided by NSF NCAR via the AWS Open Data Sponsorship Program.
    ERA5 is the fifth generation of the ECMWF global reanalysis and is available on
    a 0.25 degree WGS84 grid at hourly intervals spanning from 1940 to the present.

    Parameters
    ----------
    max_workers : int, optional
        Max workers in the async IO thread pool. Only applied when using the sync
        call function and will modify the default async loop if one exists,
        by default 24
    cache : bool, optional
        Cache data source locally, by default True
    verbose : bool, optional
        Print download progress, by default True
    async_timeout : int, optional
        Timeout in seconds for async operations, by default 600

    Warning
    -------
    This is a remote data source and can potentially download a large amount of
    data to your local machine for large requests.

    Note
    ----
    Additional resources:
    https://registry.opendata.aws/nsf-ncar-era5/
    """

    NCAR_EAR5_LAT = np.linspace(90, -90, 721)
    NCAR_EAR5_LON = np.linspace(0, 359.75, 1440)

    def __init__(
        self,
        max_workers: int = 24,
        cache: bool = True,
        verbose: bool = True,
        async_timeout: int = 600,
    ):
        self._max_workers = max_workers
        self._cache = cache
        self._verbose = verbose
        self.async_timeout = async_timeout
    def __call__(
        self,
        time: datetime | list[datetime] | TimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Function to get data

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for (UTC).
        variable : str | list[str] | VariableArray
            String, list of strings or array of strings that refer to variables to
            return. Must be in the NCAR lexicon.

        Returns
        -------
        xr.DataArray
            ERA5 weather data array from NCAR ERA5
        """
        nest_asyncio.apply()  # Patch asyncio to work in notebooks
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # If no event loop exists, create one
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        # Modify the worker amount
        loop.set_default_executor(
            concurrent.futures.ThreadPoolExecutor(max_workers=self._max_workers)
        )

        xr_array = loop.run_until_complete(
            asyncio.wait_for(self.fetch(time, variable), timeout=self.async_timeout)
        )

        # Delete cache if needed
        if not self._cache:
            shutil.rmtree(self.cache)

        return xr_array
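    # Minimal synchronous usage sketch (assumptions: network access to the public
    # nsf-ncar-era5 bucket and that "t2m" / "u10m" are valid NCAR lexicon keys):
    #
    #   from datetime import datetime
    #   from earth2studio.data.ncar import NCAR_ERA5
    #
    #   ds = NCAR_ERA5(max_workers=8)
    #   da = ds(time=datetime(2000, 1, 1, 6), variable=["t2m", "u10m"])
    #   # Expected dims (time, variable, lat, lon) -> (1, 2, 721, 1440)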
    async def fetch(
        self,
        time: datetime | list[datetime] | TimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Async function to get data

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for (UTC).
        variable : str | list[str] | VariableArray
            String, list of strings or array of strings that refer to variables to
            return. Must be in the NCAR lexicon.

        Returns
        -------
        xr.DataArray
            ERA5 weather data array from NCAR ERA5
        """
        time, variable = prep_data_inputs(time, variable)

        # Create cache dir if it doesn't exist
        Path(self.cache).mkdir(parents=True, exist_ok=True)

        # Make sure input time is valid
        self._validate_time(time)

        # Create tasks and group based on variable
        data_arrays: dict[str, list[xr.DataArray]] = {}
        async_tasks = []
        for task in self._create_tasks(time, variable).values():
            future = self.fetch_wrapper(task)
            async_tasks.append(future)

        # Now wait
        results = await tqdm.gather(
            *async_tasks, desc="Fetching NCAR ERA5 data", disable=(not self._verbose)
        )

        # Group based on variable
        for result in results:
            key = str(result.coords["variable"])
            if key not in data_arrays:
                data_arrays[key] = []
            data_arrays[key].append(result)

        # Concat times for same variable groups
        array_list = [xr.concat(arrs, dim="time") for arrs in data_arrays.values()]
        # Now concat variables
        res = xr.concat(array_list, dim="variable")
        res.name = None  # remove name, which is kept from one of the arrays

        # Delete cache if needed
        if not self._cache:
            shutil.rmtree(self.cache)

        return res.sel(time=time, variable=variable)
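    # Async usage sketch for callers already running an event loop, which can await
    # fetch directly and skip the nest_asyncio / run_until_complete machinery in
    # __call__ ("t2m" again assumed to be a valid lexicon key):
    #
    #   import asyncio
    #   from datetime import datetime
    #   from earth2studio.data.ncar import NCAR_ERA5
    #
    #   async def main() -> None:
    #       ds = NCAR_ERA5()
    #       da = await ds.fetch(time=[datetime(2000, 1, 1)], variable=["t2m"])
    #
    #   asyncio.run(main())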
    def _create_tasks(
        self, time: list[datetime], variable: list[str]
    ) -> dict[str, NCARAsyncTask]:
        """Create download tasks, each corresponding to one NetCDF file on S3. Each
        file holds either one day (pressure-level products) or one month (surface
        products) of hourly data.

        Parameters
        ----------
        time : list[datetime]
            Timestamps to be downloaded (UTC).
        variable : list[str]
            List of variables to be downloaded.

        Returns
        -------
        dict[str, NCARAsyncTask]
            Download tasks keyed by S3 file URI.
        """
        tasks: dict[str, NCARAsyncTask] = {}

        # Group pressure-level variables
        s3_pattern = "s3://nsf-ncar-era5/{product}/{year}{month:02}/{product}.{variable}.{grid}.{year}{month:02}{daystart:02}00_{year}{month:02}{dayend:02}23.nc"
        for t in time:
            for v in variable:
                ncar_name, _ = NCAR_ERA5Lexicon[v]
                product = ncar_name.split("::")[0]
                variable_name = ncar_name.split("::")[1]
                grid = ncar_name.split("::")[2]
                level_index = int(ncar_name.split("::")[3])
                data_variable = f"{variable_name.split('_')[-1].upper()}"

                # Pressure-level data is held in daily nc files
                if product == "e5.oper.an.pl":
                    daystart = t.day
                    dayend = t.day
                    time_index = t.hour
                # Surface data is held in monthly files
                else:
                    daystart = 1
                    dayend = calendar.monthrange(t.year, t.month)[-1]
                    time_index = int(
                        (t - datetime(t.year, t.month, 1)).total_seconds() / 3600
                    )

                file_name = s3_pattern.format(
                    product=product,
                    variable=variable_name,
                    grid=grid,
                    year=t.year,
                    month=t.month,
                    daystart=daystart,
                    dayend=dayend,
                )
                if file_name in tasks:
                    tasks[file_name].ncar_time_indices[time_index] = t
                    tasks[file_name].ncar_level_indices[level_index] = v
                else:
                    tasks[file_name] = NCARAsyncTask(
                        ncar_file_uri=file_name,
                        ncar_data_variable=data_variable,
                        ncar_time_indices={time_index: t},
                        ncar_level_indices={level_index: v},
                    )
        return tasks

    async def fetch_wrapper(
        self,
        task: NCARAsyncTask,
    ) -> xr.DataArray:
        """Small wrapper to pack arrays into the DataArray"""
        out = await self.fetch_array(
            task.ncar_file_uri,
            task.ncar_data_variable,
            list(task.ncar_time_indices.keys()),
            list(task.ncar_level_indices.keys()),
        )
        # Rename levels coord to variable
        out = out.rename({"level": "variable"})
        out = out.assign_coords(variable=list(task.ncar_level_indices.values()))
        # Shouldn't be needed but just in case
        out = out.assign_coords(time=list(task.ncar_time_indices.values()))
        return out

    async def fetch_array(
        self,
        nc_file_uri: str,
        data_variable: str,
        time_idx: list[int],
        level_idx: list[int],
    ) -> xr.DataArray:
        """Fetches requested array from remote store

        Parameters
        ----------
        nc_file_uri : str
            S3 URI to NetCDF file
        data_variable : str
            Data variable name of the array to use in the NetCDF file
        time_idx : list[int]
            Time indexes (hours since start time of file)
        level_idx : list[int]
            Pressure level indices if applicable, should be same length as time_idx

        Returns
        -------
        xr.DataArray
            Data array of the requested variable
        """
        logger.debug(
            f"Fetching NCAR ERA5 variable: {data_variable} in file {nc_file_uri}"
        )
        # Here we manually cache the data arrays because fsspec caches the extracted
        # NetCDF file. Not optimal: there can be some repeat storage for different
        # level / time indexes
        sha = hashlib.sha256(
            (
                str(nc_file_uri) + str(data_variable) + str(time_idx) + str(level_idx)
            ).encode()
        )
        filename = sha.hexdigest()
        cache_path = os.path.join(self.cache, filename)

        if os.path.exists(cache_path):
            ds = await asyncio.to_thread(xr.open_dataarray, cache_path)
            ds = await asyncio.to_thread(ds.load)
        else:
            # New fs every call so we don't block; netcdf reads seem to not support
            # open_async -> S3AsyncStreamedFile (big sad)
            fs = s3fs.S3FileSystem(anon=True, asynchronous=False)
            with fs.open(nc_file_uri, "rb", block_size=4 * 1400 * 720) as f:
                ds = await asyncio.to_thread(
                    xr.open_dataset, f, engine="h5netcdf", cache=False
                )
                # Sometimes data fields have VAR_ prepended
                if f"VAR_{data_variable}" in ds:
                    data_variable = f"VAR_{data_variable}"
                if data_variable not in ds:
                    raise ValueError(
                        f"Variable '{data_variable}' or 'VAR_{data_variable}' from task not found in dataset. "
                        + f"Available variables: {list(ds.keys())}."
                    )
                # Pressure level variable
                if "level" in ds.dims:
                    ds = ds.isel(time=list(time_idx), level=list(level_idx))[
                        data_variable
                    ]
                # Other product
                else:
                    ds = ds.isel(time=list(time_idx))[data_variable]
                    ds = ds.expand_dims({"level": [0]}, axis=1)
                # Load the data, this is the actual download
                ds = await asyncio.to_thread(ds.load)

            # Cache nc file if caching is enabled
            if self._cache:
                await asyncio.to_thread(ds.to_netcdf, cache_path, engine="h5netcdf")

        return ds

    @classmethod
    def _validate_time(cls, times: list[datetime]) -> None:
        """Verify that date time is valid for ERA5 based on offline knowledge.

        Parameters
        ----------
        times : list[datetime]
            Timestamps to be downloaded (UTC).
        """
        for time in times:
            if time < datetime(1940, 1, 1):
                raise ValueError(
                    f"Requested date time {time} must be after January 1st, 1940 for NCAR ERA5"
                )
            if not (time - datetime(1900, 1, 1)).total_seconds() % 3600 == 0:
                raise ValueError(
                    f"Requested date time {time} needs to be on a 1 hour interval for NCAR ERA5"
                )

    @property
    def cache(self) -> str:
        """Return appropriate cache location."""
        cache_location = os.path.join(datasource_cache_root(), "ncar_era5")
        if not self._cache:
            cache_location = os.path.join(cache_location, "tmp_ncar_era5")
        if not os.path.exists(cache_location):
            os.makedirs(cache_location, exist_ok=True)
        return cache_location
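
# Illustrative expansion of the s3_pattern used in _create_tasks (the
# "128_130_t" / "128_167_2t" / "ll025sc" tokens are hypothetical lexicon output,
# shown only to make the layout concrete). A pressure-level request for
# 2000-01-01 maps to a single-day file; a surface request maps to the
# full-month file:
#
#   s3://nsf-ncar-era5/e5.oper.an.pl/200001/e5.oper.an.pl.128_130_t.ll025sc.2000010100_2000010123.nc
#   s3://nsf-ncar-era5/e5.oper.an.sfc/200001/e5.oper.an.sfc.128_167_2t.ll025sc.2000010100_2000013123.nc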