earth2studio.models.da.StormCastSDA
- class earth2studio.models.da.StormCastSDA(regression_model, diffusion_model, means, stds, invariants, hrrr_lat_lim=(273, 785), hrrr_lon_lim=(579, 1219), variables=array(['u10m', 'v10m', 't2m', 'msl', 'u1hl', 'u2hl', 'u3hl', 'u4hl', 'u5hl', 'u6hl', 'u7hl', 'u8hl', 'u9hl', 'u10hl', 'u11hl', 'u13hl', 'u15hl', 'u20hl', 'u25hl', 'u30hl', 'v1hl', 'v2hl', 'v3hl', 'v4hl', 'v5hl', 'v6hl', 'v7hl', 'v8hl', 'v9hl', 'v10hl', 'v11hl', 'v13hl', 'v15hl', 'v20hl', 'v25hl', 'v30hl', 't1hl', 't2hl', 't3hl', 't4hl', 't5hl', 't6hl', 't7hl', 't8hl', 't9hl', 't10hl', 't11hl', 't13hl', 't15hl', 't20hl', 't25hl', 't30hl', 'q1hl', 'q2hl', 'q3hl', 'q4hl', 'q5hl', 'q6hl', 'q7hl', 'q8hl', 'q9hl', 'q10hl', 'q11hl', 'q13hl', 'q15hl', 'q20hl', 'q25hl', 'q30hl', 'Z1hl', 'Z2hl', 'Z3hl', 'Z4hl', 'Z5hl', 'Z6hl', 'Z7hl', 'Z8hl', 'Z9hl', 'Z10hl', 'Z11hl', 'Z13hl', 'Z15hl', 'Z20hl', 'Z25hl', 'Z30hl', 'p1hl', 'p2hl', 'p3hl', 'p4hl', 'p5hl', 'p6hl', 'p7hl', 'p8hl', 'p9hl', 'p10hl', 'p11hl', 'p13hl', 'p15hl', 'p20hl', 'refc'], dtype='<U5'), conditioning_means=None, conditioning_stds=None, conditioning_variables=array(['u10m', 'v10m', 't2m', 'tcwv', 'sp', 'msl', 'u1000', 'u850', 'u500', 'u250', 'v1000', 'v850', 'v500', 'v250', 'z1000', 'z850', 'z500', 'z250', 't1000', 't850', 't500', 't250', 'q1000', 'q850', 'q500', 'q250'], dtype='<U5'), conditioning_data_source=None, time_tolerance=numpy.timedelta64(10, 'm'), sampler_steps=36, sampler_args=None, sda_std_obs=0.1, sda_gamma=0.001)
StormCast with score-based data assimilation (SDA), using diffusion posterior sampling (DPS) for convection-allowing regional forecasts. Combines a regression model and a diffusion model with DPS guidance to assimilate observations during inference. The model time step is 1 hour, and the model takes as input:
High-resolution (3 km) HRRR state over the central United States (99 variables)
High-resolution land-sea mask and orography invariants
Coarse-resolution (25 km) global state (26 variables)
Point observations for data assimilation
The high-resolution grid is the HRRR Lambert conformal projection. Coarse-resolution inputs are regridded to the HRRR grid internally.
Note
For more information see the following references:
- Parameters:
regression_model (torch.nn.Module) – Deterministic model used to make an initial prediction
diffusion_model (torch.nn.Module) – Generative model correcting the deterministic prediction
means (torch.Tensor) – Mean value of each input high-resolution variable
stds (torch.Tensor) – Standard deviation of each input high-resolution variable
invariants (torch.Tensor) – Static invariant quantities
hrrr_lat_lim (tuple[int, int], optional) – HRRR grid latitude limits, covering the StormCastV1 region in the central United States, by default (273, 785)
hrrr_lon_lim (tuple[int, int], optional) – HRRR grid longitude limits, covering the StormCastV1 region in the central United States, by default (579, 1219)
variables (np.array, optional) – High-resolution variables, by default np.array(VARIABLES)
conditioning_means (torch.Tensor | None, optional) – Means to normalize conditioning data, by default None
conditioning_stds (torch.Tensor | None, optional) – Standard deviations to normalize conditioning data, by default None
conditioning_variables (np.array, optional) – Global variables for conditioning, by default np.array(CONDITIONING_VARIABLES)
conditioning_data_source (DataSource | ForecastSource | None, optional) – Data source to use for global conditioning. Required for running in iterator mode, by default None
time_tolerance (TimeTolerance, optional) – Time tolerance for filtering observations. Observations within the tolerance window around each requested time will be used for data assimilation, by default np.timedelta64(10, "m")
sampler_steps (int, optional) – Number of diffusion sampler steps, by default 36
sampler_args (dict[str, float | int] | None, optional) – Arguments to pass to the diffusion sampler, by default None
sda_std_obs (float, optional) – Observation noise standard deviation for DPS guidance, by default 0.1
sda_gamma (float, optional) – SDA scaling factor for DPS guidance, by default 0.001
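The roles of sda_std_obs and sda_gamma can be illustrated with a minimal NumPy sketch of a Gaussian observation likelihood of the kind used in score-based data assimilation. Everything in it is an assumption made for illustration, not the StormCastSDA implementation: the function name guidance_gradient, the point-sampling observation operator, and the variance form std_obs**2 + gamma * sigma_t**2 are hypothetical. The sketch returns the gradient with respect to the denoised estimate only; a full DPS step would also backpropagate through the denoiser.

```python
import numpy as np


def guidance_gradient(x_hat, obs_values, obs_index, sigma_t, std_obs=0.1, gamma=0.001):
    """Gradient of a Gaussian log-likelihood with respect to the denoised estimate.

    Hypothetical helper, assuming observations sample the state at flat indices.
    """
    # Mismatch between the observations and the estimate at the observed cells
    residual = obs_values - x_hat[obs_index]
    # Assumed variance: observation noise plus a term that grows with the
    # diffusion noise level sigma_t, scaled by gamma
    var = std_obs**2 + gamma * sigma_t**2
    grad = np.zeros_like(x_hat)
    # Accumulate contributions, so repeated indices add up instead of overwriting
    np.add.at(grad, obs_index, residual / var)
    return grad


x_hat = np.zeros(5)
grad = guidance_gradient(x_hat, np.array([1.0]), np.array([2]), sigma_t=10.0)
```

In this sketch, a larger sda_std_obs or a larger noise level sigma_t weakens the pull toward the observations, and unobserved cells receive no direct guidance.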
- __call__(x, obs)
Runs the assimilation model for one step.
- Parameters:
x (xr.DataArray) – Input state on the HRRR curvilinear grid
obs (pd.DataFrame | None) – Sparse observations DataFrame, or None for no assimilation
- Returns:
Output state one time-step into the future
- Return type:
xr.DataArray
- Raises:
RuntimeError – If conditioning data source is not initialized
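Because x lives on a curvilinear grid while obs is sparse, each observation has to be related to grid cells in some way. The following is a minimal sketch of a nearest-cell lookup, assuming 2-D latitude and longitude arrays for the grid. The helper name nearest_cell and the nearest-neighbour choice are assumptions for illustration; the documentation does not state how StormCastSDA performs this mapping.

```python
import numpy as np


def nearest_cell(lat2d, lon2d, lat, lon):
    """Return the (row, col) index of the grid cell closest to a point.

    Hypothetical helper: uses a flat-earth distance in degrees, which is only
    adequate for picking the nearest cell on a fine regional grid.
    """
    dlat = lat2d - lat
    # Wrap longitude differences into [-180, 180) so 265.1 and -94.9 match
    dlon = (lon2d - lon + 180.0) % 360.0 - 180.0
    # Shrink east-west differences by cos(latitude) to approximate distance
    dist2 = dlat**2 + (np.cos(np.deg2rad(lat)) * dlon) ** 2
    return np.unravel_index(np.argmin(dist2), lat2d.shape)


# Toy 1-degree grid standing in for the real curvilinear coordinates
lat2d, lon2d = np.meshgrid(
    np.linspace(30.0, 40.0, 11), np.linspace(260.0, 270.0, 11), indexing="ij"
)
row, col = nearest_cell(lat2d, lon2d, 35.2, -94.9)
```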
- create_generator(x)
Creates a generator for iterative forecast with data assimilation.
The generator yields forecast states and receives observation DataFrames via send(). At each step, conditioning data is fetched, observations are mapped to the HRRR grid, and the diffusion model produces the next forecast step.
- Parameters:
x (xr.DataArray) – Initial state on the HRRR curvilinear grid
pd.DataFrame | None – Observations sent via generator.send(). Pass None for steps without assimilation.
- Yields:
xr.DataArray – Forecast state at each time step
- Return type:
Generator[DataArray, DataFrame | None, None]
Example
>>> gen = model.create_generator(x0)
>>> state = next(gen)  # yields initial state x0
>>> state = gen.send(obs_df)  # step 1 with observations
>>> state = gen.send(None)  # step 2 without observations
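The yield/send protocol used above can be tried without the model. The sketch below is a toy stand-in, not the StormCastSDA generator: toy_generator, the "+ 1.0" forecast step, and the averaging "assimilation" are hypothetical placeholders that only mirror the control flow (prime with next(), then send observations or None).

```python
from collections.abc import Generator
from typing import Optional


def toy_generator(x0: float) -> Generator[float, Optional[float], None]:
    """Toy generator mirroring the prime-then-send control flow."""
    state = x0
    # The first next() call runs up to this yield and returns the initial state
    obs = yield state
    while True:
        state = state + 1.0  # placeholder for one forecast step
        if obs is not None:
            state = 0.5 * (state + obs)  # placeholder for assimilation
        # Hand back the new state and wait for the next observation (or None)
        obs = yield state


gen = toy_generator(0.0)
s0 = next(gen)       # 0.0, the initial state
s1 = gen.send(4.0)   # 2.5, step 1 with an observation
s2 = gen.send(None)  # 3.5, step 2 without observations
gen.close()
```

Calling send() with a non-None value before the generator has been primed with next() raises a TypeError, which is why the first call in the example is next(gen).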
- classmethod load_model(package, conditioning_data_source=<earth2studio.data.gfs.GFS_FX object>, time_tolerance=numpy.timedelta64(10, 'm'), sampler_steps=36, sda_std_obs=0.1, sda_gamma=0.001)
Load the assimilation model from a package.
- Parameters:
package (Package) – Package to load model from
conditioning_data_source (DataSource | ForecastSource, optional) – Data source to use for global conditioning, by default GFS_FX
time_tolerance (TimeTolerance, optional) – Time tolerance for filtering observations. Observations within the tolerance window around each requested time will be used for data assimilation, by default np.timedelta64(10, "m")
sampler_steps (int, optional) – Number of diffusion sampler steps, by default 36
sda_std_obs (float, optional) – Observation noise standard deviation for DPS guidance, by default 0.1
sda_gamma (float, optional) – SDA scaling factor for DPS guidance, by default 0.001
- Returns:
Assimilation model
- Return type:
AssimilationModel
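The effect of time_tolerance can be pictured with a small pandas sketch that keeps only observations close to a requested time. This is an illustration under stated assumptions, not the library's filtering code: the helper name filter_by_time, the "time" and "observation" column names, and the symmetric, inclusive window are all hypothetical.

```python
import numpy as np
import pandas as pd


def filter_by_time(obs, request_time, tolerance=np.timedelta64(10, "m")):
    """Keep rows whose 'time' lies within +/- tolerance of request_time.

    Hypothetical helper assuming a datetime64 'time' column.
    """
    target = np.datetime64(request_time)
    # Absolute time offset of every observation from the requested time
    offsets = np.abs(obs["time"].to_numpy() - target)
    return obs.loc[offsets <= tolerance]


obs = pd.DataFrame(
    {
        "time": pd.to_datetime(
            ["2024-01-01T00:05", "2024-01-01T00:30", "2023-12-31T23:52"]
        ),
        "observation": [1.0, 2.0, 3.0],
    }
)
# Keeps the 00:05 and 23:52 rows; 00:30 is outside the 10-minute window
kept = filter_by_time(obs, "2024-01-01T00:00")
```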