Utils
safe_index(tensor, index, device)
Safely indexes a tensor using a given index and returns the result on a specified device.
Note: forcing the result onto the target device could be implemented as `return tensor[index.to(tensor.device)].to(device)`, but this incurs costly cross-device data migration.
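For contrast, the "forcing" variant mentioned in the note can be written as a minimal sketch, assuming PyTorch. The name `safe_index_forced` is hypothetical and is not part of bionemo; only the one-line body comes from the note above. `Tensor.to` returns the tensor unchanged when it is already on the requested device, so the cost appears only when the devices actually differ.

```python
import torch


def safe_index_forced(tensor: torch.Tensor, index: torch.Tensor, device) -> torch.Tensor:
    """Hypothetical 'forcing' variant: never raises, but may copy data twice."""
    # index.to(tensor.device) copies the index if it lives on another device;
    # .to(device) then copies the indexed result to the requested device.
    # Both are no-ops when the devices already match.
    return tensor[index.to(tensor.device)].to(device)


# Usage: on CPU nothing is migrated.
print(safe_index_forced(torch.arange(5), torch.tensor([0, 4]), "cpu"))
```

This variant trades the documented `ValueError` for silent, potentially expensive transfers, which is why the documented function validates devices instead.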
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tensor | Tensor | The tensor to be indexed. | required |
index | Tensor | The index to use for indexing the tensor. | required |
device | device | The device on which the result should be returned. | required |
Returns:
Type | Description |
---|---|
Tensor | The indexed tensor on the specified device. |
Raises:
Type | Description |
---|---|
ValueError | If tensor, index, and device are not all on the same device. |
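A minimal sketch of the documented contract, assuming PyTorch: check that `tensor` and `index` already live on `device`, raise `ValueError` otherwise, and index without any copy. This is not the bionemo source (which is not reproduced on this page). The helper `_same_device` is hypothetical, and treating a device without an explicit index (e.g. `cuda`) as matching any `cuda:N` is an assumption; strict `torch.device` equality would reject `cuda` against a tensor on `cuda:0`.

```python
import torch


def _same_device(a: torch.device, b: torch.device) -> bool:
    # Hypothetical helper: same device type, and the same index unless one
    # side omits it (assumption: "cuda" is treated as compatible with "cuda:N").
    if a.type != b.type:
        return False
    return a.index is None or b.index is None or a.index == b.index


def safe_index(tensor: torch.Tensor, index: torch.Tensor, device) -> torch.Tensor:
    """Sketch: index `tensor` with `index`, refusing to migrate data implicitly."""
    device = device if isinstance(device, torch.device) else torch.device(device)
    if not (_same_device(tensor.device, device) and _same_device(index.device, device)):
        raise ValueError(
            f"tensor ({tensor.device}), index ({index.device}) and device ({device}) "
            "must all be on the same device"
        )
    # All inputs already agree, so plain indexing returns a result on `device`.
    return tensor[index]


# Usage: CPU inputs with a CPU target succeed.
print(safe_index(torch.arange(10) * 2, torch.tensor([1, 3]), torch.device("cpu")))
```

Raising instead of moving data keeps device transfers explicit at the call site, which matches the rationale in the note about costly migration.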
Source code in bionemo/moco/interpolants/discrete_time/utils.py