Source code for earth2studio.data.utils

# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import inspect
import os
import tempfile
import time
import typing
import weakref
from collections import OrderedDict
from collections.abc import Callable
from datetime import datetime, timedelta
from hashlib import sha256
from inspect import signature
from pathlib import Path
from shutil import rmtree
from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeVar

import numpy as np
import torch
import xarray as xr
from fsspec import filesystem
from fsspec.asyn import AsyncFileSystem
from fsspec.callbacks import DEFAULT_CALLBACK
from fsspec.implementations.cache_mapper import create_cache_mapper
from fsspec.implementations.cache_metadata import CacheMetadata
from fsspec.utils import isfilelike
from loguru import logger

from earth2studio.data.base import DataSource, ForecastSource
from earth2studio.utils.interp import LatLonInterpolation
from earth2studio.utils.time import (
    leadtimearray_to_timedelta,
    timearray_to_datetime,
    to_time_array,
)
from earth2studio.utils.type import CoordSystem, LeadTimeArray, TimeArray, VariableArray

if TYPE_CHECKING:
    from fsspec.implementations.cache_mapper import AbstractCacheMapper


[docs] def fetch_data( source: DataSource | ForecastSource, time: TimeArray, variable: VariableArray, lead_time: LeadTimeArray = np.array([np.timedelta64(0, "h")]), device: torch.device = "cpu", interp_to: CoordSystem = None, interp_method: str = "nearest", ) -> tuple[torch.Tensor, CoordSystem]: """Utility function to fetch data for models and load data on the target device. If desired, xarray interpolation/regridding in the spatial domain can be used by passing a target coordinate system via the optional `interp_to` argument. Parameters ---------- source : DataSource The data source to fetch from time : TimeArray Timestamps to return data for (UTC). variable : VariableArray Strings or list of strings that refer to variables to return lead_time : LeadTimeArray, optional Lead times to fetch for each provided time, by default np.array(np.timedelta64(0, "h")) device : torch.device, optional Torch devive to load data tensor to, by default "cpu" interp_to : CoordSystem, optional If provided, the fetched data will be interpolated to the coordinates specified by lat/lon arrays in this CoordSystem interp_method : str Interpolation method to use with xarray (by default 'nearest') Returns ------- tuple[torch.Tensor, CoordSystem] Tuple containing output tensor and coordinate OrderedDict """ sig = signature(source.__call__) if "lead_time" in sig.parameters: # Working with a Forecast Data Source da = source(time, lead_time, variable) # type: ignore else: da = [] for lead in lead_time: adjust_times = np.array([t + lead for t in time], dtype="datetime64[ns]") da0 = source(adjust_times, variable) # type: ignore da0 = da0.expand_dims(dim={"lead_time": 1}, axis=1) da0 = da0.assign_coords(lead_time=np.array([lead], dtype="timedelta64[ns]")) da0 = da0.assign_coords(time=time) da.append(da0) da = xr.concat(da, "lead_time") return prep_data_array( da, device=device, interp_to=interp_to, interp_method=interp_method, )
[docs] def prep_data_array( da: xr.DataArray, device: torch.device = "cpu", interp_to: CoordSystem | None = None, interp_method: str = "nearest", ) -> tuple[torch.Tensor, CoordSystem]: """Prepares a data array from a data source for inference workflows by converting the data array to a torch tensor and the coordinate system to an OrderedDict. If desired, xarray interpolation/regridding in the spatial domain can be used by passing a target coordinate system via the optional `interp_to` argument. Parameters ---------- da : xr.DataArray Input data array device : torch.device, optional Torch devive to load data tensor to, by default "cpu" interp_to : CoordSystem, optional If provided, the fetched data will be interpolated to the coordinates specified by lat/lon arrays in this CoordSystem interp_method : str Interpolation method to use with xarray (by default 'nearest') Returns ------- tuple[torch.Tensor, CoordSystem] Tuple containing output tensor and coordinate OrderedDict """ # Initialize the output CoordSystem out_coords = OrderedDict() for dim in da.coords.dims: if dim in da.coords: out_coords[dim] = np.array(da.coords[dim]) # Fetch data and regrid if necessary if interp_to is not None: if len(interp_to["_lat"].shape) != len(interp_to["_lon"].shape): raise ValueError( "Discrepancy in interpolation coordinates: latitude has different number of dims than longitude" ) if "lat" not in da.dims: # Data source uses curvilinear coordinates if interp_method != "linear": raise ValueError( "fetch_data does not support interpolation methods other than linear when data source has a curvilinear grid" ) interp = LatLonInterpolation( lat_in=da["_lat"].values, lon_in=da["_lon"].values, lat_out=interp_to["_lat"], lon_out=interp_to["_lon"], ).to(device) data = torch.Tensor(da.values).to(device) out = interp(data) # HARD CODE FOR STORMCAST # TODO: FIX THIS BY CORRECTING STORMCAST COORDINATES if "hrrr_y" in out_coords: del out_coords["hrrr_y"] if "hrrr_x" in out_coords: del out_coords["hrrr_x"] else: if len(interp_to["_lat"].shape) > 1 or len(interp_to["_lon"].shape) > 1: # Target grid uses curvilinear coordinates: define internal dims y, x target_lat = xr.DataArray(interp_to["_lat"], dims=["y", "x"]) target_lon = xr.DataArray(interp_to["_lon"], dims=["y", "x"]) else: target_lat = xr.DataArray(interp_to["_lat"], dims=["_lat"]) target_lon = xr.DataArray(interp_to["_lon"], dims=["_lon"]) da = da.interp( lat=target_lat, lon=target_lon, method=interp_method, ) out = torch.Tensor(da.values).to(device) out_coords["_lat"] = interp_to["_lat"] out_coords["_lon"] = interp_to["_lon"] else: out = torch.Tensor(da.values).to(device) if "lat" in da.coords and "lat" not in da.coords.dims: # Curvilinear grid case: lat/lon coords are 2D arrays, not in dims out_coords["lat"] = da.coords["lat"].values out_coords["lon"] = da.coords["lon"].values else: for dim in da.coords.dims: if dim not in ["time", "lead_time", "variable"]: out_coords[dim] = np.array(da.coords[dim]) return out, out_coords
def prep_data_inputs( time: datetime | list[datetime] | TimeArray, variable: str | list[str] | VariableArray, ) -> tuple[list[datetime], list[str]]: """Simple method to pre-process data source inputs into a common form Parameters ---------- time : datetime | list[datetime] | TimeArray Datetime, list of datetimes or array of np.datetime64 to fetch variable : str | list[str] | VariableArray String, list of strings or array of strings that refer to variables Returns ------- tuple[list[datetime], list[str]] Time and variable lists """ if isinstance(variable, str): variable = [variable] if isinstance(time, datetime): time = [time] if isinstance(time, np.ndarray): # np.datetime64 -> datetime time = timearray_to_datetime(time) return time, variable def prep_forecast_inputs( time: datetime | list[datetime] | TimeArray, lead_time: timedelta | list[timedelta] | LeadTimeArray, variable: str | list[str] | VariableArray, ) -> tuple[list[datetime], list[timedelta], list[str]]: """Simple method to pre-process forecast source inputs into a common form Parameters ---------- time : datetime | list[datetime] | TimeArray Datetime, list of datetimes or array of np.datetime64 to fetch lead_time: timedelta | list[timedelta], LeadTimeArray Timedelta, list of timedeltas or array of np.timedelta to fetch variable : str | list[str] | VariableArray String, list of strings or array of strings that refer to variables Returns ------- tuple[list[datetime], list[timedelta], list[str]] Time, lead time, and variable lists """ if isinstance(lead_time, timedelta): lead_time = [lead_time] if isinstance(lead_time, np.ndarray): # np.timedelta64 -> timedelta lead_time = leadtimearray_to_timedelta(lead_time) time, variable = prep_data_inputs(time, variable) return time, lead_time, variable
[docs] def datasource_to_file( file_name: str, source: DataSource, time: list[str] | list[datetime] | TimeArray, variable: VariableArray, lead_time: LeadTimeArray = np.array([np.timedelta64(0, "h")]), backend: Literal["netcdf", "zarr"] = "netcdf", chunks: dict[str, int] = {"variable": 1}, dtype: np.dtype | None = None, backend_kwargs: dict[str, Any] = {}, ) -> None: """Utility function that can be used for building a local data store needed for an inference request. This file can then be used with the :py:class:`earth2studio.data.DataArrayFile` data source to load data from file. This is useful when multiple runs of the same input data is needed. Parameters ---------- file_name : str File name of output NetCDF source : DataSource The original data source to fetch from time : list[str] | list[datetime] | list[np.datetime64] List of time strings, datetimes or np.datetime64 (UTC) variable : VariableArray Strings or list of strings that refer to variables to return lead_time : LeadTimeArray, optional Lead times to fetch for each provided time, by default np.array(np.timedelta64(0, "h")) backend : Literal["netcdf", "zarr"], optional Storage backend to save output file as, by default "netcdf" chunks : dict[str, int], optional Chunk sizes along each dimension, by default {"variable": 1} dtype : np.dtype, optional Data type for storing data backend_kwargs : dict[str, Any], optional Dictionary of keyword arguments forwarded to the underlying ``xarray.DataArray.to_netcdf`` / ``xarray.DataArray.to_zarr`` call depending on the selected backend. """ if isinstance(time, datetime): time = [time] time = to_time_array(time) # Spot check the write location is okay before pull testfile = tempfile.TemporaryFile(dir=Path(file_name).parent.resolve()) testfile.close() # Compile all times for lead in lead_time: adjust_times = np.array([t + lead for t in time], dtype="datetime64[ns]") time = np.concatenate([time, adjust_times], axis=0) time = np.unique(time) # Fetch da = source(time, variable) da = da.assign_coords(time=time) da = da.chunk(chunks=chunks) if dtype is not None: da = da.astype(dtype=dtype) match backend: case "netcdf": da.to_netcdf(file_name, **backend_kwargs) case "zarr": da.to_zarr(file_name, **backend_kwargs) case _: raise ValueError(f"Unsupported backend {backend}")
def datasource_cache_root() -> str: """Returns the root directory for data sources""" default_cache = os.path.join(os.path.expanduser("~"), ".cache", "earth2studio") default_cache = os.environ.get("EARTH2STUDIO_CACHE", default_cache) try: os.makedirs(default_cache, exist_ok=True) except OSError as e: logger.error( f"Failed to create cache folder {default_cache}, check permissions" ) raise e return default_cache def get_msc_filesystem() -> filesystem | None: """This helper function checks if Multi-Storage Client is available and sets up the MSC configuration if needed. Note ---- Can force MSC to not be used with the environment variable EARTH2STUDIO_DISABLE_MSC Returns ------- filesystem | None Returns multi-storage file system if installed, None otherwise """ if str(os.getenv("EARTH2STUDIO_DISABLE_MSC", "")).strip().lower() in ("1", "true"): return None try: from multistorageclient.contrib.async_fs import MultiStorageAsyncFileSystem config_path = Path(__file__).parent / "msc_config.yaml" os.environ["MSC_CONFIG"] = str(config_path) return MultiStorageAsyncFileSystem except ImportError: return None T = TypeVar("T") @typing.no_type_check class AsyncCachingFileSystem(AsyncFileSystem): """Async locally caching filesystem, layer over any other FS for use with zarr 3.0. Presently extremely limited, much is just copied from the WholeFileCache Parameters ---------- target_protocol: str (optional) Target filesystem protocol. Provide either this or ``fs``. cache_storage: str or list(str) Location to store files. If "TMP", this is a temporary directory, and will be cleaned up by the OS when this process ends (or later). If a list, each location will be tried in the order given, but only the last will be considered writable. cache_check: int Number of seconds between reload of cache metadata check_files: bool Whether to explicitly see if the UID of the remote file matches the stored one before using. Warning: some file systems such as HTTP cannot reliably give a unique hash of the contents of some path, so be sure to set this option to False. expiry_time: int The time in seconds after which a local copy is considered useless. Set to falsy to prevent expiry. The default is equivalent to one week. target_options: dict or None Passed to the instantiation of the FS, if fs is None. fs: filesystem instance The target filesystem to run against. Provide this or ``protocol``. same_names: bool (optional) By default, target URLs are hashed using a ``HashCacheMapper`` so that files from different backends with the same basename do not conflict. If this argument is ``true``, a ``BasenameCacheMapper`` is used instead. Other cache mapper options are available by using the ``cache_mapper`` keyword argument. Only one of this and ``cache_mapper`` should be specified. compression: str (optional) To decompress on download. Can be 'infer' (guess from the URL name), one of the entries in ``fsspec.compression.compr``, or None for no decompression. cache_mapper: AbstractCacheMapper (optional) The object use to map from original filenames to cached filenames. Only one of this and ``same_names`` should be specified. """ protocol: ClassVar[str | tuple[str, ...]] = ("blockcache", "cached") @typing.no_type_check def __init__( self, target_protocol=None, cache_storage="TMP", cache_check=10, check_files=False, expiry_time=604800, target_options=None, fs=None, same_names: bool | None = None, compression=None, cache_mapper: AbstractCacheMapper | None = None, **kwargs, ): super().__init__(**kwargs) if fs is None and target_protocol is None: raise ValueError( "Please provide filesystem instance(fs) or target_protocol" ) if not (fs is None) ^ (target_protocol is None): raise ValueError( "Both filesystems (fs) and target_protocol may not be both given." ) if cache_storage == "TMP": tempdir = tempfile.mkdtemp() storage = [tempdir] weakref.finalize(self, self._remove_tempdir, tempdir) else: if isinstance(cache_storage, str): storage = [cache_storage] else: storage = cache_storage os.makedirs(storage[-1], exist_ok=True) self.storage = storage self.kwargs = target_options or {} self.cache_check = cache_check self.check_files = check_files self.expiry = expiry_time self.compression = compression # Size of cache in bytes. If None then the size is unknown and will be # recalculated the next time cache_size() is called. On writes to the # cache this is reset to None. self._cache_size = None if same_names is not None and cache_mapper is not None: raise ValueError( "Cannot specify both same_names and cache_mapper in " "CachingFileSystem.__init__" ) if cache_mapper is not None: self._mapper = cache_mapper else: self._mapper = create_cache_mapper( same_names if same_names is not None else False ) self.target_protocol = ( target_protocol if isinstance(target_protocol, str) else (fs.protocol if isinstance(fs.protocol, str) else fs.protocol[0]) ) self._metadata = CacheMetadata(self.storage) self.load_cache() self.fs = fs if fs is not None else filesystem(target_protocol, **self.kwargs) if not self.fs.async_impl: raise ValueError("Underlying filesystem needs to be async") def _strip_protocol(path): # acts as a method, since each instance has a difference target return self.fs._strip_protocol(type(self)._strip_protocol(path)) self._strip_protocol: Callable = _strip_protocol @typing.no_type_check @staticmethod def _remove_tempdir(tempdir): try: rmtree(tempdir) except Exception: # noqa: S110 pass @typing.no_type_check def _mkcache(self): os.makedirs(self.storage[-1], exist_ok=True) @typing.no_type_check def cache_size(self): """Return size of cache in bytes. If more than one cache directory is in use, only the size of the last one (the writable cache directory) is returned. """ if self._cache_size is None: cache_dir = self.storage[-1] self._cache_size = filesystem("file").du(cache_dir, withdirs=True) return self._cache_size @typing.no_type_check def load_cache(self): """Read set of stored blocks from file""" self._metadata.load() self._mkcache() self.last_cache = time.time() @typing.no_type_check def save_cache(self): """Save set of stored blocks from file""" self._mkcache() self._metadata.save() self.last_cache = time.time() self._cache_size = None @typing.no_type_check def _check_cache(self): """Reload caches if time elapsed or any disappeared""" self._mkcache() if not self.cache_check: # explicitly told not to bother checking return timecond = time.time() - self.last_cache > self.cache_check existcond = all(os.path.exists(storage) for storage in self.storage) if timecond or not existcond: self.load_cache() @typing.no_type_check def _check_file(self, path): """Is path in cache and still valid""" path = self._strip_protocol(path) self._check_cache() return self._metadata.check_file(path, self) @typing.no_type_check def clear_cache(self): """Remove all files and metadata from the cache In the case of multiple cache locations, this clears only the last one, which is assumed to be the read/write one. """ rmtree(self.storage[-1]) self.load_cache() self._cache_size = None @typing.no_type_check def clear_expired_cache(self, expiry_time=None): """Remove all expired files and metadata from the cache In the case of multiple cache locations, this clears only the last one, which is assumed to be the read/write one. Parameters ---------- expiry_time: int The time in seconds after which a local copy is considered useless. If not defined the default is equivalent to the attribute from the file caching instantiation. """ if not expiry_time: expiry_time = self.expiry self._check_cache() expired_files, writable_cache_empty = self._metadata.clear_expired(expiry_time) for fn in expired_files: if os.path.exists(fn): os.remove(fn) if writable_cache_empty: rmtree(self.storage[-1]) self.load_cache() self._cache_size = None @typing.no_type_check def pop_from_cache(self, path): """Remove cached version of given file Deletes local copy of the given (remote) path. If it is found in a cache location which is not the last, it is assumed to be read-only, and raises PermissionError """ path = self._strip_protocol(path) fn = self._metadata.pop_file(path) if fn is not None: os.remove(fn) self._cache_size = None @typing.no_type_check def _parent(self, path): return self.fs._parent(path) @typing.no_type_check async def _ukey(self, path): """Hash of file properties, to tell if it has changed""" return sha256(str(await self.fs._info(path)).encode()).hexdigest() @typing.no_type_check async def _make_local_details(self, path): """Create file detail dictionary for given file path""" hash = self._mapper(path) fn = os.path.join(self.storage[-1], hash) detail = { "original": path, "fn": hash, "blocks": True, "time": time.time(), "uid": await self._ukey(path), } self._metadata.update_file(path, detail) logger.debug(f"Copying {path} to local cache") return fn @typing.no_type_check async def _cat_file( self, path, start=None, end=None, on_error="raise", callback=DEFAULT_CALLBACK, **kwargs, ): """Cat file, this is what is used inside Zarr 3.0 at the moment""" getpath = None try: # Check if file is in cache # This doesnt do anything fancy for chunked data, different chunk ranges # are stored in different files even if there is over lap path_chunked = path if start: path_chunked += f"_{start}" if end: path_chunked += f"_{end}" detail = self._check_file(path_chunked) if not detail: storepath = await self._make_local_details(path_chunked) getpath = path else: detail, storepath = ( detail if isinstance(detail, tuple) else (None, detail) ) except Exception as e: if on_error == "raise": raise if on_error == "return": out = e # If file was not in cache, get it using the base file system if getpath: resp = await self.fs._cat_file(getpath, start=start, end=end) # Save to file if isfilelike(storepath): outfile = storepath else: outfile = open(storepath, "wb") # noqa: ASYNC101, ASYNC230 try: outfile.write(resp) # IDK yet # callback.relative_update(len(chunk)) finally: if not isfilelike(storepath): outfile.close() self.save_cache() # Call back is weird here, like the progress should be on the file fetch # but then how do we deal with the read? Maybe we times it by 2x? if start is None: start = 0 callback.set_size(1) with open(storepath, "rb") as f: f.seek(start) if end is not None: out = f.read(end - start) else: out = f.read() callback.relative_update(1) return out @typing.no_type_check def __getattribute__(self, item): # TODO: Update if item in { "load_cache", # "_open", "save_cache", # "close_and_update", "__init__", "__getattribute__", "__reduce__", "_make_local_details", "_ukey", "open", "cat", "_cat_file", "cat_ranges", "get", "read_block", "tail", "head", # "info", # "ls", "exists", "isfile", "isdir", "_check_file", "_check_cache", "_mkcache", "clear_cache", "clear_expired_cache", "pop_from_cache", "local_file", "_paths_from_path", "get_mapper", "open_many", "commit_many", "hash_name", "__hash__", "__eq__", "to_json", "to_dict", "cache_size", "pipe_file", "pipe", "start_transaction", "end_transaction", }: # all the methods defined in this class. Note `open` here, since # it calls `_open`, but is actually in superclass return lambda *args, **kw: getattr(type(self), item).__get__(self)( *args, **kw ) if item in ["__reduce_ex__"]: raise AttributeError if item in ["transaction"]: # property return type(self).transaction.__get__(self) if item in ["_cache", "transaction_type"]: # class attributes return getattr(type(self), item) if item == "__class__": return type(self) d = object.__getattribute__(self, "__dict__") fs = d.get("fs", None) # fs is not immediately defined if item in d: return d[item] elif fs is not None: if item in fs.__dict__: # attribute of instance return fs.__dict__[item] # attributed belonging to the target filesystem cls = type(fs) m = getattr(cls, item) if (inspect.isfunction(m) or inspect.isdatadescriptor(m)) and ( not hasattr(m, "__self__") or m.__self__ is None ): # instance method return m.__get__(fs, cls) return m # class method or attribute else: # attributes of the superclass, while target is being set up return super().__getattribute__(item)