Source code for earth2studio.data.rand

# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
from collections.abc import Callable
from datetime import datetime, timedelta
from typing import Any

import numpy as np
import pandas as pd
import pyarrow as pa
import xarray as xr

from earth2studio.data.utils import prep_data_inputs, prep_forecast_inputs
from earth2studio.utils.type import FieldArray, LeadTimeArray, TimeArray, VariableArray


class Random:
    """Data source producing normally distributed random fields.

    Intended mainly as a lightweight stand-in for a real data source in tests.

    Parameters
    ----------
    domain_coords: OrderedDict[str, np.ndarray]
        Coordinates of the spatial (or other trailing) dimensions the random
        data should span, e.g. lat and lon. The mapping order determines the
        order of the trailing dimensions in the returned array.
    """

    def __init__(
        self,
        domain_coords: OrderedDict[str, np.ndarray],
    ):
        self.domain_coords = domain_coords

    def __call__(
        self,
        time: datetime | list[datetime] | TimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Return a random gaussian data array for the requested request.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for.
        variable : str | list[str] | VariableArray
            Strings or list of strings that refer to variables to return.

        Returns
        -------
        xr.DataArray
            Random data array with dimensions ``time``, ``variable`` followed
            by the keys of ``domain_coords``.
        """
        # Inputs are passed through the shared helper from
        # earth2studio.data.utils before their lengths are used.
        time, variable = prep_data_inputs(time, variable)

        # Leading dims come from the request, trailing dims from the domain.
        coords = {"time": time, "variable": variable}
        coords.update(self.domain_coords)

        sizes = [len(time), len(variable)]
        sizes.extend(len(axis) for axis in self.domain_coords.values())

        return xr.DataArray(
            data=np.random.randn(*sizes),
            dims=list(coords),
            coords=coords,
        )
class Random_FX:
    """Forecast data source producing normally distributed random fields.

    Behaves like a random data source with an additional ``lead_time``
    dimension. Intended mainly as a lightweight stand-in for a real forecast
    source in tests.

    Parameters
    ----------
    domain_coords: OrderedDict[str, np.ndarray]
        Coordinates of the spatial (or other trailing) dimensions the random
        data should span, e.g. lat and lon. The mapping order determines the
        order of the trailing dimensions in the returned array.
    """

    def __init__(
        self,
        domain_coords: OrderedDict[str, np.ndarray],
    ):
        self.domain_coords = domain_coords

    def __call__(  # type: ignore[override]
        self,
        time: datetime | list[datetime] | TimeArray,
        lead_time: timedelta | list[timedelta] | LeadTimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Return a random gaussian forecast data array.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for.
        lead_time : timedelta | list[timedelta] | LeadTimeArray
            Forecast lead times to return data for.
        variable : str | list[str] | VariableArray
            Strings or list of strings that refer to variables to return.

        Returns
        -------
        xr.DataArray
            Random data array with dimensions ``time``, ``lead_time``,
            ``variable`` followed by the keys of ``domain_coords``.
        """
        # Inputs are passed through the shared helper from
        # earth2studio.data.utils before their lengths are used.
        time, lead_time, variable = prep_forecast_inputs(time, lead_time, variable)

        # Leading dims come from the request, trailing dims from the domain.
        coords = {"time": time, "lead_time": lead_time, "variable": variable}
        coords.update(self.domain_coords)

        sizes = [len(time), len(lead_time), len(variable)]
        sizes.extend(len(axis) for axis in self.domain_coords.values())

        return xr.DataArray(
            data=np.random.randn(*sizes),
            dims=list(coords),
            coords=coords,
        )
class RandomDataFrame:
    """A randomly generated DataFrame source. Primarily useful for testing.

    Generates random observations at random locations for specified times and
    variables. Each observation is a point in space-time with a random value.

    Parameters
    ----------
    n_obs : int, optional
        Number of random observations to generate for every requested
        (time, variable) pair, by default 10
    tolerance : timedelta | np.timedelta64, optional
        Time tolerance; when positive, observation times are sampled uniformly
        within +/- tolerance of each requested time. Otherwise the requested
        time is used exactly, by default np.timedelta64(0)
    schema : pa.Schema | None, optional
        PyArrow schema listing the fields to generate. If None, uses the
        default class-level SCHEMA, by default None
    field_generators : dict[str, Callable[[], Any]] | None, optional
        Dictionary mapping field names to zero-argument generator functions.
        These are merged with the default generators (time, lat, lon, elev,
        observation, variable); user-provided entries override defaults. Note
        that the ``time`` and ``variable`` fields are always filled from the
        sampled observation time and the requested variable, so generators
        registered under those two names are never invoked, by default None
    """

    SOURCE_ID = "earth2studio.data.RandomDataFrame"
    SCHEMA = pa.schema(
        [
            pa.field("time", pa.timestamp("ns")),
            pa.field("lat", pa.float32()),
            pa.field("lon", pa.float32()),
            pa.field("observation", pa.float32()),
            pa.field("variable", pa.string()),
        ]
    )

    def __init__(
        self,
        n_obs: int = 10,
        tolerance: timedelta | np.timedelta64 = np.timedelta64(0),
        schema: pa.Schema | None = None,
        field_generators: dict[str, Callable[[], Any]] | None = None,
    ):
        self.n_obs = n_obs

        # Normalize tolerance to python timedelta
        if isinstance(tolerance, np.timedelta64):
            self.tolerance = pd.to_timedelta(tolerance).to_pytimedelta()
        else:
            self.tolerance = tolerance

        # Use provided schema or default class schema
        self.schema = schema if schema is not None else self.SCHEMA

        # Default field generators. "elev" is available for custom schemas
        # even though the default SCHEMA does not contain it.
        default_generators: dict[str, Callable[[], Any]] = {
            "time": lambda: None,  # Placeholder; filled from the observation time
            "lat": lambda: np.random.uniform(-90.0, 90.0),
            "lon": lambda: np.random.uniform(0.0, 360.0),
            "elev": lambda: np.random.uniform(0.0, 1000.0),
            "observation": lambda: np.random.randn(),
            "variable": lambda: None,  # Placeholder; filled from the variable
        }
        user_generators = field_generators if field_generators is not None else {}
        self._field_generators = {**default_generators, **user_generators}

    def __call__(
        self,
        time: datetime | list[datetime] | TimeArray,
        variable: str | list[str] | VariableArray,
        fields: str | list[str] | pa.Schema | FieldArray | None = None,
    ) -> pd.DataFrame:
        """Retrieve random observation DataFrame.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for.
        variable : str | list[str] | VariableArray
            Strings or list of strings that refer to variables to return.
        fields : str | list[str] | pa.Schema | FieldArray | None, optional
            Fields to include in output, by default None (all fields).

        Returns
        -------
        pd.DataFrame
            Random observation DataFrame with ``n_obs`` rows per requested
            (time, variable) pair. ``attrs["source"]`` is set to SOURCE_ID.
            When no rows are generated, the frame is empty but still carries
            the resolved schema's column names.

        Raises
        ------
        KeyError
            If a requested field is not part of the instance schema, or if a
            non-nullable schema field has no registered generator.
        TypeError
            If ``fields`` is a schema whose field types disagree with the
            instance schema.
        """
        time, variable = prep_data_inputs(time, variable)
        schema = self._resolve_fields(fields)

        # Loop invariant: hoisted out of the per-observation loop
        tolerance_seconds = self.tolerance.total_seconds()

        rows = []
        for t in time:
            # Convert time to a pandas timestamp so timedelta offsets can be added
            t_dt = pd.to_datetime(t)
            for v in variable:
                for _ in range(self.n_obs):
                    obs_time = self._sample_obs_time(t_dt, tolerance_seconds)
                    rows.append(self._generate_row(schema, obs_time, v))

        if rows:
            df = pd.DataFrame(rows)
        else:
            # A DataFrame built from an empty list has no columns at all; keep
            # the schema's column names so downstream column access still works.
            df = pd.DataFrame(columns=schema.names)
        df.attrs["source"] = self.SOURCE_ID

        # Filter to requested fields (the empty frame is already restricted)
        if rows and fields is not None:
            field_names = [name for name in schema.names if name in df.columns]
            df = df[field_names]

        return df

    def _sample_obs_time(
        self, base_time: pd.Timestamp, tolerance_seconds: float
    ) -> pd.Timestamp:
        """Return ``base_time``, jittered uniformly within the tolerance window."""
        if tolerance_seconds > 0:
            random_offset_seconds = np.random.uniform(
                -tolerance_seconds, tolerance_seconds
            )
            return base_time + timedelta(seconds=random_offset_seconds)
        # No (positive) tolerance, use exact time
        return base_time

    def _generate_row(
        self, schema: pa.Schema, obs_time: Any, variable: Any
    ) -> dict[str, Any]:
        """Build one observation row with a value for every field in ``schema``."""
        row: dict[str, Any] = {}
        for field in schema:
            field_name = field.name
            if field_name in self._field_generators:
                # time and variable come from the request, not from a generator
                if field_name == "time":
                    row[field_name] = obs_time
                elif field_name == "variable":
                    row[field_name] = variable
                else:
                    row[field_name] = self._field_generators[field_name]()
            elif field.nullable:
                # No generator available; a nullable field may be left empty
                row[field_name] = None
            else:
                raise KeyError(
                    f"Field '{field_name}' not found in field generators "
                    f"and is not nullable. Available generators: "
                    f"{list(self._field_generators.keys())}"
                )
        return row

    def _require_schema_field(self, name: str) -> pa.Field:
        """Return the instance-schema field called ``name`` or raise KeyError."""
        if name not in self.schema.names:
            raise KeyError(
                f"Field '{name}' not found in schema. "
                f"Available fields: {self.schema.names}"
            )
        return self.schema.field(name)

    def _resolve_fields(
        self, fields: str | list[str] | pa.Schema | FieldArray | None
    ) -> pa.Schema:
        """Convert fields parameter into a validated PyArrow schema.

        Parameters
        ----------
        fields : str | list[str] | pa.Schema | FieldArray | None
            Field specification. Can be:
            - None: Returns the full instance schema
            - str: Single field name to select from schema
            - list[str]: List of field names to select from schema
            - pa.Schema: Validated against instance schema for compatibility

        Returns
        -------
        pa.Schema
            A PyArrow schema containing only the requested fields

        Raises
        ------
        KeyError
            If a requested field name does not exist in the instance schema.
        TypeError
            If a field of a provided schema has a different type than the
            field of the same name in the instance schema.
        """
        if fields is None:
            return self.schema

        if isinstance(fields, str):
            fields = [fields]

        if isinstance(fields, pa.Schema):
            # Validate provided schema against instance schema
            for field in fields:
                expected_type = self._require_schema_field(field.name).type
                if field.type != expected_type:
                    raise TypeError(
                        f"Field '{field.name}' has type {field.type}, "
                        f"expected {expected_type} from schema"
                    )
            return fields

        # fields is a sequence of names - select fields from instance schema
        return pa.schema([self._require_schema_field(name) for name in fields])