# Source code for earth2studio.run

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
from datetime import datetime
from math import ceil

import numpy as np
import torch
from loguru import logger
from tqdm import tqdm

from earth2studio.data import DataSource, fetch_data
from earth2studio.io import IOBackend
from earth2studio.models.dx import DiagnosticModel
from earth2studio.models.px import PrognosticModel
from earth2studio.perturbation import Perturbation
from earth2studio.utils.coords import CoordSystem, map_coords, split_coords
from earth2studio.utils.time import to_time_array

def _tqdm_log_sink(message: str) -> None:
    """Forward an already-formatted log message to ``tqdm.write``.

    ``end=""`` is passed so that ``tqdm.write`` does not append its own line
    terminator to the message it is given.
    """
    tqdm.write(message, end="")


# Drop the handlers currently registered on the loguru logger, then register
# the tqdm-backed sink so log output is routed through tqdm.write.
logger.remove()
logger.add(_tqdm_log_sink, colorize=True)


# sphinx - deterministic start
def deterministic(
    time: list[str] | list[datetime] | list[np.datetime64],
    nsteps: int,
    prognostic: PrognosticModel,
    data: DataSource,
    io: IOBackend,
    output_coords: CoordSystem = OrderedDict({}),
    device: torch.device | None = None,
) -> IOBackend:
    """Built in deterministic workflow.

    This workflow creates a deterministic inference pipeline to produce a forecast
    prediction using a prognostic model. The initial state plus ``nsteps`` forecast
    steps (``nsteps + 1`` states in total) are written to the IO backend.

    Parameters
    ----------
    time : list[str] | list[datetime] | list[np.datetime64]
        List of string, datetimes or np.datetime64
    nsteps : int
        Number of forecast steps, must be non-negative
    prognostic : PrognosticModel
        Prognostic model
    data : DataSource
        Data source
    io : IOBackend
        IO object
    output_coords: CoordSystem, optional
        IO output coordinate system override, by default OrderedDict({})
    device : torch.device, optional
        Device to run inference on, by default None

    Returns
    -------
    IOBackend
        Output IO object

    Raises
    ------
    ValueError
        If ``nsteps`` is negative
    """
    # sphinx - deterministic end
    # A negative step count would never satisfy the `step == nsteps` stop
    # condition of the rollout loop below, so reject it before doing any work.
    if nsteps < 0:
        raise ValueError(f"nsteps must be non-negative, got {nsteps}")

    logger.info("Running simple workflow!")
    # Load model onto the device
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Inference device: {device}")
    prognostic = prognostic.to(device)
    # sphinx - fetch data start
    # Fetch data from data source and load onto device
    prognostic_ic = prognostic.input_coords()
    time = to_time_array(time)
    x, coords = fetch_data(
        source=data,
        time=time,
        variable=prognostic_ic["variable"],
        lead_time=prognostic_ic["lead_time"],
        device=device,
    )
    logger.success(f"Fetched data from {data.__class__.__name__}")
    # sphinx - fetch data end

    # Set up IO backend
    # The model's output coordinates are queried once and reused below. A fresh
    # input_coords() is handed to output_coords() so that `prognostic_ic` is
    # never shared with that call.
    prognostic_oc = prognostic.output_coords(prognostic.input_coords())
    total_coords = prognostic_oc.copy()
    for key, value in prognostic_oc.items():
        # Scrub batch dims
        if value.shape == (0,):
            del total_coords[key]
    total_coords["time"] = time
    # One entry per written state: multiples 0..nsteps of the model's lead time
    step_lead_time = prognostic_oc["lead_time"]
    total_coords["lead_time"] = np.asarray(
        [step_lead_time * i for i in range(nsteps + 1)]
    ).flatten()
    # Make (time, lead_time) the two leading dimensions of the output array
    total_coords.move_to_end("lead_time", last=False)
    total_coords.move_to_end("time", last=False)

    # Apply user overrides; only values of existing keys are replaced, so
    # assigning while iterating is safe.
    for key, value in total_coords.items():
        total_coords[key] = output_coords.get(key, value)
    var_names = total_coords.pop("variable")
    io.add_array(total_coords, var_names)

    # Map lat and lon if needed
    x, coords = map_coords(x, coords, prognostic_ic)
    # Create prognostic iterator
    model = prognostic.create_iterator(x, coords)

    logger.info("Inference starting!")
    with tqdm(total=nsteps + 1, desc="Running inference") as pbar:
        for step, (x, coords) in enumerate(model):
            # Subselect domain/variables as indicated in output_coords
            x, coords = map_coords(x, coords, output_coords)
            io.write(*split_coords(x, coords))
            pbar.update(1)
            # Stop explicitly after nsteps + 1 states have been written
            if step == nsteps:
                break

    logger.success("Inference complete")
    return io
# sphinx - diagnostic start
def diagnostic(
    time: list[str] | list[datetime] | list[np.datetime64],
    nsteps: int,
    prognostic: PrognosticModel,
    diagnostic: DiagnosticModel,
    data: DataSource,
    io: IOBackend,
    output_coords: CoordSystem = OrderedDict({}),
    device: torch.device | None = None,
) -> IOBackend:
    """Built in diagnostic workflow.

    This workflow creates a deterministic inference pipeline that couples a
    prognostic model with a diagnostic model. Every state produced by the
    prognostic model is passed through the diagnostic model before it is written;
    ``nsteps + 1`` states are written in total.

    Parameters
    ----------
    time : list[str] | list[datetime] | list[np.datetime64]
        List of string, datetimes or np.datetime64
    nsteps : int
        Number of forecast steps, must be non-negative
    prognostic : PrognosticModel
        Prognostic model
    diagnostic: DiagnosticModel
        Diagnostic model, must be on same coordinate axis as prognostic
    data : DataSource
        Data source
    io : IOBackend
        IO object
    output_coords: CoordSystem, optional
        IO output coordinate system override, by default OrderedDict({})
    device : torch.device, optional
        Device to run inference on, by default None

    Returns
    -------
    IOBackend
        Output IO object

    Raises
    ------
    ValueError
        If ``nsteps`` is negative
    """
    # sphinx - diagnostic end
    # A negative step count would never satisfy the `step == nsteps` stop
    # condition of the rollout loop below, so reject it before doing any work.
    if nsteps < 0:
        raise ValueError(f"nsteps must be non-negative, got {nsteps}")

    logger.info("Running diagnostic workflow!")
    # Load model onto the device
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Inference device: {device}")
    prognostic = prognostic.to(device)
    diagnostic = diagnostic.to(device)
    # Fetch data from data source and load onto device
    prognostic_ic = prognostic.input_coords()
    diagnostic_ic = diagnostic.input_coords()
    time = to_time_array(time)
    x, coords = fetch_data(
        source=data,
        time=time,
        variable=prognostic_ic["variable"],
        lead_time=prognostic_ic["lead_time"],
        device=device,
    )
    logger.success(f"Fetched data from {data.__class__.__name__}")

    # Set up IO backend
    # Both models' output coordinates are queried once and reused below. A
    # fresh input_coords() is handed to the prognostic output_coords() so that
    # `prognostic_ic` is never shared with that call.
    prognostic_oc = prognostic.output_coords(prognostic.input_coords())
    diagnostic_oc = diagnostic.output_coords(diagnostic_ic)
    # Work on a copy so the object returned by the model is never mutated
    total_coords = prognostic_oc.copy()
    for key, value in prognostic_oc.items():
        # Where the diagnostic model defines the same dimension, its
        # coordinate replaces the prognostic one
        if key in diagnostic_oc:
            total_coords[key] = diagnostic_oc[key]
        # Scrub batch dims
        if value.shape == (0,):
            del total_coords[key]
    total_coords["time"] = time
    # One entry per written state: multiples 0..nsteps of the model's lead time
    step_lead_time = prognostic_oc["lead_time"]
    total_coords["lead_time"] = np.asarray(
        [step_lead_time * i for i in range(nsteps + 1)]
    ).flatten()
    # Make (time, lead_time) the two leading dimensions of the output array
    total_coords.move_to_end("lead_time", last=False)
    total_coords.move_to_end("time", last=False)

    # Apply user overrides; only values of existing keys are replaced, so
    # assigning while iterating is safe.
    for key, value in total_coords.items():
        total_coords[key] = output_coords.get(key, value)
    var_names = total_coords.pop("variable")
    io.add_array(total_coords, var_names)

    # Map lat and lon if needed
    x, coords = map_coords(x, coords, prognostic_ic)
    # Create prognostic iterator
    model = prognostic.create_iterator(x, coords)

    logger.info("Inference starting!")
    with tqdm(total=nsteps + 1, desc="Running inference") as pbar:
        for step, (x, coords) in enumerate(model):
            # Run diagnostic
            x, coords = map_coords(x, coords, diagnostic_ic)
            x, coords = diagnostic(x, coords)
            # Subselect domain/variables as indicated in output_coords
            x, coords = map_coords(x, coords, output_coords)
            io.write(*split_coords(x, coords))
            pbar.update(1)
            # Stop explicitly after nsteps + 1 states have been written
            if step == nsteps:
                break

    logger.success("Inference complete")
    return io
# sphinx - ensemble start
def ensemble(
    time: list[str] | list[datetime] | list[np.datetime64],
    nsteps: int,
    nensemble: int,
    prognostic: PrognosticModel,
    data: DataSource,
    io: IOBackend,
    perturbation: Perturbation,
    batch_size: int | None = None,
    output_coords: CoordSystem = OrderedDict({}),
    device: torch.device | None = None,
) -> IOBackend:
    """Built in ensemble workflow.

    The initial condition is fetched once onto the CPU. For every batch of
    ensemble members it is moved to the inference device, repeated along a new
    leading ``ensemble`` dimension, perturbed and rolled out with the prognostic
    model. ``nsteps + 1`` states are written per member.

    Parameters
    ----------
    time : list[str] | list[datetime] | list[np.datetime64]
        List of string, datetimes or np.datetime64
    nsteps : int
        Number of forecast steps, must be non-negative
    nensemble : int
        Number of ensemble members to run inference for, must be at least 1
    prognostic : PrognosticModel
        Prognostic models
    data : DataSource
        Data source
    io : IOBackend
        IO object
    perturbation : Perturbation
        Method to perturb the initial condition to create an ensemble.
    batch_size: int, optional
        Number of ensemble members to run in a single batch, must be at least 1
        when given. Values larger than nensemble are clamped to nensemble. By
        default None, which runs all members in one batch.
    output_coords: CoordSystem, optional
        IO output coordinate system override, by default OrderedDict({})
    device : torch.device, optional
        Device to run inference on, by default None

    Returns
    -------
    IOBackend
        Output IO object

    Raises
    ------
    ValueError
        If ``nsteps`` is negative, ``nensemble`` is smaller than 1, or
        ``batch_size`` is given and smaller than 1
    """
    # sphinx - ensemble end
    # Validate up front: a negative nsteps never satisfies the rollout stop
    # condition, and a non-positive member count or batch size breaks the
    # batch arithmetic below (division by zero / empty batch range).
    if nsteps < 0:
        raise ValueError(f"nsteps must be non-negative, got {nsteps}")
    if nensemble < 1:
        raise ValueError(f"nensemble must be at least 1, got {nensemble}")
    if batch_size is not None and batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    logger.info("Running ensemble inference!")

    # Load model onto the device
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Inference device: {device}")
    prognostic = prognostic.to(device)

    # Fetch data from data source; it is kept on the CPU here and moved to the
    # inference device once per batch inside the loop below
    prognostic_ic = prognostic.input_coords()
    time = to_time_array(time)
    x0, coords0 = fetch_data(
        source=data,
        time=time,
        variable=prognostic_ic["variable"],
        lead_time=prognostic_ic["lead_time"],
        device="cpu",
    )
    logger.success(f"Fetched data from {data.__class__.__name__}")

    # Set up IO backend with information from output_coords (if applicable).
    # `|` builds a new mapping, so coords0 itself is never modified.
    total_coords = {"ensemble": np.arange(nensemble)} | coords0
    # One entry per written state: multiples 0..nsteps of the model's lead time
    step_lead_time = prognostic.output_coords(prognostic.input_coords())["lead_time"]
    total_coords["lead_time"] = np.asarray(
        [step_lead_time * i for i in range(nsteps + 1)]
    ).flatten()

    # Apply user overrides; only values of existing keys are replaced, so
    # assigning while iterating is safe.
    for key, value in total_coords.items():
        total_coords[key] = output_coords.get(key, value)
    variables_to_save = total_coords.pop("variable")
    io.add_array(total_coords, variables_to_save)

    # Compute batch sizes
    batch_size = nensemble if batch_size is None else min(nensemble, batch_size)
    number_of_batches = ceil(nensemble / batch_size)

    logger.info(
        f"Starting {nensemble} Member Ensemble Inference with "
        f"{number_of_batches} number of batches."
    )
    # batch_start is the index of the first ensemble member in each batch
    for batch_start in tqdm(
        range(0, nensemble, batch_size),
        total=number_of_batches,
        desc="Total Ensemble Batches",
    ):
        # Get fresh batch data
        x = x0.to(device)

        # Expand x, coords for ensemble; the last batch may be smaller
        mini_batch_size = min(batch_size, nensemble - batch_start)
        coords = {
            "ensemble": np.arange(batch_start, batch_start + mini_batch_size)
        } | coords0

        # Unsqueeze x for batching ensemble: add a leading dimension and repeat
        # it mini_batch_size times, leaving every existing dimension as is
        x = x.unsqueeze(0).repeat(mini_batch_size, *([1] * x.ndim))

        # Map lat and lon if needed
        x, coords = map_coords(x, coords, prognostic_ic)

        # Perturb ensemble
        x, coords = perturbation(x, coords)

        # Create prognostic iterator
        model = prognostic.create_iterator(x, coords)

        with tqdm(
            total=nsteps + 1,
            desc=f"Running batch {batch_start} inference",
            leave=False,
        ) as pbar:
            for step, (x, coords) in enumerate(model):
                # Subselect domain/variables as indicated in output_coords
                x, coords = map_coords(x, coords, output_coords)
                io.write(*split_coords(x, coords))
                pbar.update(1)
                # Stop explicitly after nsteps + 1 states have been written
                if step == nsteps:
                    break

    logger.success("Inference complete")
    return io