.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/17_io_performance.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_17_io_performance.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_17_io_performance.py:

IO Backend Performance
======================

Leverage different IO backends for storing inference results.

This example explores the IO backends inside Earth2Studio and how they can be used
to write data to different formats and locations. IO is a core part of any inference
pipeline and, depending on the desired target, can dramatically impact performance.
This example walks users through the different IO backend APIs in a simple workflow.

In this example you will learn:

- Initializing, creating arrays and writing with the Zarr IO backend
- Initializing, creating arrays and writing with the NetCDF IO backend
- Initializing and writing with the asynchronous, non-blocking Zarr IO backend
- Performance implications and strategies that can be used

.. GENERATED FROM PYTHON SOURCE LINES 37-44

.. code-block:: Python

    # /// script
    # dependencies = [
    #   "earth2studio[dlwp] @ git+https://github.com/NVIDIA/earth2studio.git",
    #   "matplotlib",
    # ]
    # ///

.. GENERATED FROM PYTHON SOURCE LINES 45-50

Set Up
------

To demonstrate the different IO backends, this example uses a simple ensemble
workflow that we create manually. One could use the built-in workflow in
Earth2Studio; however, writing it ourselves allows us to better understand the APIs.

.. GENERATED FROM PYTHON SOURCE LINES 52-58

We need the following components:

- Data Source: Pull data from the GFS data API :py:class:`earth2studio.data.GFS`.
- Prognostic Model: Use the built-in DLWP model :py:class:`earth2studio.models.px.DLWP`.
- Perturbation Method: Use the standard Gaussian method :py:class:`earth2studio.perturbation.Gaussian`.
- IO Backends: Use a few IO backends including :py:class:`earth2studio.io.AsyncZarrBackend`,
  :py:class:`earth2studio.io.NetCDF4Backend` and :py:class:`earth2studio.io.ZarrBackend`.

.. GENERATED FROM PYTHON SOURCE LINES 60-89

.. code-block:: Python

    import os

    os.makedirs("outputs", exist_ok=True)
    from dotenv import load_dotenv

    load_dotenv()  # TODO: make common example prep function

    import torch

    from earth2studio.data import GFS, DataSource, fetch_data
    from earth2studio.io import AsyncZarrBackend, IOBackend, NetCDF4Backend, ZarrBackend
    from earth2studio.models.px import DLWP, PrognosticModel
    from earth2studio.perturbation import Gaussian, Perturbation

    # Get the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the DLWP model package
    package = DLWP.load_default_package()
    model = DLWP.load_model(package)
    model = model.to(device)

    # Create the GFS data source
    ds = GFS()

    # Create perturbation method
    pt = Gaussian()

.. GENERATED FROM PYTHON SOURCE LINES 90-97

Creating a Simple Ensemble Workflow
-----------------------------------

Start by creating a simple ensemble inference workflow. This is essentially a
simpler version of the built-in ensemble workflow :py:meth:`earth2studio.run.ensemble`.
In this case, the workflow runs ensemble inference to predict a 5 day forecast for
Christmas 2022. Following standard Earth2Studio practices, the function accepts an
initialized prognostic model, data source, IO backend and perturbation method.

.. GENERATED FROM PYTHON SOURCE LINES 99-200

.. code-block:: Python


    import os
    import time
    from datetime import datetime, timedelta

    import numpy as np
    from tqdm import tqdm

    from earth2studio.utils.coords import map_coords, split_coords
    from earth2studio.utils.time import to_time_array

    times = [datetime(2022, 12, 20)]
    nsteps = 20  # Assuming 6-hour time steps


    def christmas_five_day_ensemble(
        times: list[datetime],
        nsteps: int,
        prognostic: PrognosticModel,
        data: DataSource,
        io: IOBackend,
        perturbation: Perturbation,
        nensemble: int = 8,
        device: str = "cuda",
    ) -> None:
        """Ensemble inference example"""
        # ==========================================
        # Fetch Initialization Data
        prognostic_ic = prognostic.input_coords()
        times = to_time_array(times)
        x, coords0 = fetch_data(
            source=data,
            time=times,
            variable=prognostic_ic["variable"],
            lead_time=prognostic_ic["lead_time"],
            device=device,
        )
        # ==========================================

        # ==========================================
        # Set up IO backend by pre-allocating arrays (not needed for AsyncZarrBackend)
        total_coords = prognostic.output_coords(prognostic.input_coords()).copy()
        if "batch" in total_coords:
            del total_coords["batch"]
        total_coords["time"] = times
        total_coords["lead_time"] = np.asarray(
            [
                prognostic.output_coords(prognostic.input_coords())["lead_time"] * i
                for i in range(nsteps + 1)
            ]
        ).flatten()
        total_coords.move_to_end("lead_time", last=False)
        total_coords.move_to_end("time", last=False)
        total_coords = {"ensemble": np.arange(nensemble)} | total_coords
        variables_to_save = total_coords.pop("variable")
        io.add_array(total_coords, variables_to_save)
        # ==========================================

        # ==========================================
        # Run inference
        coords = {"ensemble": np.arange(nensemble)} | coords0.copy()
        x = x.unsqueeze(0).repeat(nensemble, *([1] * x.ndim))
        # Map lat and lon if needed
        x, coords = map_coords(x, coords, prognostic_ic)
        # Perturb ensemble
        x, coords = perturbation(x, coords)
        # Create prognostic iterator
        model = prognostic.create_iterator(x, coords)
        with tqdm(
            total=nsteps + 1,
            desc="Running batch inference",
            position=1,
            leave=False,
        ) as pbar:
            for step, (x, coords) in enumerate(model):
                # Dump result to IO, split_coords separates variables to different arrays
                x, coords = map_coords(x, coords, {"variable": np.array(["t2m", "tcwv"])})
                io.write(*split_coords(x, coords))
                pbar.update(1)
                if step == nsteps:
                    break
        # ==========================================


    def get_folder_size(folder_path: str) -> float:
        """Get folder size in megabytes"""
        if os.path.isfile(folder_path):
            return os.path.getsize(folder_path) / (1024 * 1024)
        total_size = 0
        for dirpath, dirnames, filenames in os.walk(folder_path):
            for filename in filenames:
                file_path = os.path.join(dirpath, filename)
                total_size += os.path.getsize(file_path)
        return total_size / (1024 * 1024)

.. GENERATED FROM PYTHON SOURCE LINES 201-209

Local Storage Zarr IO
---------------------

As a baseline, let's run the Zarr IO backend, saving to local disk. Local storage is
typically preferred since the data can be accessed with standard libraries after the
inference pipeline is finished. Chunking plays an important role in performance, both
with respect to compression and when accessing the data. Here we chunk the output
data based on time and lead_time.

.. GENERATED FROM PYTHON SOURCE LINES 211-222

.. code-block:: Python

    io = ZarrBackend(
        "outputs/17_io_sync.zarr",
        chunks={"time": 1, "lead_time": 1},
        backend_kwargs={"overwrite": True},
    )
    start_time = time.time()
    christmas_five_day_ensemble(times, nsteps, model, ds, io, pt, device=device)
    zarr_local_clock = time.time() - start_time

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Fetching GFS data:   0%|          | 0/7 [00:00<?, ?it/s]
    Running batch inference:   0%|          | 0/21 [00:00<?, ?it/s]
.. container:: sphx-glr-download sphx-glr-download-python

    :download:`Download Python source code: 17_io_performance.py <17_io_performance.py>`

.. container:: sphx-glr-download sphx-glr-download-zip

    :download:`Download zipped: 17_io_performance.zip <17_io_performance.zip>`

.. only:: html

    .. rst-class:: sphx-glr-signature

        `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
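A quick back-of-the-envelope check can make the chunking discussion in the Local
Storage Zarr IO section concrete. The sketch below is plain Python: the 8-member
ensemble matches the example above, while the 721 x 1440 float32 grid is an assumed
illustration, not a value taken from the model. It estimates the uncompressed size of
a single chunk when only time and lead_time are chunked, so each chunk spans the full
ensemble, lat and lon extents.

```python
def chunk_size_mb(
    nensemble: int, nlat: int, nlon: int, itemsize: int = 4
) -> float:
    """Uncompressed megabytes of one chunk with chunks={"time": 1, "lead_time": 1}.

    With only time and lead_time chunked, a chunk holds every ensemble member
    and the full lat/lon grid for one (time, lead_time) pair.
    """
    return nensemble * nlat * nlon * itemsize / (1024 * 1024)


# Hypothetical 0.25-degree grid (721 x 1440) at float32, 8 ensemble members
print(f"{chunk_size_mb(8, 721, 1440):.1f} MB per chunk")  # → 31.7 MB per chunk
```

Chunking additional dimensions (for example ensemble) would shrink each chunk at the
cost of many more objects on disk; which trade-off is right depends on how the data
will be read back.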