Output Handling
While input data handling is primarily managed by the data sources in earth2studio.data, output handling is managed by the IO backends available in earth2studio.io.
These backends are designed to balance two goals: letting users customize the arrays and metadata within the exposed backend, and making it easy to design reusable workflows.
The key extension to the (x, coords) data structure that moves through the rest of the earth2studio code, needed for compatibility with output stores, is the notion of an array_name. Names distinguish between different arrays within a backend and are currently required for storing Datasets in xarray, zarr, and netcdf.
This means that the user must supply a name when adding an array to a store or when writing an array. A frequent pattern is to extract one dimension of an array, such as "variable", and treat each of its entries as an individual array in the backend; see the examples below.
IO Backend Interface
The full requirements for a standard IO backend are defined explicitly in earth2studio/io/base.py.
```python
from typing import Any, Protocol, runtime_checkable

import torch

from earth2studio.utils.type import CoordSystem


@runtime_checkable
class IOBackend(Protocol):
    """Interface for a generic IO backend."""

    def __init__(
        self,
    ) -> None:
        pass

    def add_array(
        self, coords: CoordSystem, array_name: str | list[str], **kwargs: Any
    ) -> None:
        """
        Add an array with `array_name` to the existing IO backend object.

        Parameters
        ----------
        coords : OrderedDict
            Ordered dictionary representing the dimensions and coordinate data
            of x.
        array_name : str | list[str]
            Name(s) of the array(s) that will be initialized with coordinates
            as dimensions.
        kwargs : dict[str, Any], optional
            Optional keyword arguments that will be passed to the IO backend
            when the array is created.
        """
        pass

    def write(
        self,
        x: torch.Tensor | list[torch.Tensor],
        coords: CoordSystem,
        array_name: str | list[str],
    ) -> None:
        """
        Write data to the current backend using the passed array_name.

        Parameters
        ----------
        x : torch.Tensor | list[torch.Tensor]
            Tensor(s) to be written to the backend.
        coords : OrderedDict
            Coordinates of the passed data.
        array_name : str | list[str]
            Name(s) of the array(s) that will be written to.
        """
        pass
```
Note
IO Backends do not need to inherit this protocol; it is only used to define the required APIs. Some built-in IO backends may also offer additional functionality that is not universally supported (and hence not required).
There are two important methods that must be supported: add_array, which adds an array to the underlying store and any attached coordinates, and write, which explicitly stores the provided data in the backend.
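To make the two required methods concrete, here is a minimal sketch of an in-memory backend, assuming NumPy arrays in place of PyTorch tensors. `DictBackend` and `MinimalIOBackend` are hypothetical names, not part of earth2studio; the protocol is a trimmed local stand-in for IOBackend so the snippet runs without the library. Because the protocol is runtime_checkable, a class that merely defines add_array and write passes an isinstance check without inheriting from it. The write method locates its target region by matching coordinate values, which is one reasonable strategy, not a description of how the built-in backends do it.

```python
from collections import OrderedDict
from typing import Any, Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class MinimalIOBackend(Protocol):
    """Hypothetical local stand-in for the required IO backend methods."""

    def add_array(self, coords: OrderedDict, array_name: Any, **kwargs: Any) -> None: ...

    def write(self, x: Any, coords: OrderedDict, array_name: Any) -> None: ...


class DictBackend:
    """Hypothetical in-memory backend; note that it does not inherit the protocol."""

    def __init__(self) -> None:
        self.coords: OrderedDict = OrderedDict()
        self.dims: dict[str, list[str]] = {}
        self.arrays: dict[str, np.ndarray] = {}

    def add_array(self, coords: OrderedDict, array_name: Any, **kwargs: Any) -> None:
        names = [array_name] if isinstance(array_name, str) else list(array_name)
        for dim, values in coords.items():
            self.coords[dim] = np.asarray(values)
        # Preallocate a NaN-filled array spanning every coordinate
        shape = tuple(len(v) for v in coords.values())
        for name in names:
            self.dims[str(name)] = list(coords)
            self.arrays[str(name)] = np.full(shape, np.nan)

    def write(self, x: Any, coords: OrderedDict, array_name: Any) -> None:
        single = isinstance(array_name, str)
        names = [array_name] if single else list(array_name)
        tensors = [x] if single else list(x)
        for data, name in zip(tensors, names):
            name = str(name)
            # For each dimension of the array (coords must contain all of them),
            # find the stored position of every incoming coordinate value
            index = [
                [int(np.flatnonzero(self.coords[dim] == v)[0]) for v in np.asarray(coords[dim])]
                for dim in self.dims[name]
            ]
            self.arrays[name][np.ix_(*index)] = np.asarray(data)


io = DictBackend()
io.add_array(
    OrderedDict({"lead_time": np.array([0, 6, 12]), "variable": np.array(["t2m", "u10m"])}),
    "fields",
)
io.write(
    np.ones((1, 2)),
    OrderedDict({"lead_time": np.array([6]), "variable": np.array(["t2m", "u10m"])}),
    "fields",
)
print(isinstance(io, MinimalIOBackend))  # True
```

In this sketch, writing a single lead time fills only that slice and leaves the rest of the preallocated array as NaN.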
Calling write may induce a synchronization if the input tensor resides on the GPU and the store resides on the CPU. Most stores convert PyTorch tensors to NumPy arrays in this process. The earth2studio.io.kv backend has the option of storing data on the GPU, which can be done asynchronously.
Most data stores offer several additional utilities such as __contains__, __getitem__, __len__, and __iter__. For example, see the implementation in earth2studio.io.ZarrBackend:
```python
# Excerpt from earth2studio.io.ZarrBackend
from collections import OrderedDict
from typing import Any, Iterator

import zarr

from earth2studio.utils.type import CoordSystem


class ZarrBackend:
    def __init__(
        self,
        file_name: str | None = None,
        chunks: dict[str, int] = {},
        backend_kwargs: dict[str, Any] = {"overwrite": False},
    ) -> None:
        if file_name is None:
            self.store = zarr.storage.MemoryStore()
        else:
            self.store = zarr.storage.DirectoryStore(file_name)
        self.root = zarr.group(self.store, **backend_kwargs)

        # Read data from file, if available
        self.coords: CoordSystem = OrderedDict({})
        self.chunks = chunks.copy()
        for array in self.root:
            dims = self.root[array].attrs["_ARRAY_DIMENSIONS"]
            for dim in dims:
                if dim not in self.coords:
                    self.coords[dim] = self.root[dim]

        for array in self.root:
            if array not in self.coords:
                dims = self.root[array].attrs["_ARRAY_DIMENSIONS"]
                for c, d in zip(self.root[array].chunks, dims):
                    self.chunks[d] = c

    def __contains__(self, item: str) -> bool:
        """Checks if item in Zarr Group.

        Parameters
        ----------
        item : str
            Name of the group member.
        """
        return self.root.__contains__(item)

    def __getitem__(self, item: str) -> zarr.core.Array:
        """Gets item in Zarr Group.

        Parameters
        ----------
        item : str
            Name of the group member.
        """
        return self.root.__getitem__(item)

    def __len__(
        self,
    ) -> int:
        """Gets length of Zarr Group."""
        return self.root.__len__()

    def __iter__(
        self,
    ) -> Iterator:
        """Return an iterator over Zarr Group member names."""
        return self.root.__iter__()
```
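The sketch below shows the same delegation pattern in a self-contained form, with a plain dict standing in for the zarr group. `InspectableStore` is a hypothetical class, not part of earth2studio. As the __init__ excerpt above suggests, coordinate arrays are members of the same group as the data arrays, so in ZarrBackend membership tests, len, and iteration cover both; the sketch mirrors that by storing a lat coordinate next to a data array.

```python
from collections.abc import Iterator

import numpy as np


class InspectableStore:
    """Hypothetical store that forwards container methods to a dict,
    the way ZarrBackend forwards them to its zarr group (self.root)."""

    def __init__(self) -> None:
        self.root: dict[str, np.ndarray] = {}

    def __contains__(self, item: str) -> bool:
        return item in self.root

    def __getitem__(self, item: str) -> np.ndarray:
        return self.root[item]

    def __len__(self) -> int:
        return len(self.root)

    def __iter__(self) -> Iterator[str]:
        return iter(self.root)


store = InspectableStore()
store.root["lat"] = np.linspace(90, -90, 3)  # a coordinate array
store.root["t2m"] = np.zeros((2, 3))  # a data array

print("t2m" in store)  # True
print(len(store))  # 2
print(list(store))  # ['lat', 't2m']
```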
Because of its datetime compatibility, we recommend using the ZarrBackend as the default.
Initializing a Store
A common data pattern seen throughout our example workflows is to initialize the variables and dimensions of a backend using a complete CoordSystem. For example:
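```python
from collections import OrderedDict

# Build a complete CoordSystem; each `...` is a placeholder for the
# coordinate array of that dimension
total_coords = OrderedDict(
    {
        "ensemble": ...,
        "time": ...,
        "lead_time": ...,
        "variable": ...,
        "lat": ...,
        "lon": ...,
    }
)

# Give an informative array name
array_name = "fields"

# Initialize all dimensions in total_coords and the array "fields"
# (`io` is an existing IO backend instance)
io.add_array(total_coords, array_name)
```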
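As a concrete illustration, the sketch below fills in a complete CoordSystem with hypothetical values (two ensemble members, one initial time, five 6-hourly lead times, two variables, and a 0.25-degree global grid) using NumPy arrays. Assuming, as in the example above, that add_array creates one array spanning every dimension, that array's shape is the length of each coordinate in dictionary order.

```python
from collections import OrderedDict

import numpy as np

# Hypothetical coordinate values chosen only for illustration
total_coords = OrderedDict(
    {
        "ensemble": np.arange(2),
        "time": np.array([np.datetime64("2024-01-01T00:00")]),
        "lead_time": np.arange(0, 25, 6).astype("timedelta64[h]"),  # 0 to 24 h
        "variable": np.array(["t2m", "u10m"]),
        "lat": np.linspace(90, -90, 721),  # 0.25-degree spacing
        "lon": np.linspace(0, 360, 1440, endpoint=False),
    }
)

# An array spanning every coordinate has one axis per dimension,
# sized by the length of that dimension's coordinate array
shape = tuple(len(v) for v in total_coords.values())
n_elements = int(np.prod(shape))

print(shape)  # (2, 1, 5, 2, 721, 1440)
print(n_elements)  # 20764800
```

Computing the element count this way before calling add_array is a cheap check on how large the store will become.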
It can be tedious to define each coordinate and dimension. Luckily, if we have a prognostic or diagnostic model, most of this information is already available. Here is an example of such a use case:
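```python
from collections import OrderedDict

import numpy as np

# Set up IO backend
# Assumes `prognostic` (a prognostic model), `time`, `nsteps`, `array_name`,
# and `io` (an IO backend) are already defined
# Copy prognostic model output coordinates, skipping "batch" and empty dims
total_coords = OrderedDict(
    {
        k: v
        for k, v in prognostic.output_coords(prognostic.input_coords).items()
        if (k != "batch") and (v.shape != (0,))
    }
)
total_coords["time"] = time
# Expand the output lead_time step into one entry per rollout step (0 to nsteps)
total_coords["lead_time"] = np.asarray(
    [total_coords["lead_time"] * i for i in range(nsteps + 1)]
).flatten()
# Reorder so the leading dimensions are (time, lead_time, ...)
total_coords.move_to_end("lead_time", last=False)
total_coords.move_to_end("time", last=False)
io.add_array(total_coords, array_name)
```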
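To see what the lead-time expansion and reordering above produce, the sketch below runs the same steps on a hypothetical output coordinate system, assuming a model whose output lead_time holds a single 6-hour step. The coordinate values, `nsteps`, and `time` are made up for illustration, and `output_coords` stands in for the result of prognostic.output_coords(...).

```python
from collections import OrderedDict

import numpy as np

# Hypothetical stand-in for prognostic.output_coords(...): one 6 h step
output_coords = OrderedDict(
    {
        "batch": np.empty(0),
        "lead_time": np.array([np.timedelta64(6, "h")]),
        "variable": np.array(["t2m", "u10m"]),
        "lat": np.linspace(90, -90, 4),
        "lon": np.linspace(0, 270, 4),
    }
)
nsteps = 4
time = np.array([np.datetime64("2024-01-01T00:00")])

total_coords = OrderedDict(
    {k: v for k, v in output_coords.items() if (k != "batch") and (v.shape != (0,))}
)
total_coords["time"] = time
# Step i of the rollout sits at i * 6 h, for i = 0 .. nsteps
total_coords["lead_time"] = np.asarray(
    [total_coords["lead_time"] * i for i in range(nsteps + 1)]
).flatten()
total_coords.move_to_end("lead_time", last=False)
total_coords.move_to_end("time", last=False)

print(list(total_coords))  # ['time', 'lead_time', 'variable', 'lat', 'lon']
```

Multiplying the single step by 0 through nsteps yields one lead time per rollout step, including the initial condition at 0 hours, and the two move_to_end calls bring time and lead_time to the front.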
Prognostic models, diagnostic models, statistics, and metrics are required to have an output_coords method, which maps an input coordinate system to the corresponding output coordinate system. This method is meant to simulate the result of __call__ without actually computing the forward pass. See the API documentation for more details.
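The sketch below illustrates the idea with a hypothetical diagnostic, `WindSpeedDiagnostic`, which is not part of earth2studio and uses NumPy arrays instead of tensors. Its output_coords derives the output coordinates (here replacing the u10m and v10m variables with a single ws10m entry) without touching any data, so a caller could size an IO backend before running the model.

```python
from collections import OrderedDict

import numpy as np


class WindSpeedDiagnostic:
    """Hypothetical diagnostic: (u10m, v10m) -> ws10m. Not part of earth2studio."""

    def output_coords(self, input_coords: OrderedDict) -> OrderedDict:
        # Validate and map coordinates only; no array data is touched here
        missing = {"u10m", "v10m"} - {str(v) for v in input_coords["variable"]}
        if missing:
            raise ValueError(f"missing input variables: {sorted(missing)}")
        output_coords = input_coords.copy()
        output_coords["variable"] = np.array(["ws10m"])
        return output_coords

    def __call__(self, x: np.ndarray, coords: OrderedDict) -> tuple[np.ndarray, OrderedDict]:
        # The forward pass is the only place the data is actually computed
        axis = list(coords).index("variable")
        names = [str(v) for v in coords["variable"]]
        u = np.take(x, [names.index("u10m")], axis=axis)
        v = np.take(x, [names.index("v10m")], axis=axis)
        return np.sqrt(u**2 + v**2), self.output_coords(coords)


coords = OrderedDict(
    {"variable": np.array(["u10m", "v10m"]), "lat": np.arange(2), "lon": np.arange(3)}
)
diag = WindSpeedDiagnostic()

# Output coordinates are available without running the model
out_coords = diag.output_coords(coords)

x = np.stack([np.full((2, 3), 3.0), np.full((2, 3), 4.0)])
y, y_coords = diag(x, coords)
print(y.shape)  # (1, 2, 3)
```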
Another common IO use case is to extract a particular dimension (usually variable) as the array names.
```python
# A modification of the previous example: one array per variable,
# each spanning the remaining dimensions
var_names = total_coords.pop("variable")
io.add_array(total_coords, var_names)
```
Writing to the store
Once the data arrays have been initialized in the backend, writing to those arrays is a single line of code.
```python
# Assumes `model`, `x`, `coords`, `io`, and `array_name` are already defined
x, coords = model(x, coords)
io.write(x, coords, array_name)
```
If, as above, the user is extracting a dimension of the tensor to use as array names, then they can make use of earth2studio.utils.coords.split_coords:
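```python
from earth2studio.utils.coords import split_coords

# Split x along "variable" into one tensor per variable, then write each
# tensor to the array named after its variable
io.write(*split_coords(x, coords, dim="variable"))
```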
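The unpacking works because split_coords returns its outputs in the order write expects: the per-variable data, the reduced coordinates, and the array names. The NumPy function below is a hypothetical re-creation of that behavior for illustration only; `split_coords_sketch` is not part of earth2studio, and the real function operates on PyTorch tensors.

```python
from collections import OrderedDict

import numpy as np


def split_coords_sketch(
    x: np.ndarray, coords: OrderedDict, dim: str = "variable"
) -> tuple[list[np.ndarray], OrderedDict, list[str]]:
    """Hypothetical NumPy re-creation of the split pattern (not earth2studio's code)."""
    axis = list(coords).index(dim)
    # One array per entry of `dim`; np.take with a scalar index drops that axis
    arrays = [np.take(x, i, axis=axis) for i in range(x.shape[axis])]
    # The remaining coordinates describe every one of those arrays
    reduced = OrderedDict((k, v) for k, v in coords.items() if k != dim)
    names = [str(n) for n in coords[dim]]
    # Same order as the arguments of write(x, coords, array_name)
    return arrays, reduced, names


coords = OrderedDict(
    {
        "lead_time": np.array([0, 6]),
        "variable": np.array(["t2m", "u10m", "v10m"]),
        "lat": np.arange(4),
    }
)
x = np.arange(24).reshape(2, 3, 4)

arrays, reduced, names = split_coords_sketch(x, coords)
print(names)  # ['t2m', 'u10m', 'v10m']
print(arrays[0].shape)  # (2, 4)
```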