
earth2studio.models.batch.batch_func

class earth2studio.models.batch.batch_func

Batch utility decorator that can be added to prognostic and diagnostic models to enable automatic batching of data. The class is used as a decorator: apply an instance of it, @batch_func(), to the model methods (such as __call__) where this functionality is desired.
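To illustrate the general idea of automatic batching, here is a minimal sketch assuming that "batching" means collapsing any extra leading dimensions of the input into a single "batch" dimension, calling the model, then restoring those dimensions on the output. The decorator simple_batch_func, its n_core_dims parameter, and ToyModel are hypothetical names used only for illustration. This is not Earth2Studio's implementation of batch_func, and NumPy arrays stand in for torch tensors.

```python
from collections import OrderedDict
from functools import wraps

import numpy as np


def simple_batch_func(n_core_dims):
    """Hypothetical batching decorator (NOT Earth2Studio's batch_func).

    Leading dimensions beyond the last ``n_core_dims`` are flattened into one
    "batch" dimension before the wrapped method runs, then restored afterwards.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(self, x, coords):
            keys = list(coords.keys())
            lead_keys = keys[: len(keys) - n_core_dims]
            core_keys = keys[len(keys) - n_core_dims :]
            lead_shape = x.shape[: x.ndim - n_core_dims]

            # Collapse all leading dimensions into a single batch dimension
            xb = x.reshape((-1,) + x.shape[x.ndim - n_core_dims :])
            cb = OrderedDict([("batch", np.empty(0))])
            for k in core_keys:
                cb[k] = coords[k]

            yb, out_cb = func(self, xb, cb)

            # Restore the original leading dimensions and their coordinates
            y = yb.reshape(lead_shape + yb.shape[1:])
            out = OrderedDict((k, coords[k]) for k in lead_keys)
            for k, v in list(out_cb.items())[1:]:  # skip the "batch" entry
                out[k] = v
            return y, out

        return wrapper

    return decorator


class ToyModel:
    input_coords = OrderedDict(
        [("batch", np.empty(0)), ("lat", np.arange(3)), ("lon", np.arange(4))]
    )

    @simple_batch_func(n_core_dims=2)
    def __call__(self, x, coords):
        # The wrapped method only ever sees a single leading "batch" dimension
        assert x.ndim == 3 and next(iter(coords)) == "batch"
        return 2 * x, coords


x = np.ones((2, 5, 3, 4))
coords = OrderedDict(
    [("time", np.arange(2)), ("ensemble", np.arange(5)),
     ("lat", np.arange(3)), ("lon", np.arange(4))]
)
y, c = ToyModel()(x, coords)
print(y.shape)  # (2, 5, 3, 4)
print(list(c))  # ['time', 'ensemble', 'lat', 'lon']
```

The design choice here is that the model body is written once against a fixed rank ("batch" plus its core dimensions), while callers may stack arbitrary extra dimensions such as time or ensemble members in front.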

Note

A model's input_coords and output_coords attributes must have “batch” as their first dimension, i.e. the first key of each ordered dictionary needs to be “batch”.
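The requirement above can be checked mechanically because OrderedDict preserves insertion order. The sketch below assumes coordinates are stored as OrderedDict objects mapping names to NumPy arrays, as in the example; validate_batch_coords, GoodModel and BadModel are hypothetical names for illustration and are not part of Earth2Studio.

```python
from collections import OrderedDict

import numpy as np


def validate_batch_coords(model) -> None:
    """Hypothetical check: raise if input_coords/output_coords do not start with "batch"."""
    for name in ("input_coords", "output_coords"):
        coords = getattr(model, name)
        # OrderedDict preserves insertion order, so the first key is the leading dimension
        first = next(iter(coords), None)
        if first != "batch":
            raise ValueError(f"{name} must have 'batch' as its first key, got {first!r}")


class GoodModel:
    input_coords = OrderedDict([("batch", np.empty(0)), ("lat", np.arange(3))])
    output_coords = OrderedDict([("batch", np.empty(0)), ("lat", np.arange(3))])


class BadModel:
    # "batch" is present but not first, so this violates the requirement
    input_coords = OrderedDict([("lat", np.arange(3)), ("batch", np.empty(0))])
    output_coords = OrderedDict([("batch", np.empty(0)), ("lat", np.arange(3))])


validate_batch_coords(GoodModel())  # passes silently
try:
    validate_batch_coords(BadModel())
except ValueError as err:
    print(err)  # input_coords must have 'batch' as its first key, got 'lat'
```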

Example

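from collections import OrderedDict

import numpy as np
import torch

from earth2studio.models.batch import batch_func
from earth2studio.utils.type import CoordSystem


class Model:

    # "batch" must be the first key; "..." stands for the model's remaining dimensions
    input_coords = OrderedDict([("batch", np.empty(0)), ...])
    output_coords = OrderedDict([("batch", np.empty(0)), ...])

    @batch_func()  # batch_func is instantiated, then the instance decorates __call__
    def __call__(
        self,
        x: torch.Tensor,
        coords: CoordSystem,
    ) -> tuple[torch.Tensor, CoordSystem]:
        ...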