# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from collections import OrderedDict
from collections.abc import Generator
from itertools import product
import numpy as np
import pandas as pd
import torch
import xarray as xr
import zarr
from loguru import logger
from earth2studio.data import GFS_FX, HRRR, DataSource, ForecastSource, fetch_data
from earth2studio.models.auto import AutoModelMixin, Package
from earth2studio.models.da.base import AssimilationModel
from earth2studio.models.da.utils import filter_time_range
from earth2studio.utils import (
handshake_coords,
handshake_dim,
handshake_size,
)
from earth2studio.utils.imports import (
OptionalDependencyFailure,
check_optional_dependencies,
)
from earth2studio.utils.time import normalize_time_tolerance
from earth2studio.utils.type import CoordSystem, FrameSchema, TimeTolerance
try:
    import cupy as cp
except ImportError:
    cp = None
try:
    from scipy.spatial import cKDTree
except ImportError:
    OptionalDependencyFailure("stormcast")
    cKDTree = None
try:
    from omegaconf import OmegaConf
    from physicsnemo.diffusion.guidance import (
        DataConsistencyDPSGuidance,
        DPSScorePredictor,
    )
    from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
    from physicsnemo.diffusion.preconditioners import EDMPreconditioner
    from physicsnemo.diffusion.preconditioners.legacy import EDMPrecond
    from physicsnemo.diffusion.samplers import sample
    from physicsnemo.models.diffusion_unets import StormCastUNet
except ImportError:
    OptionalDependencyFailure("stormcast")
    # Bind *every* optionally-imported name to None so later references fail
    # through the check_optional_dependencies() guard instead of raising a
    # bare NameError. The original only bound 3 of the 8 names; e.g.
    # EDMPrecond in load_model and EDMNoiseScheduler/sample in _forward would
    # otherwise be undefined when physicsnemo is missing.
    OmegaConf = None
    DataConsistencyDPSGuidance = None
    DPSScorePredictor = None
    EDMNoiseScheduler = None
    EDMPreconditioner = None
    EDMPrecond = None
    sample = None
    StormCastUNet = None
# Variables used in StormCastV1 paper.
# High-resolution state: 4 surface fields, 6 atmospheric fields on 16 hybrid
# levels (pressure "p" is dropped above level 20), plus composite reflectivity.
VARIABLES = (
    ["u10m", "v10m", "t2m", "msl"]
    + [
        f"{name}{lvl}hl"
        for name in ["u", "v", "t", "q", "Z", "p"]
        for lvl in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 20, 25, 30]
        if not (name == "p" and lvl > 20)
    ]
    + ["refc"]
)
# Coarse-resolution conditioning: 6 surface fields plus 5 atmospheric fields
# on 4 pressure levels.
CONDITIONING_VARIABLES = ["u10m", "v10m", "t2m", "tcwv", "sp", "msl"] + [
    f"{name}{plvl}"
    for name in ["u", "v", "z", "t", "q"]
    for plvl in [1000, 850, 500, 250]
]
# Static invariant fields used as additional model inputs
INVARIANTS = ["lsm", "orography"]
[docs]
@check_optional_dependencies()
class StormCastSDA(torch.nn.Module, AutoModelMixin):
"""StormCast with score-based data assimilation (SDA) using diffusion posterior
sampling for convection-allowing regional forecasts. Combines a regression and
diffusion model with DPS guidance to assimilate observations during inference.
Model time step size is 1 hour, taking as input:
- High-resolution (3km) HRRR state over the central United States (99 vars)
- High-resolution land-sea mask and orography invariants
- Coarse resolution (25km) global state (26 vars)
- Point observations for data assimilation
The high-resolution grid is the HRRR Lambert conformal projection.
Coarse-resolution inputs are regridded to the HRRR grid internally.
Note
----
For more information see the following references:
- https://arxiv.org/abs/2408.10958
- https://huggingface.co/nvidia/stormcast-v1-era5-hrrr
- https://arxiv.org/abs/2306.10574
Parameters
----------
regression_model : torch.nn.Module
Deterministic model used to make an initial prediction
diffusion_model : torch.nn.Module
Generative model correcting the deterministic prediction
means : torch.Tensor
Mean value of each input high-resolution variable
stds : torch.Tensor
Standard deviation of each input high-resolution variable
invariants : torch.Tensor
Static invariant quantities
hrrr_lat_lim : tuple[int, int], optional
HRRR grid latitude limits, defaults to be the StormCastV1 region in central
United States, by default (273, 785)
hrrr_lon_lim : tuple[int, int], optional
HRRR grid longitude limits, defaults to be the StormCastV1 region in central
United States, by default (579, 1219)
variables : np.array, optional
High-resolution variables, by default np.array(VARIABLES)
conditioning_means : torch.Tensor | None, optional
Means to normalize conditioning data, by default None
conditioning_stds : torch.Tensor | None, optional
Standard deviations to normalize conditioning data, by default None
conditioning_variables : np.array, optional
Global variables for conditioning, by default np.array(CONDITIONING_VARIABLES)
conditioning_data_source : DataSource | ForecastSource | None, optional
Data Source to use for global conditioning. Required for running in iterator mode, by default None
time_tolerance : TimeTolerance, optional
Time tolerance for filtering observations. Observations within the tolerance
window around each requested time will be used for data assimilation,
by default np.timedelta64(10, "m")
sampler_steps : int, optional
Number of diffusion sampler steps, by default 36
sampler_args : dict[str, float | int] | None, optional
Arguments to pass to the diffusion sampler, by default None
sda_std_obs : float, optional
Observation noise standard deviation for DPS guidance, by default 0.1
sda_gamma : float, optional
SDA scaling factor for DPS guidance, by default 0.001
"""
    def __init__(
        self,
        regression_model: torch.nn.Module,
        diffusion_model: torch.nn.Module,
        means: torch.Tensor,
        stds: torch.Tensor,
        invariants: torch.Tensor,
        hrrr_lat_lim: tuple[int, int] = (273, 785),
        hrrr_lon_lim: tuple[int, int] = (579, 1219),
        variables: np.array = np.array(VARIABLES),
        conditioning_means: torch.Tensor | None = None,
        conditioning_stds: torch.Tensor | None = None,
        conditioning_variables: np.array = np.array(CONDITIONING_VARIABLES),
        conditioning_data_source: DataSource | ForecastSource | None = None,
        time_tolerance: TimeTolerance = np.timedelta64(10, "m"),
        sampler_steps: int = 36,
        sampler_args: dict[str, float | int] | None = None,
        sda_std_obs: float = 0.1,
        sda_gamma: float = 0.001,
    ):
        super().__init__()
        self.regression_model = regression_model
        self.diffusion_model = diffusion_model
        # Normalization statistics and static fields are registered as buffers
        # so they follow the module across .to()/.cuda() moves
        self.register_buffer("means", means)
        self.register_buffer("stds", stds)
        self.register_buffer("invariants", invariants)
        # Empty buffer used solely by the `device` property to track the
        # module's current device
        self.register_buffer("device_buffer", torch.empty(0))
        self.sampler_steps = sampler_steps
        # Default EDM sampler settings; any user-supplied sampler_args
        # override these key by key
        self.sampler_args = {
            "sigma_min": 0.002,
            "sigma_max": 800,
            "rho": 7,
            "S_churn": 0.0,
            "S_min": 0.0,
            "S_max": float("inf"),
            "S_noise": 1,
        }
        if sampler_args is not None:
            self.sampler_args.update(sampler_args)
        self._tolerance = normalize_time_tolerance(time_tolerance)
        self.sda_std_obs = sda_std_obs
        # Norm order used by the DPS guidance term (L2)
        self.sda_dps_norm = 2
        self.sda_gamma = sda_gamma
        # Crop the full HRRR curvilinear grid down to the configured
        # sub-region (defaults match the StormCastV1 central-US domain)
        hrrr_lat, hrrr_lon = HRRR.grid()
        self.lat = hrrr_lat[
            hrrr_lat_lim[0] : hrrr_lat_lim[1], hrrr_lon_lim[0] : hrrr_lon_lim[1]
        ]
        self.lon = hrrr_lon[
            hrrr_lat_lim[0] : hrrr_lat_lim[1], hrrr_lon_lim[0] : hrrr_lon_lim[1]
        ]
        self.hrrr_x = HRRR.HRRR_X[hrrr_lon_lim[0] : hrrr_lon_lim[1]]
        self.hrrr_y = HRRR.HRRR_Y[hrrr_lat_lim[0] : hrrr_lat_lim[1]]
        # Build ordered boundary polygon from 2D grid perimeter for
        # point-in-grid testing (top row -> right col -> bottom row -> left col)
        # The slice offsets deliberately drop the duplicated corner vertices so
        # the perimeter is traversed exactly once.
        self._grid_boundary = np.column_stack(
            [
                np.concatenate(
                    [
                        self.lat[0, :],
                        self.lat[1:, -1],
                        self.lat[-1, -2::-1],
                        self.lat[-2:0:-1, 0],
                    ]
                ),
                np.concatenate(
                    [
                        self.lon[0, :],
                        self.lon[1:, -1],
                        self.lon[-1, -2::-1],
                        self.lon[-2:0:-1, 0],
                    ]
                ),
            ]
        )  # [n_boundary, 2] ordered (lat, lon)
        # Build a KD-tree over (lat, lon) for efficient nearest-grid-point queries
        # TODO: Make cpu and gpu support
        self._grid_tree = cKDTree(np.column_stack([self.lat.ravel(), self.lon.ravel()]))
        self.variables = variables
        self.conditioning_variables = conditioning_variables
        self.conditioning_data_source = conditioning_data_source
        if conditioning_data_source is None:
            warnings.warn(
                "No conditioning data source was provided to StormCast, "
                + "set the conditioning_data_source attribute of the model "
                + "before running inference."
            )
        # Conditioning normalization buffers are optional; _forward checks for
        # their presence in self._buffers before applying them
        if conditioning_means is not None:
            self.register_buffer("conditioning_means", conditioning_means)
        if conditioning_stds is not None:
            self.register_buffer("conditioning_stds", conditioning_stds)
    @property
    def device(self) -> torch.device:
        """Device the model's parameters and buffers currently live on.

        Tracked through the empty ``device_buffer`` registered in ``__init__``,
        so the value follows ``.to()`` / ``.cuda()`` moves automatically.
        """
        return self.device_buffer.device
def init_coords(self) -> tuple[CoordSystem]:
"""Initialization coordinate system"""
return (
OrderedDict(
{
"time": np.empty(0),
"lead_time": np.array([np.timedelta64(0, "h")]),
"variable": np.array(self.variables),
"hrrr_y": self.hrrr_y,
"hrrr_x": self.hrrr_x,
}
),
)
def input_coords(self) -> tuple[FrameSchema]:
"""Input coordinate system specifying required DataFrame fields."""
return (
FrameSchema(
{
"time": np.empty(0, dtype="datetime64[ns]"),
"lat": np.empty(0, dtype=np.float32),
"lon": np.empty(0, dtype=np.float32),
"observation": np.empty(0, dtype=np.float32),
"variable": np.array(self.variables, dtype=str),
}
),
)
def output_coords(self, input_coords: tuple[CoordSystem]) -> tuple[CoordSystem]:
"""Output coordinate system of the assimilation model
Parameters
----------
input_coords : tuple[CoordSystem]
Coordinates of tensor used to initialize the forecast model.
Returns
-------
CoordSystem
Coordinate system dictionary
"""
output_coords = OrderedDict(
{
"time": np.empty(0),
"lead_time": np.array([np.timedelta64(1, "h")]),
"variable": np.array(self.variables),
"hrrr_y": self.hrrr_y,
"hrrr_x": self.hrrr_x,
}
)
target_input_coords = self.init_coords()[0]
handshake_dim(input_coords[0], "hrrr_x", 4)
handshake_dim(input_coords[0], "hrrr_y", 3)
handshake_dim(input_coords[0], "variable", 2)
# Index coords are arbitrary as long its on the HRRR grid, so just check size
handshake_size(input_coords[0], "hrrr_y", self.lat.shape[0])
handshake_size(input_coords[0], "hrrr_x", self.lat.shape[1])
handshake_coords(input_coords[0], target_input_coords, "variable")
output_coords["time"] = input_coords[0]["time"]
output_coords["lead_time"] = (
output_coords["lead_time"] + input_coords[0]["lead_time"]
)
return (output_coords,)
[docs]
@classmethod
def load_default_package(cls) -> Package:
"""Load assimilation package"""
package = Package(
"hf://nvidia/stormcast-v1-era5-hrrr@6c89a0877a0d6b231033d3b0d8b9828a6f833ed8",
cache_options={
"cache_storage": Package.default_cache("stormcast"),
"same_names": True,
},
)
return package
[docs]
    @classmethod
    @check_optional_dependencies()
    def load_model(
        cls,
        package: Package,
        conditioning_data_source: DataSource | ForecastSource = GFS_FX(verbose=False),
        time_tolerance: TimeTolerance = np.timedelta64(10, "m"),
        sampler_steps: int = 36,
        sda_std_obs: float = 0.1,
        sda_gamma: float = 0.001,
    ) -> AssimilationModel:
        """Load assimilation from package

        Parameters
        ----------
        package : Package
            Package to load model from
        conditioning_data_source : DataSource | ForecastSource, optional
            Data source to use for global conditioning, by default GFS_FX.
            NOTE(review): this default is constructed once at import time and
            shared across calls — confirm GFS_FX(verbose=False) is stateless
            enough for that to be safe.
        time_tolerance : TimeTolerance, optional
            Time tolerance for filtering observations. Observations within the
            tolerance window around each requested time will be used for data
            assimilation, by default np.timedelta64(10, "m")
        sampler_steps : int, optional
            Number of diffusion sampler steps, by default 36
        sda_std_obs : float, optional
            Observation noise standard deviation for DPS guidance, by default 0.1
        sda_gamma : float, optional
            SDA scaling factor for DPS guidance, by default 0.001

        Returns
        -------
        AssimilationModel
            Assimilation model
        """
        # Touch config.json only so Hugging Face records a download for usage
        # statistics; its absence is not an error
        try:
            package.resolve("config.json")  # HF tracking download statistics
        except FileNotFoundError:
            pass
        # The model.yaml config may use ${eval:...} interpolations
        try:
            OmegaConf.register_new_resolver("eval", eval)
        except ValueError:
            # Likely already registered so skip
            pass
        # load model registry:
        config = OmegaConf.load(package.resolve("model.yaml"))
        # TODO: remove strict=False once checkpoints/imports updated to new diffusion API
        regression = StormCastUNet.from_checkpoint(
            package.resolve("StormCastUNet.0.0.mdlus"),
            strict=False,
        )
        diffusion = EDMPrecond.from_checkpoint(
            package.resolve("EDMPrecond.0.0.mdlus"),
            strict=False,
        )
        # Load metadata: means, stds, grid
        store = zarr.storage.ZipStore(package.resolve("metadata.zarr.zip"), mode="r")
        metadata = xr.open_zarr(store, zarr_format=2)
        variables = metadata["variable"].values
        conditioning_variables = metadata["conditioning_variable"].values
        # Expand dims and tensorify normalization buffers
        # (shape [1, C, 1, 1] so they broadcast over batch and spatial dims)
        means = torch.from_numpy(metadata["means"].values[None, :, None, None])
        stds = torch.from_numpy(metadata["stds"].values[None, :, None, None])
        conditioning_means = torch.from_numpy(
            metadata["conditioning_means"].values[None, :, None, None]
        )
        conditioning_stds = torch.from_numpy(
            metadata["conditioning_stds"].values[None, :, None, None]
        )
        # Load invariants, selecting only those named in the model config
        invariants = metadata["invariants"].sel(invariant=config.data.invariants).values
        invariants = torch.from_numpy(invariants).repeat(1, 1, 1, 1)
        # EDM sampler arguments from the config override the built-in defaults
        if config.sampler_args is not None:
            sampler_args = config.sampler_args
        else:
            sampler_args = {}
        return cls(
            regression,
            diffusion,
            means,
            stds,
            invariants,
            variables=variables,
            conditioning_means=conditioning_means,
            conditioning_stds=conditioning_stds,
            conditioning_data_source=conditioning_data_source,
            conditioning_variables=conditioning_variables,
            time_tolerance=time_tolerance,
            sampler_steps=sampler_steps,
            sampler_args=sampler_args,
            sda_std_obs=sda_std_obs,
            sda_gamma=sda_gamma,
        )
    @torch.no_grad()
    def _forward(
        self,
        x: torch.Tensor,
        conditioning: torch.Tensor,
        y_obs: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """Run one regression + DPS-guided diffusion step in normalized space.

        The regression model makes a deterministic prediction; the diffusion
        model then samples a residual correction, guided toward the (residual)
        observations via data-consistency DPS guidance.

        Parameters
        ----------
        x : torch.Tensor
            Current high-resolution state in physical units (normalized here
            with ``means``/``stds``)
        conditioning : torch.Tensor
            Coarse conditioning fields on the HRRR grid, physical units
        y_obs : torch.Tensor
            Gridded observation values in physical units (zeros where no
            observation, per ``_build_obs_tensors``)
        mask : torch.Tensor
            1.0 where an observation exists, 0.0 elsewhere

        Returns
        -------
        torch.Tensor
            Next state, de-normalized back to physical units

        NOTE(review): this method is under ``@torch.no_grad()`` even though DPS
        guidance requires gradients through the denoiser (see the module-level
        note above ``__call__``) — presumably the physicsnemo guidance
        re-enables autograd locally; confirm against the physicsnemo docs.
        """
        # Scale data (conditioning stats are optional buffers)
        if "conditioning_means" in self._buffers:
            conditioning = conditioning - self.conditioning_means
        if "conditioning_stds" in self._buffers:
            conditioning = conditioning / self.conditioning_stds
        x = (x - self.means) / self.stds
        y_obs = (y_obs - self.means) / self.stds
        # Run regression model on [state, conditioning, invariants]
        invariant_tensor = self.invariants.repeat(x.shape[0], 1, 1, 1)
        concats = torch.cat((x, conditioning, invariant_tensor), dim=1)
        out = self.regression_model(concats)
        y_obs = y_obs - out  # Convert to residual obs
        # Concat for diffusion conditioning
        condition = torch.cat((x, out, invariant_tensor), dim=1)
        # Initial latents drawn at the maximum noise level
        latents = torch.randn_like(x, dtype=torch.float64)
        latents = self.sampler_args["sigma_max"] * latents  # Initial guess

        # Closure binding the fixed conditioning into the x0-predictor
        def _conditional_diffusion(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
            return self.diffusion_model(x, t, condition=condition)

        scheduler = EDMNoiseScheduler(
            sigma_min=self.sampler_args["sigma_min"],
            sigma_max=self.sampler_args["sigma_max"],
            rho=self.sampler_args["rho"],
        )
        # Data-consistency guidance pulls samples toward observed residuals
        guidance = DataConsistencyDPSGuidance(
            mask=mask,
            y=y_obs,
            std_y=self.sda_std_obs,
            norm=self.sda_dps_norm,
            gamma=self.sda_gamma,
            sigma_fn=scheduler.sigma,
            alpha_fn=scheduler.alpha,
        )
        score_predictor = DPSScorePredictor(
            x0_predictor=_conditional_diffusion,
            x0_to_score_fn=scheduler.x0_to_score,
            guidances=guidance,
        )
        denoiser = scheduler.get_denoiser(score_predictor=score_predictor)
        # Stochastic Heun (EDM) sampling of the residual correction
        edm_out = sample(
            denoiser,
            latents,
            noise_scheduler=scheduler,
            num_steps=self.sampler_steps,
            solver="edm_stochastic_heun",
            solver_options={
                "S_churn": self.sampler_args["S_churn"],
                "S_min": self.sampler_args["S_min"],
                "S_max": self.sampler_args["S_max"],
                "S_noise": self.sampler_args["S_noise"],
            },
        )
        # Final state = regression prediction + sampled residual, de-normalized
        out += edm_out
        out = out * self.stds + self.means
        return out
@staticmethod
def _points_in_polygon(points: np.ndarray, polygon: np.ndarray) -> np.ndarray:
"""Vectorized ray casting point-in-polygon test.
TODO: Improved this (GPU and reduce memory requirement)
make a general purpose util maybe...
Note
----
For more information see the following references:
https://observablehq.com/@tmcw/understanding-point-in-polygon
Parameters
----------
points : np.ndarray
Points to test, shape [n, 2]
polygon : np.ndarray
Ordered polygon vertices, shape [m, 2]
Returns
-------
np.ndarray
Boolean array of shape [n], True if point is inside polygon
"""
px, py = points[:, 0], points[:, 1] # [n]
vx, vy = polygon[:, 0], polygon[:, 1] # [m]
vx_next = np.roll(vx, -1)
vy_next = np.roll(vy, -1)
# For each edge (m) and each point (n), check if horizontal ray crosses
# Broadcasting: [m, 1] vs [1, n] -> [m, n]
crosses = (vy[:, None] > py[None, :]) != (vy_next[:, None] > py[None, :])
dvy = vy_next[:, None] - vy[:, None]
safe_dvy = np.where(dvy == 0, 1.0, dvy) # avoid division by zero; masked later
x_intersect = (vx_next[:, None] - vx[:, None]) * (
py[None, :] - vy[:, None]
) / safe_dvy + vx[:, None]
hits = crosses & (px[None, :] < x_intersect)
# Odd number of crossings = inside
return (np.sum(hits, axis=0) % 2) == 1
    def _build_obs_tensors(
        self,
        obs: pd.DataFrame | None,
        request_time: np.datetime64,
        device: torch.device,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Map sparse point observations onto dense HRRR-grid tensors.

        Observations within the time tolerance of ``request_time`` are kept,
        restricted to points inside the curvilinear grid boundary, snapped to
        the nearest grid cell via the KD-tree, and averaged when multiple
        observations land in the same (variable, y, x) cell.

        Parameters
        ----------
        obs : pd.DataFrame | None
            Observation frame with time/lat/lon/variable/observation columns
            (see ``input_coords``), or None for no assimilation
        request_time : np.datetime64
            Time the observations should be valid for
        device : torch.device
            Device on which to allocate the output tensors

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            ``(y_obs, mask)``, each of shape [1, n_var, n_y, n_x]; ``mask`` is
            1.0 where an observation was assigned, 0.0 elsewhere
        """
        n_var = len(self.variables)
        n_hrrr_y, n_hrrr_x = self.lat.shape
        # Zero-filled defaults are returned whenever no usable obs remain
        y_obs = torch.zeros(
            1, n_var, n_hrrr_y, n_hrrr_x, device=device, dtype=torch.float32
        )
        mask = torch.zeros(
            1, n_var, n_hrrr_y, n_hrrr_x, device=device, dtype=torch.float32
        )
        if obs is None or len(obs) == 0:
            return y_obs, mask
        # Filter observations within tolerance window
        time_filtered = filter_time_range(
            obs, request_time, self._tolerance, time_column="time"
        )
        if len(time_filtered) == 0:
            return y_obs, mask
        # TODO, make native cudf support
        # Convert to pandas if cudf for reliable string/value access
        if hasattr(time_filtered, "to_pandas"):
            time_filtered = time_filtered.to_pandas()
        obs_lat = time_filtered["lat"].values.astype(np.float64)
        obs_lon = time_filtered["lon"].values.astype(np.float64)
        obs_var = time_filtered["variable"].values
        obs_val = time_filtered["observation"].values.astype(np.float32)
        # Normalize lon to 0-360 to match HRRR grid
        obs_lon = np.where(obs_lon < 0, obs_lon + 360.0, obs_lon)
        # Filter observations to those inside the curvilinear grid boundary
        # using ray casting point-in-polygon on the precomputed perimeter
        obs_points = np.column_stack([obs_lat, obs_lon])
        in_grid = self._points_in_polygon(obs_points, self._grid_boundary)
        if not in_grid.any():
            return y_obs, mask
        obs_lat = obs_lat[in_grid]
        obs_lon = obs_lon[in_grid]
        obs_var = obs_var[in_grid]
        obs_val = obs_val[in_grid]
        # Find nearest HRRR grid point for each observation using a KD-tree
        # to avoid allocating the full [n_obs, n_grid] distance matrix.
        _, nearest_flat = self._grid_tree.query(np.column_stack([obs_lat, obs_lon]))
        # Unflatten the KD-tree index back to (y, x) grid indices
        nearest_y = nearest_flat // n_hrrr_x
        nearest_x = nearest_flat % n_hrrr_x
        # Map variable names to indices; unknown variables get -1 and are dropped
        var_to_idx = {str(v): i for i, v in enumerate(self.variables)}
        var_indices = np.array([var_to_idx.get(str(v), -1) for v in obs_var])
        valid = var_indices >= 0
        # Average multiple observations that map to the same grid cell
        if valid.any():
            vi = torch.tensor(var_indices[valid], device=device, dtype=torch.long)
            yi = torch.tensor(nearest_y[valid], device=device, dtype=torch.long)
            xi = torch.tensor(nearest_x[valid], device=device, dtype=torch.long)
            vals = torch.tensor(obs_val[valid], device=device, dtype=torch.float32)
            # Flatten (vi, yi, xi) into a single linear index for scatter ops
            flat_idx = vi * (n_hrrr_y * n_hrrr_x) + yi * n_hrrr_x + xi
            flat_sum = torch.zeros(
                n_var * n_hrrr_y * n_hrrr_x, device=device, dtype=torch.float32
            )
            flat_cnt = torch.zeros_like(flat_sum)
            # Sum values and counts per cell, then divide for the mean
            flat_sum.scatter_add_(0, flat_idx, vals)
            flat_cnt.scatter_add_(0, flat_idx, torch.ones_like(vals))
            occupied = flat_cnt > 0
            flat_avg = torch.where(occupied, flat_sum / flat_cnt, flat_sum)
            y_obs[0] = flat_avg.view(n_var, n_hrrr_y, n_hrrr_x)
            mask[0] = occupied.float().view(n_var, n_hrrr_y, n_hrrr_x)
        return y_obs, mask
    def _fetch_and_interp_conditioning(self, x: xr.DataArray) -> xr.DataArray:
        """Fetch global conditioning data and interpolate to HRRR curvilinear grid.

        Parameters
        ----------
        x : xr.DataArray
            Input state DataArray with time and lead_time coordinates

        Returns
        -------
        xr.DataArray
            Conditioning data interpolated onto the HRRR grid

        Raises
        ------
        RuntimeError
            If conditioning data source is not initialized
        """
        device = self.device
        if self.conditioning_data_source is None:
            raise RuntimeError(
                "StormCast has been called without initializing the model's conditioning_data_source"
            )
        c: xr.DataArray = fetch_data(
            self.conditioning_data_source,
            time=x.coords["time"].data,
            variable=self.conditioning_variables,
            lead_time=x.coords["lead_time"].data,
            device=self.device,
            legacy=False,
        )
        # Interpolate conditioning from regular lat/lon grid to HRRR curvilinear grid
        if cp is not None and isinstance(c.data, cp.ndarray):
            # GPU path: bilinear interpolation using cupy, data stays on GPU
            with cp.cuda.Device(device.index or 0):
                data = c.data
                src_lat = cp.asarray(c.coords["lat"].values, dtype=cp.float64)
                src_lon = cp.asarray(c.coords["lon"].values, dtype=cp.float64)
                target_lat_cp = cp.asarray(self.lat, dtype=cp.float64)
                target_lon_cp = cp.asarray(self.lon, dtype=cp.float64)
                # Ensure ascending order for searchsorted (latitude is
                # commonly descending in weather data, e.g. 90 -> -90)
                if src_lat[-1] < src_lat[0]:
                    src_lat = src_lat[::-1]
                    data = data[..., ::-1, :]
                if src_lon[-1] < src_lon[0]:
                    src_lon = src_lon[::-1]
                    data = data[..., :, ::-1]
                # Compute fractional indices via searchsorted (handles
                # non-uniform spacing); src_lat and src_lon need to be ascending
                lat_idx = cp.searchsorted(src_lat, target_lat_cp.ravel()) - 1
                lat_idx = cp.clip(lat_idx, 0, len(src_lat) - 2)
                lat_idx = lat_idx.reshape(target_lat_cp.shape)
                lon_idx = cp.searchsorted(src_lon, target_lon_cp.ravel()) - 1
                lon_idx = cp.clip(lon_idx, 0, len(src_lon) - 2)
                lon_idx = lon_idx.reshape(target_lon_cp.shape)
                # Lower-left and upper-right source cell corners
                lat0 = lat_idx
                lon0 = lon_idx
                lat1 = lat0 + 1
                lon1 = lon0 + 1
                # Fractional weights between grid cells, clipped so targets
                # outside the source grid clamp to the nearest edge value
                wlat = (target_lat_cp - src_lat[lat0]) / (src_lat[lat1] - src_lat[lat0])
                wlon = (target_lon_cp - src_lon[lon0]) / (src_lon[lon1] - src_lon[lon0])
                wlat = cp.clip(wlat, 0.0, 1.0)
                wlon = cp.clip(wlon, 0.0, 1.0)
                # Standard bilinear blend of the four surrounding corners
                interp_data = (
                    data[..., lat0, lon0] * (1 - wlat) * (1 - wlon)
                    + data[..., lat0, lon1] * (1 - wlat) * wlon
                    + data[..., lat1, lon0] * wlat * (1 - wlon)
                    + data[..., lat1, lon1] * wlat * wlon
                )
                # Re-wrap with HRRR curvilinear coordinates
                c = xr.DataArray(
                    data=interp_data,
                    dims=["time", "lead_time", "variable", "hrrr_y", "hrrr_x"],
                    coords={
                        "time": c.coords["time"],
                        "lead_time": c.coords["lead_time"],
                        "variable": c.coords["variable"],
                        "hrrr_y": self.hrrr_y,
                        "hrrr_x": self.hrrr_x,
                        "lat": (["hrrr_y", "hrrr_x"], self.lat),
                        "lon": (["hrrr_y", "hrrr_x"], self.lon),
                    },
                )
        else:
            # CPU path: use xarray's built-in interpolation
            target_lat = xr.DataArray(self.lat, dims=["hrrr_y", "hrrr_x"])
            target_lon = xr.DataArray(self.lon, dims=["hrrr_y", "hrrr_x"])
            c = c.interp(lat=target_lat, lon=target_lon, method="linear")
            c = c.assign_coords(
                hrrr_y=("hrrr_y", self.hrrr_y),
                hrrr_x=("hrrr_x", self.hrrr_x),
                lat=(["hrrr_y", "hrrr_x"], self.lat),
                lon=(["hrrr_y", "hrrr_x"], self.lon),
            )
        return c
    def _to_output_dataarray(
        self,
        x_tensor: torch.Tensor,
        output_coords: tuple[CoordSystem],
    ) -> xr.DataArray:
        """Convert output tensor to xr.DataArray with HRRR grid coordinates.

        Parameters
        ----------
        x_tensor : torch.Tensor
            Output tensor from _forward
        output_coords : tuple[CoordSystem]
            Output coordinate system from output_coords()

        Returns
        -------
        xr.DataArray
            Output DataArray with cupy backend on GPU or numpy on CPU
        """
        device = self.device
        (oc,) = output_coords
        if device.type == "cuda" and cp is not None:
            # Keep the data on-device as a cupy array
            with cp.cuda.Device(device.index or 0):
                out_data = cp.asarray(x_tensor.detach())
        else:
            out_data = x_tensor.detach().cpu().numpy()
        # NOTE(review): the lat/lon branch in this comprehension appears
        # unreachable — output_coords() only emits time/lead_time/variable/
        # hrrr_y/hrrr_x keys — so it is defensive; the 2D lat/lon coords are
        # always attached by the merged dict below.
        return xr.DataArray(
            data=out_data,
            dims=list(oc.keys()),
            coords={
                k: ((["hrrr_y", "hrrr_x"], v) if k in ("lat", "lon") else v)
                for k, v in oc.items()
            }
            | {
                "lat": (["hrrr_y", "hrrr_x"], self.lat),
                "lon": (["hrrr_y", "hrrr_x"], self.lon),
            },
        )
# NOTE: @torch.inference_mode() is intentionally omitted here.
# DPS guidance requires gradient computation through the denoiser for
# the score correction step; inference_mode would disable those gradients.
[docs]
    def __call__(
        self,
        x: xr.DataArray,
        obs: pd.DataFrame | None,
    ) -> xr.DataArray:
        """Runs assimilation model 1 step.

        Parameters
        ----------
        x : xr.DataArray
            Input state on the HRRR curvilinear grid
        obs : pd.DataFrame | None
            Sparse observations DataFrame, or None for no assimilation

        Returns
        -------
        xr.DataArray
            Output state one time-step into the future

        Raises
        ------
        RuntimeError
            If conditioning data source is not initialized
        """
        if self.conditioning_data_source is None:
            raise RuntimeError(
                "StormCast has been called without initializing the model's conditioning_data_source"
            )
        device = self.device
        c = self._fetch_and_interp_conditioning(x)
        x_coords = OrderedDict({dim: x.coords[dim].values for dim in x.dims})
        output_coords = self.output_coords((x_coords,))
        # NOTE(review): torch.as_tensor shares memory with x.data when it can,
        # so the in-place writes below may mutate the caller's input DataArray
        # — confirm this is intended.
        x_tensor = torch.as_tensor(x.data)
        c_tensor = torch.as_tensor(c.data)
        # Step each batch time independently; observations are assimilated at
        # the forecast valid time (input time + 1 h lead)
        for j, t in enumerate(x.coords["time"].data):
            obs_time = t + output_coords[0]["lead_time"][0]
            y_obs, mask = self._build_obs_tensors(obs, obs_time, device)
            x_tensor[j, :] = self._forward(x_tensor[j, :], c_tensor[j, :], y_obs, mask)
        return self._to_output_dataarray(x_tensor, output_coords)
[docs]
    def create_generator(
        self, x: xr.DataArray
    ) -> Generator[xr.DataArray, pd.DataFrame | None, None]:
        """Creates a generator for iterative forecast with data assimilation.

        The generator yields forecast states and receives observation DataFrames
        via ``send()``. At each step, conditioning data is fetched, observations
        are mapped to the HRRR grid, and the diffusion model produces the next
        forecast step.

        Parameters
        ----------
        x : xr.DataArray
            Initial state on the HRRR curvilinear grid

        Yields
        ------
        xr.DataArray
            Forecast state at each time step

        Receives
        --------
        pd.DataFrame | None
            Observations sent via ``generator.send()``. Pass ``None`` for
            steps without assimilation.

        Raises
        ------
        RuntimeError
            If conditioning data source is not initialized

        Example
        -------
        >>> gen = model.create_generator(x0)
        >>> state = next(gen)  # yields initial state x0
        >>> state = gen.send(obs_df)  # step 1 with observations
        >>> state = gen.send(None)  # step 2 without observations
        """
        if self.conditioning_data_source is None:
            raise RuntimeError(
                "StormCast has been called without initializing the model's "
                "conditioning_data_source"
            )
        # Yield the initial state so the caller can inspect it
        obs = yield x
        try:
            while True:
                # Fetch and interpolate conditioning onto HRRR grid
                c = self._fetch_and_interp_conditioning(x)
                # Compute output coords (advances lead_time by 1h)
                x_coords = OrderedDict({dim: x.coords[dim].values for dim in x.dims})
                output_coords = self.output_coords((x_coords,))
                # Zero-copy conversion to torch tensors
                x_tensor = torch.as_tensor(x.data)
                c_tensor = torch.as_tensor(c.data)
                # Run forward with observations, one batch time at a time
                for j, t in enumerate(x.coords["time"].data):
                    obs_time = t + output_coords[0]["lead_time"][0]
                    y_obs, mask = self._build_obs_tensors(obs, obs_time, self.device)
                    x_tensor[j] = self._forward(x_tensor[j], c_tensor[j], y_obs, mask)
                # Build output DataArray and use as next input
                x = self._to_output_dataarray(x_tensor, output_coords)
                # Yield forecast result and wait for next observations
                obs = yield x
        except GeneratorExit:
            # Caller closed the generator; nothing to release beyond logging
            logger.info("StormCast SDA clean up")