Batch Dimension#

This section of the user guide expands on how batching is handled inside Earth2Studio. As discussed in the data movement section, there is a dedicated coordinate axis, batch, which is used in many of the model implementations. batch represents a dynamic axis that can be of any size, enabling models to make better use of compute resources.

Batch dimensions have the following rules:

  • Must be able to support any nonzero size

  • Must be the leading dimension of the coordinate system

  • Must be set to np.empty(0) in the object's coordinate property
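
The last two rules can be checked mechanically. The sketch below is one possible way to do so, assuming coordinates are stored in an OrderedDict of NumPy arrays as in the examples in this section. The helper check_batch_coords is a hypothetical name introduced here for illustration and is not part of Earth2Studio. The first rule (supporting any nonzero size) is a property of the model itself and cannot be verified from the coordinates alone.

```python
from collections import OrderedDict

import numpy as np


def check_batch_coords(coords: OrderedDict) -> list:
    """Return a list of rule violations (empty if the definition looks valid).

    Hypothetical helper for illustration only; not part of Earth2Studio.
    """
    problems = []
    keys = list(coords)
    if "batch" not in keys:
        return problems  # no batch dimension declared, nothing to check
    # Rule: batch must be the leading dimension
    if keys[0] != "batch":
        problems.append("batch is not the leading dimension")
    # Rule: batch must be declared as an empty array, e.g. np.empty(0)
    if np.asarray(coords["batch"]).shape != (0,):
        problems.append("batch is not an empty array of shape (0,)")
    return problems


good = OrderedDict({"batch": np.empty(0), "lat": np.linspace(90, -90, 4)})
not_leading = OrderedDict({"lat": np.linspace(90, -90, 4), "batch": np.empty(0)})
not_empty = OrderedDict({"batch": np.zeros(1), "lat": np.linspace(90, -90, 4)})

print(check_batch_coords(good))
print(check_batch_coords(not_leading))
print(check_batch_coords(not_empty))
```

The first call should report no problems, while the other two should each report one violation.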

Good Coordinate Definitions#

batch is the leading dimension with a value of np.empty(0).

coords = OrderedDict(
        {
            "batch": np.empty(0),
            "lead_time": np.array([np.timedelta64(0, "h")]),
            "variable": np.array(VARIABLES),
            "lat": np.linspace(90, -90, 720, endpoint=False),
            "lon": np.linspace(0, 360, 1440, endpoint=False),
        }
    )

Other coordinates can also have a value of np.empty(0) to denote an additional dynamic axis, but such axes imply a required data type. In this case, the model supports a batch dimension and also has a dynamic time axis, which must be a NumPy array of type np.datetime64. See the data movement section for expected types.

coords = OrderedDict(
        {
            "batch": np.empty(0),
            "time": np.empty(0),
            "variable": np.array(VARIABLES),
            "lat": np.linspace(90, -90, 720, endpoint=False),
            "lon": np.linspace(0, 360, 1440, endpoint=False),
        }
    )
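
Because np.empty(0) is only a placeholder, the expected type is a convention that the caller has to respect. The short sketch below illustrates this, assuming the time axis is filled with np.datetime64 values at call time. The timestamps are made up for illustration.

```python
import numpy as np

# The placeholder used in the coordinate definition carries no type information:
# np.empty(0) defaults to a floating point dtype.
placeholder = np.empty(0)
print(placeholder.shape, placeholder.dtype)

# Concrete values supplied for the dynamic time axis (illustrative timestamps)
time = np.array(["2024-01-01T00:00", "2024-01-01T06:00"], dtype="datetime64[ns]")

# A simple check that the supplied axis has the expected type
print(np.issubdtype(time.dtype, np.datetime64))
```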

Bad Coordinate Definitions#

Batch dimension is not leading.

coords = OrderedDict(
        {
            "variable": np.array(VARIABLES),
            "batch": np.empty(0),
            "lat": np.linspace(90, -90, 720, endpoint=False),
            "lon": np.linspace(0, 360, 1440, endpoint=False),
        }
    )

Batch dimension is not of size 0.

coords = OrderedDict(
        {
            "batch": np.zeros(1),
            "variable": np.array(VARIABLES),
            "lat": np.linspace(90, -90, 720, endpoint=False),
            "lon": np.linspace(0, 360, 1440, endpoint=False),
        }
    )

Batch Decorator#

While the batch dimension is useful for communicating that a dimension accepts variable input sizes, manually manipulating data into a form that matches the batch dimension is inconvenient. To make batch-supporting models easier to use, Earth2Studio offers a utility decorator, earth2studio.models.batch.batch_func, which automates transforming extra leading dimensions into a single batch dimension. This utility must be used in an object with coordinate properties.

The batch function does the following steps:

  1. Squeeze leading dims into a batch dimension

  2. Update batched input coordinates with a single batch index dimension

  3. Execute the wrapped function and get outputs

  4. Replace output batch coord with the batched input coordinates

  5. Unsqueeze the leading batch coordinate into original input dimensions

Consider the following example:

from collections import OrderedDict

import numpy as np
import torch

from earth2studio.models.batch import batch_func, batch_coords


class BatchModel:
    input_coords = OrderedDict({"batch": np.zeros(0), "dim1": np.arange(2)})

    @batch_coords()
    def output_coords(
        self,
        input_coords: OrderedDict
        ) -> OrderedDict:
        return OrderedDict({"batch": np.zeros(0), "dim2": np.arange(4)})

    @batch_func()
    def __call__(self, input, coords):
        print("Model Input:", input.size(), coords)
        out = torch.cat([input, input], dim=-1)
        out_c = self.output_coords(coords).copy()
        return out, out_c


input_coords = OrderedDict(
    {"batched_dim0": np.arange(2), "batched_dim1": np.arange(3), "dim1": np.arange(2)}
)
input = torch.randn(2, 3, 2)

model = BatchModel()
print("Input:", input.size(), input_coords)
output, output_coords = model(input, input_coords)
print("Output:", output.size(), output_coords)

The output of the above script will be:

Input: torch.Size([2, 3, 2]) OrderedDict([('batched_dim0', array([0, 1])), ('batched_dim1', array([0, 1, 2])), ('dim1', array([0, 1]))])

Model Input: torch.Size([6, 2]) OrderedDict([('batch', array([0, 1, 2, 3, 4, 5])), ('dim1', array([0, 1]))])

Output: torch.Size([2, 3, 4]) OrderedDict([('batched_dim0', array([0, 1])), ('batched_dim1', array([0, 1, 2])), ('dim2', array([0, 1, 2, 3]))])

Note that the leading two dimensions were squeezed into a single batch dimension before the execution of the model’s BatchModel.__call__(). The leading dimensions were then restored while preserving the updated domain coordinates from the model’s output.
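
The flatten-and-restore behaviour described above can be sketched in plain NumPy. This is a minimal illustration under stated assumptions, not the actual batch_func implementation: run_batched, toy_model and the n_model_dims parameter are hypothetical names introduced here, and NumPy arrays stand in for torch tensors.

```python
from collections import OrderedDict

import numpy as np


def run_batched(func, x, coords, n_model_dims):
    """Flatten extra leading dims into one batch axis, call func, restore them.

    Illustrative only; this is not the batch_func implementation.
    n_model_dims is the number of trailing (non-batch) dims the model expects.
    """
    n_lead = x.ndim - n_model_dims
    lead_shape = x.shape[:n_lead]
    lead_coords = list(coords.items())[:n_lead]
    model_coords = list(coords.items())[n_lead:]

    # Steps 1-2: collapse leading dims and use a single batch index coordinate
    flat = x.reshape((-1,) + x.shape[n_lead:])
    flat_coords = OrderedDict([("batch", np.arange(flat.shape[0]))] + model_coords)

    # Step 3: call the wrapped function on the flattened data
    out, out_coords = func(flat, flat_coords)

    # Steps 4-5: drop the batch coord and restore the original leading dims
    out_coords = OrderedDict(
        lead_coords + [(k, v) for k, v in out_coords.items() if k != "batch"]
    )
    out = out.reshape(lead_shape + out.shape[1:])
    return out, out_coords


def toy_model(x, coords):
    # Doubles the last axis, mirroring the torch.cat call in the example above
    out = np.concatenate([x, x], axis=-1)
    out_coords = OrderedDict(
        {"batch": coords["batch"], "dim2": np.arange(out.shape[-1])}
    )
    return out, out_coords


x = np.random.randn(2, 3, 2)
coords = OrderedDict(
    {"batched_dim0": np.arange(2), "batched_dim1": np.arange(3), "dim1": np.arange(2)}
)
out, out_coords = run_batched(toy_model, x, coords, n_model_dims=1)
print(out.shape, list(out_coords))
```

With these inputs the sketch should report an output shape of (2, 3, 4) with the keys batched_dim0, batched_dim1 and dim2, matching the shapes in the example above. An input with no extra leading dimensions is reshaped to a batch of size one, which is removed again on output.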

The batch decorator will also add a batch axis to an input whose coordinates are missing only the batch dimension and contain no additional leading dimensions. In this instance, a batch size of one is implied. For example, using the model in the example above:

input_coords = OrderedDict(
    {"dim1": np.arange(2)}
)
input = torch.randn(2)

model = BatchModel()
print("Input:", input.size(), input_coords)
output, output_coords = model(input, input_coords)
print("Output:", output.size(), output_coords)

will execute successfully with an output of:

Input: torch.Size([2]) OrderedDict([('dim1', array([0, 1]))])

Model Input: torch.Size([1, 2]) OrderedDict([('batch', array([0])), ('dim1', array([0, 1]))])

Output: torch.Size([4]) OrderedDict([('dim2', array([0, 1, 2, 3]))])

Batch Dimension in IO#

The IO backends require users to pre-define the output coordinate system to which the data will be exported. Typically, a good way to do this is to look at the output coordinate system of the model, which will usually include a batch dimension. However, the model won’t actually return a batch dimension, so a common pattern is to replace this batch dimension with whatever leading coordinates the input will have.

For example, refer to the built-in workflows. The setup process for the IO backend is handled in a general manner by first getting the output coordinates of the model, removing empty dimensions such as batch and then prepending known leading dimensions like time.

total_coords = prognostic.output_coords.copy()
for key, value in prognostic.output_coords.items():
    if value.shape == (0,):
        del total_coords[key]
total_coords["time"] = time
total_coords["lead_time"] = np.asarray(
    [prognostic.output_coords["lead_time"] * i for i in range(nsteps + 1)]
).flatten()
total_coords.move_to_end("lead_time", last=False)
total_coords.move_to_end("time", last=False)
var_names = total_coords.pop("variable")
io.add_array(total_coords, var_names)
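
To see what this coordinate manipulation produces, the following self-contained sketch repeats the same steps on made-up data. The output_coords dictionary, the time array, nsteps and the variable names are hypothetical stand-ins for the model output coordinates and run settings, and the final io.add_array call is omitted because no IO backend is constructed here.

```python
from collections import OrderedDict

import numpy as np

# Hypothetical stand-ins for the model output coordinates and run settings
output_coords = OrderedDict(
    {
        "batch": np.empty(0),
        "lead_time": np.array([np.timedelta64(6, "h")]),
        "variable": np.array(["t2m", "u10m"]),
        "lat": np.linspace(90, -90, 4),
        "lon": np.linspace(0, 360, 8, endpoint=False),
    }
)
time = np.array(["2024-01-01T00:00"], dtype="datetime64[ns]")
nsteps = 2

# Drop empty (dynamic) dimensions such as batch
total_coords = OrderedDict(
    (key, value) for key, value in output_coords.items() if value.shape != (0,)
)
# Add time and expand lead_time to cover every step, including step 0
total_coords["time"] = time
total_coords["lead_time"] = np.asarray(
    [output_coords["lead_time"] * i for i in range(nsteps + 1)]
).flatten()
# Move lead_time, then time, to the front so the order is time, lead_time, ...
total_coords.move_to_end("lead_time", last=False)
total_coords.move_to_end("time", last=False)
# The variable names are passed separately from the remaining coordinates
var_names = total_coords.pop("variable")

print(list(total_coords))
print(total_coords["lead_time"])
```

The resulting coordinate order should be time, lead_time, lat, lon, with lead_time holding one entry per step (0, 6 and 12 hours for these illustrative settings).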