# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal
import numpy as np
import torch
from earth2studio.utils.type import CoordSystem
[docs]
def handshake_dim(
input_coords: CoordSystem,
required_dim: str,
required_index: int | None = None,
) -> None:
"""Simple check to see if coordinate system has a dimension in a particular index
Parameters
----------
input_coords : CoordSystem
Input coordinate system to validate
required_dim : str
Required dimension (name of coordinate)
required_index : int, optional
Required index of dimension if needed, by default None
Raises
------
KeyError
If required dimension is not found in the input coordinate system
ValueError
If the required index is outside the dimensionality of the input coordinate system
ValueError
If dimension is not in the required index
Returns
-------
None
"""
if required_dim not in input_coords:
raise KeyError(
f"Required dimension {required_dim} not found in input coordinates"
)
input_dims = list(input_coords.keys())
if required_index is None:
return
try:
input_dims[required_index]
except IndexError:
raise ValueError(
f"Required index {required_index} outside dimensionality of input coordinate system of {len(input_dims)}"
)
if input_dims[required_index] != required_dim:
raise ValueError(
f"Required dimension {required_dim} not found in the required index {required_index} in dim list {input_dims}"
)
[docs]
def handshake_coords(
    input_coords: CoordSystem,
    target_coords: CoordSystem,
    required_dim: str,
) -> None:
    """Simple check to see if the required dimensions have the same coordinate system

    The coordinate arrays of `required_dim` must have the same shape and the same
    values in both coordinate systems.

    Parameters
    ----------
    input_coords : CoordSystem
        Input coordinate system to validate
    target_coords : CoordSystem
        Target coordinate system
    required_dim : str
        Required dimension (name of coordinate)

    Raises
    ------
    KeyError
        If required dim is not present in coordinate systems
    ValueError
        If coordinates of required dimensions don't match (different shape or
        different values)

    Returns
    -------
    None
    """
    if required_dim not in input_coords:
        raise KeyError(
            f"Required dimension {required_dim} not found in input coordinates"
        )
    if required_dim not in target_coords:
        raise KeyError(
            f"Required dimension {required_dim} not found in target coordinates"
        )
    # np.array_equal compares shapes before values. A plain `np.all(a == b)` would
    # broadcast: a length-1 array would "match" any array repeating its value (or an
    # empty array), and incompatible shapes would raise a NumPy broadcasting error
    # instead of the ValueError documented above.
    if not np.array_equal(input_coords[required_dim], target_coords[required_dim]):
        raise ValueError(
            f"Coordinate systems for required dim {required_dim} are not the same"
        )
[docs]
def handshake_size(
    input_coords: CoordSystem,
    required_dim: str,
    required_size: int,
) -> None:
    """Check that the coordinate array of one dimension has a required length

    Parameters
    ----------
    input_coords : CoordSystem
        Input coordinate system to validate
    required_dim : str
        Required dimension (name of coordinate)
    required_size : int
        Required coordinate system size

    Raises
    ------
    KeyError
        If required dim is not present in input coordinate system
    ValueError
        If required dimension is not of required size

    Returns
    -------
    None

    Note
    ----
    Presently assumes coordinate system of given dimension is 1D: only the length
    of its leading axis is compared.
    """
    if required_dim not in input_coords:
        raise KeyError(
            f"Required dimension {required_dim} not found in input coordinates"
        )

    leading_length = input_coords[required_dim].shape[0]
    if leading_length != required_size:
        raise ValueError(
            f"Coordinate size for required dim {required_dim} is not of size {required_size}"
        )
[docs]
def _exact_match_indices(inc: np.ndarray, outc: np.ndarray) -> np.ndarray:
    """Return, for each element of `outc`, its position in `inc`.

    Every element of `outc` must already be known to occur in `inc`. The result
    follows the order of `outc` (a repeated value selects the same position again);
    if `inc` itself repeats a value, its first occurrence is used.
    """
    # A stable sort keeps equal values in their original relative order, so the
    # leftmost match found by searchsorted is also the earliest position in `inc`.
    order = np.argsort(inc, kind="stable")
    return order[np.searchsorted(inc[order], outc)]


def _nearest_indices(inc: np.ndarray, outc: np.ndarray) -> np.ndarray:
    """Return, for each element of `outc`, the position of the closest element of
    `inc` by absolute difference; ties go to the earliest position in `inc`."""
    # (len(inc), len(outc)) distance table built by broadcasting, so no repeated
    # copies of either coordinate array have to be materialised first.
    distance = np.abs(inc[:, np.newaxis] - outc[np.newaxis, :])
    return np.argmin(distance, axis=0)


def map_coords(
    x: torch.Tensor,
    input_coords: CoordSystem,
    output_coords: CoordSystem,
    method: Literal["nearest"] = "nearest",
    ignore_batch: bool = True,
) -> tuple[torch.Tensor, CoordSystem]:
    """A basic interpolation util to map between coordinate systems with common
    dimensions. Namely, `output_coords` should consist of keys that are present in
    `input_coords`. Note that `output_coords` do not need to have all the dimensions
    of the `input_coords`.

    Numeric output coordinates are mapped by nearest-neighbour selection along the
    matching dimension of `x`. Non-numeric output coordinates are sub-selected by
    exact match, in the order given by `output_coords`. The dimensions "batch",
    "time" and "lead_time" of `output_coords` are skipped and left unchanged.

    Parameters
    ----------
    x : torch.Tensor
        Input data to map; its dimensions follow the key order of `input_coords`
    input_coords : CoordSystem
        Respective input coordinate system
    output_coords : CoordSystem
        Target output coordinates to map.
    method : Literal["nearest"], optional
        Method to use for mapping numeric coordinates, by default "nearest". Only
        nearest-neighbour selection is implemented; the value is not inspected.
    ignore_batch: bool, optional
        Ignore batch dimension in output coordinate if present, by default True.
        NOTE: currently not consulted; "batch" is always skipped.

    Returns
    -------
    tuple[torch.Tensor, CoordSystem]
        Mapped data and coordinate system. The coordinate system is a shallow copy
        of `input_coords` with the mapped dimensions replaced by `output_coords`
        values; `input_coords` itself is not modified.

    Raises
    ------
    KeyError
        If output coordinate has a dimension not in the input coordinate
    ValueError
        If value in non-numeric output coordinate is not in input coordinate
    """
    mapped_coords = input_coords.copy()
    input_dims = list(input_coords)
    for key, outc in output_coords.items():
        # TODO: Need better solution, time is numeric
        if key in ("batch", "time", "lead_time"):
            continue
        if key not in input_coords:
            raise KeyError(f"Output coordinate dim {key} not found in input coords")
        inc = mapped_coords[key]

        if np.issubdtype(outc.dtype, np.number):
            # Method = nearest (the only method implemented)
            indx = _nearest_indices(inc, outc)
        else:
            # Non-numeric coordinates cannot be interpolated: sub-select exact matches
            if not np.all(np.isin(outc, inc)):
                raise ValueError(f"Error! Some elements of {outc} are not in {inc}.")
            indx = _exact_match_indices(inc, outc)

        x = torch.index_select(
            x,
            input_dims.index(key),
            torch.tensor(indx, dtype=torch.int32, device=x.device),
        )
        mapped_coords[key] = outc
        # TODO: Linear interpolation for numeric coordinates
    return x, mapped_coords
[docs]
def split_coords(
    x: torch.Tensor, coords: CoordSystem, dim: str = "variable"
) -> tuple[list[torch.Tensor], CoordSystem, np.ndarray]:
    """
    A utility function to split a dimension from a (x,coords) pair and convert it into
    a list of tensors, a CoordSystem, and the coordinate values of the extracted
    dimension.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor
    coords : CoordSystem
        Coordinates referring to the dimensions of x
    dim : str
        Name of the dimension in coords to split along, by default "variable"

    Returns
    -------
    list[torch.Tensor]
        One tensor per index along the extracted dimension, with that dimension
        removed (an empty list when the dimension has size 0).
    CoordSystem
        The updated coord system with the extracted dimension removed.
    np.ndarray
        The values of the dimension extracted from the coordinate system.

    Raises
    ------
    ValueError
        If `dim` is not in `coords`
    """
    if dim not in coords:
        raise ValueError(f"dim {dim} is not in coords: {list(coords)}.")
    reduced_coords = coords.copy()
    dim_index = list(reduced_coords).index(dim)
    values = reduced_coords.pop(dim)
    # unbind yields exactly x.shape[dim_index] tensors. split(1) + squeeze would
    # yield one (empty) chunk for a size-0 dimension, out of step with `values`.
    xs = list(x.unbind(dim_index))
    return xs, reduced_coords, values