# Source code for earth2studio.models.dx.corrdiff

# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import zipfile
from collections import OrderedDict
from collections.abc import Callable
from importlib.metadata import version
from pathlib import Path
from typing import Literal

import numpy as np
import torch
import zarr

try:
    from physicsnemo.models import Module as PhysicsNemoModule
    from physicsnemo.utils.generative import (
        StackedRandomGenerator,
    )
    from physicsnemo.utils.generative import (
        deterministic_sampler as ablation_sampler,
    )
except ImportError:
    PhysicsNemoModule = None
    StackedRandomGenerator = None
    ablation_sampler = None

from earth2studio.models.auto import AutoModelMixin, Package
from earth2studio.models.batch import batch_coords, batch_func
from earth2studio.models.dx.base import DiagnosticModel
from earth2studio.utils import (
    check_extra_imports,
    handshake_coords,
    handshake_dim,
)
from earth2studio.utils.interp import latlon_interpolation_regular
from earth2studio.utils.type import CoordSystem

# Input variable names, in the channel order the model expects (used by
# ``CorrDiffTaiwan.input_coords``). Grouped: a single column field, then the
# z/t/u/v fields at level 500, the same fields at level 850, and finally the
# near-surface fields.
VARIABLES = [
    "tcwv",
    *("z500", "t500", "u500", "v500"),
    *("z850", "t850", "u850", "v850"),
    *("t2m", "u10m", "v10m"),
]

# Output variable names, in the order of the model's output channels (used by
# ``CorrDiffTaiwan.output_coords``).
OUT_VARIABLES = [*("mrr", "t2m", "u10m", "v10m")]


@check_extra_imports(
    "corrdiff", [PhysicsNemoModule, StackedRandomGenerator, ablation_sampler]
)
class CorrDiffTaiwan(torch.nn.Module, AutoModelMixin):
    """
    CorrDiff is a Corrector Diffusion model that learns mappings between
    low- and high-resolution weather data with high fidelity. This particular
    model was trained over a particular region near Taiwan.

    The low-resolution input is bilinearly interpolated onto the output grid.
    A deterministic mean is predicted with ``regression_model`` and a residual,
    sampled with ``residual_model`` starting from random latents, is added to it.

    Note
    ----
    This model and checkpoint are from Mardani, Morteza, et al. 2023. For more
    information see the following references:

    - https://arxiv.org/abs/2309.15214
    - https://catalog.ngc.nvidia.com/orgs/nvidia/teams/modulus/models/corrdiff_inference_package

    Parameters
    ----------
    residual_model : torch.nn.Module
        Core pytorch model used by the residual (diffusion) sampler
    regression_model : torch.nn.Module
        Core pytorch model used for the regression (mean) prediction
    in_center : torch.Tensor
        Model input center normalization tensor of size [12,1,1], one entry per
        input variable
    in_scale : torch.Tensor
        Model input scale normalization tensor of size [12,1,1], one entry per
        input variable
    out_center : torch.Tensor
        Model output center normalization tensor of size [4,1,1]
    out_scale : torch.Tensor
        Model output scale normalization tensor of size [4,1,1]
    out_lat : torch.Tensor
        Full 2D output latitude grid. The outermost row/column on every side is
        cropped, so the model output lives on ``out_lat[1:-1, 1:-1]`` (448 x 448
        for the default package, i.e. a 450 x 450 full grid)
    out_lon : torch.Tensor
        Full 2D output longitude grid, same shape and cropping as ``out_lat``
    number_of_samples : int, optional
        Number of high resolution samples to draw from diffusion model.
        Default is 1
    number_of_steps: int, optional
        Number of langevin diffusion steps during sampling algorithm.
        Default is 8
    solver: Literal['euler', 'heun']
        Discretization of diffusion process. Only 'euler' and 'heun'
        are supported. Default is 'euler'

    Raises
    ------
    ValueError
        If ``number_of_samples`` or ``number_of_steps`` is not a positive
        integer, or ``solver`` is not one of 'euler' / 'heun'.
    """

    def __init__(
        self,
        residual_model: torch.nn.Module,
        regression_model: torch.nn.Module,
        in_center: torch.Tensor,
        in_scale: torch.Tensor,
        out_center: torch.Tensor,
        out_scale: torch.Tensor,
        out_lat: torch.Tensor,
        out_lon: torch.Tensor,
        number_of_samples: int = 1,
        number_of_steps: int = 8,
        solver: Literal["euler", "heun"] = "euler",
    ):
        super().__init__()
        self.residual_model = residual_model
        self.regression_model = regression_model
        self.register_buffer("in_center", in_center)
        self.register_buffer("in_scale", in_scale)
        self.register_buffer("out_center", out_center)
        self.register_buffer("out_scale", out_scale)
        # The full grid is the interpolation target; the output grid drops the
        # outer ring of points (matches the crop in `_interpolate`).
        self.register_buffer("out_lat_full", out_lat)
        self.register_buffer("out_lon_full", out_lon)
        self.register_buffer("out_lat", out_lat[1:-1, 1:-1])
        self.register_buffer("out_lon", out_lon[1:-1, 1:-1])

        # A valid value must be an int AND at least 1; reject if either fails.
        if not isinstance(number_of_samples, int) or number_of_samples < 1:
            raise ValueError("`number_of_samples` must be a positive integer.")
        if not isinstance(number_of_steps, int) or number_of_steps < 1:
            raise ValueError("`number_of_steps` must be a positive integer.")
        if solver not in ["heun", "euler"]:
            raise ValueError(f"{solver} is not supported, must be in ['heun', 'euler']")
        self.number_of_samples = number_of_samples
        self.number_of_steps = number_of_steps
        self.solver = solver

    def input_coords(self) -> CoordSystem:
        """Input coordinate system"""
        return OrderedDict(
            {
                "batch": np.empty(0),
                "variable": np.array(VARIABLES),
                "lat": np.linspace(19.25, 28, 36, endpoint=True),
                "lon": np.linspace(116, 126, 40, endpoint=False),
            }
        )

    @batch_coords()
    def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
        """Output coordinate system of diagnostic model

        Parameters
        ----------
        input_coords : CoordSystem
            Input coordinate system to transform into output_coords. Its
            variable/lat/lon dimensions are checked against `input_coords()`.

        Returns
        -------
        CoordSystem
            Coordinate system dictionary
        """
        output_coords = OrderedDict(
            {
                "batch": np.empty(0),
                "sample": np.arange(self.number_of_samples),
                "variable": np.array(OUT_VARIABLES),
                "lat": self.out_lat.cpu().numpy(),
                "lon": self.out_lon.cpu().numpy(),
            }
        )

        target_input_coords = self.input_coords()

        handshake_dim(input_coords, "lon", 3)
        handshake_dim(input_coords, "lat", 2)
        handshake_dim(input_coords, "variable", 1)
        handshake_coords(input_coords, target_input_coords, "lon")
        handshake_coords(input_coords, target_input_coords, "lat")
        handshake_coords(input_coords, target_input_coords, "variable")

        output_coords["batch"] = input_coords["batch"]
        return output_coords

    @classmethod
    def load_default_package(cls) -> Package:
        """Default pre-trained corrdiff model package from Nvidia model registry"""
        return Package(
            "ngc://models/nvidia/modulus/corrdiff_inference_package@1",
            cache_options={
                "cache_storage": Package.default_cache("corrdiff_taiwan"),
                "same_names": True,
            },
        )

    @classmethod
    @check_extra_imports(
        "corrdiff", [PhysicsNemoModule, StackedRandomGenerator, ablation_sampler]
    )
    def load_model(cls, package: Package) -> DiagnosticModel:
        """Load diagnostic from package

        Parameters
        ----------
        package : Package
            Package that resolves ``corrdiff_inference_package.zip``

        Returns
        -------
        DiagnosticModel
            Diagnostic model

        Raises
        ------
        ImportError
            If the optional CorrDiff dependencies are not installed
        """
        if StackedRandomGenerator is None or ablation_sampler is None:
            raise ImportError(
                "Additional CorrDiff model dependencies are not installed. See install documentation for details."
            )

        checkpoint_zip = Path(package.resolve("corrdiff_inference_package.zip"))
        # Have to manually unzip here. Should not zip checkpoints in the future
        with zipfile.ZipFile(checkpoint_zip, "r") as zip_ref:
            zip_ref.extractall(checkpoint_zip.parent)
        package_root = checkpoint_zip.parent / "corrdiff_inference_package"

        residual = PhysicsNemoModule.from_checkpoint(
            str(package_root / "checkpoints" / "diffusion.mdlus")
        ).eval()
        regression = PhysicsNemoModule.from_checkpoint(
            str(package_root / "checkpoints" / "regression.mdlus")
        ).eval()

        # Get dataset for lat/lon grid info and centers/stds
        dataset_path = str(
            package_root / "dataset" / "2023-01-24-cwb-4years_5times.zarr"
        )
        try:
            zarr_major_version = int(version("zarr").split(".")[0])
        except Exception:
            # Fallback to older method if version check fails
            zarr_major_version = 2  # Assume older version if we can't determine
        # zarr >= 3 replaced DirectoryStore with LocalStore
        if zarr_major_version >= 3:
            store = zarr.storage.LocalStore(dataset_path)
        else:
            store = zarr.storage.DirectoryStore(dataset_path)
        root = zarr.group(store)

        # Get output lat/lon grid
        out_lat = torch.as_tensor(root["XLAT"][:], dtype=torch.float32)
        out_lon = torch.as_tensor(root["XLONG"][:], dtype=torch.float32)

        def _channel_stats(key: str, inds: list[int]) -> torch.Tensor:
            # Select channels and reshape to [C, 1, 1] so the statistics
            # broadcast over the trailing (lat, lon) dimensions.
            return (
                torch.as_tensor(root[key][inds], dtype=torch.float32)
                .unsqueeze(1)
                .unsqueeze(1)
            )

        # get normalization info; indices select the dataset channels that
        # correspond to VARIABLES / OUT_VARIABLES
        in_inds = [0, 1, 2, 3, 4, 9, 10, 11, 12, 17, 18, 19]
        in_center = _channel_stats("era5_center", in_inds)
        in_scale = _channel_stats("era5_scale", in_inds)

        out_inds = [0, 17, 18, 19]
        out_center = _channel_stats("cwb_center", out_inds)
        out_scale = _channel_stats("cwb_scale", out_inds)

        return cls(
            residual,
            regression,
            in_center,
            in_scale,
            out_center,
            out_scale,
            out_lat,
            out_lon,
        )

    def _interpolate(self, x: torch.Tensor) -> torch.Tensor:
        """Interpolate from the input lat/lon grid (``input_coords()``) onto the
        full output grid (``out_lat_full``, ``out_lon_full``) using bilinear
        interpolation, then crop the outer ring so the result matches
        ``out_lat`` / ``out_lon``."""
        input_coords = self.input_coords()
        return latlon_interpolation_regular(
            x,
            torch.as_tensor(input_coords["lat"], device=x.device, dtype=torch.float32),
            torch.as_tensor(input_coords["lon"], device=x.device, dtype=torch.float32),
            self.out_lat_full,
            self.out_lon_full,
        )[..., 1:-1, 1:-1]

    @torch.inference_mode()
    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run CorrDiff on a single un-batched input of shape
        [variable, lat, lon] and return [sample, out_variable, out_lat, out_lon].

        Raises
        ------
        ValueError
            If ``self.solver`` is not 'euler' or 'heun'.
        """
        if self.solver not in ["euler", "heun"]:
            raise ValueError(
                f"solver must be either 'euler' or 'heun' but got {self.solver}"
            )

        # Interpolate onto the output grid, add a leading sample dimension and
        # normalize with the [C, 1, 1] input statistics.
        x = self._interpolate(x).unsqueeze(0)
        x = (x - self.in_center) / self.in_scale

        # Output grid size; derived from the grid buffers instead of a
        # hard-coded value so the grid channels always match `x`.
        img_resolution_x = self.out_lat.shape[0]
        img_resolution_y = self.out_lon.shape[1]

        # Create sin/cos grid channels spanning [0, 2*pi] along each axis
        x1 = np.sin(np.linspace(0, 2 * np.pi, img_resolution_x))
        x2 = np.cos(np.linspace(0, 2 * np.pi, img_resolution_x))
        y1 = np.sin(np.linspace(0, 2 * np.pi, img_resolution_y))
        y2 = np.cos(np.linspace(0, 2 * np.pi, img_resolution_y))
        # meshgrid(y, x) with default 'xy' indexing -> arrays of shape [len(x), len(y)]
        grid_x1, grid_y1 = np.meshgrid(y1, x1)
        grid_x2, grid_y2 = np.meshgrid(y2, x2)
        grid = torch.as_tensor(
            np.expand_dims(
                np.stack((grid_x1, grid_y1, grid_x2, grid_y2), axis=0), axis=0
            ),
            dtype=torch.float32,
            device=x.device,
        )

        # Concat Grids
        x = torch.cat((x, grid), dim=1)

        # Repeat for sample size; each sample gets its own seed
        sample_seeds = torch.arange(self.number_of_samples)
        x = x.repeat(self.number_of_samples, 1, 1, 1)

        # Create latents
        rnd = StackedRandomGenerator(x.device, sample_seeds)
        latents = rnd.randn(
            [
                self.number_of_samples,
                self.regression_model.img_out_channels,
                img_resolution_x,
                img_resolution_y,
            ],
            device=x.device,
        )

        mean = self.unet_regression(
            self.regression_model,
            torch.zeros_like(latents),
            x,
            num_steps=self.number_of_steps,
        )
        res = ablation_sampler(
            self.residual_model,
            latents,
            x,
            randn_like=rnd.randn_like,
            num_steps=self.number_of_steps,
            solver=self.solver,
        )
        x = mean + res
        # Undo output normalization
        return self.out_scale * x + self.out_center

    @batch_func()
    def __call__(
        self,
        x: torch.Tensor,
        coords: CoordSystem,
    ) -> tuple[torch.Tensor, CoordSystem]:
        """Forward pass of diagnostic

        Each batch element is processed independently by `_forward` and written
        into a float32 output tensor shaped by `output_coords(coords)`.
        """
        output_coords = self.output_coords(coords)

        out = torch.zeros(
            [len(v) for v in output_coords.values()],
            device=x.device,
            dtype=torch.float32,
        )
        for i in range(out.shape[0]):
            out[i] = self._forward(x[i])

        return out, output_coords

    @staticmethod
    def unet_regression(
        net: torch.nn.Module,
        latents: torch.Tensor,
        img_lr: torch.Tensor,
        class_labels: torch.Tensor = None,
        randn_like: Callable = torch.randn_like,
        num_steps: int = 8,
        sigma_min: float = 0.0,
        sigma_max: float = 0.0,
        rho: int = 7,
        S_churn: float = 0,
        S_min: float = 0,
        S_max: float = float("inf"),
        S_noise: float = 0.0,
    ) -> torch.Tensor:
        """
        Perform a single U-Net regression evaluation.

        The signature mirrors the diffusion sampler. Only the first entry of the
        noise schedule is used: the network is evaluated once at ``t = 1.0``.
        ``randn_like``, ``S_churn``, ``S_min``, ``S_max`` and ``S_noise`` are
        accepted for signature compatibility but are not used.

        Parameters
        ----------
        net : torch.nn.Module
            U-Net model for regression.
        latents : torch.Tensor
            Latent representation (scaled by the first noise level).
        img_lr : torch.Tensor
            Low-resolution input image (conditioning).
        class_labels : torch.Tensor, optional
            Class labels for conditional generation.
        randn_like : function, optional
            Unused.
        num_steps : int, optional
            Number of entries in the noise-level schedule.
        sigma_min : float, optional
            Minimum noise level (clamped to ``net.sigma_min``).
        sigma_max : float, optional
            Maximum noise level (clamped to ``net.sigma_max``).
        rho : int, optional
            Exponent for noise level interpolation.
        S_churn : float, optional
            Unused.
        S_min : float, optional
            Unused.
        S_max : float, optional
            Unused.
        S_noise : float, optional
            Unused.

        Returns
        -------
        torch.Tensor:
            Network prediction, cast to float64.
        """
        # Adjust noise levels based on what's supported by the network.
        sigma_min = max(sigma_min, net.sigma_min)
        sigma_max = min(sigma_max, net.sigma_max)

        # Time step discretization. Guard the denominator so num_steps == 1
        # does not produce a 0/0 NaN schedule (unchanged for num_steps >= 2).
        step_indices = torch.arange(
            num_steps, dtype=torch.float64, device=latents.device
        )
        t_steps = (
            sigma_max ** (1 / rho)
            + step_indices
            / max(num_steps - 1, 1)
            * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
        ) ** rho
        t_steps = torch.cat(
            [net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]
        )  # t_N = 0

        # conditioning
        x_lr = img_lr

        # Single network evaluation at the first noise level.
        x_hat = latents.to(torch.float64) * t_steps[0]
        t_hat = torch.tensor(1.0, dtype=torch.float64, device=latents.device)

        x_next = net(x_hat, x_lr, t_hat, class_labels).to(torch.float64)

        return x_next