# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import zipfile
from collections import OrderedDict
from collections.abc import Generator, Iterator
from pathlib import Path
import numpy as np
import torch
from earth2studio.models.auto import AutoModelMixin, Package
from earth2studio.models.batch import batch_coords, batch_func
from earth2studio.models.px.base import PrognosticModel
from earth2studio.models.px.utils import PrognosticMixin
from earth2studio.utils import handshake_coords, handshake_dim
from earth2studio.utils.imports import (
OptionalDependencyFailure,
check_optional_dependencies,
)
from earth2studio.utils.type import CoordSystem
try:
    from physicsnemo.models.afno import AFNO
except ImportError:
    # physicsnemo is an optional dependency (the "fcn" extra). Register the
    # failure and bind a placeholder so this module can still be imported
    # without it. NOTE(review): presumably the failure is surfaced by the
    # @check_optional_dependencies() decorators when FCN is used — confirm.
    OptionalDependencyFailure("fcn")
    AFNO = None
# Ordered channel names of the FCN model state (26 variables). The order must
# match the channel layout of the checkpoint and its normalization statistics.
VARIABLES = (
    "u10m v10m t2m sp msl t850 "
    "u1000 v1000 z1000 "
    "u850 v850 z850 "
    "u500 v500 z500 t500 "
    "z50 r500 r850 tcwv "
    "u100m v100m "
    "u250 v250 z250 t250"
).split()
@check_optional_dependencies()
class FCN(torch.nn.Module, AutoModelMixin, PrognosticMixin):
    """FourCastNet global prognostic model. Consists of a single model with a time-step
    size of 6 hours. FourCastNet operates on a 0.25 degree equirectangular lat-lon
    grid (south pole excluded) with 26 variables.

    Note
    ----
    This model is a version of the FourCastNet paper's model retrained on
    additional atmospheric variables. For additional information see the
    following resources:

    - https://arxiv.org/abs/2202.11214
    - https://catalog.ngc.nvidia.com/orgs/nvidia/teams/modulus/models/modulus_fcn

    Parameters
    ----------
    core_model : torch.nn.Module
        Core PyTorch model with loaded weights
    center : torch.Tensor
        Model center normalization tensor with one value per variable, either of
        size [26] or already broadcastable against a (batch, variable, lat, lon)
        tensor (e.g. [1, 26, 1, 1])
    scale : torch.Tensor
        Model scale normalization tensor with one value per variable, either of
        size [26] or already broadcastable against a (batch, variable, lat, lon)
        tensor (e.g. [1, 26, 1, 1])
    """

    def __init__(
        self,
        core_model: torch.nn.Module,
        center: torch.Tensor,
        scale: torch.Tensor,
    ):
        super().__init__()
        self.model = core_model
        # ``_forward`` normalizes tensors laid out as (batch, variable, lat, lon).
        # A 1-D per-variable vector would otherwise broadcast against the trailing
        # lon axis, so lift it onto the variable axis. Other ranks are kept as-is.
        if center.dim() == 1:
            center = center.reshape(1, -1, 1, 1)
        if scale.dim() == 1:
            scale = scale.reshape(1, -1, 1, 1)
        self.register_buffer("center", center)
        self.register_buffer("scale", scale)

    # sphinx - coords start
    def input_coords(self) -> CoordSystem:
        """Input coordinate system of the prognostic model

        Returns
        -------
        CoordSystem
            Coordinate system dictionary
        """
        return OrderedDict(
            {
                "batch": np.empty(0),
                "lead_time": np.array([np.timedelta64(0, "h")]),
                "variable": np.array(VARIABLES),
                "lat": np.linspace(90, -90, 720, endpoint=False),
                "lon": np.linspace(0, 360, 1440, endpoint=False),
            }
        )

    @batch_coords()
    def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
        """Output coordinate system of the prognostic model

        Parameters
        ----------
        input_coords : CoordSystem
            Input coordinate system to transform into output_coords

        Returns
        -------
        CoordSystem
            Coordinate system dictionary
        """
        output_coords = OrderedDict(
            {
                "batch": np.empty(0),
                "lead_time": np.array([np.timedelta64(6, "h")]),
                "variable": np.array(VARIABLES),
                "lat": np.linspace(90, -90, 720, endpoint=False),
                "lon": np.linspace(0, 360, 1440, endpoint=False),
            }
        )

        # Validate against the expected input coords with lead time made relative
        # to the last provided lead time, so any absolute lead time is accepted.
        test_coords = input_coords.copy()
        test_coords["lead_time"] = (
            test_coords["lead_time"] - input_coords["lead_time"][-1]
        )
        target_input_coords = self.input_coords()
        for i, key in enumerate(target_input_coords):
            if key != "batch":
                handshake_dim(test_coords, key, i)
                handshake_coords(test_coords, target_input_coords, key)

        # Carry the caller's batch coordinates through and advance lead time by
        # one model step (6 hours).
        output_coords["batch"] = input_coords["batch"]
        output_coords["lead_time"] = (
            output_coords["lead_time"] + input_coords["lead_time"]
        )
        return output_coords

    # sphinx - coords end

    def __str__(
        self,
    ) -> str:
        return "fcn"

    @classmethod
    def load_default_package(cls) -> Package:
        """Load prognostic package"""
        return Package(
            "ngc://models/nvidia/modulus/modulus_fcn@v0.2",
            cache_options={
                "cache_storage": Package.default_cache("fcn"),
                "same_names": True,
            },
        )

    @classmethod
    @check_optional_dependencies()
    def load_model(
        cls,
        package: Package,
    ) -> PrognosticModel:
        """Load prognostic from package

        Parameters
        ----------
        package : Package
            Package providing the zipped checkpoint archive ``fcn.zip``

        Returns
        -------
        PrognosticModel
            FourCastNet prognostic model with loaded weights and normalization
        """
        fcn_zip = Path(package.resolve("fcn.zip"))
        # Have to manually unzip here. Should not zip checkpoints in the future
        with zipfile.ZipFile(fcn_zip, "r") as zip_ref:
            zip_ref.extractall(fcn_zip.parent)
        model_dir = fcn_zip.parent / "fcn"

        model = AFNO.from_checkpoint(str(model_dir / "fcn.mdlus"))
        model.eval()

        local_center = torch.Tensor(np.load(str(model_dir / "global_means.npy")))
        local_std = torch.Tensor(np.load(str(model_dir / "global_stds.npy")))

        return cls(model, center=local_center, scale=local_std)

    @torch.inference_mode()
    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize, run one step of the core model, and denormalize.

        Expects ``x`` laid out as (batch, lead_time, variable, lat, lon) with a
        single lead time (enforced by ``output_coords``); the lead-time axis is
        dropped for the core model and restored afterwards.
        """
        x = x.squeeze(1)
        x = (x - self.center) / self.scale
        x = self.model(x)
        x = self.scale * x + self.center
        x = x.unsqueeze(1)
        return x

    @batch_func()
    def __call__(
        self,
        x: torch.Tensor,
        coords: CoordSystem,
    ) -> tuple[torch.Tensor, CoordSystem]:
        """Runs prognostic model 1 step.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor
        coords : CoordSystem
            Input coordinate system

        Returns
        -------
        tuple[torch.Tensor, CoordSystem]
            Output tensor and coordinate system 6 hours in the future
        """
        output_coords = self.output_coords(coords)
        x = self._forward(x)
        return x, output_coords

    @batch_func()
    def _default_generator(
        self, x: torch.Tensor, coords: CoordSystem
    ) -> Generator[tuple[torch.Tensor, CoordSystem], None, None]:
        coords = coords.copy()
        # Validate the initial coordinates before emitting the initial condition.
        self.output_coords(coords)
        yield x, coords

        while True:
            # Front hook
            x, coords = self.front_hook(x, coords)

            # Advance the state one 6-hour model step
            coords = self.output_coords(coords)
            x = self._forward(x)

            # Rear hook
            x, coords = self.rear_hook(x, coords)

            yield x, coords.copy()

    def create_iterator(
        self, x: torch.Tensor, coords: CoordSystem
    ) -> Iterator[tuple[torch.Tensor, CoordSystem]]:
        """Creates an iterator which can be used to perform time-integration of the
        prognostic model. Will return the initial condition first (0th step).

        Parameters
        ----------
        x : torch.Tensor
            Input tensor
        coords : CoordSystem
            Input coordinate system

        Yields
        ------
        Iterator[tuple[torch.Tensor, CoordSystem]]
            Iterator that generates time-steps of the prognostic model containing the
            output data tensor and coordinate system dictionary.
        """
        yield from self._default_generator(x, coords)