earth2studio.models.batch.batch_func

class earth2studio.models.batch.batch_func

Batch utility decorator that can be added to prognostic and diagnostic models to enable automatic batching of input data. This class provides a decorator that should be applied to each model method where this functionality is desired.
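The general idea behind automatic batching can be sketched independently of Earth2Studio. Any extra leading dimensions are folded into the single “batch” dimension the model expects, the model runs once, and the result is unfolded again. The function run_batched and its n_model_dims parameter are hypothetical, and the sketch uses NumPy for brevity. It is a minimal illustration of the technique under those assumptions, not batch_func's actual implementation.

```python
import numpy as np


def run_batched(fn, x: np.ndarray, n_model_dims: int) -> np.ndarray:
    """Hypothetical sketch: fold all leading dims of x into one batch dim.

    fn is assumed to accept an array of shape (batch, *model_dims) and to
    return an array whose first dimension is still batch.
    """
    lead_shape = x.shape[: x.ndim - n_model_dims]   # extra leading dims
    model_shape = x.shape[x.ndim - n_model_dims :]  # dims the model expects
    flat = x.reshape((-1, *model_shape))           # (prod(lead_shape), *model_shape)
    out = fn(flat)                                 # single call on the folded batch
    # Unfold the batch dim back into the original leading dims
    return out.reshape((*lead_shape, *out.shape[1:]))
```

For an input of shape (2, 3, 4, 5) with n_model_dims=2, fn receives a single array of shape (6, 4, 5), and its output is reshaped so that it again starts with the leading dimensions (2, 3).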

Note

A model's input_coords and output_coords attributes must have “batch” as the first dimension of their coordinate systems, i.e., the first key must be “batch”.
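The key-order requirement can be checked directly on a coordinate system. This is a minimal sketch that assumes coordinate systems are OrderedDicts mapping dimension names to NumPy arrays, as in the example on this page. The helper has_leading_batch and the sample coordinates are hypothetical.

```python
from collections import OrderedDict

import numpy as np


def has_leading_batch(coords: OrderedDict) -> bool:
    """Hypothetical helper: True if the first key of coords is "batch"."""
    return len(coords) > 0 and next(iter(coords)) == "batch"


# Hypothetical coordinate systems; only the key order matters here
good = OrderedDict([("batch", np.empty(0)), ("lat", np.linspace(90, -90, 3))])
bad = OrderedDict([("lat", np.linspace(90, -90, 3)), ("batch", np.empty(0))])
```

Because OrderedDict preserves insertion order, the check only has to look at the first key.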

Note

When decorating a model method such as __call__, the method must have the signature (self, *args: Any, **kwargs: Any) -> tuple[torch.Tensor, CoordSystem]. All positional arguments must form a sequence of (x, CoordSystem) pairs. All keyword arguments are passed through to the wrapped function unmodified.
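The calling convention can be illustrated with a toy decorator. The names toy_batch_func and ToyModel are hypothetical. The sketch mimics only the signature handling described in this note (positional (x, coords) pairs, keyword arguments forwarded unchanged) plus the “batch”-first check, and it performs no actual batching. It also assumes the pairs are passed flat and consecutively, as in x, coords.

```python
import functools
from collections import OrderedDict
from typing import Any, Callable

import numpy as np


def toy_batch_func() -> Callable:
    """Hypothetical stand-in that mimics only the calling convention.

    Used as @toy_batch_func(), mirroring the @batch_func() usage.
    """

    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)
        def wrapper(self, *args: Any, **kwargs: Any):
            # Positional args are read as consecutive (x, coords) pairs
            if len(args) % 2 != 0:
                raise TypeError("positional args must form (x, coords) pairs")
            pairs = [(args[i], args[i + 1]) for i in range(0, len(args), 2)]
            for _, coords in pairs:
                if next(iter(coords)) != "batch":
                    raise ValueError('first coordinate must be "batch"')
            # Keyword arguments are forwarded to the wrapped method unchanged
            return fn(self, *args, **kwargs)

        return wrapper

    return decorator


class ToyModel:
    @toy_batch_func()
    def __call__(self, x, coords, scale: float = 1.0):
        return x * scale, coords


coords = OrderedDict([("batch", np.arange(2)), ("value", np.arange(3))])
y, out_coords = ToyModel()(np.ones((2, 3)), coords, scale=2.0)
```

Here scale reaches __call__ untouched, while a call with an unpaired positional argument or with coordinates whose first key is not “batch” is rejected before the method runs.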

Example

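from collections import OrderedDict

import numpy as np
import torch

from earth2studio.models.batch import batch_func
from earth2studio.utils.type import CoordSystem


class Model:

    # "batch" must be the first coordinate; remaining dimensions are elided
    input_coords = OrderedDict([("batch", np.empty(0)), ...])
    output_coords = OrderedDict([("batch", np.empty(0)), ...])

    @batch_func()
    def __call__(
        self,
        x: torch.Tensor,
        coords: CoordSystem,
    ) -> tuple[torch.Tensor, CoordSystem]:
        ...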