# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import zipfile
from collections import OrderedDict
from collections.abc import Generator, Iterator
from pathlib import Path
import numpy as np
import torch
# physicsnemo is an optional dependency: fall back to None when it is missing so
# this module still imports; the check_extra_imports decorators below receive
# the (possibly None) AFNO symbol.
try:
    from physicsnemo.models.afno import AFNO
except ImportError:
    AFNO = None
from earth2studio.models.auto import AutoModelMixin, Package
from earth2studio.models.batch import batch_coords, batch_func
from earth2studio.models.px.base import PrognosticModel
from earth2studio.models.px.utils import PrognosticMixin
from earth2studio.utils import check_extra_imports, handshake_coords, handshake_dim
from earth2studio.utils.type import CoordSystem
# Ordered names of the 26 model variables. input_coords/output_coords expose
# this list as the "variable" coordinate, so the order must not be changed.
VARIABLES = [
    "u10m",
    "v10m",
    "t2m",
    "sp",
    "msl",
    "t850",
    "u1000",
    "v1000",
    "z1000",
    "u850",
    "v850",
    "z850",
    "u500",
    "v500",
    "z500",
    "t500",
    "z50",
    "r500",
    "r850",
    "tcwv",
    "u100m",
    "v100m",
    "u250",
    "v250",
    "z250",
    "t250",
]
@check_extra_imports("fcn", [AFNO])
class FCN(torch.nn.Module, AutoModelMixin, PrognosticMixin):
    """FourCastNet global prognostic model.

    Consists of a single model with a time-step size of 6 hours. FourCastNet
    operates on a 0.25 degree lat-lon equirectangular grid (south pole excluded)
    with 26 variables.

    Note
    ----
    This model is a retrained version of the model from the FourCastNet paper,
    trained on more atmospheric variables. For additional information see the
    following resources:

    - https://arxiv.org/abs/2202.11214
    - https://catalog.ngc.nvidia.com/orgs/nvidia/teams/modulus/models/modulus_fcn

    Parameters
    ----------
    core_model : torch.nn.Module
        Core PyTorch model with loaded weights
    center : torch.Tensor
        Model center normalization tensor of size [26]
    scale : torch.Tensor
        Model scale normalization tensor of size [26]
    """

    def __init__(
        self,
        core_model: torch.nn.Module,
        center: torch.Tensor,
        scale: torch.Tensor,
    ):
        super().__init__()
        self.model = core_model
        # Registered as buffers (not parameters) so they are stored on the module
        # as `self.center` / `self.scale` without being trainable.
        self.register_buffer("center", center)
        self.register_buffer("scale", scale)

    # sphinx - coords start
    def input_coords(self) -> CoordSystem:
        """Input coordinate system of the prognostic model

        Returns
        -------
        CoordSystem
            Coordinate system dictionary
        """
        return OrderedDict(
            {
                "batch": np.empty(0),
                "lead_time": np.array([np.timedelta64(0, "h")]),
                "variable": np.array(VARIABLES),
                # 720 latitudes starting at 90 and stopping short of -90
                "lat": np.linspace(90, -90, 720, endpoint=False),
                "lon": np.linspace(0, 360, 1440, endpoint=False),
            }
        )

    @batch_coords()
    def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
        """Output coordinate system of the prognostic model

        The input coordinates are checked against :py:meth:`input_coords` with
        the handshake helpers before the output system is built.
        NOTE(review): the handshake helpers are presumed to raise on a mismatch;
        the exception type is defined outside this file.

        Parameters
        ----------
        input_coords : CoordSystem
            Input coordinate system to transform into output_coords

        Returns
        -------
        CoordSystem
            Coordinate system dictionary
        """
        output_coords = OrderedDict(
            {
                "batch": np.empty(0),
                "lead_time": np.array([np.timedelta64(6, "h")]),
                "variable": np.array(VARIABLES),
                "lat": np.linspace(90, -90, 720, endpoint=False),
                "lon": np.linspace(0, 360, 1440, endpoint=False),
            }
        )

        # Shift the lead times so the last entry becomes zero. The check below
        # then compares against the zero lead time of input_coords() regardless
        # of how far into a rollout the incoming coordinates already are.
        test_coords = input_coords.copy()
        test_coords["lead_time"] = (
            test_coords["lead_time"] - input_coords["lead_time"][-1]
        )
        target_input_coords = self.input_coords()
        for i, key in enumerate(target_input_coords):
            # The batch dimension is free-form and therefore not validated
            if key != "batch":
                handshake_dim(test_coords, key, i)
                handshake_coords(test_coords, target_input_coords, key)

        # Carry the batch coordinate through and advance lead time by 6 hours
        output_coords["batch"] = input_coords["batch"]
        output_coords["lead_time"] = (
            output_coords["lead_time"] + input_coords["lead_time"]
        )
        return output_coords

    # sphinx - coords end

    def __str__(
        self,
    ) -> str:
        return "fcn"

    @classmethod
    def load_default_package(cls) -> Package:
        """Load prognostic package"""
        return Package(
            "ngc://models/nvidia/modulus/modulus_fcn@v0.2",
            cache_options={
                "cache_storage": Package.default_cache("fcn"),
                "same_names": True,
            },
        )

    @classmethod
    @check_extra_imports("fcn", [AFNO])
    def load_model(
        cls,
        package: Package,
    ) -> PrognosticModel:
        """Load prognostic from package

        Resolves ``fcn.zip`` from the package, extracts it next to the archive
        and builds the model from the extracted checkpoint and normalization
        arrays.

        Parameters
        ----------
        package : Package
            Package to load the model files from

        Returns
        -------
        PrognosticModel
            Prognostic model
        """
        fcn_zip = Path(package.resolve("fcn.zip"))
        # Have to manually unzip here. Should not zip checkpoints in the future
        # NOTE(review): extractall trusts the member paths inside the archive;
        # only use with packages from a trusted source.
        with zipfile.ZipFile(fcn_zip, "r") as zip_ref:
            zip_ref.extractall(fcn_zip.parent)

        model = AFNO.from_checkpoint(str(fcn_zip.parent / Path("fcn/fcn.mdlus")))
        model.eval()

        local_center = torch.Tensor(
            np.load(str(fcn_zip.parent / Path("fcn/global_means.npy")))
        )
        local_std = torch.Tensor(
            np.load(str(fcn_zip.parent / Path("fcn/global_stds.npy")))
        )
        return cls(model, center=local_center, scale=local_std)

    @torch.inference_mode()
    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize, run the core model once and de-normalize the result.

        Dimension 1 is squeezed before the model call and re-inserted afterwards
        (presumably the single lead-time axis, following the key order of
        ``input_coords`` - TODO confirm). ``center`` and ``scale`` must broadcast
        against the squeezed tensor.
        """
        x = x.squeeze(1)
        x = (x - self.center) / self.scale
        x = self.model(x)
        x = self.scale * x + self.center
        x = x.unsqueeze(1)
        return x

    @batch_func()
    def __call__(
        self,
        x: torch.Tensor,
        coords: CoordSystem,
    ) -> tuple[torch.Tensor, CoordSystem]:
        """Runs prognostic model 1 step.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor
        coords : CoordSystem
            Input coordinate system

        Returns
        -------
        tuple[torch.Tensor, CoordSystem]
            Output tensor and coordinate system 6 hours in the future
        """
        # Computing the output coordinates first also validates the input ones
        output_coords = self.output_coords(coords)
        x = self._forward(x)
        return x, output_coords

    @batch_func()
    def _default_generator(
        self, x: torch.Tensor, coords: CoordSystem
    ) -> Generator[tuple[torch.Tensor, CoordSystem], None, None]:
        """Yield the initial state, then step the model forward indefinitely."""
        coords = coords.copy()

        # Called only to validate the input coordinates; the result is discarded
        self.output_coords(coords)

        # 0th step: the unmodified initial condition
        yield x, coords

        while True:
            # Front hook
            x, coords = self.front_hook(x, coords)

            # Advance the state and its coordinates by one model step
            coords = self.output_coords(coords)
            x = self._forward(x)

            # Rear hook
            x, coords = self.rear_hook(x, coords)

            # Hand out a copy so consumers cannot mutate the running coordinates
            yield x, coords.copy()

    def create_iterator(
        self, x: torch.Tensor, coords: CoordSystem
    ) -> Iterator[tuple[torch.Tensor, CoordSystem]]:
        """Creates an iterator which can be used to perform time-integration of
        the prognostic model. Will return the initial condition first (0th step).

        Parameters
        ----------
        x : torch.Tensor
            Input tensor
        coords : CoordSystem
            Input coordinate system

        Yields
        ------
        Iterator[tuple[torch.Tensor, CoordSystem]]
            Iterator that generates time-steps of the prognostic model containing
            the output data tensor and coordinate system dictionary.
        """
        yield from self._default_generator(x, coords)