Source code for earth2studio.data.ncar

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import calendar
import hashlib
import multiprocessing
import os
import shutil
from datetime import date, datetime
from functools import partial

import numpy as np
import pandas as pd
import s3fs
import xarray as xr
from loguru import logger
from tqdm import tqdm

from earth2studio.data.utils import (
    datasource_cache_root,
    prep_data_inputs,
)
from earth2studio.lexicon import NCAR_ERA5Lexicon
from earth2studio.utils.type import TimeArray, VariableArray

logger.remove()
logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True)


class NCAR_ERA5:
    """ERA5 data provided by NSF NCAR via the AWS Open Data Sponsorship Program.
    ERA5 is the fifth generation of the ECMWF global reanalysis and is available
    on a 0.25 degree WGS84 grid at hourly intervals spanning from 1940 to the
    present.

    Parameters
    ----------
    n_workers : int, optional
        Number of parallel workers, by default 8
    cache : bool, optional
        Cache data source on local memory, by default True
    verbose : bool, optional
        Print download progress, by default False

    Warning
    -------
    This is a remote data source and can potentially download a large amount of
    data to your local machine for large requests.

    Note
    ----
    Additional resources:
    https://registry.opendata.aws/nsf-ncar-era5/
    """

    def __init__(self, n_workers: int = 8, cache: bool = True, verbose: bool = False):
        self._cache = cache
        self._verbose = verbose
        self._n_workers = n_workers
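    # A minimal construction sketch (values are illustrative, not recommendations):
    #
    #   source = NCAR_ERA5(n_workers=4, cache=True, verbose=True)
    #
    # With cache=True, downloaded slices persist under
    # ``datasource_cache_root()/ncar_era5`` (see the ``cache`` property below);
    # with cache=False, that directory is deleted after each call.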
    def __call__(
        self,
        time: datetime | list[datetime] | TimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Retrieve ERA5 data.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for (UTC).
        variable : str | list[str] | VariableArray
            String, list of strings or array of strings that refer to variables to
            return. Must be in the NCAR_ERA5 lexicon.

        Returns
        -------
        xr.DataArray
            ERA5 weather data array
        """
        time, variable = prep_data_inputs(time, variable)
        self._validate_time(time)

        data_arrays: dict[str, list[xr.DataArray]] = {}
        tasks = self._create_tasks(time, variable)
        logger.debug("Download tasks: {}", str(tasks))

        ctx = multiprocessing.get_context("spawn")  # s3fs requires spawn or forkserver
        fn = partial(self._fetch_dataarray, cache_path=self.cache, cache=self._cache)
        with ctx.Pool(self._n_workers) as p:
            for ename, arr in tqdm(
                p.imap_unordered(fn, tasks),
                "Step",
                len(tasks),
                disable=(not self._verbose),
            ):
                data_arrays.setdefault(ename, []).append(arr)

        # Concat time and variable dims
        array_list = [xr.concat(arrs, dim="time") for arrs in data_arrays.values()]
        res = xr.concat(array_list, dim="variable", combine_attrs="drop")
        res.name = None  # remove name, which is kept from one of the arrays
        res = res.transpose("time", "variable", ...)

        # Delete cache if needed
        if not self._cache:
            shutil.rmtree(self.cache)

        return res.sel(time=time, variable=variable)  # reorder to match inputs
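    # Illustrative call sketch. Assumes "t2m" (surface) and "z500" (pressure
    # level) are valid NCAR_ERA5Lexicon keys; actual keys come from the lexicon:
    #
    #   da = source(
    #       time=[datetime(2022, 1, 1, 0), datetime(2022, 1, 1, 6)],
    #       variable=["t2m", "z500"],
    #   )
    #   # da.dims -> ("time", "variable", "lat", "lon") on the 0.25 degree grid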
    @staticmethod
    def _create_tasks(times: list[datetime], variables: list[str]) -> list[dict]:
        """Create download tasks, each corresponding to one file on S3.

        Parameters
        ----------
        times : list[datetime]
            Timestamps to be downloaded (UTC).
        variables : list[str]
            List of variables to be downloaded.

        Returns
        -------
        list[dict]
            List of download tasks.
        """
        groups: dict[str, dict] = {}  # group pressure-level variables
        for var in set(variables):
            spec, _ = NCAR_ERA5Lexicon[var]
            lvl = spec.split(".")[3]  # surface/pressure-level
            ename = spec.split(".")[4].split("_")[-1]  # ECMWF name
            groups.setdefault(
                ename,
                {"var": var if lvl == "sfc" else var[0], "levels": [], "spec": spec},
            )
            if lvl == "pl":  # Collect pressure levels
                groups[ename]["levels"].append(int(var[1:]))

        pattern = "s3://nsf-ncar-era5/e5.oper.an.{lvl}/{y}{m:02}/{spec}.{y}{m:02}{d:02}00_{y}{m:02}{dend:02}23.nc"

        tasks = []  # group tasks by S3 object
        times_by_day: dict[
            date, list[datetime]
        ] = {}  # pressure-level variables are in daily files
        times_by_month: dict[
            date, list[datetime]
        ] = {}  # surface-level variables are in monthly files
        for dt in times:
            times_by_day.setdefault(dt.date(), []).append(dt)
            times_by_month.setdefault(dt.date().replace(day=1), []).append(dt)

        for ename, group in groups.items():
            if len(group["levels"]) == 0:  # Surface-level variable, monthly files
                for month, dts in times_by_month.items():
                    tasks.append(
                        {
                            "ename": ename,
                            "var": group["var"],  # Earth-2 variable specifier
                            "levels": [],
                            "dts": dts,
                            "s3pfx": pattern.format(
                                lvl="sfc",
                                y=month.year,
                                m=month.month,
                                spec=group["spec"],
                                d=1,
                                dend=calendar.monthrange(month.year, month.month)[-1],
                            ),
                        }
                    )
            else:  # Pressure-level variable, daily files
                for day, dts in times_by_day.items():
                    tasks.append(
                        {
                            "ename": ename,
                            "var": group["var"],  # Earth-2 variable specifier
                            "levels": group["levels"],
                            "dts": dts,
                            "s3pfx": pattern.format(
                                lvl="pl",
                                y=day.year,
                                m=day.month,
                                spec=group["spec"],
                                d=day.day,
                                dend=day.day,
                            ),
                        }
                    )

        return tasks

    @staticmethod
    def _fetch_dataarray(
        task: dict, cache_path: str, cache: bool
    ) -> tuple[str, xr.DataArray]:
        """Retrieve ERA5 data for a single group of times/variables.

        Parameters
        ----------
        task : dict
            Download task, specifying the variables, times, and S3 location.
        cache_path : str
            Local cache directory
        cache : bool
            Cache data source on local memory

        Returns
        -------
        tuple[str, xr.DataArray]
            ECMWF variable name and ERA5 data for the given group of
            times/variables.
        """
        ename, var, levels, dts, s3pfx = (
            task["ename"],
            task["var"],
            task["levels"],
            task["dts"],
            task["s3pfx"],
        )
        fs = s3fs.S3FileSystem(anon=True)

        # Here we manually cache the data arrays. This is because fsspec caches the
        # entire HDF5 file by default, which is large for this data, so instead we
        # manually cache only the slice we need.
        sha = hashlib.sha256(
            (str(ename) + str(var) + str(levels) + str(dts) + str(s3pfx)).encode()
        )
        filename = sha.hexdigest()
        cache_path = os.path.join(cache_path, filename)

        if os.path.exists(cache_path):
            ds = xr.open_dataarray(cache_path)
        else:
            with fs.open(s3pfx, "rb", block_size=1 * 1024 * 1024) as f:
                ds = xr.open_dataset(f, engine="h5netcdf", cache=False)
                if ename[0].isalpha():
                    xrname = ename.upper()
                else:  # ECMWF names starting with a digit get a VAR_ prefix
                    xrname = f"VAR_{ename.upper()}"
                if len(levels) == 0:  # Surface-level variable
                    ds = ds.sel(time=dts)[xrname]
                    ds = ds.rename({"latitude": "lat", "longitude": "lon"})
                    ds = xr.concat([ds], pd.Index([var], name="variable"))
                    ds = ds.load()
                else:  # Pressure-level variable
                    ds = ds.sel(time=dts, level=[float(lvl) for lvl in levels])[xrname]
                    ds = ds.rename(latitude="lat", longitude="lon", level="variable")
                    ds["variable"] = np.array([f"{var}{lvl}" for lvl in levels])
                    ds = ds.load()

            # Save to cache, could be better optimized by not saving the coords
            # For some reason the default netcdf engine was giving errors
            if cache:
                ds.to_netcdf(cache_path, engine="h5netcdf")

        return ename, ds

    @classmethod
    def _validate_time(cls, times: list[datetime]) -> None:
        """Verify that the requested times are valid for ERA5, based on offline
        knowledge.

        Parameters
        ----------
        times : list[datetime]
            Timestamps to be downloaded (UTC).
        """
        for time in times:
            if (time - datetime(1900, 1, 1)).total_seconds() % 3600 != 0:
                raise ValueError(
                    f"Requested date time {time} needs to be on a 1 hour interval for ERA5"
                )

    @property
    def cache(self) -> str:
        """Return appropriate cache location."""
        cache_path = os.path.join(datasource_cache_root(), "ncar_era5")
        if not os.path.exists(cache_path):
            os.makedirs(cache_path, exist_ok=True)
        return cache_path