GenCastMini
- class earth2studio.models.px.GenCastMini(ckpt, diffs_stddev_by_level, mean_by_level, stddev_by_level, min_by_level, land_sea_mask, geopotential_at_surface, sst_nan_mask, seed=0, jit_compile=True)
Global | MRF | 2024 | 40 GB
GenCast Mini diffusion-based weather prediction model.
A stochastic weather prediction model based on conditional diffusion that predicts in 12-hour time steps. This mini variant operates on a 1.0-degree (181 x 360) latitude-longitude grid with 13 pressure levels. Each step takes 2 input frames (t - 12 h and t) and predicts the state at t + 12 h.
The mini variant was trained on ERA5 reanalysis data (pre-2019) and has significantly lower memory requirements (~16 GB of VRAM) than the full 0.25-degree operational model. This wrapper runs the model with operational inputs, which include a 12-hour total precipitation input set to zero.
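As a rough illustration of the 12-hour time stepping described above, the sketch below computes the two input frame times and the valid times of an n-step rollout. `rollout_times` is a hypothetical helper written for this example, not part of earth2studio; it only performs the time arithmetic and does not touch the model.

```python
from datetime import datetime, timedelta

STEP = timedelta(hours=12)  # GenCast Mini time step


def rollout_times(t0: datetime, n_steps: int) -> tuple[list[datetime], list[datetime]]:
    """Return the two input frame times and the valid times of n_steps forecasts.

    Hypothetical helper for illustration only; not an earth2studio function.
    """
    inputs = [t0 - STEP, t0]  # the two input frames: t - 12 h and t
    outputs = [t0 + STEP * k for k in range(1, n_steps + 1)]  # t + 12 h, t + 24 h, ...
    return inputs, outputs


inputs, outputs = rollout_times(datetime(2024, 1, 1, 0), n_steps=4)
print([t.isoformat() for t in inputs])
print([t.isoformat() for t in outputs])
```

Four steps therefore cover a 48-hour forecast. In an autoregressive rollout, each step would use the two most recent frames as its inputs.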
Note
This model is provided by DeepMind. For more information, see the following references:
Warning
We encourage users to familiarize themselves with the license restrictions of this model’s checkpoints.
- Parameters:
ckpt (gencast.CheckPoint) – Model checkpoint containing weights and configuration
diffs_stddev_by_level (xr.Dataset) – Standard deviation of differences by level for normalization
mean_by_level (xr.Dataset) – Mean values by level for normalization
stddev_by_level (xr.Dataset) – Standard deviation by level for normalization
min_by_level (xr.Dataset) – Minimum values by level for NaN cleaning
land_sea_mask (np.ndarray) – Land-sea mask on lat-lon grid
geopotential_at_surface (np.ndarray) – Geopotential at surface on lat-lon grid
sst_nan_mask (np.ndarray) – Boolean mask of grid points where sea surface temperature (SST) is NaN, distinguishing land from ocean points
seed (int | None, optional) – Random seed for the JAX PRNG key used in stochastic sampling. If None, a new random seed is generated each time the model is called, so repeated calls produce different samples, by default 0.
jit_compile (bool, optional) – JIT-compile the model forward pass; this requires 24 GB of host RAM. Compilation adds a one-time cost (several minutes on the first call) but makes subsequent calls significantly faster, by default True.
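The three normalization statistics are easiest to understand through the scheme used in the GraphCast family of models: inputs are standardized with `mean_by_level` and `stddev_by_level`, while the network output is treated as a change from the most recent input frame and rescaled with `diffs_stddev_by_level`. The sketch below assumes this wrapper follows the same scheme. The scalar statistics are made-up numbers for one variable at one level, and `normalize_input` and `denormalize_residual` are hypothetical helpers, not earth2studio functions.

```python
# Made-up statistics for a single variable at one pressure level (assumption).
# Real values come from the xr.Dataset arguments passed to the constructor.
MEAN = 280.0        # stands in for mean_by_level
STDDEV = 20.0       # stands in for stddev_by_level
DIFFS_STDDEV = 2.0  # stands in for diffs_stddev_by_level


def normalize_input(x: float) -> float:
    """Standardize a raw input value to zero mean and unit variance."""
    return (x - MEAN) / STDDEV


def denormalize_residual(current: float, residual_norm: float) -> float:
    """Map a normalized residual prediction back to a physical value.

    Assumes the network predicts the normalized 12-hour change relative to
    the most recent input frame (GraphCast-style residual scheme).
    """
    return current + residual_norm * DIFFS_STDDEV


print(normalize_input(300.0))            # 1.0
print(denormalize_residual(300.0, 0.5))  # 301.0
```

Scaling the residual by the standard deviation of 12-hour differences, rather than of the raw field, keeps the predicted changes on a unit scale even for variables whose absolute values vary widely.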
- __call__(x, coords)
Runs the prognostic model one step (12 hours) forward.
- Parameters:
x (torch.Tensor) – Input tensor
coords (CoordSystem) – Input coordinate system
- Returns:
Output tensor and coordinate system 12 hours in the future
- Return type:
tuple[torch.Tensor, CoordSystem]
- create_iterator(x, coords)
Creates an iterator that can be used to perform time integration of the prognostic model. The initial condition is returned first (0th step).
- Parameters:
x (torch.Tensor) – Input tensor
coords (CoordSystem) – Input coordinate system
- Yields:
Iterator[tuple[torch.Tensor, CoordSystem]] – Iterator that generates time steps of the prognostic model, each containing the output data tensor and coordinate system dictionary.
- Return type:
Iterator[tuple[Tensor, OrderedDict[str, ndarray]]]
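The iterator contract above (initial condition first, then one forecast per step) can be illustrated with a plain Python generator. This is a sketch under stated assumptions, not the earth2studio implementation: `toy_step` is a hypothetical stand-in for one model call that simply extrapolates linearly from the two input frames, and scalars stand in for tensors and coordinate systems.

```python
from collections.abc import Iterator
from datetime import datetime, timedelta
from itertools import islice

STEP = timedelta(hours=12)  # one model step


def toy_step(prev: float, curr: float) -> float:
    """Hypothetical stand-in for one model call: linear extrapolation."""
    return 2 * curr - prev


def create_iterator(prev: float, curr: float, time: datetime) -> Iterator[tuple[float, datetime]]:
    """Yield the initial condition first, then one forecast per 12-hour step."""
    yield curr, time  # 0th step: the initial condition, unchanged
    while True:
        # Slide the two-frame window: the new forecast becomes the current
        # frame and the old current frame becomes the previous one.
        prev, curr = curr, toy_step(prev, curr)
        time += STEP
        yield curr, time


# The generator never ends on its own, so take a fixed number of steps.
steps = list(islice(create_iterator(0.0, 1.0, datetime(2024, 1, 1)), 3))
for value, valid_time in steps:
    print(valid_time.isoformat(), value)
```

Because the generator is unbounded, callers choose the forecast length by how many items they consume, for example with `itertools.islice` or a loop that breaks after n steps.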
- classmethod load_default_package()
Load default pre-trained GenCast Mini model package from Google Cloud.
- Returns:
Model package
- Return type:
Package
- classmethod load_model(package, jit_compile=True, seed=0)
Load prognostic model from package.
- Parameters:
package (Package) – Package to load model from
jit_compile (bool, optional) – JIT-compile the model forward pass, by default True.
seed (int | None, optional) – Random seed for JAX PRNG key used in stochastic sampling, by default 0.
- Returns:
Prognostic model
- Return type:
PrognosticModel
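The constructor's `seed` parameter distinguishes a fixed seed, reused on every call, from `None`, which draws a fresh seed per call. The sketch below mimics that behavior using Python's `random` module as a stand-in for the JAX PRNG key that the model actually uses. `resolve_seed` and `sample_noise` are hypothetical helpers, and the Gaussian draws merely stand in for the diffusion sampler's noise.

```python
import random
import secrets
from typing import Optional


def resolve_seed(seed: Optional[int]) -> int:
    """Return the seed for one call (hypothetical helper).

    A fixed seed is reused on every call; None draws a fresh seed each time.
    """
    return secrets.randbits(32) if seed is None else seed


def sample_noise(seed: Optional[int], n: int = 3) -> list[float]:
    """Draw n Gaussian samples, standing in for the sampler's noise."""
    rng = random.Random(resolve_seed(seed))
    return [rng.gauss(0.0, 1.0) for _ in range(n)]


print(sample_noise(0) == sample_noise(0))        # True: reproducible
print(sample_noise(None) == sample_noise(None))  # almost surely False
```

A fixed seed is useful for reproducible tests, while `seed=None` lets repeated calls generate distinct ensemble members.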