GenCastMini

class earth2studio.models.px.GenCastMini(
ckpt,
diffs_stddev_by_level,
mean_by_level,
stddev_by_level,
min_by_level,
land_sea_mask,
geopotential_at_surface,
sst_nan_mask,
seed=0,
jit_compile=True,
)
Global | MRF | 2024 | 40 GB

GenCast Mini diffusion-based weather prediction model.

A stochastic weather prediction model based on conditional diffusion that predicts in 12-hour time steps. This mini variant operates at 1.0-degree (181x360) resolution with 13 pressure levels. The model takes 2 input frames (t-12h and t) and predicts 12 hours ahead.
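The grid size and step arithmetic in the paragraph above can be made concrete. The sketch below builds a 1.0-degree latitude/longitude grid and the lead times reached by repeated 12-hour steps. The north-to-south latitude ordering, the 0 to 359 longitude convention, and the helper `lead_times` are illustrative assumptions, not taken from the wrapper.

```python
import numpy as np

# 1.0-degree global grid: latitudes include both poles (181 points),
# longitudes wrap at 360 (360 points). Ordering is an assumption.
lat = np.linspace(90.0, -90.0, 181)
lon = np.arange(0.0, 360.0, 1.0)
assert (lat.size, lon.size) == (181, 360)

STEP = np.timedelta64(12, "h")

# The two input frames sit at t - 12h and t.
input_lead_times = np.array([-STEP, np.timedelta64(0, "h")])

def lead_times(n_steps: int) -> np.ndarray:
    """Hypothetical helper: lead times reached after each 12-hour step."""
    return STEP * np.arange(1, n_steps + 1)

# A 15-day forecast needs 15 * 24 / 12 = 30 steps.
n_steps = 15 * 24 // 12
print(n_steps)  # 30
```

Each call advances the valid time by exactly one 12-hour step, so the number of calls for a given horizon is simply the horizon in hours divided by 12.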

The mini variant was trained on ERA5 reanalysis data (pre-2019) and has significantly lower memory requirements (~16 GB VRAM) than the full 0.25-degree operational model. This wrapper runs the model with operational inputs, which include a zero-valued 12-hour total precipitation input.

Note

This model is provided by DeepMind. For more information see the following references:

Warning

We encourage users to familiarize themselves with the license restrictions of this model’s checkpoints.

Parameters:
  • ckpt (gencast.CheckPoint) – Model checkpoint containing weights and configuration

  • diffs_stddev_by_level (xr.Dataset) – Standard deviation of differences by level for normalization

  • mean_by_level (xr.Dataset) – Mean values by level for normalization

  • stddev_by_level (xr.Dataset) – Standard deviation by level for normalization

  • min_by_level (xr.Dataset) – Minimum values by level for NaN cleaning

  • land_sea_mask (np.ndarray) – Land-sea mask on lat-lon grid

  • geopotential_at_surface (np.ndarray) – Geopotential at surface on lat-lon grid

  • sst_nan_mask (np.ndarray) – Boolean mask marking where SST values are NaN (land points, where sea surface temperature is undefined)

  • seed (int | None, optional) – Random seed for the JAX PRNG key used in stochastic sampling. If None, a new random seed is generated each time the model is called, producing stochastic forecasts, by default 0.

  • jit_compile (bool, optional) – Whether to JIT-compile the model forward pass; this requires 24 GB of host RAM. JIT compilation adds a one-time cost (several minutes on the first call) but makes subsequent calls significantly faster, by default True.
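The normalization statistics and masks above are, as far as we understand DeepMind's reference GenCast code, consumed by a normalization step and an SST NaN-cleaning step. A minimal NumPy sketch, assuming the usual scheme: inputs are standardized with `mean_by_level`/`stddev_by_level`, the 12-hour change is scaled by `diffs_stddev_by_level`, and NaN SST values over land are filled with the `min_by_level` value and reintroduced after the forecast using `sst_nan_mask`. All function names are hypothetical, and scalars stand in for the per-level xarray statistics.

```python
import numpy as np

def normalize(x, mean, stddev):
    """Standardize a field with its per-variable (per-level) statistics."""
    return (x - mean) / stddev

def normalize_residual(x_next, x_curr, diffs_stddev):
    """Scale the 12-hour change by the stddev of historical differences."""
    return (x_next - x_curr) / diffs_stddev

def fill_sst_nans(sst, fill_value):
    """Replace NaNs (land points) with a finite value, e.g. the SST minimum."""
    return np.where(np.isnan(sst), fill_value, sst)

def restore_sst_nans(sst, nan_mask):
    """Reintroduce NaNs at the masked (land) points of a model output."""
    return np.where(nan_mask, np.nan, sst)

# Toy 2x3 SST field in kelvin with two land (NaN) points.
sst = np.array([[290.0, np.nan, 285.0],
                [np.nan, 300.0, 295.0]])
sst_nan_mask = np.isnan(sst)

filled = fill_sst_nans(sst, fill_value=271.0)   # hypothetical min_by_level value
z = normalize(filled, mean=288.0, stddev=10.0)  # hypothetical mean/stddev
dz = normalize_residual(filled + 1.0, filled, diffs_stddev=0.5)
restored = restore_sst_nans(filled, sst_nan_mask)
```

Filling with a finite value keeps NaNs out of the network's arithmetic, while the mask ensures land points are reported as NaN again in the output.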

__call__(x, coords)

Runs the prognostic model one step.

Parameters:
  • x (torch.Tensor) – Input tensor

  • coords (CoordSystem) – Input coordinate system

Returns:

Output tensor and coordinate system 12 hours in the future

Return type:

tuple[torch.Tensor, CoordSystem]
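The `__call__` contract maps `(x, coords)` to a tensor and coordinate system one 12-hour step later. The toy below mimics only that contract: `toy_step` is a hypothetical persistence "model" that repeats the latest frame, NumPy arrays stand in for `torch.Tensor`, and an `OrderedDict` stands in for `CoordSystem`. The dimension names, variable names, and single-frame output are assumptions; the real model's coordinates may include further dimensions such as `time`.

```python
from collections import OrderedDict
import numpy as np

STEP = np.timedelta64(12, "h")

def toy_step(x, coords):
    """Hypothetical persistence 'model' with the same (x, coords) contract.

    x has dims (lead_time, variable, lat, lon) with lead_time = [t-12h, t].
    Returns a single frame valid 12 hours after the latest input frame.
    """
    out = x[-1:].copy()               # persistence: repeat the latest frame
    out_coords = OrderedDict(coords)  # shallow copy; the input is not mutated
    out_coords["lead_time"] = coords["lead_time"][-1:] + STEP
    return out, out_coords

coords = OrderedDict(
    lead_time=np.array([-12, 0], dtype="timedelta64[h]"),
    variable=np.array(["t2m", "msl"]),  # hypothetical variable names
    lat=np.linspace(90.0, -90.0, 181),
    lon=np.arange(0.0, 360.0, 1.0),
)
x = np.zeros((2, 2, 181, 360), dtype=np.float32)

y, y_coords = toy_step(x, coords)
print(y.shape)  # (1, 2, 181, 360)
```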

create_iterator(x, coords)

Creates an iterator that can be used to perform time integration of the prognostic model. The iterator yields the initial condition first (0th step).

Parameters:
  • x (torch.Tensor) – Input tensor

  • coords (CoordSystem) – Input coordinate system

Yields:

Iterator[tuple[torch.Tensor, CoordSystem]] – Iterator that generates time steps of the prognostic model, each containing the output data tensor and coordinate system dictionary.

Return type:

Iterator[tuple[Tensor, OrderedDict[str, ndarray]]]
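Time integration with an iterator of this kind can be pictured as a generator that first yields the initial condition and then repeatedly applies one 12-hour step. Because the model consumes two frames (t-12h and t), the sketch keeps a sliding two-frame window. `toy_iterator` and `plus_one_step` are hypothetical illustrations, not the wrapper's code, and whether the real iterator yields both input frames at step 0 is an assumption here.

```python
from collections import OrderedDict
from itertools import islice
from typing import Callable, Iterator
import numpy as np

STEP = np.timedelta64(12, "h")

def toy_iterator(x, coords, step: Callable) -> Iterator:
    """Yield the initial condition, then one 12-hour step per iteration.

    A sliding two-frame window [t-12h, t] is kept as the model input.
    """
    yield x, coords  # 0th step: the initial condition itself
    while True:
        y, y_coords = step(x, coords)
        yield y, y_coords
        # Slide the window: drop the oldest frame and append the prediction.
        x = np.concatenate([x[-1:], y], axis=0)
        coords = OrderedDict(coords)
        coords["lead_time"] = np.concatenate(
            [coords["lead_time"][-1:], y_coords["lead_time"]]
        )

def plus_one_step(x, coords):
    """Hypothetical step: latest frame plus one, valid 12 hours later."""
    out_coords = OrderedDict(coords)
    out_coords["lead_time"] = coords["lead_time"][-1:] + STEP
    return x[-1:] + 1.0, out_coords

# Dims (lead_time, variable); lead_time = [t-12h, t].
coords0 = OrderedDict(
    lead_time=np.array([-12, 0], dtype="timedelta64[h]"),
    variable=np.array(["t2m"]),
)
x0 = np.zeros((2, 1), dtype=np.float32)

steps = list(islice(toy_iterator(x0, coords0, plus_one_step), 4))
hours = [int(c["lead_time"][-1] / np.timedelta64(1, "h")) for _, c in steps]
print(hours)  # [0, 12, 24, 36]
```

The generator never terminates on its own, so callers bound the rollout, here with `itertools.islice`.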

classmethod load_default_package()

Load default pre-trained GenCast Mini model package from Google Cloud.

Returns:

Model package

Return type:

Package

classmethod load_model(package, jit_compile=True, seed=0)

Load prognostic model from package.

Parameters:
  • package (Package) – Package to load model from

  • jit_compile (bool, optional) – JIT-compile the model forward pass, by default True.

  • seed (int | None, optional) – Random seed for JAX PRNG key used in stochastic sampling, by default 0.
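The `seed` semantics (a fixed seed for reproducible sampling, `None` for a fresh random seed on each call) can be illustrated without JAX. The sketch below swaps in NumPy's `default_rng` for JAX PRNG keys; `toy_sample` is a hypothetical stand-in for the noise drawn by one stochastic forecast, and how the wrapper derives keys across successive calls with a fixed seed is not covered.

```python
import numpy as np

def toy_sample(seed, shape=(3,)):
    """Hypothetical noise draw: fixed seed is reproducible, None is fresh.

    np.random.default_rng(None) seeds from OS entropy, analogous to
    generating a new random seed on every call.
    """
    return np.random.default_rng(seed).standard_normal(shape)

a = toy_sample(0)
b = toy_sample(0)
print(np.array_equal(a, b))  # True: same seed, same noise

c = toy_sample(None)  # independent draws from fresh entropy
d = toy_sample(None)
```

Under this reading, distinct ensemble members come from distinct seeds or from `seed=None`.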

Returns:

Prognostic model

Return type:

PrognosticModel