earth2studio.models.dx.CorrDiffTaiwan

class earth2studio.models.dx.CorrDiffTaiwan(residual_model, regression_model, in_center, in_scale, out_center, out_scale, out_lat, out_lon, number_of_samples=1, number_of_steps=8, solver='euler')

CorrDiff is a Corrector Diffusion model that learns high-fidelity mappings from low-resolution to high-resolution weather data. This model was trained over a region near Taiwan.
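The corrector idea can be illustrated with a toy sketch. In the CorrDiff approach, a regression network predicts the conditional mean of the high-resolution fields and a diffusion model generates a stochastic residual on top of that mean, which corresponds to the regression_model and residual_model constructor arguments. The functions `regression_mean`, `sample_residual` and `corrdiff_sketch` below are hypothetical NumPy stand-ins, not earth2studio APIs; the 2x upsampling and Gaussian noise merely take the place of the trained networks.

```python
import numpy as np


def regression_mean(x_low: np.ndarray) -> np.ndarray:
    """Hypothetical regression step: predicts the conditional mean of the
    high-resolution field (here just a 2x nearest-neighbour upsample)."""
    return np.repeat(np.repeat(x_low, 2, axis=-2), 2, axis=-1)


def sample_residual(mean: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Hypothetical diffusion step: draws a stochastic residual with the same
    shape as the mean (here small Gaussian noise as a placeholder)."""
    return 0.1 * rng.standard_normal(mean.shape)


def corrdiff_sketch(x_low: np.ndarray, number_of_samples: int = 1, seed: int = 0) -> np.ndarray:
    """Each sample = shared regression mean + independently drawn residual."""
    rng = np.random.default_rng(seed)
    mean = regression_mean(x_low)  # computed once, reused by every sample
    return np.stack([mean + sample_residual(mean, rng) for _ in range(number_of_samples)])


x_low = np.ones((4, 8, 8))  # 4 hypothetical channels on an 8 x 8 grid
samples = corrdiff_sketch(x_low, number_of_samples=3)
print(samples.shape)  # (3, 4, 16, 16)
```

Because the mean is deterministic and only the residual is sampled, the spread between samples comes entirely from the diffusion component.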

Note

This model and checkpoint are from Mardani, Morteza, et al. 2023. See that publication for more information.

Parameters:
  • residual_model (torch.nn.Module) – Core PyTorch diffusion model that generates the stochastic residual

  • regression_model (torch.nn.Module) – Core PyTorch regression model that predicts the conditional mean

  • in_center (torch.Tensor) – Model input center normalization tensor of size [20,1,1]

  • in_scale (torch.Tensor) – Model input scale normalization tensor of size [20,1,1]

  • out_center (torch.Tensor) – Model output center normalization tensor of size [4,1,1]

  • out_scale (torch.Tensor) – Model output scale normalization tensor of size [4,1,1]

  • out_lat (torch.Tensor) – Output latitude grid of size [448, 448]

  • out_lon (torch.Tensor) – Output longitude grid of size [448, 448]

  • number_of_samples (int, optional) – Number of high-resolution samples to draw from the diffusion model. Default is 1

  • number_of_steps (int, optional) – Number of Langevin diffusion steps used by the sampling algorithm. Default is 8

  • solver (Literal['euler', 'heun'], optional) – Discretization scheme for the diffusion process. Only ‘euler’ and ‘heun’ are supported. Default is ‘euler’
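The number_of_steps and solver arguments control how the continuous diffusion process is discretized. As a minimal sketch, the hypothetical `integrate` function below applies the same two schemes to the toy ODE dx/dt = -x rather than to the actual CorrDiff sampler: Euler uses one slope evaluation per step, while Heun adds a second evaluation at the Euler prediction and averages the two slopes.

```python
import math
from typing import Callable, Literal


def integrate(
    f: Callable[[float, float], float],
    x0: float,
    t0: float,
    t1: float,
    number_of_steps: int = 8,
    solver: Literal["euler", "heun"] = "euler",
) -> float:
    """Integrate dx/dt = f(x, t) from t0 to t1 with a fixed step size."""
    if solver not in ("euler", "heun"):
        raise ValueError("Only 'euler' and 'heun' are supported")
    h = (t1 - t0) / number_of_steps
    x, t = x0, t0
    for _ in range(number_of_steps):
        d = f(x, t)                  # slope at the start of the step
        x_euler = x + h * d          # first-order Euler update
        if solver == "heun":
            d_next = f(x_euler, t + h)       # slope at the Euler prediction
            x = x + 0.5 * h * (d + d_next)   # second-order trapezoidal correction
        else:
            x = x_euler
        t += h
    return x


exact = math.exp(-1.0)  # analytic solution of dx/dt = -x at t = 1 with x(0) = 1
errors = {}
for scheme in ("euler", "heun"):
    approx = integrate(lambda x, t: -x, 1.0, 0.0, 1.0, number_of_steps=8, solver=scheme)
    errors[scheme] = abs(approx - exact)
print(errors)
```

With the same number of steps, Heun is typically more accurate but costs two function evaluations per step. In a diffusion sampler each evaluation is a network call, so the choice trades accuracy against runtime.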

__call__(x, coords)

Forward pass of the diagnostic model

Parameters:
  • x (Tensor)

  • coords (OrderedDict[str, ndarray])

Return type:

tuple[Tensor, OrderedDict[str, ndarray]]
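The shapes of the normalization tensors passed to the constructor suggest a per-channel transform that broadcasts over the spatial grid around the forward pass. The sketch below assumes the conventional (x - center) / scale input normalization and its inverse, y * scale + center, for the outputs. The statistics are random placeholders, `normalize_input` and `denormalize_output` are hypothetical helpers, and the 36 x 40 input grid is an illustrative size, not the model's actual input resolution.

```python
import numpy as np

# Placeholder statistics with the documented shapes:
# 20 input channels and 4 output channels, each broadcast over [H, W].
rng = np.random.default_rng(0)
in_center = rng.standard_normal((20, 1, 1))
in_scale = rng.uniform(0.5, 2.0, size=(20, 1, 1))
out_center = rng.standard_normal((4, 1, 1))
out_scale = rng.uniform(0.5, 2.0, size=(4, 1, 1))


def normalize_input(x: np.ndarray) -> np.ndarray:
    """Assumed input transform: (x - center) / scale, per channel."""
    return (x - in_center) / in_scale


def denormalize_output(y: np.ndarray) -> np.ndarray:
    """Assumed inverse transform back to physical units for the 4 outputs."""
    return y * out_scale + out_center


x = rng.standard_normal((20, 36, 40))  # hypothetical input [C, H, W]
x_norm = normalize_input(x)

y_norm = np.zeros((4, 448, 448))  # stand-in for normalized output on the 448 x 448 grid
y = denormalize_output(y_norm)
print(x_norm.shape, y.shape)
```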

classmethod load_default_package()

Default pre-trained CorrDiff model package from the NVIDIA model registry

Return type:

Package

classmethod load_model(package)

Load the diagnostic model from a package

Parameters:

package (Package)

Return type:

DiagnosticModel

Examples using earth2studio.models.dx.CorrDiffTaiwan

Generative Downscaling