earth2studio.models.dx.CorrDiffCMIP6

class earth2studio.models.dx.CorrDiffCMIP6(input_variables, output_variables, residual_model, regression_model, lat_input_grid, lon_input_grid, lat_output_grid, lon_output_grid, in_center, in_scale, out_center, out_scale, invariants=None, invariant_center=None, invariant_scale=None, number_of_samples=1, number_of_steps=18, solver='euler', sampler_type='stochastic', inference_mode='both', hr_mean_conditioning=True, seed=None, grid_spacing_tolerance=1e-05, grid_bounds_margin=0.0, sigma_min=None, sigma_max=None, time_feature_center=None, time_feature_scale=None, output_lead_times=array([-12], dtype='timedelta64[h]'))

CMIP6-to-ERA5 downscaling model based on the CorrDiff architecture. The model downscales in both the spatial and temporal dimensions and is designed to work with the earth2studio.data.CMIP6MultiRealm data source.

Note

For more information see the following references:

Parameters:
  • input_variables (Sequence[str]) – List of input variable names

  • output_variables (Sequence[str]) – List of output variable names

  • residual_model (torch.nn.Module) – Core pytorch model for diffusion step

  • regression_model (torch.nn.Module) – Core pytorch model for regression step

  • lat_input_grid (torch.Tensor) – Input latitude grid of size [in_lat]

  • lon_input_grid (torch.Tensor) – Input longitude grid of size [in_lon]

  • lat_output_grid (torch.Tensor) – Output latitude grid of size [out_lat]

  • lon_output_grid (torch.Tensor) – Output longitude grid of size [out_lon]

  • in_center (torch.Tensor) – Model input center normalization tensor of size [in_var]

  • in_scale (torch.Tensor) – Model input scale normalization tensor of size [in_var]

  • out_center (torch.Tensor) – Model output center normalization tensor of size [out_var]

  • out_scale (torch.Tensor) – Model output scale normalization tensor of size [out_var]

  • invariants (OrderedDict | None, optional) – Dictionary of invariant features, by default None

  • invariant_center (torch.Tensor | None, optional) – Model invariant center normalization tensor, by default None

  • invariant_scale (torch.Tensor | None, optional) – Model invariant scale normalization tensor, by default None

  • number_of_samples (int, optional) – Number of high-resolution samples to draw from the diffusion model, by default 1

  • number_of_steps (int, optional) – Number of Langevin diffusion steps used by the sampling algorithm, by default 18

  • solver (Literal["euler", "heun"], optional) – Discretization scheme for the diffusion process, by default “euler”

  • sampler_type (Literal["deterministic", "stochastic"], optional) – Type of sampler to use, by default “stochastic”

  • inference_mode (Literal["regression", "both"], optional) – Inference mode: “regression” runs only the regression model, while “both” adds the diffusion-sampled residual to the regression output. Diffusion-only inference is not supported in CorrDiffCMIP6. By default “both”

  • hr_mean_conditioning (bool, optional) – Whether to use high-resolution mean conditioning, by default True

  • seed (int | None, optional) – Random seed for reproducibility, by default None

  • grid_spacing_tolerance (float, optional) – Relative tolerance for checking regular grid spacing, by default 1e-5

  • grid_bounds_margin (float, optional) – Fraction of input grid range to allow for extrapolation, by default 0.0

  • sigma_min (float | None, optional) – Minimum noise level for diffusion process, by default None

  • sigma_max (float | None, optional) – Maximum noise level for diffusion process, by default None

  • time_feature_center (torch.Tensor | None, optional) – Normalization center for the time features, solar zenith angle (SZA) and hour of day (HOD), of size [2], by default None

  • time_feature_scale (torch.Tensor | None, optional) – Normalization scale for the time features (SZA, HOD) of size [2], by default None

  • output_lead_times (LeadTimeArray, optional) – Output lead times to sample at within the input time window. The default package was trained to support lead times in [-12, +11] hours at hourly intervals; this keeps the input data aligned with the temporal features (SZA, HOD), which are computed at the valid time. By default np.array([np.timedelta64(-12, "h")])
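The in_center/in_scale and out_center/out_scale tensors hold one normalization statistic per variable. The following is a minimal NumPy sketch, assuming the common convention (x - center) / scale applied along the variable axis of a [var, lat, lon] array; the helper names normalize and denormalize and the statistics used are hypothetical and are not taken from earth2studio.

```python
import numpy as np


def normalize(x: np.ndarray, center: np.ndarray, scale: np.ndarray) -> np.ndarray:
    """Map physical fields of shape [var, lat, lon] to normalized units.

    center and scale have shape [var]; indexing with [:, None, None] reshapes
    them to [var, 1, 1] so they broadcast over the two spatial axes.
    """
    return (x - center[:, None, None]) / scale[:, None, None]


def denormalize(y: np.ndarray, center: np.ndarray, scale: np.ndarray) -> np.ndarray:
    """Invert normalize(): recover physical units from normalized fields."""
    return y * scale[:, None, None] + center[:, None, None]


# Hypothetical statistics for two variables (e.g. temperature in K, pressure in Pa)
center = np.array([280.0, 101325.0])
scale = np.array([5.0, 500.0])

# Constant 3x4 fields: 285 K and 100825 Pa
x = np.stack([np.full((3, 4), 285.0), np.full((3, 4), 100825.0)])

y = normalize(x, center, scale)  # y[0] is all 1.0, y[1] is all -1.0
x_back = denormalize(y, center, scale)  # recovers x
```

The same [:, None, None] indexing broadcasts identically on torch.Tensor objects, so the sketch carries over to the tensor types listed in the parameters.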
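The grid_spacing_tolerance and grid_bounds_margin parameters describe two checks on the coordinate grids. A sketch of what such checks could look like, assuming regularity is tested with a relative tolerance on consecutive spacings and that the margin widens the input range by a fraction of its extent on each side; is_regular_grid, padded_bounds and output_within_input are hypothetical helpers, not the library's internal code.

```python
import numpy as np


def is_regular_grid(grid: np.ndarray, rtol: float = 1e-5) -> bool:
    """True if every spacing in np.diff(grid) matches the first one within rtol."""
    spacing = np.diff(grid)
    return bool(np.allclose(spacing, spacing[0], rtol=rtol, atol=0.0))


def padded_bounds(grid: np.ndarray, margin: float = 0.0) -> tuple[float, float]:
    """Widen [min, max] of the input grid by margin * (max - min) on each side."""
    lo, hi = float(np.min(grid)), float(np.max(grid))
    pad = margin * (hi - lo)
    return lo - pad, hi + pad


def output_within_input(out_grid: np.ndarray, in_grid: np.ndarray, margin: float = 0.0) -> bool:
    """True if every output coordinate lies inside the (optionally padded) input range."""
    lo, hi = padded_bounds(in_grid, margin)
    return bool(np.all((out_grid >= lo) & (out_grid <= hi)))


lat_in = np.linspace(0.0, 10.0, 21)   # regular 0.5-degree spacing
lat_out = np.linspace(0.0, 10.4, 53)  # extends 0.4 degrees past the input range

print(is_regular_grid(lat_in))                            # True
print(is_regular_grid(np.array([0.0, 1.0, 2.5])))         # False
print(output_within_input(lat_out, lat_in))               # False
print(output_within_input(lat_out, lat_in, margin=0.05))  # True (pad = 0.5)
```

Using min/max and np.diff keeps both helpers valid for descending grids (for example latitudes running from 90 to -90), since a descending regular grid still has constant, if negative, spacing.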

Examples

Run a single forward pass to predict CMIP6->ERA5 at two lead times within the input window

>>> import numpy as np
>>> import torch
>>> import xarray as xr
>>> from earth2studio.data import CMIP6, CMIP6MultiRealm, fetch_data
>>> from earth2studio.models.dx import CorrDiffCMIP6
>>>
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>> model = CorrDiffCMIP6.load_model(
...     CorrDiffCMIP6.load_default_package(),
...     output_lead_times=np.array([np.timedelta64(-12, "h"), np.timedelta64(-6, "h")]),
... )
>>> model.seed = 1  # Set seed for reproducibility
>>> model.number_of_samples = 1  # Modify number of samples if needed
>>> model = model.to(device)
>>>
>>> # Build CMIP6 multi-realm data source; roughly 60 GB of data will be fetched
>>> cmip6_kwargs = dict(
...     experiment_id="ssp585",
...     source_id="CanESM5",
...     variant_label="r1i1p2f1",
...     exact_time_match=True,
... )
>>> data = CMIP6MultiRealm([CMIP6(table_id=t, **cmip6_kwargs) for t in ("day", "Eday", "SIday")])
>>>
>>> x, coords = fetch_data(
...     source=data,
...     time=np.array([np.datetime64("2037-09-06T12:00")]),  # Time must be 12:00 UTC
...     lead_time=model.input_coords()["lead_time"],
...     variable=model.input_coords()["variable"],
...     device=device,
... )
>>>
>>> # Run model forward pass
>>> out, out_coords = model(x, coords)
>>> da = xr.DataArray(data=out.cpu().numpy(), coords=out_coords, dims=list(out_coords.keys()))
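The output_lead_times constraint can be checked before loading the model. The sketch below uses a hypothetical pre-flight helper, check_output_lead_times (not part of earth2studio), that enforces the documented [-12, +11] hour, hourly range, and it computes the valid times for the example above, assuming the usual earth2studio convention that the valid time equals time + lead_time.

```python
import numpy as np

# Documented support range of the default package: [-12, +11] hours, hourly.
MIN_LEAD = np.timedelta64(-12, "h")
MAX_LEAD = np.timedelta64(11, "h")


def check_output_lead_times(lead_times: np.ndarray) -> np.ndarray:
    """Hypothetical check mirroring the documented lead-time constraint.

    Raises ValueError if any lead time falls outside [-12 h, +11 h] or is not a
    whole number of hours; otherwise returns the lead times in hours.
    """
    lt = np.asarray(lead_times).astype("timedelta64[s]")
    if np.any(lt < MIN_LEAD) or np.any(lt > MAX_LEAD):
        raise ValueError("output_lead_times must lie within [-12, +11] hours")
    if np.any(lt % np.timedelta64(1, "h") != np.timedelta64(0, "s")):
        raise ValueError("output_lead_times must be whole hours")
    return lt.astype("timedelta64[h]")


# Every hourly lead time the default package supports: -12 h ... +11 h (24 values)
all_leads = np.arange(-12, 12).astype("timedelta64[h]")

# Valid times for the example above, assuming valid_time = time + lead_time
time = np.datetime64("2037-09-06T12:00")
leads = check_output_lead_times(
    np.array([np.timedelta64(-12, "h"), np.timedelta64(-6, "h")])
)
valid_times = time + leads  # 2037-09-06T00:00 and 2037-09-06T06:00
```

Under this convention, a 12:00 UTC input time with lead times -12 h and -6 h yields outputs valid at 00:00 and 06:00 UTC on the same day.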
__call__(x, coords)

Forward pass of the diagnostic model

Parameters:
  • x (Tensor) – Input tensor

  • coords (OrderedDict[str, ndarray]) – Input coordinate system

Return type:

tuple[Tensor, OrderedDict[str, ndarray]]

classmethod load_default_package()

Load the default diagnostic package

Return type:

Package

classmethod load_model(package, output_lead_times=array([-12], dtype='timedelta64[h]'), device='cpu')

Load the diagnostic model from a package

Parameters:
  • package (Package) – Package containing model weights and configuration

  • output_lead_times (LeadTimeArray, optional) – Output lead times to sample at, by default np.array([np.timedelta64(-12, "h")])

  • device (str, optional) – Device to load model on, by default “cpu”

Returns:

Diagnostic model

Return type:

DiagnosticModel