Source code for earth2studio.models.px.aifsens

# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import zipfile
from collections import OrderedDict
from collections.abc import Generator, Iterator

import numpy as np
import torch

from earth2studio.data import IFS
from earth2studio.data.utils import fetch_data
from earth2studio.models.auto import AutoModelMixin, Package
from earth2studio.models.batch import batch_coords, batch_func
from earth2studio.models.px.base import PrognosticModel
from earth2studio.models.px.utils import PrognosticMixin
from earth2studio.utils import handshake_coords, handshake_dim
from earth2studio.utils.imports import (
    OptionalDependencyFailure,
    check_optional_dependencies,
)
from earth2studio.utils.type import CoordSystem

try:
    import anemoi.models  # noqa: F401
    import earthkit.regrid  # noqa: F401
    import ecmwf.opendata  # noqa: F401
    import flash_attn  # noqa: F401
except ImportError:
    OptionalDependencyFailure("aifsens")

# Ordered variable names of the model dataset. The position of a name in this
# list is the index used by ``model.data_indices`` elsewhere in this module, so
# the order must not change.
VARIABLES = [
    # Level fields: every field is expanded over the same 13 numeric level
    # suffixes (presumably pressure levels in hPa), field-major order.
    *(
        f"{field}{level}"
        for field in ("q", "t", "u", "v", "w", "z")
        for level in (50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000)
    ),
    # Single-level fields
    "u10m",
    "v10m",
    "d2m",
    "t2m",
    "lsm",
    "msl",
    "sdor",
    "skt",
    "slor",
    "sp",
    "tcw",
    "z",
    "cp06",
    "tp06",
    # Generated forcings
    "cos_latitude",
    "cos_longitude",
    "sin_latitude",
    "sin_longitude",
    "cos_julian_day",
    "cos_local_time",
    "sin_julian_day",
    "sin_local_time",
    "insolation",  # cosine zenith angle
    # Remaining single-level fields
    "stl1",
    "stl2",
    "ssrd06",
    "strd06",
    "sf",
    "tcc",
    "mcc",
    "hcc",
    "lcc",
    "u100m",
    "v100m",
    "ro",
]  # from config.json >> dataset.variables


@check_optional_dependencies()
class AIFSENS(torch.nn.Module, AutoModelMixin, PrognosticMixin):
    """Artificial Intelligence Forecasting System Ensemble (AIFS ENS v1.0), a
    probabilistic, ensemble-based forecast model from the European Centre for
    Medium-Range Weather Forecasts (ECMWF). AIFS ENS uses a GNN encoder/decoder with
    a sliding-window transformer processor, trained on ERA5 reanalysis and
    operational NWP analyses, and is run four times daily with a 6-hour time step.
    The model is trained with a CRPS objective over a small ensemble to provide
    calibrated probabilistic output.

    Note
    ----
    This model uses the checkpoints provided by ECMWF.
    For additional information see the following resources:

    - https://arxiv.org/abs/2406.01465
    - https://huggingface.co/ecmwf/aifs-ens-1.0
    - https://github.com/ecmwf/anemoi-core

    Parameters
    ----------
    model : torch.nn.Module
        Core PyTorch module with the pretrained AIFSENS weights loaded.
    latitudes : torch.Tensor
        Latitude values for the native model grid, registered as a buffer for
        interpolation.
    longitudes : torch.Tensor
        Longitude values for the native model grid, registered as a buffer for
        interpolation.
    interpolation_matrix : torch.Tensor
        CSR sparse matrix mapping ERA5 lat/lon inputs onto the native model grid.
    inverse_interpolation_matrix : torch.Tensor
        CSR sparse matrix mapping outputs from the native model grid back to ERA5
        lat/lon.
    invariants : torch.Tensor
        Tensor of shape [4, 721, 1440] containing the invariant fields
        "lsm", "sdor", "slor" and "z"

    Warning
    -------
    We encourage users to familiarize themselves with the license restrictions of
    this model's checkpoints.
    """

    # Static fields that are appended to the user supplied input by the wrapper
    VARIABLE_INVARIANTS = ["lsm", "sdor", "slor", "z"]
    # Fields that are generated by the wrapper from grid / time information
    VARIABLE_FORCINGS = [
        "cos_latitude",
        "cos_longitude",
        "sin_latitude",
        "sin_longitude",
        "cos_julian_day",
        "cos_local_time",
        "sin_julian_day",
        "sin_local_time",
        "insolation",
    ]

    def __init__(
        self,
        model: torch.nn.Module,
        latitudes: torch.Tensor,
        longitudes: torch.Tensor,
        interpolation_matrix: torch.Tensor,
        inverse_interpolation_matrix: torch.Tensor,
        invariants: torch.Tensor,
    ) -> None:
        super().__init__()
        self.model = model
        # Buffers so that they follow the module across devices
        self.register_buffer("invariants", invariants)
        self.register_buffer("latitudes", latitudes)
        self.register_buffer("longitudes", longitudes)
        self.register_buffer("interpolation_matrix", interpolation_matrix)
        self.register_buffer(
            "inverse_interpolation_matrix", inverse_interpolation_matrix
        )
        # The first len(VARIABLE_INVARIANTS) forcing indices are treated as the
        # invariant fields, the remainder as the generated forcings
        n_invariants = len(self.VARIABLE_INVARIANTS)
        forcing = self.model.data_indices.data.output.forcing
        self.invariant_ids = forcing[:n_invariants]
        self.forcing_ids = forcing[n_invariants:]

    @property
    def input_variables(self) -> list[str]:
        """Names of the variables the user has to provide (no forcings/invariants)"""
        indices = torch.cat(
            [
                self.model.data_indices.data.input.prognostic,
                self.model.data_indices.data.input.forcing,
            ]
        )
        # Sort the concatenated tensor
        indices = indices.sort().values
        # Keep only elements NOT forcings or invariants
        mask = ~torch.isin(indices, self.model.data_indices.data.input.forcing)
        selected = [VARIABLES[i] for i in indices[mask].tolist()]
        return selected

    @property
    def output_variables(self) -> list[str]:
        """Names of the variables returned by the model (no forcings/invariants)"""
        # Output constants + prognostic and diagnostic - generated forcings
        indices = torch.cat(
            [
                self.model.data_indices.data.output.forcing,
                self.model.data_indices.data.output.full,
            ]
        )
        # Sort the concatenated tensor and drop duplicates
        indices = torch.unique(indices.sort().values)
        # Keep only elements that are not forcings
        mask = ~torch.isin(indices, self.model.data_indices.data.output.forcing)
        selected = [VARIABLES[i] for i in indices[mask].tolist()]
        return selected

    def input_coords(self) -> CoordSystem:
        """Input coordinate system of the prognostic model

        Returns
        -------
        CoordSystem
            Coordinate system dictionary
        """
        return OrderedDict(
            {
                "batch": np.empty(0),
                "time": np.empty(0),
                "lead_time": np.array(
                    [np.timedelta64(-6, "h"), np.timedelta64(0, "h")]
                ),
                "variable": np.array(self.input_variables),
                "lat": np.linspace(90.0, -90.0, 721),
                "lon": np.linspace(0, 360, 1440, endpoint=False),
            }
        )

    @batch_coords()
    def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
        """Output coordinate system of the prognostic model

        Parameters
        ----------
        input_coords : CoordSystem
            Input coordinate system to transform into output_coords
            by default None, will use self.input_coords.

        Returns
        -------
        CoordSystem
            Coordinate system dictionary
        """
        output_coords = OrderedDict(
            {
                "batch": np.empty(0),
                "time": np.empty(0),
                "lead_time": np.array([np.timedelta64(6, "h")]),
                "variable": np.array(self.output_variables),
                "lat": np.linspace(90.0, -90.0, 721),
                "lon": np.linspace(0, 360, 1440, endpoint=False),
            }
        )
        if input_coords is None:
            return output_coords

        # Shift the lead times so that the last input step sits at 0h; this lets
        # the same check be used on every step of a roll-out
        test_coords = input_coords.copy()
        test_coords["lead_time"] = (
            test_coords["lead_time"] - input_coords["lead_time"][-1]
        )
        target_input_coords = self.input_coords()
        for i, key in enumerate(target_input_coords):
            if key not in ["batch", "time"]:
                handshake_dim(test_coords, key, i)
                handshake_coords(test_coords, target_input_coords, key)

        output_coords["batch"] = input_coords["batch"]
        output_coords["time"] = input_coords["time"]
        output_coords["lead_time"] = (
            input_coords["lead_time"][-1] + output_coords["lead_time"]
        )
        return output_coords

    @classmethod
    def load_default_package(cls) -> Package:
        """Load prognostic package"""
        package = Package(
            "hf://ecmwf/aifs-ens-1.0",
            cache_options={
                "cache_storage": Package.default_cache("aifs-ens-1.0"),
                "same_names": True,
            },
        )
        return package

    @classmethod
    @check_optional_dependencies()
    def load_model(cls, package: Package) -> PrognosticModel:
        """Load prognostic from package

        Parameters
        ----------
        package : Package
            Package to resolve the model checkpoint from

        Returns
        -------
        PrognosticModel
            Instantiated AIFS ENS model wrapper
        """
        # Load model
        model_path = package.resolve("aifs-ens-crps-1.0.ckpt")
        # NOTE(security): weights_only=False unpickles arbitrary objects from the
        # checkpoint; only load checkpoints from a trusted package.
        model = torch.load(
            model_path, weights_only=False, map_location=torch.ones(1).device
        )
        model.eval()

        # Define the path to the metadata file
        metadata_path = "inference-anemoi-by_epoch-epoch_001-step_000040_tp_fix_0.05/anemoi-metadata/ai-models.json"

        # Extract metadata and supporting arrays from the checkpoint, which is
        # read here as a zip archive
        with zipfile.ZipFile(model_path, "r") as zipf:
            # Load metadata, closing the archive member handle when done
            with zipf.open(metadata_path) as metadata_file:
                metadata = json.load(metadata_file)

            # Load supporting arrays
            supporting_arrays = {}
            for key, entry in metadata.get("supporting_arrays_paths", {}).items():
                supporting_arrays[key] = np.frombuffer(
                    zipf.read(entry["path"]),
                    dtype=entry["dtype"],
                ).reshape(entry["shape"])

        # Load interpolation matrix
        # TODO: Maybe change this to allow for multiple packages?
        interpolation_package = Package(
            "https://get.ecmwf.int/repository/earthkit/regrid/db/1/mir_16_linear",
            cache_options={
                "cache_storage": Package.default_cache(
                    "aifs-ens-1.0/interpolation_matrix"
                ),
                "same_names": True,
            },
        )
        interpolation_matrix_path = interpolation_package.resolve(
            "9533e90f8433424400ab53c7fafc87ba1a04453093311c0b5bd0b35fedc1fb83.npz"
        )
        interpolation_matrix = np.load(interpolation_matrix_path)
        torch_interpolation_matrix = torch.sparse_csr_tensor(
            crow_indices=torch.from_numpy(interpolation_matrix["indptr"]),
            col_indices=torch.from_numpy(interpolation_matrix["indices"]),
            values=torch.from_numpy(interpolation_matrix["data"]),
            size=(interpolation_matrix["shape"][0], interpolation_matrix["shape"][1]),
            dtype=torch.float64,
        )

        inverse_interpolation_package = Package(
            "https://get.ecmwf.int/repository/earthkit/regrid/db/1/mir_16_linear/",
            cache_options={
                "cache_storage": Package.default_cache(
                    "aifs-ens-1.0/inverse_interpolation_matrix"
                ),
                "same_names": True,
            },
        )
        inverse_interpolation_matrix_path = inverse_interpolation_package.resolve(
            "7f0be51c7c1f522592c7639e0d3f95bcbff8a044292aa281c1e73b842736d9bf.npz"
        )
        inverse_interpolation_matrix = np.load(inverse_interpolation_matrix_path)
        torch_inverse_interpolation_matrix = torch.sparse_csr_tensor(
            crow_indices=torch.from_numpy(inverse_interpolation_matrix["indptr"]),
            col_indices=torch.from_numpy(inverse_interpolation_matrix["indices"]),
            values=torch.from_numpy(inverse_interpolation_matrix["data"]),
            size=(
                inverse_interpolation_matrix["shape"][0],
                inverse_interpolation_matrix["shape"][1],
            ),
            dtype=torch.float64,
        )

        # Fetch invariants from IFS, note that there are deviations between these
        # invariant fields depending on where and what time the data is fetched.
        # For this model, we will use ECMWF's own invariants in the IFS data store.
        ifs = IFS(cache=True, verbose=False)
        invariants, _ = fetch_data(
            source=ifs,
            time=np.array([np.datetime64("2026-01-01T00:00:00")]),
            variable=cls.VARIABLE_INVARIANTS,
        )
        invariants = invariants.squeeze()

        # Can also fetch from NCAR ERA5 backup but these have some differences
        # invariant_package = Package(
        #     "https://nsf-ncar-era5.s3.amazonaws.com/e5.oper.invariant/197901/",
        #     cache_options={
        #         "cache_storage": Package.default_cache(
        #             "aifs-ens-1.0"
        #         ),
        #         "same_names": True,
        #     },
        # )
        # invariant_arrays = []
        # for key, value in {"lsm": 172, "sdor": 160, "slor": 163, "z": 129}.items():
        #     ds = xr.load_dataset(invariant_package.resolve(f"e5.oper.invariant.128_{value:03d}_{key}.ll025sc.1979010100_1979010100.nc"))
        #     invariant_arrays.append(ds[key.upper()].values)
        # invariants = torch.Tensor(invariant_arrays).squeeze()

        return cls(
            model,
            latitudes=torch.Tensor(supporting_arrays["latitudes"]).reshape(1, 1, -1, 1),
            longitudes=torch.Tensor(supporting_arrays["longitudes"]).reshape(
                1, 1, -1, 1
            ),
            interpolation_matrix=torch_interpolation_matrix,
            inverse_interpolation_matrix=torch_inverse_interpolation_matrix,
            invariants=invariants,
        )

    def get_cos_sin_julian_day(
        self,
        time_array: np.datetime64,
        longitudes: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Get cosine and sine of Julian day

        Parameters
        ----------
        time_array : np.datetime64
            Time to compute the forcing for
        longitudes : torch.Tensor
            Only used as a template; the result has its shape and device

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            Cosine and sine fields, constant over the grid
        """
        # Whole days since the start of the year plus the fraction of the day
        days = (
            time_array.astype("datetime64[D]") - time_array.astype("datetime64[Y]")
        ).astype(np.float32)
        hours = (
            time_array.astype("datetime64[h]") - time_array.astype("datetime64[D]")
        ).astype(np.float32)
        julian_days = days + (hours / 24.0)
        normalized = 2 * np.pi * (julian_days / 365.25)

        cos_julian_day = torch.full_like(
            longitudes, np.cos(normalized), dtype=torch.float32
        )
        sin_julian_day = torch.full_like(
            longitudes, np.sin(normalized), dtype=torch.float32
        )
        return cos_julian_day, sin_julian_day

    def get_cos_sin_local_time(
        self,
        time_array: np.datetime64,
        longitudes: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Get cosine and sine of local time

        Parameters
        ----------
        time_array : np.datetime64
            Time to compute the forcing for
        longitudes : torch.Tensor
            Longitudes in degrees

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            Cosine and sine of the local time, same shape as ``longitudes``
        """
        hours = (
            time_array.astype("datetime64[h]") - time_array.astype("datetime64[D]")
        ).astype(np.float32)
        # Hour of day and longitude both mapped onto [0, 2 pi)
        normalized_time = 2 * np.pi * (hours / 24.0)
        normalized_longitudes = 2 * np.pi * (longitudes / 360.0)
        tau = normalized_time + normalized_longitudes
        cos_local_time = torch.cos(tau)
        sin_local_time = torch.sin(tau)
        return cos_local_time, sin_local_time

    def get_cosine_zenith_fields(
        self,
        date: np.datetime64,
        latitudes: torch.Tensor,
        longitudes: torch.Tensor,
    ) -> torch.Tensor:
        """Get cosine zenith fields for input time array

        Parameters
        ----------
        date : np.datetime64
            Time to compute the forcing for
        latitudes : torch.Tensor
            Latitudes in degrees
        longitudes : torch.Tensor
            Longitudes in degrees

        Returns
        -------
        torch.Tensor
            Cosine of the solar zenith angle, negative values clipped to zero
        """
        # Get Julian day
        days = (date.astype("datetime64[D]") - date.astype("datetime64[Y]")).astype(
            np.float32
        )
        hours = (date.astype("datetime64[h]") - date.astype("datetime64[D]")).astype(
            np.float32
        )
        # Seconds past the full hour
        seconds = (date.astype("datetime64[s]") - date.astype("datetime64[h]")).astype(
            np.float32
        )
        # Fractional day of year. The hour of day has to be part of the fraction
        # (as in get_cos_sin_julian_day); ``seconds`` only holds the sub-hour
        # remainder, so on its own it would drop up to 23 hours.
        julian_day = days + hours / 24.0 + seconds / 86400.0

        # Convert angle to tensor
        angle = torch.tensor(
            julian_day / 365.25 * torch.pi * 2, device=latitudes.device
        )

        # declination in [degrees]
        declination = (
            0.396372
            - 22.91327 * torch.cos(angle)
            + 4.025430 * torch.sin(angle)
            - 0.387205 * torch.cos(2 * angle)
            + 0.051967 * torch.sin(2 * angle)
            - 0.154527 * torch.cos(3 * angle)
            + 0.084798 * torch.sin(3 * angle)
        )

        # time correction in [h.degrees]
        time_correction = (
            0.004297
            + 0.107029 * torch.cos(angle)
            - 1.837877 * torch.sin(angle)
            - 0.837378 * torch.cos(2 * angle)
            - 2.340475 * torch.sin(2 * angle)
        )

        # Convert to radians
        declination = torch.deg2rad(declination)
        latitudes = torch.deg2rad(latitudes)

        # Calculate sine and cosine of declination and latitude
        sindec_sinlat = torch.sin(declination) * torch.sin(latitudes)
        cosdec_coslat = torch.cos(declination) * torch.cos(latitudes)

        # Solar hour angle
        solar_angle = torch.deg2rad((hours - 12) * 15 + longitudes + time_correction)
        zenith_angle = sindec_sinlat + cosdec_coslat * torch.cos(solar_angle)

        # Clip negative values
        return torch.clamp(zenith_angle, min=0.0)

    def _time_forcings(
        self,
        time0: np.datetime64,
        time1: np.datetime64,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the time dependent forcings for the two input steps.

        Returns ``(cos_julian_day, cos_local_time, sin_julian_day, sin_local_time,
        cos_zenith_angle)``; in each tensor the fields of ``time0`` and ``time1``
        are concatenated along dim 1.
        """
        cos_jd_0, sin_jd_0 = self.get_cos_sin_julian_day(time0, self.longitudes)
        cos_jd_1, sin_jd_1 = self.get_cos_sin_julian_day(time1, self.longitudes)
        cos_lt_0, sin_lt_0 = self.get_cos_sin_local_time(time0, self.longitudes)
        cos_lt_1, sin_lt_1 = self.get_cos_sin_local_time(time1, self.longitudes)
        # Insolation / cosine zenith angle
        cos_za_0 = self.get_cosine_zenith_fields(
            time0, self.latitudes, self.longitudes
        )
        cos_za_1 = self.get_cosine_zenith_fields(
            time1, self.latitudes, self.longitudes
        )
        return (
            torch.cat([cos_jd_0, cos_jd_1], dim=1),
            torch.cat([cos_lt_0, cos_lt_1], dim=1),
            torch.cat([sin_jd_0, sin_jd_1], dim=1),
            torch.cat([sin_lt_0, sin_lt_1], dim=1),
            torch.cat([cos_za_0, cos_za_1], dim=1),
        )

    def _add_invariants(
        self,
        x: torch.Tensor,
        coords: CoordSystem,
    ) -> torch.Tensor:
        """add ['lsm', 'sdor', 'slor', 'z'] to input tensor

        The variable axis (third from last) is widened by the number of invariant
        fields; the invariants are written to ``self.invariant_ids`` and the
        original data fills the remaining positions in order.
        """
        shape = list(x.shape)
        shape[-3] += self.invariants.shape[-3]
        _x = torch.zeros(shape, device=x.device)
        all_ids = torch.arange(shape[-3])
        variable_ids = all_ids[~torch.isin(all_ids, self.invariant_ids)]
        _x[..., variable_ids, :, :] = x
        _x[..., self.invariant_ids, :, :] = self.invariants
        return _x

    def _prepare_input(
        self,
        x: torch.Tensor,
        coords: CoordSystem,
    ) -> torch.Tensor:
        """Prepare input tensor and coordinates for the AIFS ENS model.

        Adds the invariants, interpolates onto the native model grid and inserts
        the generated forcings. Expects ``x`` as
        [batch, time, lead_time, variable, lat, lon].
        """
        # add invariants
        x = self._add_invariants(x, coords)

        # Interpolate the input tensor to the model grid
        shape = x.shape
        x = x.flatten(start_dim=4)  # merge lat / lon
        x = x.flatten(end_dim=3)  # merge batch / time / lead / variable
        x = torch.swapaxes(x, 0, -1)
        x = x.to(dtype=torch.float64)  # interpolation matrix is float64
        x = self.interpolation_matrix @ x
        x = x.to(dtype=torch.float32)
        x = torch.swapaxes(x, 0, -1)
        x = x.reshape([shape[0] * shape[1], shape[2], shape[3], -1])
        x = torch.swapaxes(x, 2, 3)  # variable axis last

        # Get cos, sin of latitude and longitude, duplicated for both input steps
        lat_rad = torch.deg2rad(self.latitudes)
        lon_rad = torch.deg2rad(self.longitudes)
        cos_latitude = torch.cat([torch.cos(lat_rad)] * 2, dim=1)
        cos_longitude = torch.cat([torch.cos(lon_rad)] * 2, dim=1)
        sin_latitude = torch.cat([torch.sin(lat_rad)] * 2, dim=1)
        sin_longitude = torch.cat([torch.sin(lon_rad)] * 2, dim=1)

        # Time dependent forcings for the previous and the current step
        (
            cos_julian_day,
            cos_local_time,
            sin_julian_day,
            sin_local_time,
            cos_zenith_angle,
        ) = self._time_forcings(
            coords["time"][0] - np.timedelta64(6, "h"), coords["time"][0]
        )

        # Combine inputs; forcings are inserted after the first 90 channels
        n_samples = shape[0] * shape[1]
        forcings = [
            cos_latitude,
            cos_longitude,
            sin_latitude,
            sin_longitude,
            cos_julian_day,
            cos_local_time,
            sin_julian_day,
            sin_local_time,
            cos_zenith_angle,
        ]
        x = torch.cat(
            [
                x[:, :, :, :90],
                *(forcing.repeat(n_samples, 1, 1, 1) for forcing in forcings),
                x[:, :, :, 90:],
            ],
            dim=3,
        )
        return x

    def _update_input(
        self,
        x: torch.Tensor,
        coords: CoordSystem,
    ) -> torch.Tensor:
        """Update time based inputs.

        Reduces the full variable tensor returned by ``_forward`` to the model
        inputs and overwrites the time dependent forcings for the lead times in
        ``coords``.
        """
        time0 = coords["time"][0] + coords["lead_time"][0]
        time1 = coords["time"][0] + coords["lead_time"][1]

        # Select only inputs
        # From AnemoiModelInterface.DataIndices
        # https://anemoi.readthedocs.io/projects/models/en/latest/modules/data_indices.html#usage-information
        x = x[..., self.model.data_indices.data.input.full]

        (
            cos_julian_day,
            cos_local_time,
            sin_julian_day,
            sin_local_time,
            cos_zenith_angle,
        ) = self._time_forcings(time0, time1)

        # Add terms to x
        x[:, :, :, 94:95] = cos_julian_day
        x[:, :, :, 95:96] = cos_local_time
        x[:, :, :, 96:97] = sin_julian_day
        x[:, :, :, 97:98] = sin_local_time
        x[:, :, :, 98:99] = cos_zenith_angle
        return x

    def _prepare_output(
        self,
        x: torch.Tensor,
        coords: CoordSystem,
    ) -> torch.Tensor:
        """Prepare output tensor of the AIFS ENS model.

        Drops the forcing channels, keeps the second step along dim 1,
        interpolates back to the lat/lon grid and reshapes to the dimensions
        given by ``coords``.
        """
        # Remove generated forcings
        all_indices = torch.arange(x.size(-1))
        keep = torch.isin(
            all_indices, self.model.data_indices.data.output.forcing, invert=True
        )
        x = x[..., keep]

        shape = x.shape
        # Interpolate the model grid to the lat lon grid
        x = x[:, 1:2]
        x = x.flatten(end_dim=1)
        x = torch.swapaxes(x, 0, 1)
        x = x.flatten(start_dim=1)
        x = x.to(dtype=torch.float64)  # interpolation matrix is float64
        x = self.inverse_interpolation_matrix @ x
        x = x.to(dtype=torch.float32)
        x = torch.reshape(x, [x.shape[0], shape[0], shape[-1]])
        x = torch.swapaxes(x, 0, 1)
        x = torch.swapaxes(x, 1, 2)
        x = torch.reshape(
            x,
            [
                coords["batch"].shape[0],
                coords["time"].shape[0],
                coords["lead_time"].shape[0],
                coords["variable"].shape[0],
                coords["lat"].shape[0],
                coords["lon"].shape[0],
            ],
        )
        return x

    @torch.inference_mode()
    def _forward(
        self,
        x: torch.Tensor,
        coords: CoordSystem,
        step: int = 1,
    ) -> tuple[torch.Tensor, CoordSystem]:
        """Run one model step and assemble a tensor holding all ``VARIABLES``.

        Slot 0 along dim 1 carries the latest input step, slot 1 the prediction
        together with the forcings of the latest input step.
        """
        output_coords = self.output_coords(coords)

        # autocast expects the bare device type ("cuda"), not "cuda:0"
        with torch.autocast(device_type=x.device.type, dtype=torch.float16):
            y = self.model.predict_step(x, fcstep=step)

        out = torch.empty(
            (x.shape[0], x.shape[1], x.shape[2], len(VARIABLES)),
            device=x.device,
        )
        out[:, 0, :, self.model.data_indices.data.input.full] = x[:, 1]
        out[:, 1, :, self.model.data_indices.data.output.full] = y[:, 0]
        out[:, 1, :, self.model.data_indices.data.input.forcing] = x[
            :, 1, :, self.model.data_indices.model.input.forcing
        ]
        return out, output_coords

    @batch_func()
    def __call__(
        self,
        x: torch.Tensor,
        coords: CoordSystem,
    ) -> tuple[torch.Tensor, CoordSystem]:
        """Runs prognostic model 1 step.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor
        coords : CoordSystem
            Input coordinate system

        Returns
        -------
        tuple[torch.Tensor, CoordSystem]
            Output tensor and coordinate system 6 hours in the future
        """
        _ = self.output_coords(coords)  # NOTE: Quick fix for exception handling
        x = self._prepare_input(x, coords)
        x, coords = self._forward(x, coords)
        x = self._prepare_output(x, coords)
        return x, coords

    def _fill_input(
        self, x: torch.Tensor, coords: CoordSystem
    ) -> tuple[torch.Tensor, CoordSystem]:
        """
        Fill the model input tensor by selecting prognostic + forcing variables,
        while removing generated forcings (indices 92–100).

        Returns the input expanded to the model output variables together with
        the matching coordinate system. Variables that are not part of the input
        are set to zero.
        """
        # add invariants to prognostics
        x = self._add_invariants(x, coords)
        batch, time, lead, _, height, width = x.shape

        # Zero initialised so that variables which are not part of the input do
        # not expose uninitialised memory
        out = torch.zeros(
            (batch, time, lead, len(VARIABLES), height, width),
            device=x.device,
        )

        # Collect relevant indices from model (prognostic + forcing)
        indices = (
            torch.cat(
                [
                    self.model.data_indices.data.input.prognostic,
                    self.model.data_indices.data.input.forcing,
                ]
            )
            .sort()
            .values
        )

        # Keep only valid indices (exclude generated forcings)
        valid_mask = ~torch.isin(indices, self.forcing_ids)
        target_ids = indices[valid_mask]

        # Copy every batch / time entry into its own slot; taking x[0, 0] here
        # would broadcast the first sample over the whole batch
        out[:, :, 0, target_ids] = x[:, :, 0]
        out[:, :, 1, target_ids] = x[:, :, 1]

        # Drop generated forcing / invariants from output
        all_indices = torch.arange(len(VARIABLES))
        keep = torch.isin(
            all_indices, self.model.data_indices.data.output.forcing, invert=True
        )
        out = out[:, :, :, keep, ...]

        # Update coordinates with remaining variable names
        selected_variables = [VARIABLES[i] for i in all_indices[keep].tolist()]
        out_coords = coords.copy()
        out_coords["variable"] = np.array(selected_variables)
        return out, out_coords

    @batch_func()
    def _default_generator(
        self, x: torch.Tensor, coords: CoordSystem
    ) -> Generator[tuple[torch.Tensor, CoordSystem], None, None]:
        """Yield the initial state followed by an unbounded model roll-out."""
        coords = coords.copy()

        # Validate the input coordinates
        self.output_coords(coords)

        # Initial condition (0th step): only the latest input step is yielded
        first_out, coords_out = self._fill_input(x, coords)
        coords_out["lead_time"] = coords["lead_time"][1:]
        yield first_out[:, :, 1:], coords_out

        # Prepare input tensor
        x = self._prepare_input(x, coords)

        step = 1
        while True:
            # Front hook
            x, coords = self.front_hook(x, coords)

            # Run one model step
            y, coords_out = self._forward(x, coords, step=step)

            # Prepare output tensor
            output_tensor = self._prepare_output(y, coords_out)

            # Rear hook
            output_tensor, coords_out = self.rear_hook(output_tensor, coords_out)

            # Yield output tensor
            yield output_tensor, coords_out.copy()

            # Update coordinates
            coords["lead_time"] = (
                coords["lead_time"]
                + self.output_coords(self.input_coords())["lead_time"]
            )

            # Prepare input tensor
            x = self._update_input(y, coords)
            step += 1

    def create_iterator(
        self, x: torch.Tensor, coords: CoordSystem
    ) -> Iterator[tuple[torch.Tensor, CoordSystem]]:
        """Creates a iterator which can be used to perform time-integration of the
        prognostic model. Will return the initial condition first (0th step).

        Parameters
        ----------
        x : torch.Tensor
            Input tensor
        coords : CoordSystem
            Input coordinate system

        Yields
        ------
        Iterator[tuple[torch.Tensor, CoordSystem]]
            Iterator that generates time-steps of the prognostic model container the
            output data tensor and coordinate system dictionary.
        """
        yield from self._default_generator(x, coords)