earth2studio.models.da.StormCastSDA

class earth2studio.models.da.StormCastSDA(regression_model, diffusion_model, means, stds, invariants, hrrr_lat_lim=(273, 785), hrrr_lon_lim=(579, 1219), variables=array(['u10m', 'v10m', 't2m', 'msl', 'u1hl', 'u2hl', 'u3hl', 'u4hl', 'u5hl', 'u6hl', 'u7hl', 'u8hl', 'u9hl', 'u10hl', 'u11hl', 'u13hl', 'u15hl', 'u20hl', 'u25hl', 'u30hl', 'v1hl', 'v2hl', 'v3hl', 'v4hl', 'v5hl', 'v6hl', 'v7hl', 'v8hl', 'v9hl', 'v10hl', 'v11hl', 'v13hl', 'v15hl', 'v20hl', 'v25hl', 'v30hl', 't1hl', 't2hl', 't3hl', 't4hl', 't5hl', 't6hl', 't7hl', 't8hl', 't9hl', 't10hl', 't11hl', 't13hl', 't15hl', 't20hl', 't25hl', 't30hl', 'q1hl', 'q2hl', 'q3hl', 'q4hl', 'q5hl', 'q6hl', 'q7hl', 'q8hl', 'q9hl', 'q10hl', 'q11hl', 'q13hl', 'q15hl', 'q20hl', 'q25hl', 'q30hl', 'Z1hl', 'Z2hl', 'Z3hl', 'Z4hl', 'Z5hl', 'Z6hl', 'Z7hl', 'Z8hl', 'Z9hl', 'Z10hl', 'Z11hl', 'Z13hl', 'Z15hl', 'Z20hl', 'Z25hl', 'Z30hl', 'p1hl', 'p2hl', 'p3hl', 'p4hl', 'p5hl', 'p6hl', 'p7hl', 'p8hl', 'p9hl', 'p10hl', 'p11hl', 'p13hl', 'p15hl', 'p20hl', 'refc'], dtype='<U5'), conditioning_means=None, conditioning_stds=None, conditioning_variables=array(['u10m', 'v10m', 't2m', 'tcwv', 'sp', 'msl', 'u1000', 'u850', 'u500', 'u250', 'v1000', 'v850', 'v500', 'v250', 'z1000', 'z850', 'z500', 'z250', 't1000', 't850', 't500', 't250', 'q1000', 'q850', 'q500', 'q250'], dtype='<U5'), conditioning_data_source=None, time_tolerance=numpy.timedelta64(10, 'm'), sampler_steps=36, sampler_args=None, sda_std_obs=0.1, sda_gamma=0.001)

StormCast with score-based data assimilation (SDA), using diffusion posterior sampling (DPS) for convection-allowing regional forecasts. It combines a regression model and a diffusion model with DPS guidance to assimilate observations during inference. The model time step is 1 hour, and each step takes as input:

  • High-resolution (3km) HRRR state over the central United States (99 vars)

  • High-resolution land-sea mask and orography invariants

  • Coarse resolution (25km) global state (26 vars)

  • Point observations for data assimilation

The high-resolution grid is the HRRR Lambert conformal projection. Coarse-resolution inputs are regridded to the HRRR grid internally.
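As a rough illustration of the regional window, the default `hrrr_lat_lim` and `hrrr_lon_lim` in the signature above can be read as index ranges into the full HRRR grid. The sketch below assumes half-open slicing and a full grid shape of 1059 x 1799; both are assumptions made for illustration and are not stated on this page.

```python
import numpy as np

# Assumption: full CONUS HRRR 3 km grid shape as (y, x); illustration only.
HRRR_FULL_SHAPE = (1059, 1799)

# Defaults taken from the class signature.
hrrr_lat_lim = (273, 785)
hrrr_lon_lim = (579, 1219)

# Assumption: the limits are applied as half-open index slices.
full_grid = np.zeros(HRRR_FULL_SHAPE, dtype=np.float32)
crop = full_grid[
    hrrr_lat_lim[0]:hrrr_lat_lim[1],
    hrrr_lon_lim[0]:hrrr_lon_lim[1],
]

print(crop.shape)  # (512, 640)
```

Under these assumptions the default region is a 512 x 640 window of the 3 km grid.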

Parameters:
  • regression_model (torch.nn.Module) – Deterministic model used to make an initial prediction

  • diffusion_model (torch.nn.Module) – Generative model correcting the deterministic prediction

  • means (torch.Tensor) – Mean value of each input high-resolution variable

  • stds (torch.Tensor) – Standard deviation of each input high-resolution variable

  • invariants (torch.Tensor) – Static invariant quantities

  • hrrr_lat_lim (tuple[int, int], optional) – HRRR grid latitude limits; the default is the StormCastV1 region in the central United States, by default (273, 785)

  • hrrr_lon_lim (tuple[int, int], optional) – HRRR grid longitude limits; the default is the StormCastV1 region in the central United States, by default (579, 1219)

  • variables (np.array, optional) – High-resolution variables, by default np.array(VARIABLES)

  • conditioning_means (torch.Tensor | None, optional) – Means to normalize conditioning data, by default None

  • conditioning_stds (torch.Tensor | None, optional) – Standard deviations to normalize conditioning data, by default None

  • conditioning_variables (np.array, optional) – Global variables for conditioning, by default np.array(CONDITIONING_VARIABLES)

  • conditioning_data_source (DataSource | ForecastSource | None, optional) – Data source to use for global conditioning; required for running in iterator mode, by default None

  • time_tolerance (TimeTolerance, optional) – Time tolerance for filtering observations. Observations within the tolerance window around each requested time will be used for data assimilation, by default np.timedelta64(10, "m")

  • sampler_steps (int, optional) – Number of diffusion sampler steps, by default 36

  • sampler_args (dict[str, float | int] | None, optional) – Arguments to pass to the diffusion sampler, by default None

  • sda_std_obs (float, optional) – Observation noise standard deviation for DPS guidance, by default 0.1

  • sda_gamma (float, optional) – SDA scaling factor for DPS guidance, by default 0.001
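The effect of `time_tolerance` can be pictured as a mask over observation timestamps. The following is a minimal sketch, not the library's implementation: the helper name `within_tolerance` is hypothetical, and the symmetric, inclusive window is an assumption.

```python
import numpy as np


def within_tolerance(obs_times, request_time, tolerance=np.timedelta64(10, "m")):
    """Boolean mask of observations within +/- tolerance of request_time.

    Hypothetical helper; assumes a symmetric window with inclusive bounds.
    """
    obs_times = np.asarray(obs_times, dtype="datetime64[s]")
    request_time = np.datetime64(request_time, "s")
    return np.abs(obs_times - request_time) <= tolerance


obs_times = np.array(
    [
        "2024-01-01T11:45",
        "2024-01-01T11:52",
        "2024-01-01T12:00",
        "2024-01-01T12:10",
        "2024-01-01T12:11",
    ],
    dtype="datetime64[s]",
)

mask = within_tolerance(obs_times, "2024-01-01T12:00")
print(mask.tolist())  # [False, True, True, True, False]
```

With the default 10 minute tolerance, only observations stamped between 11:50 and 12:10 would be kept for a 12:00 forecast step in this sketch.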

__call__(x, obs)

Runs the assimilation model forward one step.

Parameters:
  • x (xr.DataArray) – Input state on the HRRR curvilinear grid

  • obs (pd.DataFrame | None) – Sparse observations DataFrame, or None for no assimilation

Returns:

Output state one time-step into the future

Return type:

xr.DataArray

Raises:

RuntimeError – If conditioning data source is not initialized

create_generator(x)

Creates a generator for iterative forecast with data assimilation.

The generator yields forecast states and receives observation DataFrames via send(). At each step, conditioning data is fetched, observations are mapped to the HRRR grid, and the diffusion model produces the next forecast step.

Parameters:
  • x (xr.DataArray) – Initial state on the HRRR curvilinear grid

  • sent value (pd.DataFrame | None) – Observations sent via generator.send(). Pass None for steps without assimilation.

Yields:

xr.DataArray – Forecast state at each time step

Return type:

Generator[DataArray, DataFrame | None, None]

Example

>>> gen = model.create_generator(x0)
>>> state = next(gen)           # yields initial state x0
>>> state = gen.send(obs_df)    # step 1 with observations
>>> state = gen.send(None)      # step 2 without observations
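
The yield/send handshake used here is plain Python generator behavior and can be tried without the model. The sketch below uses a hypothetical `toy_generator` with floats standing in for the xr.DataArray state and the pd.DataFrame observations; the update rule is invented for illustration and is unrelated to the diffusion model.

```python
def toy_generator(x0):
    """Stand-in with the same yield/send shape as create_generator.

    Not the real model: floats replace the state and the observations.
    """
    state = x0
    obs = yield state  # the first next() yields the initial state
    while True:
        # Hypothetical forecast step: advance the state by one unit.
        state = state + 1.0
        if obs is not None:
            # Hypothetical assimilation: move halfway toward the observation.
            state = state + 0.5 * (obs - state)
        obs = yield state  # hand back the new state, wait for the next obs


gen = toy_generator(0.0)
s0 = next(gen)       # initial state
s1 = gen.send(3.0)   # step 1 with an observation
s2 = gen.send(None)  # step 2 without an observation
gen.close()

print(s0, s1, s2)  # 0.0 2.0 3.0
```

The value passed to send() is consumed by the step that produces the returned state, which is why the first call has to be next(gen) rather than send(obs).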

classmethod load_default_package()

Load the default assimilation model package

Return type:

Package

classmethod load_model(package, conditioning_data_source=<earth2studio.data.gfs.GFS_FX object>, time_tolerance=numpy.timedelta64(10, 'm'), sampler_steps=36, sda_std_obs=0.1, sda_gamma=0.001)

Load assimilation model from package

Parameters:
  • package (Package) – Package to load model from

  • conditioning_data_source (DataSource | ForecastSource, optional) – Data source to use for global conditioning, by default GFS_FX

  • time_tolerance (TimeTolerance, optional) – Time tolerance for filtering observations. Observations within the tolerance window around each requested time will be used for data assimilation, by default np.timedelta64(10, "m")

  • sampler_steps (int, optional) – Number of diffusion sampler steps, by default 36

  • sda_std_obs (float, optional) – Observation noise standard deviation for DPS guidance, by default 0.1

  • sda_gamma (float, optional) – SDA scaling factor for DPS guidance, by default 0.001

Returns:

Assimilation model

Return type:

AssimilationModel
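
A typical loading sequence chains the two classmethods documented above. This is a minimal sketch: the wrapper `load_stormcast_sda` is a hypothetical convenience function, and it assumes earth2studio is installed and that loading the default package may fetch model checkpoints. The import is deferred so that defining the function has no side effects.

```python
def load_stormcast_sda(sampler_steps=36, sda_std_obs=0.1, sda_gamma=0.001):
    """Hypothetical wrapper around the documented classmethods."""
    # Deferred import: requires earth2studio to be installed.
    from earth2studio.models.da import StormCastSDA

    # Resolve the default model package, then build the model from it.
    package = StormCastSDA.load_default_package()
    return StormCastSDA.load_model(
        package,
        sampler_steps=sampler_steps,
        sda_std_obs=sda_std_obs,
        sda_gamma=sda_gamma,
    )
```

Calling `model = load_stormcast_sda()` leaves `conditioning_data_source` and `time_tolerance` at the defaults listed above; the returned model can then be stepped with `model(x, obs)` or `model.create_generator(x)`.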

Examples using earth2studio.models.da.StormCastSDA

StormCast Score-Based Data Assimilation
