# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import zipfile
from collections import OrderedDict
from collections.abc import Generator, Iterator
from datetime import timedelta
from pathlib import Path
import numpy as np
import torch
import xarray
from earth2studio.models.auto import AutoModelMixin, Package
from earth2studio.models.batch import batch_coords, batch_func
from earth2studio.models.px.base import PrognosticModel
from earth2studio.models.px.utils import PrognosticMixin
from earth2studio.utils import handshake_coords, handshake_dim
from earth2studio.utils.imports import (
OptionalDependencyFailure,
check_optional_dependencies,
)
from earth2studio.utils.time import timearray_to_datetime
from earth2studio.utils.type import CoordSystem
try:
import physicsnemo
from physicsnemo.utils.zenith_angle import cos_zenith_angle
except ImportError:
OptionalDependencyFailure("dlwp")
physicsnemo = None
cos_zenith_angle = None
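
# Atmospheric variables this wrapper exposes on the lat-lon grid, in the fixed
# channel order matching the center / scale normalization tensors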
VARIABLES = ["t850", "z1000", "z700", "z500", "z300", "tcwv", "t2m"]
@check_optional_dependencies()
class DLWP(torch.nn.Module, AutoModelMixin, PrognosticMixin):
"""Deep learning weather prediction (DLWP) prognostic model. This is a parsimonious
global forecast model with a time-step size of 6 hours. The core model is a
convolutional encoder-decoder trained on [64,64] cubed sphere data that has an input
of 18 fields (2x7 atmos variables + 4 prescriptive) and outputs 14 fields (2x7 atmos
variables). This implementation provides a wrapper that accepts [721,1440] lat-lon
equirectangular grid of just the atmospheric varaibles as an input for better
compatability with common data sources. Prescriptive fields are added inside the
model wrapper.
Note
----
For more information about this model see:
- https://agupubs.onlinelibrary.wiley.com/doi/epdf/10.1029/2021MS002502
    - https://catalog.ngc.nvidia.com/orgs/nvidia/teams/modulus/models/modulus_dlwp_cubesphere

Parameters
----------
core_model : torch.nn.Module
Core cubed-sphere DLWP model.
landsea_mask : torch.Tensor
Land sea mask in cubed sphere form [6,64,64]
orography : torch.Tensor
Surface geopotential (orography) in cubed sphere form [6,64,64]
latgrid : torch.Tensor
Cubed sphere latitude coordinates [6,64,64]
longrid : torch.Tensor
Cubed sphere longitude coordinates [6,64,64]
cubed_sphere_transform : torch.Tensor
Sparse pytorch tensor to transform equirectangular fields to cubed sphere of
size [24576, 1038240]
cubed_sphere_inverse : torch.Tensor
Sparse pytorch tensor to transform cubed sphere fields to equirectangular of
size [1038240, 24576]
center : torch.Tensor
Model atmospheric variable center normalization tensor of size [1,7,1,1]
scale : torch.Tensor
Model atmospheric variable scale normalization tensor of size [1,7,1,1]
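
    Examples
    --------
    An illustrative sketch, assuming initial condition data matching
    ``input_coords()`` is already loaded into ``x`` and ``coords``:

    >>> package = DLWP.load_default_package()
    >>> model = DLWP.load_model(package)
    >>> x, coords = model(x, coords)  # Step the state forward by 6 hours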
"""
def __init__(
self,
core_model: torch.nn.Module,
landsea_mask: torch.Tensor,
orography: torch.Tensor,
latgrid: torch.Tensor,
longrid: torch.Tensor,
cubed_sphere_transform: torch.Tensor,
cubed_sphere_inverse: torch.Tensor,
center: torch.Tensor,
scale: torch.Tensor,
):
super().__init__()
self.model = core_model
self.register_buffer("latgrid", latgrid)
self.register_buffer("longrid", longrid)
self.register_buffer("center", center)
self.register_buffer("scale", scale)
self.register_buffer("landsea_mask", landsea_mask.unsqueeze(0))
self.register_buffer(
"topographic_height",
(orography.unsqueeze(0).unsqueeze(0) - 3.724e03) / 8.349e03,
)
self.register_buffer("M", cubed_sphere_transform.T)
self.register_buffer("N", cubed_sphere_inverse)
def input_coords(self) -> CoordSystem:
"""Input coordinate system of the prognostic model
Returns
-------
CoordSystem
Coordinate system dictionary
"""
return OrderedDict(
{
"batch": np.empty(0),
"time": np.empty(0),
"lead_time": np.array(
[np.timedelta64(-6, "h"), np.timedelta64(0, "h")]
),
"variable": np.array(VARIABLES),
"lat": np.linspace(90, -90, 721),
"lon": np.linspace(0, 360, 1440, endpoint=False),
}
)
@batch_coords()
def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
"""Output coordinate system of the prognostic model
Parameters
----------
input_coords : CoordSystem
            Input coordinate system to transform into output_coords

Returns
-------
CoordSystem
Coordinate system dictionary
"""
output_coords = OrderedDict(
{
"batch": np.empty(0),
"time": np.empty(0),
"lead_time": np.array([np.timedelta64(6, "h")]),
"variable": np.array(VARIABLES),
"lat": np.linspace(90, -90, 721),
"lon": np.linspace(0, 360, 1440, endpoint=False),
}
)
test_coords = input_coords.copy()
test_coords["lead_time"] = (
test_coords["lead_time"] - input_coords["lead_time"][-1]
)
target_input_coords = self.input_coords()
for i, key in enumerate(target_input_coords):
handshake_dim(test_coords, key, i)
if key not in ["batch", "time"]:
handshake_coords(test_coords, target_input_coords, key)
# Normal forward pass of DLWP, this method returns two time-steps
output_coords["batch"] = input_coords["batch"]
output_coords["time"] = input_coords["time"]
output_coords["lead_time"] = (
input_coords["lead_time"][-1] + output_coords["lead_time"]
)
return output_coords
@classmethod
def load_default_package(cls) -> Package:
"""Default DLWP model package on NGC"""
return Package(
"ngc://models/nvidia/modulus/modulus_dlwp_cubesphere@v0.2",
cache_options={
"cache_storage": Package.default_cache("dlwp"),
"same_names": True,
},
)
@classmethod
@check_optional_dependencies()
def load_model(
cls,
package: Package,
) -> PrognosticModel:
"""Load prognostic from package"""
        # Workaround: NGC distributes this model as a zip archive. Download the zip,
        # unpack it, and then read the needed files from the cached folder location.
dlwp_zip = Path(package.resolve("dlwp_cubesphere.zip"))
# Have to manually unzip here. Should not zip checkpoints in the future
with zipfile.ZipFile(dlwp_zip, "r") as zip_ref:
zip_ref.extractall(dlwp_zip.parent)
lsm = torch.Tensor(
xarray.open_dataset(
str(dlwp_zip.parent / Path("dlwp/land_sea_mask_rs_cs.nc"))
)["lsm"].values
)
topographic_height = torch.Tensor(
xarray.open_dataset(
str(dlwp_zip.parent / Path("dlwp/geopotential_rs_cs.nc"))
)["z"].values
)
latlon_grids = xarray.open_dataset(
str(dlwp_zip.parent / Path("dlwp/latlon_grid_field_rs_cs.nc"))
)
latgrid = torch.Tensor(latlon_grids["latgrid"].values)
longrid = torch.Tensor(latlon_grids["longrid"].values)
# load maps
input_map_wts = xarray.open_dataset(
str(dlwp_zip.parent / Path("dlwp/map_LL721x1440_CS64.nc"))
)
output_map_wts = xarray.open_dataset(
str(dlwp_zip.parent / Path("dlwp/map_CS64_LL721x1440.nc"))
)
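        # The map weight files store 1-based row/column indices; convert to
        # 0-based indexing for the torch sparse tensors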
i = input_map_wts.row.values - 1
j = input_map_wts.col.values - 1
data = input_map_wts.S.values
cubed_sphere_transform = torch.sparse_coo_tensor(
np.array((i, j)), data, dtype=torch.float
)
i = output_map_wts.row.values - 1
j = output_map_wts.col.values - 1
data = output_map_wts.S.values
cubed_sphere_inverse = torch.sparse_coo_tensor(
np.array((i, j)), data, dtype=torch.float
)
core_model = physicsnemo.Module.from_checkpoint(
str(dlwp_zip.parent / Path("dlwp/dlwp.mdlus"))
)
center = torch.Tensor(
np.load(str(dlwp_zip.parent / Path("dlwp/global_means.npy")))
)
scale = torch.Tensor(
np.load(str(dlwp_zip.parent / Path("dlwp/global_stds.npy")))
)
return cls(
core_model,
landsea_mask=lsm,
orography=topographic_height,
latgrid=latgrid,
longrid=longrid,
cubed_sphere_transform=cubed_sphere_transform,
cubed_sphere_inverse=cubed_sphere_inverse,
center=center,
scale=scale,
)
    def to_cubedsphere(self, x: torch.Tensor) -> torch.Tensor:
        """Regrids [...,721,1440] equirectangular fields to [...,6,64,64] cubed sphere"""
        # Flatten the lat-lon grid and apply the sparse regridding matrix
        x = x.reshape(*x.shape[:-2], -1) @ self.M
        x = x.reshape(*x.shape[:-1], 6, 64, 64)
        return x
    def to_equirectangular(self, x: torch.Tensor) -> torch.Tensor:
        """Regrids [...,6,64,64] cubed sphere fields to [...,721,1440] equirectangular"""
        input_shape = x.shape[:-3]
        # Flatten the cubed sphere faces and apply the sparse inverse regridding matrix
        x = (self.N @ x.reshape(-1, 6 * 64 * 64).T).T
        x = x.reshape(*input_shape, 721, 1440)
        return x
def get_cosine_zenith_fields(
        self, times: np.ndarray, lead_time: timedelta, device: torch.device | str = "cuda"
) -> torch.Tensor:
"""Creates cosine zenith fields for input time array"""
output = []
for time in timearray_to_datetime(times):
uvcossza = cos_zenith_angle(
time + lead_time,
self.longrid.cpu(),
self.latgrid.cpu(),
)
# Normalize
uvcossza = torch.Tensor(uvcossza).to(device)
uvcossza = torch.clamp(uvcossza, min=0) - 1.0 / np.pi
output.append(uvcossza)
        return torch.stack(output, dim=0)
def _prepare_input(self, input: torch.Tensor, coords: CoordSystem) -> torch.Tensor:
"""Prepares input cubed sphere tensor by adding land sea mask, uvcossza and
orography fields to input atmospheric ([14,6,64,64] -> [18,6,64,64])
"""
# Compress batch dim into time
time_array = np.tile(coords["time"], input.shape[0])
input = input.view(-1, *input.shape[2:])
uvcossza_6 = self.get_cosine_zenith_fields(
time_array, timedelta(hours=-6), input.device
).unsqueeze(1)
uvcossza_0 = self.get_cosine_zenith_fields(
time_array, timedelta(hours=0), input.device
).unsqueeze(1)
x = torch.cat([input[:, 0], uvcossza_6, input[:, 1], uvcossza_0], dim=1)
input = torch.cat(
(
x,
self.landsea_mask.repeat(x.shape[0], 1, 1, 1, 1),
self.topographic_height.repeat(x.shape[0], 1, 1, 1, 1),
),
dim=1,
)
return input
def _prepare_output(
self, output: torch.Tensor, coords: CoordSystem
) -> torch.Tensor:
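        """Splits the core model output into its two predicted lead times
        ([..., 14, 6, 64, 64] -> [..., 2, 7, 6, 64, 64]) and restores the batch
        and time dimensions
        """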
output = torch.split(output, output.shape[1] // 2, dim=1)
# Add lead time dimension back in
output = torch.stack(output, dim=1)
# Add batch dimension back in
output = output.view(-1, coords["time"].shape[0], *output.shape[1:])
return output
@torch.inference_mode()
def _forward(
self,
x: torch.Tensor,
coords: CoordSystem,
) -> torch.Tensor:
center = self.center.unsqueeze(-1)
scale = self.scale.unsqueeze(-1)
x = (x - center) / scale
x = self._prepare_input(x, coords)
x = self.model(x)
x = self._prepare_output(x, coords)
x = scale * x + center
return x
@batch_func()
def __call__(
self,
x: torch.Tensor,
coords: CoordSystem,
) -> tuple[torch.Tensor, CoordSystem]:
"""Runs prognostic model 1 step.
Parameters
----------
x : torch.Tensor
Input tensor
coords : CoordSystem
            Input coordinate system

Returns
-------
tuple[torch.Tensor, CoordSystem]
Output tensor and coordinate system 6 hours in the future
"""
output_coords = self.output_coords(coords)
x = self.to_cubedsphere(x)
x = self._forward(x, coords)
x = self.to_equirectangular(x)
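        # The core model predicts two lead times per call; return only the first
        # 6 hour step to match output_coords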
return x[:, :, :1], output_coords
@batch_func()
def _default_generator(
self, x: torch.Tensor, coords: CoordSystem
) -> Generator[tuple[torch.Tensor, CoordSystem], None, None]:
coords = coords.copy()
self.output_coords(coords)
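        # Yield the initial condition (current state, lead time 0 hours) first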
coords_out = coords.copy()
coords_out["lead_time"] = coords["lead_time"][1:]
yield x[:, :, 1:], coords_out
x = self.to_cubedsphere(x)
while True:
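            # Each forward pass of the core model advances two 6 hour steps, so the
            # lead time increases by 12 hours and both predicted steps are yielded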
# Front hook
x, coords = self.front_hook(x, coords)
# Forward pass
x = self._forward(x, coords)
coords["lead_time"] = (
coords["lead_time"]
+ 2 * self.output_coords(self.input_coords())["lead_time"]
)
x = x.clone()
# Rear hook for first predicted step
coords_out = coords.copy()
coords_out["lead_time"] = coords["lead_time"][0:1]
x[:, :, :1], coords_out = self.rear_hook(x[:, :, :1], coords_out)
# Output first predicted step
out = self.to_equirectangular(x[:, :, :1])
yield out, coords_out
# Rear hook for second predicted step
coords_out["lead_time"] = coords["lead_time"][-1:]
x[:, :, 1:], coords_out = self.rear_hook(x[:, :, 1:], coords_out)
out = self.to_equirectangular(x[:, :, 1:])
yield out, coords_out
def create_iterator(
self, x: torch.Tensor, coords: CoordSystem
) -> Iterator[tuple[torch.Tensor, CoordSystem]]:
"""Creates a iterator which can be used to perform time-integration of the
prognostic model. Will return the initial condition first (0th step).
Parameters
----------
x : torch.Tensor
Input tensor
coords : CoordSystem
            Input coordinate system

Yields
------
Iterator[tuple[torch.Tensor, CoordSystem]]
            Iterator that generates time-steps of the prognostic model containing the
            output data tensor and coordinate system dictionary.
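
        Examples
        --------
        An illustrative roll-out sketch, assuming ``x`` and ``coords`` match
        ``input_coords()``:

        >>> model_iterator = model.create_iterator(x, coords)
        >>> for step, (x_out, coords_out) in enumerate(model_iterator):
        ...     if step == 4:  # Initial condition plus four 6 hour steps
        ...         break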
"""
yield from self._default_generator(x, coords)