earth2studio.models.dx.CorrDiffCMIP6
- class earth2studio.models.dx.CorrDiffCMIP6(input_variables, output_variables, residual_model, regression_model, lat_input_grid, lon_input_grid, lat_output_grid, lon_output_grid, in_center, in_scale, out_center, out_scale, invariants=None, invariant_center=None, invariant_scale=None, number_of_samples=1, number_of_steps=18, solver='euler', sampler_type='stochastic', inference_mode='both', hr_mean_conditioning=True, seed=None, grid_spacing_tolerance=1e-05, grid_bounds_margin=0.0, sigma_min=None, sigma_max=None, time_feature_center=None, time_feature_scale=None, output_lead_times=array([-12], dtype='timedelta64[h]'))
CMIP6 to ERA5 downscaling model based on the CorrDiff architecture. This model can downscale in both the spatial and temporal dimensions. It works with the earth2studio.data.CMIP6MultiRealm data source.
Note
For more information see the following references:
- Parameters:
input_variables (Sequence[str]) – List of input variable names
output_variables (Sequence[str]) – List of output variable names
residual_model (torch.nn.Module) – Core pytorch model for diffusion step
regression_model (torch.nn.Module) – Core pytorch model for regression step
lat_input_grid (torch.Tensor) – Input latitude grid of size [in_lat]
lon_input_grid (torch.Tensor) – Input longitude grid of size [in_lon]
lat_output_grid (torch.Tensor) – Output latitude grid of size [out_lat]
lon_output_grid (torch.Tensor) – Output longitude grid of size [out_lon]
in_center (torch.Tensor) – Model input center normalization tensor of size [in_var]
in_scale (torch.Tensor) – Model input scale normalization tensor of size [in_var]
out_center (torch.Tensor) – Model output center normalization tensor of size [out_var]
out_scale (torch.Tensor) – Model output scale normalization tensor of size [out_var]
invariants (OrderedDict | None, optional) – Dictionary of invariant features, by default None
invariant_center (torch.Tensor | None, optional) – Model invariant center normalization tensor, by default None
invariant_scale (torch.Tensor | None, optional) – Model invariant scale normalization tensor, by default None
number_of_samples (int, optional) – Number of high-resolution samples to draw from the diffusion model, by default 1
number_of_steps (int, optional) – Number of Langevin diffusion steps in the sampling algorithm, by default 18
solver (Literal["euler", "heun"], optional) – Discretization of diffusion process, by default “euler”
sampler_type (Literal["deterministic", "stochastic"], optional) – Type of sampler to use, by default “stochastic”
inference_mode (Literal["regression", "both"], optional) – Inference mode: “regression” runs only the regression model, while “both” runs the regression model followed by the residual diffusion model. Diffusion-only inference is not supported in CorrDiffCMIP6. By default “both”
hr_mean_conditioning (bool, optional) – Whether to use high-resolution mean conditioning, by default True
seed (int | None, optional) – Random seed for reproducibility, by default None
grid_spacing_tolerance (float, optional) – Relative tolerance for checking regular grid spacing, by default 1e-5
grid_bounds_margin (float, optional) – Fraction of input grid range to allow for extrapolation, by default 0.0
sigma_min (float | None, optional) – Minimum noise level for diffusion process, by default None
sigma_max (float | None, optional) – Maximum noise level for diffusion process, by default None
time_feature_center (torch.Tensor | None, optional) – Normalization center for time features (sza, hod) of size [2], by default None
time_feature_scale (torch.Tensor | None, optional) – Normalization scale for time features (sza, hod) of size [2], by default None
output_lead_times (LeadTimeArray, optional) – Output lead times to sample at within the input time window. The default package is trained to support lead times in [-12, +11] hours at hourly intervals. This constraint keeps the input data aligned with the temporal features (solar zenith angle and hour of day, SZA and HOD) computed at the valid time. By default np.array([np.timedelta64(-12, "h")])
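The grid_spacing_tolerance and grid_bounds_margin parameters suggest two grid checks: that a coordinate axis is evenly spaced within a relative tolerance, and that the output grid stays inside the input grid's range, optionally widened by a fraction of that range. The sketch below illustrates both under those assumptions. `is_regular_grid` and `within_bounds` are hypothetical helpers written for illustration, not earth2studio functions, and the library's actual checks may differ in detail.

```python
import numpy as np


def is_regular_grid(coords: np.ndarray, rtol: float = 1e-5) -> bool:
    """Return True if the spacings of a 1D grid are equal within rtol.

    Hypothetical helper illustrating the idea behind grid_spacing_tolerance;
    not the library's actual check.
    """
    diffs = np.diff(coords)
    if diffs.size == 0:
        return True
    # Compare every spacing with the first one, relative to its magnitude
    return bool(np.allclose(diffs, diffs[0], rtol=rtol, atol=0.0))


def within_bounds(out_coords: np.ndarray, in_coords: np.ndarray, margin: float = 0.0) -> bool:
    """Return True if out_coords lie in the input range widened by margin * range.

    Hypothetical helper illustrating the idea behind grid_bounds_margin.
    """
    lo, hi = float(in_coords.min()), float(in_coords.max())
    pad = margin * (hi - lo)
    return bool(out_coords.min() >= lo - pad and out_coords.max() <= hi + pad)


lat_in = np.linspace(-90.0, 90.0, 181)   # illustrative 1 degree input grid
lat_out = np.linspace(-89.5, 89.5, 717)  # illustrative 0.25 degree output grid
print(is_regular_grid(lat_in))                               # True
print(within_bounds(lat_out, lat_in))                        # True
print(within_bounds(np.array([95.0]), lat_in))               # False
print(within_bounds(np.array([95.0]), lat_in, margin=0.05))  # True (range padded by 9 degrees)
```

With margin=0.0, as in the default, any output coordinate outside the input range is rejected; a positive margin permits a small amount of extrapolation at the edges.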
Examples
Run a single forward pass to predict CMIP6->ERA5 at two lead times within the input window
>>> import numpy as np
>>> import torch
>>> import xarray as xr
>>> from earth2studio.data import CMIP6, CMIP6MultiRealm, fetch_data
>>> from earth2studio.models.dx import CorrDiffCMIP6
>>>
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>> model = CorrDiffCMIP6.load_model(
...     CorrDiffCMIP6.load_default_package(),
...     output_lead_times=np.array([np.timedelta64(-12, "h"), np.timedelta64(-6, "h")]),
... )
>>> model.seed = 1  # Set seed for reproducibility
>>> model.number_of_samples = 1  # Modify number of samples if needed
>>> model = model.to(device)
>>>
>>> # Build CMIP6 multi-realm data source; about 60 GB of data will be fetched
>>> cmip6_kwargs = dict(
...     experiment_id="ssp585",
...     source_id="CanESM5",
...     variant_label="r1i1p2f1",
...     exact_time_match=True,
... )
>>> data = CMIP6MultiRealm([CMIP6(table_id=t, **cmip6_kwargs) for t in ("day", "Eday", "SIday")])
>>>
>>> x, coords = fetch_data(
...     source=data,
...     time=np.array([np.datetime64("2037-09-06T12:00")]),  # Time must be 12:00 UTC
...     lead_time=model.input_coords()["lead_time"],
...     variable=model.input_coords()["variable"],
...     device=device,
... )
>>>
>>> # Run model forward pass
>>> out, out_coords = model(x, coords)
>>> da = xr.DataArray(data=out.cpu().numpy(), coords=out_coords, dims=list(out_coords.keys()))
- __call__(x, coords)
Forward pass of the diagnostic
- Parameters:
x (Tensor) – Input tensor
coords (OrderedDict[str, ndarray]) – Coordinate system of the input tensor
- Return type:
tuple[Tensor, OrderedDict[str, ndarray]]
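__call__ takes a tensor together with its coordinate system and returns the same pair for the output. A minimal sketch of how the two are assumed to relate: each key of coords names one tensor dimension, in order, and its array lists the coordinate values along that dimension. `check_coords` is a hypothetical helper, NumPy arrays stand in for torch tensors, and the dimension and variable names below are placeholders, not the model's actual input coordinates (use model.input_coords() for those).

```python
from collections import OrderedDict

import numpy as np


def check_coords(x: np.ndarray, coords: "OrderedDict[str, np.ndarray]") -> None:
    """Raise ValueError if the array shape does not match the coordinate system.

    Hypothetical helper: assumes each key of coords names one dimension of x,
    in order, and its array lists that dimension's coordinate values.
    """
    expected = tuple(len(v) for v in coords.values())
    if tuple(x.shape) != expected:
        raise ValueError(f"shape {tuple(x.shape)} does not match coords {expected}")


# Placeholder coordinate system; names and sizes are illustrative only
coords = OrderedDict(
    time=np.array([np.datetime64("2037-09-06T12:00")]),
    lead_time=np.array([np.timedelta64(-12, "h")]),
    variable=np.array(["t2m", "u10m"]),
    lat=np.linspace(90.0, -90.0, 5),
    lon=np.linspace(0.0, 360.0, 8, endpoint=False),
)
# Allocate an array whose shape follows the coordinate system, dimension by dimension
x = np.zeros(tuple(len(v) for v in coords.values()), dtype=np.float32)
check_coords(x, coords)  # passes silently
```

Keeping shape and coordinates in lockstep is what allows the output pair to be wrapped directly in an xarray.DataArray, as in the example above.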
- classmethod load_model(package, output_lead_times=array([-12], dtype='timedelta64[h]'), device='cpu')
Load the diagnostic from a package
- Parameters:
package (Package) – Package containing model weights and configuration
output_lead_times (LeadTimeArray, optional) – Output lead times to sample at, by default np.array([np.timedelta64(-12, "h")])
device (str, optional) – Device to load model on, by default “cpu”
- Returns:
Diagnostic model
- Return type:
DiagnosticModel
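Because the default package supports output lead times between -12 and +11 hours at hourly intervals, it can help to build and check the output_lead_times array before passing it to load_model. The helpers `hourly_lead_times` and `check_lead_times` below are hypothetical, written for illustration; they encode the [-12, +11] hour window stated in the parameter description and are not part of earth2studio.

```python
import numpy as np

# Window stated for the default package: -12 h to +11 h at hourly steps
MIN_HOURS, MAX_HOURS = -12, 11


def hourly_lead_times(start_h: int, stop_h: int) -> np.ndarray:
    """Hourly lead times from start_h to stop_h inclusive, as timedelta64[h]."""
    return np.arange(start_h, stop_h + 1).astype("timedelta64[h]")


def check_lead_times(lead_times: np.ndarray) -> None:
    """Hypothetical pre-check: whole hours inside [MIN_HOURS, MAX_HOURS]."""
    # Work in integer seconds so sub-hour offsets are detected, not truncated
    seconds = lead_times.astype("timedelta64[s]").astype(np.int64)
    if np.any(seconds % 3600 != 0):
        raise ValueError("output lead times must fall on whole hours")
    hours = seconds // 3600
    if hours.min() < MIN_HOURS or hours.max() > MAX_HOURS:
        raise ValueError(f"output lead times must lie in [{MIN_HOURS}, +{MAX_HOURS}] hours")


lead_times = hourly_lead_times(-12, -6)  # -12 h, -11 h, ..., -6 h
check_lead_times(lead_times)             # passes silently
# The array could then be passed to the loader, e.g.:
# model = CorrDiffCMIP6.load_model(
#     CorrDiffCMIP6.load_default_package(), output_lead_times=lead_times
# )
```

Checking in whole seconds rather than casting straight to hours avoids silently truncating a value such as 30 minutes to 0 hours.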