OrbitGlobalPrecip
- class earth2studio.models.dx.OrbitGlobalPrecip(
- core_model,
- land_sea_mask,
- orography,
- lattitude,
- landcover,
- normalize_mean_lowres,
- normalize_std_lowres,
- normalize_mean_highres,
- normalize_std_highres,
- do_tiling,
- div,
- overlap,
- )
Global | MRF | 2025 | 40 GB
ORBIT-2 precipitation downscaling model supporting both the 9.5M-parameter and 126M-parameter variants (model_size "9.5m" and "126m").
Note
This model and checkpoint are from Wang et al. (2025). For more information, see the following references:
Note
A few details regarding the model’s variables:
- t2m and sst are combined to represent global surface temperature.
- The input variables t2m_min and t2m_max are the daily minimum and maximum of this combined surface temperature.
- The model is fine-tuned for IMERG 24-hour accumulated precipitation (tp24).
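The variable relationships above can be sketched in NumPy on synthetic hourly fields. This mirrors the torch computation in the Example section; it assumes the 24 hourly slices are stacked along axis 0 and that sst is NaN over land, as in ERA5. derive_daily_inputs is a hypothetical helper, not part of earth2studio.

```python
import numpy as np


def derive_daily_inputs(cp, lsp, t2m, sst):
    """Hypothetical helper: build tp24, t2m_max and t2m_min from hourly fields.

    Every argument has shape (24, lat, lon), one slice per hour.
    sst is assumed to be NaN over land.
    """
    # 24-hour accumulation: convective + large-scale precipitation, summed over hours
    tp24 = (cp + lsp).sum(axis=0)
    # Combined surface temperature: SST where defined (ocean), 2 m temperature elsewhere
    surface_t = np.where(np.isnan(sst), t2m, sst)
    # Daily extremes of the combined field
    t2m_max = surface_t.max(axis=0)
    t2m_min = surface_t.min(axis=0)
    return tp24, t2m_max, t2m_min
```

Over an ocean point the daily extremes come from SST alone, so a constant SST yields t2m_max == t2m_min there.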
- Parameters:
core_model (torch.nn.Module) – Core pytorch model
land_sea_mask (np.ndarray) – Binary land-sea mask at 0.25° resolution, shape (720, 1440). Values are 1 over land and 0 over ocean.
orography (np.ndarray) – Surface geopotential height (meters) at 0.25° resolution, shape (720, 1440).
lattitude (np.ndarray) – Latitude values broadcast to grid shape (720, 1440), used as a positional-encoding input to the model.
landcover (np.ndarray) – Land-use / land-cover classification at 0.25° resolution, shape (720, 1440).
(720, 1440).normalize_mean_lowres (np.lib.npyio.NpzFile) – Per-variable mean values for input normalization. Keys are variable names, values are single-element arrays.
normalize_std_lowres (np.lib.npyio.NpzFile) – Per-variable standard deviation values for input normalization. Keys are variable names, values are single-element arrays.
normalize_mean_highres (np.lib.npyio.NpzFile) – Per-variable mean values for output denormalization. Keys are variable names, values are single-element arrays.
normalize_std_highres (np.lib.npyio.NpzFile) – Per-variable standard deviation values for output denormalization. Keys are variable names, values are single-element arrays.
do_tiling (bool) – Whether to perform tiled inference.
div (int) – Number of tiles to divide the input into (used only when do_tiling is True).
overlap (int) – Number of overlap pixels used during tiled inference (used only when do_tiling is True).
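The array-valued constructor arguments can be prepared along the following lines. This is a minimal sketch under stated assumptions: the latitude rows are assumed to run north to south from 90° in 0.25° steps (the exact row convention of the checkpoint is not stated here), and the statistics are assumed to be applied channel-wise in the order of a given variable list, with the low-resolution pair normalizing inputs and the high-resolution pair denormalizing outputs, as the parameter descriptions indicate. latitude_field, channel_stats, normalize and denormalize are hypothetical names.

```python
import numpy as np

N_LAT, N_LON = 720, 1440  # 0.25 degree global grid used by the static fields


def latitude_field(start=90.0, step=-0.25):
    # Assumed row convention: north to south, beginning at `start`
    lat = start + step * np.arange(N_LAT)
    # Repeat each row's latitude across every longitude -> shape (720, 1440)
    return np.broadcast_to(lat[:, None], (N_LAT, N_LON)).copy()


def channel_stats(npz, variables):
    # Each NpzFile entry is a single-element array; take one scalar per channel
    return np.array([npz[v].item() for v in variables])


def normalize(x, variables, mean, std):
    """Normalize x of shape (..., n_vars, lat, lon); channels follow `variables`."""
    m = channel_stats(mean, variables)[:, None, None]
    s = channel_stats(std, variables)[:, None, None]
    return (x - m) / s


def denormalize(y, variables, mean, std):
    """Inverse of `normalize`, e.g. with the high-resolution statistics."""
    m = channel_stats(mean, variables)[:, None, None]
    s = channel_stats(std, variables)[:, None, None]
    return y * s + m
```

Looking statistics up by variable name, rather than by position, keeps the code correct even if the .npz file stores its keys in a different order from the model's channels.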
Example
The derived inputs tp24, t2m_max, and t2m_min must be computed from hourly ERA5 fields before calling the model:
>>> from collections import OrderedDict
>>> import numpy as np
>>> import torch
>>> from earth2studio.data import NCAR_ERA5, prep_data_array
>>> from earth2studio.models.dx import OrbitGlobalPrecip
>>> from earth2studio.utils.time import to_time_array
>>>
>>> package = OrbitGlobalPrecip.load_default_package()
>>> orbit = OrbitGlobalPrecip.load_model(package)
>>> orbit = orbit.to("cuda")
>>> data = NCAR_ERA5()
>>> time = to_time_array([np.datetime64("2023-06-01")])
>>>
>>> # Fetch base variables (all except tp24, t2m_max, t2m_min)
>>> base_vars = orbit.input_coords()["variable"][:-3]
>>> x, coords = prep_data_array(data(time, base_vars), device="cuda")
>>>
>>> # Collect the past 24 hours of cp, lsp, t2m and sst
>>> batch_p = torch.zeros((len(time), 24, 4, x.shape[-2], x.shape[-1]), device="cuda")
>>> for i in range(24):
...     time0 = np.array(time) - np.timedelta64(i, "h")
...     p, _ = prep_data_array(data(time0, ["cp", "lsp", "t2m", "sst"]), device="cuda")
...     batch_p[:, i] = p
>>> # tp24: convective + large-scale precipitation summed over 24 hours
>>> total_p_24hr = (batch_p[:, :, 0] + batch_p[:, :, 1]).sum(dim=1).unsqueeze(1)
>>> # Surface temperature: sst where defined, t2m where sst is NaN (land)
>>> t2_sst_combined = torch.where(
...     torch.isnan(batch_p[:, :, 3]), batch_p[:, :, 2], batch_p[:, :, 3]
... )
>>> t2_max = t2_sst_combined.max(dim=1).values.unsqueeze(1)
>>> t2_min = t2_sst_combined.min(dim=1).values.unsqueeze(1)
>>> # Append the derived channels in the same order as the last three input variables
>>> x = torch.cat((x, total_p_24hr, t2_max, t2_min), dim=1)
>>>
>>> input_coords = OrderedDict(
...     {k: v for k, v in orbit.input_coords().items() if k != "batch"}
... )
>>> input_coords["time"] = time
>>> input_coords.move_to_end("time", last=False)
>>> output, output_coords = orbit(x, input_coords)
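The example builds input_coords as an OrderedDict and moves "time" to the front so that its keys line up with the dimensions of x. A minimal sketch of that contract, assuming one coordinate entry per tensor dimension, in order, with one value per index along that dimension; check_coords is a hypothetical helper, not an earth2studio function, and the toy variables and grid are placeholders.

```python
from collections import OrderedDict

import numpy as np


def check_coords(x, coords):
    """Hypothetical helper: coords must name every dimension of x, in order,
    with one coordinate value per index along that dimension."""
    if x.ndim != len(coords):
        raise ValueError(f"x has {x.ndim} dims but coords has {len(coords)} keys")
    for axis, (name, values) in enumerate(coords.items()):
        if x.shape[axis] != len(values):
            raise ValueError(
                f"dim {axis} ({name!r}) has size {x.shape[axis]}, "
                f"but coords[{name!r}] has {len(values)} values"
            )


# Mirror the example: build coords without time, then move "time" to the front
coords = OrderedDict(
    variable=np.array(["t2m", "sst"]),
    lat=np.linspace(90.0, -90.0, 4),
    lon=np.arange(8) * 45.0,
)
coords["time"] = np.array([np.datetime64("2023-06-01")])
coords.move_to_end("time", last=False)
check_coords(np.zeros((1, 2, 4, 8)), coords)  # dims: (time, variable, lat, lon)
```

Without move_to_end, "time" would be the last key, and a (time, variable, lat, lon) tensor would fail the check.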
- __call__(x, coords)
Forward pass of the diagnostic model.
- Parameters:
x (Tensor)
coords (OrderedDict[str, ndarray])
- Return type:
tuple[Tensor, OrderedDict[str, ndarray]]
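When do_tiling is True, inference runs on div tiles that overlap by overlap pixels. How OrbitGlobalPrecip lays out and stitches its tiles is not described here, so the following is only a generic sketch of the overlap, crop and stitch pattern. It assumes tiles are strips along the longitude (last) axis, clamps the overlap at the array edges instead of wrapping around the globe, and discards overlap columns instead of blending them. tiled_apply is a hypothetical name.

```python
import numpy as np


def tiled_apply(fn, x, div, overlap):
    """Hypothetical sketch: apply `fn` to `div` strips of x along the last axis,
    each extended by `overlap` columns from its neighbours, then crop and stitch."""
    n = x.shape[-1]
    edges = np.linspace(0, n, div + 1).astype(int)  # tile boundaries
    pieces = []
    for start, stop in zip(edges[:-1], edges[1:]):
        lo = max(start - overlap, 0)  # extend left, clamped at the array edge
        hi = min(stop + overlap, n)   # extend right, clamped at the array edge
        out = fn(x[..., lo:hi])
        # Keep only the columns that belong to this tile
        pieces.append(out[..., start - lo : start - lo + (stop - start)])
    return np.concatenate(pieces, axis=-1)
```

The overlap matters for any operation that reads neighbouring pixels. With a 3-point smoothing stencil, overlap >= 1 reproduces the untiled result exactly, while overlap = 0 leaves artifacts at the tile seams.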
- classmethod load_model(
- package,
- model_type='global',
- model_size='9.5m',
- model_variable='precipitation',
- )
Load ORBIT-2 precipitation diagnostic model from package files.
- Parameters:
package (Package) – Model package containing configuration and checkpoint files.
model_type (Literal["global"], optional) – ORBIT-2 model family to load, by default “global”
model_size (Literal["9.5m", "126m"], optional) – ORBIT-2 model size variant to load, by default “9.5m”
model_variable (Literal["precipitation"], optional) – Target variable variant to load, by default “precipitation”
- Returns:
Loaded ORBIT-2 precipitation diagnostic model
- Return type:
DiagnosticModel