# Source code for earth2studio.models.px.pangu

# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Pangu Weather License
# (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from collections import OrderedDict
from collections.abc import Generator, Iterator
from typing import TypeVar

import numpy as np
import torch

from earth2studio.models.auto import AutoModelMixin, Package
from earth2studio.models.batch import batch_coords, batch_func
from earth2studio.models.px.base import PrognosticModel
from earth2studio.models.px.utils import PrognosticMixin
from earth2studio.models.utils import create_ort_session
from earth2studio.utils import handshake_coords, handshake_dim
from earth2studio.utils.imports import (
    OptionalDependencyFailure,
    check_optional_dependencies,
)
from earth2studio.utils.type import CoordSystem

# onnxruntime is imported defensively: when it is missing the module still
# imports, and placeholders are bound so later references to `ort` and
# `InferenceSession` (used in annotations below) resolve to *something*.
try:
    import onnxruntime as ort
    from onnxruntime import InferenceSession
except ImportError:
    # NOTE(review): presumably records the missing "pangu" dependency group for
    # `check_optional_dependencies` to report later — confirm in
    # earth2studio.utils.imports.
    OptionalDependencyFailure("pangu")
    ort = None
    # Placeholder so `InferenceSession` remains usable as an annotation
    InferenceSession = TypeVar("InferenceSession")  # type: ignore

# Pressure levels (hPa) used for every pressure-level variable, highest pressure first
_PRESSURE_LEVELS = (1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50)
# Pressure-level variable prefixes, in channel order
_PRESSURE_VARIABLES = ("z", "q", "t", "u", "v")
# Surface variables appended after all pressure-level channels
_SURFACE_VARIABLES = ("msl", "u10m", "v10m", "t2m")

# Full channel list: variable-major over pressure levels, then surface fields
VARIABLES = [
    f"{name}{level}" for name in _PRESSURE_VARIABLES for level in _PRESSURE_LEVELS
] + list(_SURFACE_VARIABLES)


# Adapted from https://raw.githubusercontent.com/ecmwf-lab/ai-models-panguweather/main/ai_models_panguweather/model.py
class PanguBase(torch.nn.Module, AutoModelMixin, PrognosticMixin):
    """Pangu base class

    Holds the shared coordinate systems, device handling and the ONNX Runtime
    forward pass used by the concrete Pangu models. Sub-classes are expected to
    create ``self.ort`` and to implement ``_default_generator``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Shape of pressure fields (var, level, lat, lon)
        self.pressure_shape = (5, 13, 721, 1440)
        # Number of pressure-level channels (5 variables x 13 levels)
        self.n_pres = 65
        # Shape of surface variable fields
        self.surface_shape = (4, 721, 1440)

        self._input_coords = OrderedDict(
            {
                "batch": np.empty(0),
                "lead_time": np.array([np.timedelta64(0, "h")]),
                "variable": np.array(VARIABLES),
                "lat": np.linspace(90, -90, 721, endpoint=True),
                "lon": np.linspace(0, 360, 1440, endpoint=False),
            }
        )

        self._output_coords = OrderedDict(
            {
                "batch": np.empty(0),
                "lead_time": np.array([np.timedelta64(6, "h")]),
                "variable": np.array(VARIABLES),
                "lat": np.linspace(90, -90, 721, endpoint=True),
                "lon": np.linspace(0, 360, 1440, endpoint=False),
            }
        )
        self.device = torch.ones(1).device  # Hack to get default device
        # Default ONNX Runtime session, created by sub-classes
        self.ort = None

    def input_coords(self) -> CoordSystem:
        """Input coordinate system of the prognostic model

        Returns
        -------
        CoordSystem
            Coordinate system dictionary
        """
        return self._input_coords.copy()

    @batch_coords()
    def output_coords(
        self,
        input_coords: CoordSystem,
    ) -> CoordSystem:
        """Output coordinate system of the prognostic model

        Parameters
        ----------
        input_coords : CoordSystem
            Input coordinate system to transform into output_coords

        Returns
        -------
        CoordSystem
            Coordinate system dictionary
        """
        output_coords = self._output_coords.copy()

        # Validate against the model input coords with the lead time shifted so
        # that the last input lead time lines up with the model's 0h lead time
        test_coords = input_coords.copy()
        test_coords["lead_time"] = (
            test_coords["lead_time"] - input_coords["lead_time"][-1]
        )
        target_input_coords = self.input_coords()
        for i, key in enumerate(target_input_coords):
            if key != "batch":
                handshake_dim(test_coords, key, i)
                handshake_coords(test_coords, target_input_coords, key)

        output_coords["batch"] = input_coords["batch"]
        # Output lead time is the model step added onto the input lead time
        output_coords["lead_time"] = (
            output_coords["lead_time"] + input_coords["lead_time"]
        )

        return output_coords

    @classmethod
    def load_default_package(cls) -> Package:
        """Load prognostic package"""
        return Package(
            "hf://NickGeneva/earth_ai/pangu",
            cache_options={
                "cache_storage": Package.default_cache("pangu"),
                "same_names": True,
            },
        )

    def to(self, device: str | torch.device | int) -> PrognosticModel:
        """Move model (and default ORT session) to device

        Parameters
        ----------
        device : str | torch.device | int
            Target device. When no index is given the current CUDA device is
            used for ``cuda`` and index 0 for any other device type.

        Returns
        -------
        PrognosticModel
            This model
        """
        device = torch.device(device)
        if device.index is None:
            if device.type == "cuda":
                device = torch.device(device.type, torch.cuda.current_device())
            else:
                device = torch.device(device.type, 0)

        super().to(device)

        if device != self.device:
            self.device = device
            # Move base ort session by re-creating it on the new device
            if self.ort is not None:
                # Relies on the session's private ``_model_path`` attribute
                model_path = self.ort._model_path
                del self.ort
                self.ort = create_ort_session(model_path, device)

        return self

    @torch.inference_mode()
    def _forward(
        self,
        x: torch.Tensor,
        coords: CoordSystem,
        ort_session: InferenceSession,
        lead_time: np.ndarray | None = None,
    ) -> tuple[torch.Tensor, CoordSystem]:
        """Run one step of the given ONNX session on every batch entry

        Parameters
        ----------
        x : torch.Tensor
            Input tensor; dimension 1 is squeezed and dimension 0 is iterated
            over as the batch
        coords : CoordSystem
            Input coordinate system
        ort_session : InferenceSession
            ONNX Runtime session to execute
        lead_time : np.ndarray | None, optional
            Step size to use for the output coordinates instead of the model's
            default output lead time, by default None

        Returns
        -------
        tuple[torch.Tensor, CoordSystem]
            Output tensor (same shape as ``x``) and output coordinate system
        """
        if lead_time is None:
            output_coords = self.output_coords(coords)
        else:
            # Temporarily override the model step; always restore it, even if
            # the coordinate validation raises
            previous_lead_time = self._output_coords["lead_time"]
            self._output_coords["lead_time"] = lead_time
            try:
                output_coords = self.output_coords(coords)
            finally:
                self._output_coords["lead_time"] = previous_lead_time

        # Ref: https://onnxruntime.ai/docs/api/python/api_summary.html
        binding = ort_session.io_binding()
        device_type = self.device.type
        # The default device captured at construction may carry no index
        device_id = 0 if self.device.index is None else self.device.index

        def bind_input(name: str, tensor: torch.Tensor) -> torch.Tensor:
            # Return the (possibly copied) contiguous tensor so the caller keeps
            # the bound memory alive until the session has run
            tensor = tensor.contiguous()
            binding.bind_input(
                name=name,
                device_type=device_type,
                device_id=device_id,
                element_type=np.float32,
                shape=tuple(tensor.shape),
                buffer_ptr=tensor.data_ptr(),
            )
            return tensor

        def bind_output(name: str, like: torch.Tensor) -> torch.Tensor:
            out = torch.empty_like(like).contiguous()
            binding.bind_output(
                name=name,
                device_type=device_type,
                device_id=device_id,
                element_type=np.float32,
                shape=tuple(out.shape),
                buffer_ptr=out.data_ptr(),
            )
            return out

        batch_output = torch.zeros_like(x)
        x = x.squeeze(1)
        # Process batches (model is single batch)
        for i in range(x.shape[0]):
            # Split channels into pressure-level and surface fields and bind
            # them; the returned tensors own the memory ORT reads from
            fields_pl = bind_input(
                "input", x[i, : self.n_pres].reshape(*self.pressure_shape)
            )
            fields_sfc = bind_input("input_surface", x[i, self.n_pres :])
            output = bind_output("output", like=fields_pl)
            output_sfc = bind_output("output_surface", like=fields_sfc)

            # Forward pass
            ort_session.run_with_iobinding(binding)

            # Flatten (var, level) back into channels and append surface fields
            output_tensor = torch.cat(
                [
                    output.view(-1, self.pressure_shape[-2], self.pressure_shape[-1]),
                    output_sfc,
                ],
                dim=0,
            ).contiguous()
            batch_output[i, 0] = output_tensor

        return batch_output, output_coords

    def _default_generator(
        self, x: torch.Tensor, coords: CoordSystem
    ) -> Generator[tuple[torch.Tensor, CoordSystem], None, None]:
        """Time-integration generator, to be implemented by sub-classes"""
        raise NotImplementedError

    def create_iterator(
        self, x: torch.Tensor, coords: CoordSystem
    ) -> Iterator[tuple[torch.Tensor, CoordSystem]]:
        """Creates a iterator which can be used to perform time-integration of the
        prognostic model. Will return the initial condition first (0th step).

        Parameters
        ----------
        x : torch.Tensor
            Input tensor
        coords : CoordSystem
            Input coordinate system


        Yields
        ------
        Iterator[tuple[torch.Tensor, CoordSystem]]
            Iterator that generates time-steps of the prognostic model container the
            output data tensor and coordinate system dictionary.
        """
        yield from self._default_generator(x, coords)


@check_optional_dependencies()
class Pangu24(PanguBase):
    """Pangu Weather 24 hour model.

    A single auto-regressive model that advances the state by 24 hours per step.
    Pangu Weather operates on 0.25 degree lat-lon grid (south-pole including)
    equirectangular grid with 69 atmospheric/surface variables.

    Note
    ----
    This model uses the ONNX checkpoints from the original publication. For
    additional information see the following resources:

    - https://doi.org/10.1038/s41586-023-06185-3
    - https://github.com/198808xc/Pangu-Weather

    Note
    ----
    To avoid ONNX init session overhead of this model we recommend setting the
    default Pytorch device to the correct target prior to model construction.

    Warning
    -------
    We encourage users to familiarize themselves with the license restrictions of
    this model's checkpoints.

    Parameters
    ----------
    ort_24hr : str
        Path to Pangu 24 hour onnx file
    """

    def __init__(
        self,
        ort_24hr: str,
    ):
        super().__init__()
        # The 24 hour session is the default (and only) session of this model
        self.ort: ort.InferenceSession = create_ort_session(ort_24hr, self.device)
        step = np.array([np.timedelta64(24, "h")])
        self._output_coords["lead_time"] = step

    @classmethod
    @check_optional_dependencies()
    def load_model(
        cls,
        package: Package,
    ) -> PrognosticModel:
        """Load prognostic from package"""
        # Resolve the 24 hour checkpoint through the package
        checkpoint_24 = package.resolve("pangu_weather_24.onnx")
        return cls(checkpoint_24)

    @batch_func()
    def __call__(
        self,
        x: torch.Tensor,
        coords: CoordSystem,
    ) -> tuple[torch.Tensor, CoordSystem]:
        """Runs 24 hour prognostic model 1 step.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor
        coords : CoordSystem
            Input coordinate system

        Returns
        -------
        tuple[torch.Tensor, CoordSystem]
            Output tensor and coordinate system 24 hours in the future
        """
        return self._forward(x, coords, self.ort)

    @batch_func()
    def _default_generator(
        self, x: torch.Tensor, coords: CoordSystem
    ) -> Generator[tuple[torch.Tensor, CoordSystem], None, None]:
        coords = coords.copy()

        # Called for its validation of the input coordinates; result is unused
        self.output_coords(coords)

        # 0th step: hand back the initial condition untouched
        yield x, coords

        while True:
            # Front hook -> 24 hour model step -> rear hook
            x, coords = self.front_hook(x, coords)
            x, coords = self._forward(x, coords, self.ort)
            x, coords = self.rear_hook(x, coords)

            yield x, coords.copy()


@check_optional_dependencies()
class Pangu6(PanguBase):
    """Pangu Weather 6 hour model.

    Combines two underlying auto-regressive models with time-step sizes of 24
    hours and 6 hours, which are interweaved during prediction. Pangu Weather
    operates on 0.25 degree lat-lon grid (south-pole including) equirectangular
    grid with 69 atmospheric/surface variables.

    Note
    ----
    This model uses the ONNX checkpoints from the original publication. For
    additional information see the following resources:

    - https://doi.org/10.1038/s41586-023-06185-3
    - https://github.com/198808xc/Pangu-Weather

    Note
    ----
    To avoid ONNX init session overhead of this model we recommend setting the
    default Pytorch device to the correct target prior to model construction.

    Warning
    -------
    We encourage users to familiarize themselves with the license restrictions of
    this model's checkpoints.

    Parameters
    ----------
    ort_24hr : str
        Path to Pangu 24 hour onnx file
    ort_6hr : str
        Path to Pangu 6 hour onnx file
    """

    def __init__(
        self,
        ort_24hr: str,
        ort_6hr: str,
    ):
        super().__init__()
        # Only require 6 hour to load session on construction
        self.ort: ort.InferenceSession = create_ort_session(ort_6hr, self.device)
        # The 24 hour checkpoint is kept as a path; its session is created lazily
        # inside the generator
        self.ort24 = ort_24hr
        step = np.array([np.timedelta64(6, "h")])
        self._output_coords["lead_time"] = step

    @classmethod
    @check_optional_dependencies()
    def load_model(
        cls,
        package: Package,
    ) -> PrognosticModel:
        """Load prognostic from package"""
        # Resolve both checkpoints through the package (24 hour first)
        checkpoint_24 = package.resolve("pangu_weather_24.onnx")
        checkpoint_6 = package.resolve("pangu_weather_6.onnx")
        return cls(checkpoint_24, checkpoint_6)

    @batch_func()
    def __call__(
        self,
        x: torch.Tensor,
        coords: CoordSystem,
    ) -> tuple[torch.Tensor, CoordSystem]:
        """Runs 6 hour prognostic model 1 step.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor
        coords : CoordSystem
            Input coordinate system

        Returns
        -------
        tuple[torch.Tensor, CoordSystem]
            Output tensor and coordinate system 6 hours in the future
        """
        return self._forward(x, coords, self.ort)

    @batch_func()
    def _default_generator(
        self, x: torch.Tensor, coords: CoordSystem
    ) -> Generator[tuple[torch.Tensor, CoordSystem], None, None]:
        coords = coords.copy()
        # Load other sessions (note .to() does not impact these)
        session_24 = create_ort_session(self.ort24, self.device)

        # Called for its validation of the input coordinates; result is unused
        self.output_coords(coords)

        # 0th step: hand back the initial condition untouched
        yield x, coords

        while True:
            # Snapshot the state at the start of the 24 hour window so the
            # 24 hour model can be applied to it after the 6 hour steps
            x_window_start = x.clone()
            coords_window_start = coords.copy()

            # Three 6-hour steps
            for _ in range(3):
                x, coords = self.front_hook(x, coords)
                x, coords = self._forward(x, coords, self.ort)
                x, coords = self.rear_hook(x, coords)
                yield x, coords.copy()

            # 24 hour step from the window-start snapshot
            x, coords = self.front_hook(x_window_start, coords_window_start)
            x, coords = self._forward(
                x, coords, session_24, np.array([np.timedelta64(24, "h")])
            )
            x, coords = self.rear_hook(x, coords)
            yield x, coords.copy()


@check_optional_dependencies()
class Pangu3(PanguBase):
    """Pangu Weather 3 hour model.

    This model consists of three underlying auto-regressive models with a
    time-step size of 24, 6 and 3 hours. These three models are interweaved during
    prediction. Pangu Weather operates on 0.25 degree lat-lon grid (south-pole
    including) equirectangular grid with 69 atmospheric/surface variables.

    Note
    ----
    This model uses the ONNX checkpoints from the original publication. For
    additional information see the following resources:

    - https://doi.org/10.1038/s41586-023-06185-3
    - https://github.com/198808xc/Pangu-Weather

    Note
    ----
    To avoid ONNX init session overhead of this model we recommend setting the
    default Pytorch device to the correct target prior to model construction.

    Warning
    -------
    We encourage users to familiarize themselves with the license restrictions of
    this model's checkpoints.

    Parameters
    ----------
    ort_24hr : str
        Path to Pangu 24 hour onnx file
    ort_6hr : str
        Path to Pangu 6 hour onnx file
    ort_3hr : str
        Path to Pangu 3 hour onnx file
    """

    def __init__(
        self,
        ort_24hr: str,
        ort_6hr: str,
        ort_3hr: str,
    ):
        super().__init__()
        # Only require 3 hour to load session on construction
        self.ort: ort.InferenceSession = create_ort_session(ort_3hr, self.device)
        # The 24 and 6 hour checkpoints are kept as paths; their sessions are
        # created inside the generator
        self.ort24 = ort_24hr
        self.ort6 = ort_6hr
        self._output_coords["lead_time"] = np.array([np.timedelta64(3, "h")])

    @classmethod
    @check_optional_dependencies()
    def load_model(
        cls,
        package: Package,
    ) -> PrognosticModel:
        """Load prognostic from package"""
        # Resolve the three checkpoints through the package
        onnx_file_24 = package.resolve("pangu_weather_24.onnx")
        onnx_file_6 = package.resolve("pangu_weather_6.onnx")
        onnx_file = package.resolve("pangu_weather_3.onnx")
        return cls(onnx_file_24, onnx_file_6, onnx_file)

    @batch_func()
    def __call__(
        self,
        x: torch.Tensor,
        coords: CoordSystem,
    ) -> tuple[torch.Tensor, CoordSystem]:
        """Runs 3 hour prognostic model 1 step.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor
        coords : CoordSystem
            Input coordinate system

        Returns
        -------
        tuple[torch.Tensor, CoordSystem]
            Output tensor and coordinate system 3 hours in the future
        """
        return self._forward(x, coords, self.ort)

    @batch_func()
    def _default_generator(
        self, x: torch.Tensor, coords: CoordSystem
    ) -> Generator[tuple[torch.Tensor, CoordSystem], None, None]:
        coords = coords.copy()
        # Load other sessions (note that .to() does not impact these)
        ort24 = create_ort_session(self.ort24, self.device)
        ort6 = create_ort_session(self.ort6, self.device)

        # Called for its validation of the input coordinates; result is unused
        self.output_coords(coords)

        # 0th step: hand back the initial condition untouched
        yield x, coords

        while True:
            # Snapshots of the window start: one for the final 24 hour step and
            # one that is advanced by the 6 hour model
            x0 = x.clone()  # Used with 24 hour model
            coords0 = coords.copy()
            x1 = x.clone()  # Used with 6 hour model
            coords1 = coords.copy()

            # Single 3-hour step (window start + 3h)
            x, coords = self.front_hook(x, coords)
            x, coords = self._forward(x, coords, self.ort)
            x, coords = self.rear_hook(x, coords)
            yield x, coords.copy()

            # Three 6-hour steps, each followed by a 3-hour step
            for _ in range(3):
                x, coords = self.front_hook(x1, coords1)
                x, coords = self._forward(
                    x, coords, ort6, np.array([np.timedelta64(6, "h")])
                )
                x, coords = self.rear_hook(x, coords)
                yield x, coords.copy()

                # The 6 hour result is the base of the next 6 hour step
                x1 = x.clone()
                coords1 = coords.copy()

                # Single 3-hour step
                x, coords = self.front_hook(x, coords)
                x, coords = self._forward(x, coords, self.ort)
                x, coords = self.rear_hook(x, coords)
                yield x, coords.copy()

            # 24 hour step from the window-start snapshot. The tensor returned by
            # the front hook is what gets forwarded (previously the un-hooked
            # snapshot was passed, silently dropping the hook's output).
            x, coords = self.front_hook(x0, coords0)
            x, coords = self._forward(
                x, coords, ort24, np.array([np.timedelta64(24, "h")])
            )
            x, coords = self.rear_hook(x, coords)
            yield x, coords.copy()