earth2studio.models.batch.batch_func
class earth2studio.models.batch.batch_func
Batch utility decorator that can be added to prognostic and diagnostic models to enable automatic batching of data. This class provides a decorator that should be applied to the model methods (such as __call__) where this functionality is desired.
Note
A model's input_coords and output_coords attributes must use “batch” as the first dimension, i.e. the first key of each coordinate dictionary must be “batch”.
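The requirement in this Note can be checked mechanically. The following is a minimal sketch; the helper has_leading_batch and the lat coordinate values are hypothetical and not part of earth2studio.

```python
from collections import OrderedDict

import numpy as np


def has_leading_batch(coords: OrderedDict) -> bool:
    """Hypothetical helper: True if the first key of a coordinate dict is "batch"."""
    return len(coords) > 0 and next(iter(coords)) == "batch"


# "batch" first: satisfies the requirement
good = OrderedDict([("batch", np.empty(0)), ("lat", np.linspace(-90, 90, 3))])
# "batch" present but not first: violates the requirement
bad = OrderedDict([("lat", np.linspace(-90, 90, 3)), ("batch", np.empty(0))])
```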
Example
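from collections import OrderedDict

import numpy as np
import torch

from earth2studio.models.batch import batch_func
from earth2studio.utils.type import CoordSystem


class Model:
    # "batch" must be the first key; remaining coordinates elided
    input_coords = OrderedDict([("batch", np.empty(0)), ...])
    output_coords = OrderedDict([("batch", np.empty(0)), ...])

    @batch_func()
    def __call__(
        self,
        x: torch.Tensor,
        coords: CoordSystem,
    ) -> tuple[torch.Tensor, CoordSystem]:
        ...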
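To illustrate the idea of automatic batching, here is a minimal, hypothetical sketch in plain NumPy. It is not earth2studio's implementation: the names toy_batch_func and ToyModel, the time/ensemble dimensions, and the strategy of flattening every leading dimension not listed after “batch” in input_coords into a single batch dimension are assumptions made for illustration only.

```python
from collections import OrderedDict
from functools import wraps

import numpy as np


class toy_batch_func:
    """Hypothetical stand-in for a batching decorator (NOT earth2studio's code).

    Flattens every leading dimension the model does not list after "batch"
    in input_coords into one "batch" dimension, calls the wrapped method
    once, then restores the caller's original leading dimensions.
    """

    def __call__(self, fn):
        @wraps(fn)
        def wrapper(model, x, coords):
            # Dimensions the model handles itself (everything after "batch")
            model_dims = list(model.input_coords.keys())[1:]
            n_lead = x.ndim - len(model_dims)
            lead_keys = list(coords.keys())[:n_lead]
            lead_shape = x.shape[:n_lead]

            # Compress all leading dims into a single "batch" dim
            x_flat = x.reshape((-1,) + x.shape[n_lead:])
            flat_coords = OrderedDict([("batch", np.arange(x_flat.shape[0]))])
            for key in model_dims:
                flat_coords[key] = coords[key]

            y, out_coords = fn(model, x_flat, flat_coords)

            # Decompress: restore the caller's leading dims and their coordinates
            y = y.reshape(lead_shape + y.shape[1:])
            restored = OrderedDict((key, coords[key]) for key in lead_keys)
            for key, value in list(out_coords.items())[1:]:
                restored[key] = value
            return y, restored

        return wrapper


class ToyModel:
    input_coords = OrderedDict(
        [("batch", np.empty(0)), ("lat", np.linspace(-90, 90, 3)), ("lon", np.arange(4) * 90.0)]
    )
    output_coords = input_coords

    @toy_batch_func()
    def __call__(self, x, coords):
        # Inside the model, exactly one leading "batch" dim is ever seen
        assert next(iter(coords)) == "batch"
        self.last_batch_size = x.shape[0]
        return 2.0 * x, coords


model = ToyModel()
coords = OrderedDict(
    [
        ("time", np.arange(2)),
        ("ensemble", np.arange(5)),
        ("lat", ToyModel.input_coords["lat"]),
        ("lon", ToyModel.input_coords["lon"]),
    ]
)
x = np.ones((2, 5, 3, 4))
out, out_coords = model(x, coords)
print(out.shape, list(out_coords), model.last_batch_size)
# prints: (2, 5, 3, 4) ['time', 'ensemble', 'lat', 'lon'] 10
```

In this sketch, compressing before the call and decompressing after it lets the model body assume a single leading batch dimension, which is consistent with the Note's requirement that “batch” be the first key of input_coords and output_coords.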