earth2studio.models.px.Atlas

class earth2studio.models.px.Atlas(autoencoders, autoencoder_processors, model, model_processor, sinterpolant, sinterpolant_sample_steps=60)

Atlas prognostic model for ERA5 variables on a 0.25° global lat-lon grid.

Atlas consumes two input lead times (t-6h and t) and predicts a single step at t+6h on a 721x1440 latitude-longitude grid.
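
The grid dimensions follow from the 0.25° spacing. The short worked sketch below assumes that latitude includes both poles and that longitude omits the repeated 360° meridian; this layout is an assumption made for illustration and is not stated in this reference. It also spells out the input and output lead times relative to the forecast time t.

```python
from datetime import timedelta

RESOLUTION_DEG = 0.25

# Latitude spans 180 degrees and is assumed to include both poles.
n_lat = int(180 / RESOLUTION_DEG) + 1
# Longitude spans 360 degrees and is assumed to omit the repeated 360 meridian.
n_lon = int(360 / RESOLUTION_DEG)

# Input lead times relative to t, and the lead time of the predicted step.
input_lead_times = [timedelta(hours=-6), timedelta(hours=0)]
output_lead_time = input_lead_times[-1] + timedelta(hours=6)

print(n_lat, n_lon)      # 721 1440
print(output_lead_time)  # 6:00:00
```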

Parameters:
  • autoencoders (nn.ModuleList) – List of autoencoders for the full-resolution physical state.

  • autoencoder_processors (nn.ModuleList) – List of autoencoder processors for the full-resolution physical state.

  • model (nn.Module) – Model for the full-resolution physical state.

  • model_processor (nn.Module) – Model processor for the full-resolution physical state.

  • sinterpolant (nn.Module) – Stochastic interpolant for the low-resolution latent state.

  • sinterpolant_sample_steps (int) – Number of sampling steps for the stochastic interpolant. Defaults to 60.

Warning

Use the iterator interface (create_iterator) for autoregressive rollouts longer than one step. Repeatedly calling __call__ and prep_next_input will not produce correct results, because the model performs autoregressive timestepping with both a full-resolution physical state and an internal low-resolution latent state.

__call__(x, coords)

Forward pass of the prognostic model, integrating a single 6h step.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (…, lead_time, variable, lat, lon) corresponding to the coordinate system. Lead times expected: [-6h, 0h].

  • coords (CoordSystem) – Coordinate dictionary describing x.

Returns:

Output tensor advanced to t+6h and its coordinate system.

Return type:

tuple[torch.Tensor, CoordSystem]
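
As an illustration of the expected input layout, the sketch below checks that a shape ends in (lead_time, variable, lat, lon), with two lead times on the 721x1440 grid. The helper `check_atlas_input_shape` and the variable count of 5 are hypothetical and not part of earth2studio; the real variable count comes from the model's own input coordinates.

```python
def check_atlas_input_shape(shape, n_variables):
    """Hypothetical helper: validate the trailing (lead_time, variable, lat, lon) dims.

    Returns the leading dimensions (the "..." part of the shape).
    """
    expected_tail = (2, n_variables, 721, 1440)
    if len(shape) < 4:
        raise ValueError(f"expected at least 4 dimensions, got {len(shape)}")
    if tuple(shape[-4:]) != expected_tail:
        raise ValueError(f"expected trailing dims {expected_tail}, got {tuple(shape[-4:])}")
    return tuple(shape[:-4])


# Placeholder shape with two leading dims and 5 variables (illustrative only).
leading = check_atlas_input_shape((1, 1, 2, 5, 721, 1440), n_variables=5)
print(leading)  # (1, 1)
```

Because `torch.Size` is a tuple subclass, `x.shape` can be passed to such a helper directly.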

create_iterator(x, coords)

Create an iterator that yields the initial state then successive 6h steps.

Parameters:
  • x (torch.Tensor) – Initial data tensor on device representing the initial condition.

  • coords (CoordSystem) – Coordinate system for the initial data tensor.

Yields:

Iterator[tuple[torch.Tensor, CoordSystem]] – Iterator yielding successive model outputs and their coordinates.

Return type:

Iterator[tuple[Tensor, OrderedDict[str, ndarray]]]

classmethod load_default_package()

Load the default package for the Atlas model.

Return type:

Package

classmethod load_model(package)

Instantiate and load Atlas from a package.

Parameters:

package (Package) – Package to load the Atlas model from.

Return type:

PrognosticModel