OrbitGlobalPrecip

class earth2studio.models.dx.OrbitGlobalPrecip(
core_model,
land_sea_mask,
orography,
lattitude,
landcover,
normalize_mean_lowres,
normalize_std_lowres,
normalize_mean_highres,
normalize_std_highres,
do_tiling,
div,
overlap,
)
Global · MRF · 2025 · 40 GB

ORBIT-2 precipitation downscaling model supporting both the 9.5M- and 126M-parameter variants (model_size "9.5m" and "126m").

Note

This model and checkpoint are from Wang et al. (2025); refer to that work for more information.

Note

A few details regarding the model’s variables:

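  • t2m and sst are combined into a single global surface-temperature field: sst is used wherever it is defined, and t2m fills the grid points where sst is NaN (e.g. over land).

  • The input variables t2m_min and t2m_max are the daily (24-hour) minimum and maximum of this combined surface temperature.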

  • The model is fine-tuned for IMERG 24-hour accumulated precipitation (tp24).
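The variable notes above can be illustrated with a small NumPy sketch that mirrors the tensor operations in the class example. It assumes the 24 hourly snapshots are stacked into an array of shape (time, 24, 4, lat, lon) with channels ordered (cp, lsp, t2m, sst); the function name derive_daily_inputs is hypothetical and not part of earth2studio.

```python
import numpy as np


def derive_daily_inputs(hourly):
    """Derive tp24, t2m_max and t2m_min from 24 hourly snapshots.

    hourly: array of shape (time, 24, 4, lat, lon); channel order is assumed
    to be (cp, lsp, t2m, sst), as in the class example.
    Returns an array of shape (time, 3, lat, lon) ordered (tp24, t2m_max, t2m_min).
    """
    cp, lsp, t2m, sst = (hourly[:, :, c] for c in range(4))
    # 24-hour accumulated precipitation: convective + large-scale, summed over hours
    tp24 = (cp + lsp).sum(axis=1)
    # Combined surface temperature: SST where defined, t2m where SST is NaN
    surface_t = np.where(np.isnan(sst), t2m, sst)
    # Daily extremes of the combined temperature field over the hour axis
    t_max = surface_t.max(axis=1)
    t_min = surface_t.min(axis=1)
    return np.stack([tp24, t_max, t_min], axis=1)
```

All three derived fields are reductions over the same hour axis, so one pass over the 24 snapshots is enough.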

Parameters:
  • core_model (torch.nn.Module) – Core PyTorch model

  • land_sea_mask (np.ndarray) – Binary land-sea mask at 0.25° resolution, shape (720, 1440). Values are 1 over land, 0 over ocean.

  • orography (np.ndarray) – Surface geopotential height (meters) at 0.25° resolution, shape (720, 1440).

  • lattitude (np.ndarray) – Latitude values broadcast to grid shape (720, 1440), used as a positional encoding input to the model.

  • landcover (np.ndarray) – Land-use / land-cover classification at 0.25° resolution, shape (720, 1440).

  • normalize_mean_lowres (np.lib.npyio.NpzFile) – Per-variable mean values for input normalization. Keys are variable names, values are single-element arrays.

  • normalize_std_lowres (np.lib.npyio.NpzFile) – Per-variable standard deviation values for input normalization. Keys are variable names, values are single-element arrays.

  • normalize_mean_highres (np.lib.npyio.NpzFile) – Per-variable mean values for output denormalization. Keys are variable names, values are single-element arrays.

  • normalize_std_highres (np.lib.npyio.NpzFile) – Per-variable standard deviation values for output denormalization. Keys are variable names, values are single-element arrays.

  • do_tiling (bool) – Whether to perform tiled inference

  • div (int) – Number of tiles to divide the input into when tiling

  • overlap (int) – Number of pixels by which adjacent tiles overlap when tiling
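The four normalization parameters are NpzFile objects keyed by variable name, each holding a single-element array. As a rough sketch, assuming a standard per-variable z-score, (x - mean) / std, applied along the channel axis (the class's internal code may differ), the low-resolution statistics would normalize inputs and the high-resolution statistics would map outputs back to physical units. The helpers load_stats, normalize, and denormalize are hypothetical; load_stats only builds an in-memory NpzFile for demonstration.

```python
import io

import numpy as np


def load_stats(**per_variable):
    """Build an in-memory NpzFile holding one single-element array per variable."""
    buf = io.BytesIO()
    np.savez(buf, **{name: np.array([value]) for name, value in per_variable.items()})
    buf.seek(0)
    return np.load(buf)  # np.load returns an NpzFile for .npz content


def _per_channel(stats, variables):
    """Stack single-element arrays into shape (1, n_vars, 1, 1) for broadcasting."""
    return np.stack([stats[name] for name in variables]).reshape(1, -1, 1, 1)


def normalize(x, variables, mean, std):
    """Z-score x of shape (batch, n_vars, lat, lon) with per-variable statistics."""
    return (x - _per_channel(mean, variables)) / _per_channel(std, variables)


def denormalize(y, variables, mean, std):
    """Invert the z-score, mapping normalized values back to physical units."""
    return y * _per_channel(std, variables) + _per_channel(mean, variables)
```

Looking statistics up by name rather than by position keeps the scaling correct even if the variable order of the input tensor changes.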

Example

The derived inputs tp24, t2m_max, and t2m_min must be computed from hourly ERA5 fields before calling the model:

>>> from collections import OrderedDict
>>> import numpy as np
>>> import torch
>>> from earth2studio.data import NCAR_ERA5, prep_data_array
>>> from earth2studio.models.dx import OrbitGlobalPrecip
>>> from earth2studio.utils.time import to_time_array
>>>
>>> package = OrbitGlobalPrecip.load_default_package()
>>> orbit = OrbitGlobalPrecip.load_model(package)
>>> orbit = orbit.to("cuda")
>>> data = NCAR_ERA5()
>>> time = to_time_array([np.datetime64("2023-06-01")])
>>>
>>> # Fetch base variables (all except tp24, t2m_max, t2m_min)
>>> base_vars = orbit.input_coords()["variable"][:-3]
>>> x, coords = prep_data_array(data(time, base_vars), device="cuda")
>>>
>>> # Build the 24-hour precipitation accumulation and daily t2m max/min from the
>>> # preceding 24 hourly snapshots; channels are (cp, lsp, t2m, sst).
>>> batch_p = torch.zeros((len(time), 24, 4, x.shape[-2], x.shape[-1]), device="cuda")
>>> for i in range(24):
...     time0 = np.array(time) - np.timedelta64(i, "h")
...     p, _ = prep_data_array(data(time0, ["cp", "lsp", "t2m", "sst"]), device="cuda")
...     batch_p[:, i] = p
>>> total_p_24hr = (batch_p[:, :, 0] + batch_p[:, :, 1]).sum(dim=1).unsqueeze(1)
>>> t2_sst_combined = torch.where(
...     torch.isnan(batch_p[:, :, 3]), batch_p[:, :, 2], batch_p[:, :, 3]
... )
>>> t2_max = t2_sst_combined.max(dim=1).values.unsqueeze(1)
>>> t2_min = t2_sst_combined.min(dim=1).values.unsqueeze(1)
>>> # Append the derived inputs in the model's order: tp24, t2m_max, t2m_min
>>> x = torch.cat((x, total_p_24hr, t2_max, t2_min), dim=1)
>>>
>>> input_coords = OrderedDict(
...     {k: v for k, v in orbit.input_coords().items() if k != "batch"}
... )
>>> input_coords["time"] = time
>>> input_coords.move_to_end("time", last=False)
>>> output, output_coords = orbit(x, input_coords)
__call__(x, coords)

Forward pass of the diagnostic model.

Parameters:
  • x (Tensor) – Input tensor

  • coords (OrderedDict[str, ndarray]) – Coordinate system describing each dimension of x

Return type:

tuple[Tensor, OrderedDict[str, ndarray]]
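Following the pattern in the class example, the coords argument can be derived from the model's own input_coords() by dropping the "batch" placeholder and placing a "time" axis first, so that it matches the leading dimension of x. The helper with_time_leading is a hypothetical convenience, not an earth2studio function, and the coordinate values used to exercise it are made up.

```python
from collections import OrderedDict

import numpy as np


def with_time_leading(model_coords, time):
    """Copy model_coords, drop the "batch" entry, and add "time" as the first axis."""
    coords = OrderedDict((k, v) for k, v in model_coords.items() if k != "batch")
    coords["time"] = np.asarray(time)
    # Move "time" to the front so the coordinate order matches x's dimension order
    coords.move_to_end("time", last=False)
    return coords
```

With a loaded model, with_time_leading(orbit.input_coords(), time) would produce the same input_coords that the class example builds by hand. The original mapping is left unmodified.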

classmethod load_default_package()

Load the default model package.

Return type:

Package

classmethod load_model(
package,
model_type='global',
model_size='9.5m',
model_variable='precipitation',
)

Load ORBIT-2 precipitation diagnostic model from package files.

Parameters:
  • package (Package) – Model package containing configuration and checkpoint files.

  • model_type (Literal["global"], optional) – ORBIT-2 model family to load, by default “global”

  • model_size (Literal["9.5m", "126m"], optional) – ORBIT-2 model size variant to load, by default “9.5m”

  • model_variable (Literal["precipitation"], optional) – Target variable variant to load, by default “precipitation”

Returns:

Loaded ORBIT-2 precipitation diagnostic model

Return type:

DiagnosticModel