
Utils

safe_index(tensor, index, device)

Safely indexes a tensor using a given index and returns the result on a specified device.

Note: the result could instead be forced onto the requested device with return tensor[index.to(tensor.device)].to(device), but that incurs costly cross-device data migration.
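For comparison, the "forcing" alternative mentioned in the note can be sketched as a separate function. The name `forced_index` is hypothetical and not part of bionemo. Each `.to()` call returns the tensor unchanged when it is already on the target device, but otherwise copies data between devices. That copy is the migration cost the note refers to.

```python
import torch
from torch import Tensor


def forced_index(tensor: Tensor, index: Tensor, device: torch.device) -> Tensor:
    """Hypothetical variant that always succeeds by migrating data.

    The index is moved to the tensor's device, and the gathered result is
    then moved to the requested device. Either move may trigger a
    host/device transfer when the devices differ.
    """
    return tensor[index.to(tensor.device)].to(device)


table = torch.arange(10.0)
idx = torch.tensor([1, 3])
out = forced_index(table, idx, torch.device("cpu"))  # gathers elements 1 and 3
```

Unlike `safe_index`, this version never raises on a device mismatch, so unintended transfers can go unnoticed. The strict check trades that convenience for an explicit error.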

Parameters:

Name     Type          Description                                           Default
tensor   Tensor        The tensor to be indexed.                             required
index    Tensor        The index to use for indexing the tensor.             required
device   torch.device  The device on which the result should be returned.    required

Returns:

Type     Description
Tensor   The indexed tensor on the specified device.

Raises:

Type         Description
ValueError   If tensor.device, index.device, and device are not all the same.
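The device check can be exercised without bionemo installed using a local stand-in. `safe_index_sketch` below is a hypothetical copy of the check, and the noise-schedule lookup (`alpha_bar` indexed by per-sample timesteps `t`) is an assumed use case chosen for illustration, not taken from the source. Constructing `torch.device("cuda")` does not require a GPU, so the mismatch branch runs on a CPU-only machine.

```python
import torch
from torch import Tensor


def safe_index_sketch(tensor: Tensor, index: Tensor, device: torch.device) -> Tensor:
    # Hypothetical local stand-in mirroring the documented device check.
    if not (tensor.device == index.device == device):
        raise ValueError("tensor, index, and device must all be on the same device")
    return tensor[index]


# Assumed use case: look up schedule values for a batch of timesteps.
alpha_bar = torch.linspace(1.0, 0.0, steps=5)  # hypothetical 5-step schedule
t = torch.tensor([0, 2, 4])                    # one timestep index per sample

# Passing the tensor's own device satisfies the check.
ok = safe_index_sketch(alpha_bar, t, alpha_bar.device)

# Requesting a different device raises instead of silently copying data.
try:
    safe_index_sketch(alpha_bar, t, torch.device("cuda"))
    raised = False
except ValueError:
    raised = True
```

Passing `tensor.device` as the `device` argument is the simplest way to satisfy the check. Because the comparison is exact, a device written without an index (such as `torch.device("cuda")`) may not compare equal to a tensor's indexed device (such as `cuda:0`).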

Source code in bionemo/moco/interpolants/discrete_time/utils.py
import torch
from torch import Tensor


def safe_index(tensor: Tensor, index: Tensor, device: torch.device) -> Tensor:
    """Safely indexes a tensor using a given index and returns the result on a specified device.

    Note: the result could instead be forced onto `device` with
    `return tensor[index.to(tensor.device)].to(device)`, but that incurs costly cross-device data migration.

    Args:
        tensor (Tensor): The tensor to be indexed.
        index (Tensor): The index to use for indexing the tensor.
        device (torch.device): The device on which the result should be returned.

    Returns:
        Tensor: The indexed tensor on the specified device.

    Raises:
        ValueError: If tensor.device, index.device, and device are not all the same.
    """
    # Refuse to move data implicitly: all three devices must match exactly.
    if not (tensor.device == index.device == device):
        raise ValueError(
            "Tensor, index, and device must all be on the same device. "
            f"Got tensor.device={tensor.device}, index.device={index.device}, and device={device}."
        )

    # The devices agree, so plain indexing already yields a result on `device`.
    return tensor[index]