earth2studio.models.px.ACE2ERA5

class earth2studio.models.px.ACE2ERA5(stepper, forcing_data_source=<earth2studio.data.ace2.ACE2ERA5Data object>, dt=np.timedelta64(6, 'h'))

ACE2-ERA5 prognostic model wrapper.

ACE2 (Ai2 Climate Emulator v2) is a 450M-parameter autoregressive emulator with 6-hour time steps, 1-degree horizontal resolution, and eight vertical layers. It exactly conserves global dry air mass and moisture, and it can be stepped stably for arbitrarily many steps at about 1500 simulated years per wall-clock day. ACE2-ERA5 was trained on the ERA5 dataset and requires forcing data during rollout (see the forcing_data_source parameter). This wrapper uses the fme package to run model forward passes.
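For a rough sense of scale, the quoted throughput can be converted into model steps per second. This is back-of-the-envelope arithmetic only, assuming a 365-day year and the 6-hour step (4 steps per simulated day); actual throughput depends on the hardware used.

```latex
\begin{aligned}
\text{steps per wall-clock day} &= 1500\,\tfrac{\text{simulated years}}{\text{wall-clock day}} \times 365\,\tfrac{\text{simulated days}}{\text{simulated year}} \times 4\,\tfrac{\text{steps}}{\text{simulated day}} = 2\,190\,000 \\
\text{steps per second} &= \frac{2\,190\,000}{86\,400\,\text{s}} \approx 25.3 \qquad (\approx 39\ \text{ms per step})
\end{aligned}
```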

Parameters:
  • stepper (Stepper) – ACE2-ERA5 fme.ace.stepper.single_module.Stepper instance loaded from a checkpoint.

  • forcing_data_source (DataSource, optional) – Data source providing forcing data during rollout. Must provide all forcing variables described in the ACE2-ERA5 paper, by default ACE2ERA5Data(mode="forcing").

  • dt (numpy.timedelta64, optional) – Model timestep used to advance lead time coordinates, by default 6 hours.
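To make the role of dt concrete, here is a minimal NumPy sketch of the lead times a rollout passes through with the default 6-hour step. lead_time_axis is a hypothetical helper for illustration, not an earth2studio function.

```python
import numpy as np


def lead_time_axis(n_steps: int, dt: np.timedelta64 = np.timedelta64(6, "h")) -> np.ndarray:
    """Lead times reached after 0..n_steps model steps of size dt (hypothetical helper)."""
    # Step k of the rollout sits at lead time k * dt.
    return np.arange(n_steps + 1) * dt


# Ten days of 6-hourly steps: 40 steps, so 41 lead times from 0 h to 240 h.
axis = lead_time_axis(40)
```

With 40 steps the final lead time is 240 hours (10 days).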


Warning

This model may only be used with input data on the GPU device that the model was loaded on. Specifically, the data must be on the device that torch.cuda.current_device() returned when the model package was loaded.

__call__(x, coords)

Runs one prognostic step using the fme predict_paired API.

Parameters:
  • x (torch.Tensor) – Input tensor

  • coords (CoordSystem) – Input coordinate system

Returns:

Output tensor and coordinate system, with lead time advanced by dt (6 hours by default)

Return type:

tuple[torch.Tensor, CoordSystem]
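The (x, coords) to (x, coords) step contract can be illustrated with a toy stand-in. This is a hypothetical sketch, not the wrapper's code: NumPy arrays replace torch.Tensor, ToyPrognostic and its placeholder "physics" are made up, the variable names and tiny grid are illustrative only, and CoordSystem is assumed to behave like an OrderedDict mapping dimension names to NumPy arrays.

```python
from collections import OrderedDict

import numpy as np


class ToyPrognostic:
    """Hypothetical stand-in mirroring the (x, coords) -> (x, coords) step contract."""

    def __init__(self, dt=np.timedelta64(6, "h")):
        self.dt = dt

    def __call__(self, x: np.ndarray, coords: OrderedDict) -> tuple[np.ndarray, OrderedDict]:
        # Placeholder "physics": a real model would run a network forward pass here.
        x_next = 0.9 * x
        # Copy the coordinate system and shift only the lead_time axis by dt.
        coords_next = OrderedDict(coords)
        coords_next["lead_time"] = coords["lead_time"] + self.dt
        return x_next, coords_next


coords = OrderedDict(
    time=np.array([np.datetime64("2020-01-01T00:00")]),
    lead_time=np.array([np.timedelta64(0, "h")]),
    variable=np.array(["t2m", "u10m"]),
    lat=np.linspace(90, -90, 4),
    lon=np.linspace(0, 360, 8, endpoint=False),
)
# One array axis per coordinate dimension, in coordinate order.
x = np.ones([len(v) for v in coords.values()])
y, out_coords = ToyPrognostic()(x, coords)
```

Copying the coordinate dictionary rather than mutating it keeps the caller's input coordinates unchanged, which matters when the same initial condition is reused.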

create_iterator(x, coords)

Creates an iterator to perform time-integration of ACE2ERA5.

Yields the first forecast step, then continues autoregressively, feeding each output back in as the next prognostic state while fetching and applying external forcings internally via _forward.

Parameters:
  • x (torch.Tensor) – Input tensor

  • coords (CoordSystem) – Input coordinate system

Returns:

Iterator of output tensors and coordinate systems

Return type:

Iterator[tuple[torch.Tensor, CoordSystem]]
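The autoregressive pattern described for create_iterator can be sketched as a plain Python generator. This is a hypothetical illustration, not the wrapper's implementation (which delegates to fme through _forward): NumPy arrays stand in for torch.Tensor, rollout, step and get_forcing are made-up names, and CoordSystem is assumed to behave like an OrderedDict of NumPy arrays with time and lead_time entries.

```python
import itertools
from collections import OrderedDict
from typing import Callable, Iterator

import numpy as np


def rollout(
    x: np.ndarray,
    coords: OrderedDict,
    step: Callable[[np.ndarray, np.ndarray], np.ndarray],
    get_forcing: Callable[[np.datetime64], np.ndarray],
    dt: np.timedelta64 = np.timedelta64(6, "h"),
) -> Iterator[tuple[np.ndarray, OrderedDict]]:
    """Hypothetical autoregressive loop: each output becomes the next input."""
    while True:
        # Valid time of the current state = initial time + current lead time.
        valid_time = coords["time"][0] + coords["lead_time"][0]
        # Look up external forcing (e.g. prescribed boundary data) for that time.
        forcing = get_forcing(valid_time)
        # Advance the prognostic state by one step; the result feeds the next step.
        x = step(x, forcing)
        coords = OrderedDict(coords)
        coords["lead_time"] = coords["lead_time"] + dt
        yield x, coords


# Toy usage: the "state" is a scalar and each step adds the forcing value.
seen_times = []


def get_forcing(t):
    seen_times.append(t)
    return np.array(1.0)


coords0 = OrderedDict(
    time=np.array([np.datetime64("2020-01-01T00:00")]),
    lead_time=np.array([np.timedelta64(0, "h")]),
)
# The generator is infinite, so take only the first four steps.
outputs = list(itertools.islice(rollout(np.array(0.0), coords0, lambda s, f: s + f, get_forcing), 4))
```

Because the generator never terminates on its own, the caller decides how many steps to take, here with itertools.islice.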

classmethod load_default_package()

Load default ACE2-ERA5 package from HuggingFace.

Return type:

Package

classmethod load_model(package, forcing_data_source=<earth2studio.data.ace2.ACE2ERA5Data object>, dt=np.timedelta64(6, 'h'))

Load ACE2-ERA5 prognostic model from a package.

Parameters:
  • package (Package) – Package to load the model checkpoint from.

  • forcing_data_source (DataSource, optional) – External forcing data source. Must provide all forcing variables described in the ACE2-ERA5 paper, by default ACE2ERA5Data(mode="forcing").

  • dt (numpy.timedelta64, optional) – Timestep for advancing lead time coordinates, by default 6 hours.

Returns:

ACE2-ERA5 prognostic model

Return type:

PrognosticModel