earth2studio.models.px.Atlas
- class earth2studio.models.px.Atlas(autoencoders, autoencoder_processors, model, model_processor, sinterpolant, sinterpolant_sample_steps=60)
Atlas prognostic model for ERA5 variables on a 0.25° global lat-lon grid.
Atlas consumes the state at two input lead times (t-6h and t) and predicts a single step to t+6h on a 721 × 1440 (latitude × longitude) grid.
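A minimal sketch of an input coordinate system matching this description. It assumes that CoordSystem is an OrderedDict of NumPy arrays (as the create_iterator return type states), that lead times are stored as numpy.timedelta64, and that latitudes run north to south. The helper name, the key names, and the variable names are hypothetical placeholders, and the leading (…) dimensions are omitted; check the model's own input coordinates for the real values.

```python
from collections import OrderedDict

import numpy as np


# Hypothetical helper: builds a CoordSystem-like OrderedDict for Atlas-style input.
# Assumptions: key names, north-to-south latitude order, placeholder variables.
def make_input_coords(variables: list[str]) -> "OrderedDict[str, np.ndarray]":
    return OrderedDict(
        lead_time=np.array([-6, 0], dtype="timedelta64[h]"),  # t-6h and t
        variable=np.array(variables),
        lat=np.linspace(90.0, -90.0, 721),  # 0.25° spacing, both poles included
        lon=np.arange(0.0, 360.0, 0.25),    # 0.25° spacing, 360° excluded
    )


coords = make_input_coords(["u10m", "v10m", "t2m"])  # placeholder variable names
shape = tuple(len(v) for v in coords.values())
print(shape)  # (2, 3, 721, 1440)
```

The 0.25° spacing is what produces the 721 × 1440 grid: 180° / 0.25° + 1 latitude points (both poles) and 360° / 0.25° longitude points (0° and 360° are the same meridian).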
- Parameters:
autoencoders (nn.ModuleList) – List of autoencoders for the full-resolution physical state.
autoencoder_processors (nn.ModuleList) – List of autoencoder processors for the full-resolution physical state.
model (nn.Module) – Model for the full-resolution physical state.
model_processor (nn.Module) – Model processor for the full-resolution physical state.
sinterpolant (nn.Module) – Stochastic interpolant for the low-resolution latent state.
sinterpolant_sample_steps (int, optional) – Number of sampling steps used by the stochastic interpolant. Defaults to 60.
Warning
This model is intended to be used through the iterator interface (create_iterator) for autoregressive rollouts longer than one step. Repeatedly calling the __call__ and prep_next_input methods will not produce correct results, because the model timesteps autoregressively using both a full-resolution physical state and an internal low-resolution latent state.
- __call__(x, coords)
Forward pass of the prognostic model, integrating a single 6h step.
- Parameters:
x (torch.Tensor) – Input tensor of shape (…, lead_time, variable, lat, lon) matching coords. Expected lead times: [-6h, 0h].
coords (CoordSystem) – Coordinate dictionary describing x.
- Returns:
Output tensor advanced to t+6h and its coordinate system.
- Return type:
tuple[torch.Tensor, CoordSystem]
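The coordinate bookkeeping for a single call can be sketched as follows. This assumes the output keeps one lead time, equal to the last input lead time plus 6 hours, and leaves every other coordinate unchanged; the helper name and the validation it performs are hypothetical and do not correspond to a function in earth2studio.

```python
from collections import OrderedDict

import numpy as np

STEP = np.timedelta64(6, "h")
EXPECTED_INPUT_LEADS = np.array([-6, 0], dtype="timedelta64[h]")


def step_output_coords(
    coords: "OrderedDict[str, np.ndarray]",
) -> "OrderedDict[str, np.ndarray]":
    """Hypothetical: derive output coords of one 6h step from the input coords."""
    lead = coords["lead_time"]
    if not np.array_equal(lead, EXPECTED_INPUT_LEADS):
        raise ValueError(f"expected lead times [-6h, 0h], got {lead}")
    out = OrderedDict(coords)  # shallow copy; other coordinates are unchanged
    # The lead_time dimension shrinks from 2 entries to 1, placed at t+6h.
    out["lead_time"] = lead[-1:] + STEP
    return out


coords = OrderedDict(
    lead_time=np.array([-6, 0], dtype="timedelta64[h]"),
    variable=np.array(["t2m"]),  # placeholder variable name
    lat=np.linspace(90.0, -90.0, 721),
    lon=np.arange(0.0, 360.0, 0.25),
)
out = step_output_coords(coords)
print(out["lead_time"])
```

Copying the input dictionary rather than mutating it keeps the caller's coords valid for reuse.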
- create_iterator(x, coords)
Create an iterator that yields the initial state, followed by successive 6h steps.
- Parameters:
x (torch.Tensor) – Initial data tensor on device representing the initial condition.
coords (CoordSystem) – Coordinate system for the initial data tensor.
- Yields:
Iterator[tuple[torch.Tensor, CoordSystem]] – Iterator yielding successive model outputs and their coordinates.
- Return type:
Iterator[tuple[Tensor, OrderedDict[str, ndarray]]]
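To illustrate why the warning recommends the iterator, here is a toy, hypothetical model (not Atlas's actual architecture) whose latent state cannot be recovered from its output. A stateless loop that rebuilds the latent from the output on every call agrees with the iterator only on the first step. The rollout helper uses itertools.islice and accounts for the initial state being yielded first; it is likewise a sketch, and with the real model you would pass it model.create_iterator(x, coords) instead of the toy iterator.

```python
from itertools import islice
from typing import Iterator


class ToyLatentModel:
    """Hypothetical model with a hidden latent z alongside the physical state x."""

    @staticmethod
    def encode(x: int) -> int:
        # Coarse summary of x; it cannot recover a latent that has already evolved.
        return x // 10

    @staticmethod
    def step(x: int, z: int) -> tuple[int, int]:
        z_next = z + 1  # the latent evolves on its own
        return x + z_next, z_next

    def __call__(self, x: int) -> int:
        # Stateless single step: the latent is rebuilt from x on every call.
        return self.step(x, self.encode(x))[0]

    def create_iterator(self, x: int) -> Iterator[int]:
        # Yields the initial state first, then carries z across steps.
        z = self.encode(x)
        yield x
        while True:
            x, z = self.step(x, z)
            yield x


def rollout(iterator: Iterator, n_steps: int) -> list:
    """Collect the initial state plus n_steps forecasts from an iterator."""
    return list(islice(iterator, n_steps + 1))


model = ToyLatentModel()
via_iterator = rollout(model.create_iterator(5), n_steps=3)

via_calls = [5]
for _ in range(3):
    via_calls.append(model(via_calls[-1]))

print(via_iterator)  # [5, 6, 8, 11]
print(via_calls)     # [5, 6, 7, 8]
```

Both paths agree after one step but diverge afterwards, which mirrors the warning: a single __call__ is fine, but longer rollouts need the state that only the iterator carries.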