Source code for earth2studio.perturbation.bv

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from collections.abc import Callable

import torch

from earth2studio.perturbation.base import Perturbation
from earth2studio.perturbation.brown import Brown
from earth2studio.utils.type import CoordSystem


class BredVector:
    """Bred Vector perturbation method, a classical technique for perturbations in
    ensemble forecasting.

    Parameters
    ----------
    model : Callable[[torch.Tensor, CoordSystem], tuple[torch.Tensor, CoordSystem]]
        Dynamical model, typically this is the prognostic AI model. It is called
        as ``model(x, coords)`` and must return an ``(output, coords)`` tuple;
        only the output tensor is used.
        TODO: Update to prognostic looper
    noise_amplitude : float | Tensor, optional
        Noise amplitude, by default 0.05. If a tensor,
        this must be broadcastable with the input data.
    integration_steps : int, optional
        Number of breeding cycles (model integrations of the perturbed state) to
        run in the forward call, by default 20
    ensemble_perturb : bool, optional
        Perturb the ensemble in an interacting fashion, by default False
    seeding_perturbation_method : Perturbation | None, optional
        Method to seed the Bred Vector perturbation, by default None, which
        creates a new Brown noise perturbation (``Brown()``) for this instance

    Note
    ----
    For additional information:

    - https://journals.ametsoc.org/view/journals/bams/74/12/1520-0477_1993_074_2317_efantg_2_0_co_2.xml
    - https://en.wikipedia.org/wiki/Bred_vector
    """

    def __init__(
        self,
        model: Callable[
            [torch.Tensor, CoordSystem],
            tuple[torch.Tensor, CoordSystem],
        ],
        noise_amplitude: float | torch.Tensor = 0.05,
        integration_steps: int = 20,
        ensemble_perturb: bool = False,
        seeding_perturbation_method: Perturbation | None = None,
    ):
        self.model = model
        # Keep the amplitude as a tensor so it can be moved to the input's device
        # at call time; scalars become a 1-element tensor of the default dtype.
        self.noise_amplitude = (
            noise_amplitude
            if isinstance(noise_amplitude, torch.Tensor)
            else torch.tensor([float(noise_amplitude)])
        )
        self.ensemble_perturb = ensemble_perturb
        self.integration_steps = integration_steps
        # Create the default seeding method per instance instead of sharing a
        # single Brown() object built once at import time as a default argument.
        self.seeding_perturbation_method = (
            Brown()
            if seeding_perturbation_method is None
            else seeding_perturbation_method
        )

    @torch.inference_mode()
    def __call__(
        self,
        x: torch.Tensor,
        coords: CoordSystem,
    ) -> tuple[torch.Tensor, CoordSystem]:
        """Apply perturbation method

        Parameters
        ----------
        x : torch.Tensor
            Input tensor intended to apply perturbation on
        coords : CoordSystem
            Ordered dict representing coordinate system that describes the tensor

        Returns
        -------
        tuple[torch.Tensor, CoordSystem]:
            Output tensor and respective coordinate system dictionary
        """
        noise_amplitude = self.noise_amplitude.to(x.device)

        # Seed the perturbation. Subtract out-of-place: a seeding method may
        # return ``x`` itself (or a view of it), in which case an in-place
        # ``dx -= x`` would zero the caller's input tensor.
        x_seed, coords = self.seeding_perturbation_method(x, coords)
        dx = x_seed - x

        # Control forecast from the unperturbed state. Clone in case the model
        # modifies its input in place.
        xd = self.model(torch.clone(x), coords)[0]

        # Breeding cycles: integrate the perturbed state, take the difference to
        # the control forecast, and rescale it before the next cycle so the
        # perturbation stays bounded (Toth & Kalnay, 1993). With zero steps the
        # seeded perturbation is returned unchanged.
        for _ in range(self.integration_steps):
            x_pert = self.model(x + dx, coords)[0]
            if self.ensemble_perturb:
                # NOTE(review): assumes the ensemble members lie along dim 0.
                dx = (x_pert - xd) + noise_amplitude * (dx - dx.mean(dim=0))
            else:
                dx = x_pert - xd

            gamma = torch.norm(x) / torch.norm(x + dx)
            dx = dx * noise_amplitude * gamma

        return x + dx, coords