Earth2Studio is now OSS!

Data Movement

Keeping Data Transparent

“Show me your code and conceal your data structures, and I shall continue to be mystified. Show me your data structures, and I won’t usually need your code; it’ll be obvious.” - Fred Brooks

Earth2Studio aims to keep data simple and interpretable between components. Given that this package interacts with geophysical data, the common data structure inside workflows is the pairing of:

  1. A PyTorch tensor (torch.Tensor) on the inference device holding the array data of interest.

  2. An OrderedDict of numpy arrays (CoordSystem) that describes the geophysical coordinate system of the tensor.

For example, perturbation methods operate on a data tensor and its coordinate system, returning an output tensor together with its coordinate system:

from typing import Protocol

import torch

from earth2studio.utils.type import CoordSystem


class Perturbation(Protocol):
    """Perturbation interface."""

    @torch.inference_mode()
    def __call__(
        self,
        x: torch.Tensor,
        coords: CoordSystem,
    ) -> tuple[torch.Tensor, CoordSystem]:
        """Apply perturbation method to input tensor

        Parameters
        ----------
        x : torch.Tensor
            Input tensor intended to apply perturbation on
        coords : CoordSystem
            Ordered dict representing coordinate system that describes the tensor

        Returns
        -------
        tuple[torch.Tensor, CoordSystem]:
            Output tensor and respective coordinate system dictionary
        """
        pass
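
To make the pairing concrete, here is a minimal sketch of a class that structurally satisfies this interface. The name AddGaussianNoise and its noise_amplitude parameter are hypothetical and are not one of Earth2Studio's own perturbations; CoordSystem is redefined locally with the same definition given in the Coordinate Systems section so the snippet runs without the package installed.

```python
from collections import OrderedDict
from typing import NewType

import numpy as np
import torch

# Local stand-in for Earth2Studio's CoordSystem type (same definition as in these docs)
CoordSystem = NewType("CoordSystem", OrderedDict[str, np.ndarray])


class AddGaussianNoise:
    """Hypothetical perturbation: adds scaled Gaussian noise to the input tensor."""

    def __init__(self, noise_amplitude: float = 0.05):
        self.noise_amplitude = noise_amplitude

    @torch.inference_mode()
    def __call__(
        self, x: torch.Tensor, coords: CoordSystem
    ) -> tuple[torch.Tensor, CoordSystem]:
        # randn_like draws noise on the same device and dtype as x, so data stays on device
        noise = self.noise_amplitude * torch.randn_like(x)
        # Adding noise does not change the grid, so the coordinates pass through unchanged
        return x + noise, coords


coords = CoordSystem(
    OrderedDict(
        {
            "lat": np.linspace(-90, 90, 181),
            "lon": np.linspace(0, 360, 360, endpoint=False),
        }
    )
)
x = torch.zeros(181, 360)
y, out_coords = AddGaussianNoise(noise_amplitude=0.1)(x, coords)
```

Because this perturbation leaves the grid untouched, it returns the coordinate system as-is; a component that reshaped or re-indexed the tensor would also need to return updated coordinates.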

In later sections, users will find that most components have APIs that either generate or interact with these two data structures. The combination of the data tensor and its coordinate system provides all the information needed to interpret any stage of a workflow.

Note

Data is always moved between components using unnormalized physical units.
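
One consequence of this convention is that any normalization a component needs happens inside that component. The sketch below is a hypothetical illustration of that contract, not Earth2Studio code: NormalizedWrapper, its center and scale statistics, and the identity _network placeholder are all assumptions, and the statistic values are illustrative only.

```python
import torch


class NormalizedWrapper:
    """Hypothetical component: normalizes internally, exchanges physical units externally."""

    def __init__(self, center: torch.Tensor, scale: torch.Tensor):
        # Per-variable statistics, shaped [variable, 1, 1] to broadcast over [variable, lat, lon]
        self.center = center
        self.scale = scale

    def _network(self, z: torch.Tensor) -> torch.Tensor:
        # Placeholder (identity) for a network trained on normalized data
        return z

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        z = (x - self.center) / self.scale  # physical units -> normalized
        z = self._network(z)
        return z * self.scale + self.center  # normalized -> physical units


# Illustrative statistics for two variables (roughly a temperature in K and a pressure in Pa)
center = torch.tensor([280.0, 101325.0]).view(2, 1, 1)
scale = torch.tensor([20.0, 1000.0]).view(2, 1, 1)

x = center + scale * torch.randn(2, 4, 8)  # input already in physical units
y = NormalizedWrapper(center, scale)(x)  # output is returned in physical units
```

Since the network here is an identity placeholder, the output matches the input up to floating-point round-off; the point is that callers only ever see physical units.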

Coordinate Systems

As previously discussed, coordinate dictionaries are a critical part of Earth-2 Inference Studio’s data movement. We wanted the coordinate object to be a fairly primitive data object that allows users to interact with the data outside the project and keeps workflows transparent. Inside the package, these are typed as CoordSystem, which is defined as follows:

CoordSystem = NewType("CoordSystem", OrderedDict[str, np.ndarray])

The dictionary is ordered because its keys correspond, in order, to the dimensions of the associated data tensor. Let’s consider a simple example of a 2D lat-lon grid:

from collections import OrderedDict

import numpy as np
import torch

x = torch.randn(181, 360)

coords = OrderedDict({
    "lat": np.linspace(-90, 90, 181),
    "lon": np.linspace(0, 360, 360, endpoint=False)
})

Much of Earth2Studio operates on a lat-lon grid, but this is not required.
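
Because the keys are ordered to match the tensor's dimensions, a coordinate system can be checked against its tensor mechanically. The helper below, check_coords, is a hypothetical sketch rather than an Earth2Studio utility; it treats an empty coordinate array as a free dimension, following the np.empty(0) convention described under Standard Coordinate Names.

```python
from collections import OrderedDict

import numpy as np
import torch


def check_coords(x: torch.Tensor, coords: OrderedDict) -> None:
    """Hypothetical helper: verify that coords describe the dimensions of x, in order."""
    if x.dim() != len(coords):
        raise ValueError(f"Tensor has {x.dim()} dims but coords has {len(coords)} keys")
    for i, (name, values) in enumerate(coords.items()):
        # An empty coordinate array marks a free dimension of any size (e.g. batch)
        if values.shape == (0,):
            continue
        if x.shape[i] != len(values):
            raise ValueError(
                f"Dimension {i} ('{name}') has size {x.shape[i]}, "
                f"but its coordinate array has length {len(values)}"
            )


x = torch.randn(181, 360)
coords = OrderedDict({
    "lat": np.linspace(-90, 90, 181),
    "lon": np.linspace(0, 360, 360, endpoint=False),
})
check_coords(x, coords)  # consistent: returns without raising

# Key order matters: swapping lat and lon no longer describes the tensor
swapped = OrderedDict([("lon", coords["lon"]), ("lat", coords["lat"])])
message = None
try:
    check_coords(x, swapped)
except ValueError as err:
    message = str(err)
print(message)
```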

Standard Coordinate Names

Earth2Studio has a dimension naming standard for its built-in feature set. We encourage users to follow similar naming schemes when possible, for compatibility with Earth-2 Inference Studio and the packages we interface with.

| Key | Description | Type |
|---|---|---|
| batch | Batch dimension of the data. Used to denote a “free” dimension; consult the batching docs for more details. | np.empty(0) |
| time | Time dimension, represented via numpy arrays of datetime objects. | np.ndarray[np.datetime64[ns]] (TimeArray) |
| lead_time | Lead time, a dimension that indexes over forecast steps. | np.ndarray[np.timedelta64[ns]] (LeadTimeArray) |
| variable | Physical variable (atmospheric, surface, etc.). Earth-2 Inference Studio has its own naming convention; see the Lexicon docs for more details. | np.ndarray[str] (VariableArray) |
| lat | Latitude coordinate array | np.ndarray[float] |
| lon | Longitude coordinate array | np.ndarray[float] |

Note

np.empty(0) is used to denote a free axis in the coordinate system, namely a tensor dimension that can be of any size greater than 0. Typically this is the batch dimension.
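
Combining the standard names, a coordinate system for a batched forecast tensor might look like the following sketch. The date, the variable names t2m and u10m, and the coarse 5 x 8 grid are placeholders chosen for illustration. Because batch is np.empty(0), the same coordinates describe a tensor of any batch size; here a batch size of 3 is picked when building the tensor shape.

```python
from collections import OrderedDict

import numpy as np
import torch

coords = OrderedDict({
    "batch": np.empty(0),  # free dimension: any size greater than 0
    "time": np.array([np.datetime64("2024-01-01T00:00")], dtype="datetime64[ns]"),
    "lead_time": np.array(
        [np.timedelta64(0, "h"), np.timedelta64(6, "h")], dtype="timedelta64[ns]"
    ),
    "variable": np.array(["t2m", "u10m"]),  # placeholder variable names
    "lat": np.linspace(90, -90, 5),
    "lon": np.linspace(0, 360, 8, endpoint=False),
})

# Build the tensor shape from the coordinates, filling the free dimension with 3
batch_size = 3
shape = tuple(batch_size if v.shape == (0,) else len(v) for v in coords.values())
x = torch.zeros(shape)
print(shape)  # (3, 1, 2, 2, 5, 8)
```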

Coordinate Utilities

The downside of using a dictionary to store coordinates is that manipulating the data tensor and then updating the coordinate array is a manual process. To help make this process less tedious, Earth2Studio has several utility functions that make interacting with coordinates easier. The bulk of these can be found in the Earth2Studio Utilities.

Warning

🚧 Under construction, todo: add some example here! 🚧
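
In the meantime, the sketch below shows the kind of manual bookkeeping these utilities are meant to reduce: selecting a subset of variables requires indexing the tensor and the variable coordinate array consistently. select_variables is a hypothetical helper written for this illustration, not a function from the Earth2Studio Utilities, and the variable names are placeholders.

```python
from collections import OrderedDict

import numpy as np
import torch


def select_variables(
    x: torch.Tensor, coords: OrderedDict, names: list[str]
) -> tuple[torch.Tensor, OrderedDict]:
    """Hypothetical helper: select variables while keeping coords in sync with x."""
    # Position of the "variable" dimension in the tensor, taken from the key order
    dim = list(coords.keys()).index("variable")
    # Map each variable name to its index; unknown names raise a KeyError
    lookup = {str(v): i for i, v in enumerate(coords["variable"])}
    idx = [lookup[n] for n in names]
    x_out = torch.index_select(x, dim, torch.tensor(idx, device=x.device))
    # Shallow copy so the caller's dictionary is not modified; only "variable" changes
    coords_out = coords.copy()
    coords_out["variable"] = coords["variable"][idx]
    return x_out, coords_out


x = torch.randn(3, 4, 4)
coords = OrderedDict({
    "variable": np.array(["a", "b", "c"]),
    "lat": np.linspace(90, -90, 4),
    "lon": np.linspace(0, 360, 4, endpoint=False),
})
y, y_coords = select_variables(x, coords, ["c", "a"])
```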

Model Coordinates

Models are components where coordinate systems are essential, both for defining what data is needed and for describing what data is produced. Both Prognostic Models and Diagnostic Models require two functions that serve coordinate information:

  • input_coords() : A function that returns the expected input coordinate system of the model. A new dictionary should be returned every time.

  • output_coords() : A function that returns the expected output coordinate system of the model given an input coordinate system. This function should also validate the input coordinate dictionary.

For example, consider the FourCastNet model’s implementation:

    def input_coords(self) -> CoordSystem:
        """Input coordinate system of the prognostic model

        Returns
        -------
        CoordSystem
            Coordinate system dictionary
        """
        return OrderedDict(
            {
                "batch": np.empty(0),
                "lead_time": np.array([np.timedelta64(0, "h")]),
                "variable": np.array(VARIABLES),
                "lat": np.linspace(90, -90, 720, endpoint=False),
                "lon": np.linspace(0, 360, 1440, endpoint=False),
            }
        )

    @batch_coords()
    def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
        """Output coordinate system of the prognostic model

        Parameters
        ----------
        input_coords : CoordSystem
            Input coordinate system to transform into output_coords

        Returns
        -------
        CoordSystem
            Coordinate system dictionary
        """

        output_coords = OrderedDict(
            {
                "batch": np.empty(0),
                "lead_time": np.array([np.timedelta64(6, "h")]),
                "variable": np.array(VARIABLES),
                "lat": np.linspace(90, -90, 720, endpoint=False),
                "lon": np.linspace(0, 360, 1440, endpoint=False),
            }
        )

        test_coords = input_coords.copy()
        test_coords["lead_time"] = (
            test_coords["lead_time"] - input_coords["lead_time"][-1]
        )
        target_input_coords = self.input_coords()
        for i, key in enumerate(target_input_coords):
            if key != "batch":
                handshake_dim(test_coords, key, i)
                handshake_coords(test_coords, target_input_coords, key)

        output_coords = output_coords.copy()
        output_coords["batch"] = input_coords["batch"]
        output_coords["lead_time"] = (
            output_coords["lead_time"] + input_coords["lead_time"]
        )

        return output_coords

The input_coords() function provides the coordinate spec of the input tensor expected by the model. In the case of FourCastNet, it expects an input tensor of shape [..., 1, 26, 720, 1440], which corresponds to [..., lead_time, variable, lat, lon].
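
The lat and lon arrays in input_coords() are where the 720 and 1440 come from. As a quick check, the snippet below rebuilds them with NumPy: both axes use 0.25 degree spacing, latitude runs from 90 southward with endpoint=False (so -90 itself is excluded and the last value is -89.75), and longitude covers 0 up to, but not including, 360.

```python
import numpy as np

# Same grid definitions as FourCastNet's input_coords() above
lat = np.linspace(90, -90, 720, endpoint=False)
lon = np.linspace(0, 360, 1440, endpoint=False)

# Latitude: 90.0, 89.75, ..., -89.75 (north to south)
print(lat[0], lat[1], lat[-1])
# Longitude: 0.0, 0.25, ..., 359.75 (360 would duplicate 0)
print(lon[0], lon[1], lon[-1])
# Grid sizes that appear in the tensor shape
print(len(lat), len(lon))
```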

The output_coords() function provides validation and transformation of coordinates by the model. That is, one should be able to deduce the coordinates of the model’s output without executing its forward pass. This function requires an input coordinate system that represents the input data the model receives. The first step is validating input_coords using the coordinate handshake utility functions. Next, the output coordinate system is built. This function is where the complexity of a model’s forward process is captured for other Earth2Studio components.
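
The same validate-then-transform pattern can be reproduced without Earth2Studio. The sketch below is a hypothetical toy model, ToyStepper, that only implements the coordinate methods: input_coords() returns a fresh dictionary on each call, and output_coords() checks the fixed dimensions before advancing lead_time by 6 hours. Its tiny grid and variable names are placeholders, and it uses plain checks in place of the handshake utilities and the batch_coords decorator used by FourCastNet.

```python
from collections import OrderedDict

import numpy as np


class ToyStepper:
    """Hypothetical 6-hour prognostic model; only the coordinate methods are sketched."""

    def input_coords(self) -> OrderedDict:
        # A fresh dictionary on every call, so callers can modify it safely
        return OrderedDict({
            "batch": np.empty(0),
            "lead_time": np.array([np.timedelta64(0, "h")]),
            "variable": np.array(["a", "b"]),
            "lat": np.linspace(90, -90, 3),
            "lon": np.linspace(0, 360, 4, endpoint=False),
        })

    def output_coords(self, input_coords: OrderedDict) -> OrderedDict:
        target = self.input_coords()
        # Validate: same dimension order, and matching values for the fixed dimensions
        if list(input_coords.keys()) != list(target.keys()):
            raise ValueError(f"Expected dims {list(target)}, got {list(input_coords)}")
        for key in ("variable", "lat", "lon"):
            if not np.array_equal(input_coords[key], target[key]):
                raise ValueError(f"Coordinate '{key}' does not match the model inputs")
        # Transform: one step advances lead_time by 6 hours; other dims pass through
        output_coords = input_coords.copy()
        output_coords["lead_time"] = input_coords["lead_time"] + np.timedelta64(6, "h")
        return output_coords


model = ToyStepper()
coords = model.input_coords()
for _ in range(2):
    coords = model.output_coords(coords)
    print("Output lead:", coords["lead_time"])
```

Chaining output_coords() calls, as the loop does, predicts the lead time of a multi-step rollout before any forward pass is run.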

We encourage users to explore the coordinate transforms of models to learn more about how they operate:

from earth2studio.models.px import FCN
model = FCN(None, None, None)
input_coords = model.input_coords()
output_coords = model.output_coords(input_coords)
print("Input  lead:", input_coords['lead_time'])
print("Output lead:", output_coords['lead_time'])
output_coords = model.output_coords(output_coords)
print("Output lead 2:", output_coords['lead_time'])

The output of this script will be:

Input  lead: [0]
Output lead: [6]
Output lead 2: [12]

Note

The batch dimension was intentionally not discussed here. Think of it as a free dimension. More information can be found in the Batch Dimension section.

Inference on the GPU

It is beneficial to leverage the GPU for as many processes as possible. Earth2Studio aims to get data from the data source and immediately convert it into the tensor and coordinate system pair on the device. From there, the data is kept on the GPU until the very last moment, when it needs to be written to memory or to file.

[Figure: earth2studio-data]

In the figure above, the data is first pulled from the data source as an Xarray data array, which is then converted to a tensor. The data remains on the device, denoted by the GPU boundary, until it needs to be written by the IO component.

Data Sources and Xarray

This may raise the question: why do data sources not output tensors and coordinate dictionaries directly? This is an opinionated decision: these data sources need to store data on the CPU regardless, and they can be extremely useful outside the context of this package. Thus they return Xarray data arrays, which are nothing more than fancy data arrays with a coordinate system attached!
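
Going from a data array to the tensor and coordinate pair, and back again for IO, is mostly bookkeeping. The two helpers below, to_tensor_coords and to_dataarray, are a hypothetical sketch of that round trip rather than Earth2Studio's internal conversion code. They assume every dimension of the DataArray has a coordinate variable, and that the coordinate system has no free (np.empty(0)) dimensions, since xarray needs a concrete coordinate for every entry along each dimension.

```python
from collections import OrderedDict

import numpy as np
import torch
import xarray as xr


def to_tensor_coords(da: xr.DataArray, device: torch.device) -> tuple[torch.Tensor, OrderedDict]:
    """Hypothetical helper: DataArray (CPU) -> (tensor on device, coordinate dict)."""
    x = torch.as_tensor(da.values, device=device)
    # One coordinate array per dimension, in the DataArray's dimension order
    coords = OrderedDict((dim, da.coords[dim].values) for dim in da.dims)
    return x, coords


def to_dataarray(x: torch.Tensor, coords: OrderedDict) -> xr.DataArray:
    """Hypothetical helper: move data back to the CPU only at write time."""
    return xr.DataArray(x.cpu().numpy(), dims=list(coords.keys()), coords=coords)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
da = xr.DataArray(
    np.random.rand(2, 3, 4).astype(np.float32),
    dims=["variable", "lat", "lon"],
    coords={
        "variable": ["a", "b"],  # placeholder variable names
        "lat": [10.0, 0.0, -10.0],
        "lon": [0.0, 90.0, 180.0, 270.0],
    },
)

x, coords = to_tensor_coords(da, device)
x = x * 2  # stand-in for computation that stays on the device
out = to_dataarray(x, coords)
```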