earth2studio.models.px.ACE2ERA5
- class earth2studio.models.px.ACE2ERA5(stepper, forcing_data_source=<earth2studio.data.ace2.ACE2ERA5Data object>, dt=np.timedelta64(6, 'h'))
ACE2-ERA5 prognostic model wrapper.
ACE2 (Ai2 Climate Emulator v2) is a 450M-parameter autoregressive emulator with 6-hour time steps, 1-degree horizontal resolution, and eight vertical layers. It exactly conserves global dry air mass and moisture and can be stepped stably for arbitrarily many steps at about 1500 simulated years per wall-clock day. ACE2-ERA5 was trained on the ERA5 dataset and requires forcing data during rollout (see the forcing_data_source parameter). This wrapper uses the fme package to run model forward passes.
- Parameters:
stepper (Stepper) – ACE2-ERA5 fme.ace.stepper.single_module.Stepper instance loaded from a checkpoint.
forcing_data_source (DataSource, optional) – Data source providing forcing data during rollout. Must provide all forcing variables described in the ACE2-ERA5 paper, by default ACE2ERA5Data(mode="forcing").
dt (numpy.timedelta64, optional) – Model timestep used to advance lead time coordinates, by default 6 hours.
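The role of dt can be illustrated with a little NumPy arithmetic. This is a minimal sketch, not part of the wrapper: it assumes a 365-day year and takes the quoted figure of about 1500 simulated years per wall-clock day at face value, and the variable names are chosen here for illustration only.

```python
import numpy as np

# Default model timestep, as shown in the class signature.
dt = np.timedelta64(6, "h")

# Number of model steps needed to cover one simulated day and one 365-day year.
steps_per_day = int(np.timedelta64(1, "D") / dt)
steps_per_year = 365 * steps_per_day

# Lead times reached after the first four forecast steps.
lead_times = np.arange(1, 5) * dt

# Step rate implied by ~1500 simulated years per wall-clock day (86,400 s).
steps_per_wallclock_second = 1500 * steps_per_year / 86_400

print(steps_per_day, steps_per_year)
print(lead_times)
print(round(steps_per_wallclock_second, 1))
```

Under these assumptions the quoted throughput corresponds to roughly 25 model steps per wall-clock second; actual performance depends on hardware and on how forcing data is fetched.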
References
ACE2-ERA5 paper: https://arxiv.org/abs/2411.11268v1
ACE2 code: ai2cm/ace
Warning
This model may only be used with input data on the GPU device that the model was loaded on. Specifically, the data must be on the same device that torch.cuda.current_device() returned when the model package was loaded.
- __call__(x, coords)
Runs one prognostic step using the fme predict_paired API.
- Parameters:
x (torch.Tensor) – Input tensor
coords (CoordSystem) – Input coordinate system
- Returns:
Output tensor and coordinate system 6 hours in the future
- Return type:
tuple[torch.Tensor, CoordSystem]
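The statement that the returned coordinate system lies 6 hours in the future can be pictured with a toy coordinate dictionary. This is a sketch under assumptions: the keys "time" and "lead_time" and the helper advance_lead_time are illustrative and are not taken from this page; only the 6-hour dt default is.

```python
from collections import OrderedDict

import numpy as np


def advance_lead_time(coords, dt=np.timedelta64(6, "h")):
    """Return a copy of coords with the lead time advanced by dt (hypothetical helper)."""
    out = OrderedDict(coords)  # shallow copy so the input mapping is left untouched
    out["lead_time"] = coords["lead_time"] + dt
    return out


# Toy input coordinates: one initialization time at lead time zero.
coords = OrderedDict(
    time=np.array(["2020-01-01T00"], dtype="datetime64[h]"),
    lead_time=np.array([0], dtype="timedelta64[h]"),
)

next_coords = advance_lead_time(coords)

# The valid time of the output is the initialization time plus the lead time.
valid_time = next_coords["time"] + next_coords["lead_time"]
print(next_coords["lead_time"], valid_time)
```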
- create_iterator(x, coords)
Creates an iterator to perform time-integration of ACE2ERA5.
Yields the first forecast step, then continues autoregressively, feeding each output back in as the next prognostic state while external forcings are fetched and applied internally via _forward.
- Parameters:
x (torch.Tensor) – Input tensor
coords (CoordSystem) – Input coordinate system
- Returns:
Iterator of output tensors and coordinate systems
- Return type:
Iterator[tuple[torch.Tensor, CoordSystem]]
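The autoregressive pattern described for create_iterator can be sketched with a plain Python generator. This is not the wrapper's implementation: rollout and toy_step are hypothetical stand-ins, the state is a simple (value, lead_hours) tuple rather than a tensor and coordinate system, and forcing data is omitted.

```python
from itertools import islice


def rollout(step, state):
    """Yield successive states, starting with the first forecast step."""
    while True:
        state = step(state)  # previous output becomes the next input
        yield state


def toy_step(state):
    """Stand-in for one model step: halve the value, advance lead time by 6 hours."""
    value, lead_hours = state
    return value * 0.5, lead_hours + 6


# The generator is unbounded, so take a finite number of steps with islice.
first_three = list(islice(rollout(toy_step, (8.0, 0)), 3))
print(first_three)
```

Because such an iterator never terminates on its own, the caller decides how many steps to consume, for example with itertools.islice or a loop with an explicit break.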
- classmethod load_default_package()
Load default ACE2-ERA5 package from HuggingFace.
- Return type:
Package
- classmethod load_model(package, forcing_data_source=<earth2studio.data.ace2.ACE2ERA5Data object>, dt=np.timedelta64(6, 'h'))
Load ACE2-ERA5 prognostic model from a package.
- Parameters:
package (Package) – Package to load the model checkpoint from.
forcing_data_source (DataSource, optional) – External forcing data source. Must provide all forcing variables described in the ACE2-ERA5 paper, by default ACE2ERA5Data(mode="forcing").
dt (numpy.timedelta64, optional) – Timestep for advancing lead time coordinates, by default 6 hours.
- Returns:
ACE2-ERA5 prognostic model
- Return type:
PrognosticModel
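Putting the two class methods together, loading the model and taking a few steps might look like the sketch below. It uses only the calls documented on this page (load_default_package, load_model, create_iterator); the helper names load_ace2era5 and first_steps are hypothetical, and preparing x and coords on the correct GPU device is left to the caller. The earth2studio import is placed inside the function so the helpers can be defined without the package installed.

```python
from itertools import islice


def load_ace2era5():
    """Load ACE2-ERA5 with the default package, forcing data source, and dt."""
    # Imported lazily; calling this function requires earth2studio and its
    # ACE2 dependencies, and fetches the default package from HuggingFace.
    from earth2studio.models.px import ACE2ERA5

    package = ACE2ERA5.load_default_package()
    return ACE2ERA5.load_model(package)


def first_steps(model, x, coords, n_steps=4):
    """Collect the first n_steps (tensor, coords) pairs from the model iterator."""
    return list(islice(model.create_iterator(x, coords), n_steps))
```

With the default 6-hour dt, four steps cover one simulated day. The input x must already be on the device the model was loaded on, as noted in the warning above.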