OT sampler
OTSampler
Sampler for Exact Mini-batch Optimal Transport Plan.
OTSampler samples coordinates according to an optimal transport (OT) plan with respect to the squared Euclidean cost, with pluggable implementations of the plan calculation. Code is adapted from https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/optimal_transport.py
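For reference, the exact minibatch OT problem behind the plan can be written as the linear program below. Here $x_0^{(i)}$ and $x_1^{(j)}$ are the flattened $i$-th source (noise) and $j$-th target (data) samples, and the marginals are assumed to be uniform, $a = b = \tfrac{1}{bs}\mathbf{1}$; this page does not state the marginals, so treat that as an assumption.

$$
\pi^\star = \arg\min_{\pi \in \Pi(a, b)} \sum_{i=1}^{bs} \sum_{j=1}^{bs} \pi_{ij} \, \lVert x_0^{(i)} - x_1^{(j)} \rVert_2^2,
\qquad
\Pi(a, b) = \left\{ \pi \in \mathbb{R}_{\ge 0}^{bs \times bs} : \pi \mathbf{1} = a,\ \pi^\top \mathbf{1} = b \right\}
$$

Under that assumption, an optimal $\pi^\star$ can be chosen as a permutation matrix scaled by $1/bs$ (Birkhoff-von Neumann theorem), so exact OT reduces to an optimal one-to-one pairing of noise and data samples.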
Source code in bionemo/moco/interpolants/continuous_time/continuous/data_augmentation/ot_sampler.py
__init__(method='exact', device='cpu', num_threads=1)
Initialize the OTSampler class.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
method | str | The optimal transport solver to use. Currently only the exact OT solver (pot.emd) is supported. | 'exact' |
device | Union[str, device] | The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu". | 'cpu' |
num_threads | Union[int, str] | Number of threads to use for the OT solver. If "max", uses the maximum number of threads. Defaults to 1. | 1 |
Raises:
Type | Description |
---|---|
ValueError | If the OT solver is not recognized. |
NotImplementedError | If the OT solver is not implemented. |
_calculate_cost_matrix(x0, x1, mask=None)
Compute the cost matrix between a source and a target minibatch.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x0 | Tensor | Shape (bs, *dim), noise from the source minibatch. | required |
x1 | Tensor | Shape (bs, *dim), data from the target minibatch. | required |
mask | Optional[Tensor] | Mask to apply to the output, shape (batchsize, nodes). If not provided, no mask is applied. Defaults to None. | None |
Returns:
Name | Type | Description |
---|---|---|
Tensor | Tensor | Shape (bs, bs), the cost matrix between noise and data in the minibatch. |
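To illustrate what this cost matrix contains, here is a minimal NumPy sketch; it is not the bionemo implementation, which works on torch tensors. The function name `squared_euclidean_cost` is hypothetical, and the mask handling (zeroing masked-out nodes in both minibatches before flattening, assuming `*dim` starts with `nodes`) is an assumption about how the mask is used.

```python
import numpy as np


def squared_euclidean_cost(x0, x1, mask=None):
    """Hypothetical sketch: (bs, bs) squared Euclidean cost between two minibatches.

    x0, x1: arrays of shape (bs, *dim). ASSUMPTION: if a (bs, nodes) mask is
    given, *dim starts with `nodes` and masked-out nodes are zeroed in both
    minibatches so they contribute nothing to the cost.
    """
    x0 = np.asarray(x0, dtype=float)
    x1 = np.asarray(x1, dtype=float)
    if mask is not None:
        mask = np.asarray(mask, dtype=float)
        # Broadcast the (bs, nodes) mask over any trailing feature dimensions.
        extra = (1,) * (x0.ndim - mask.ndim)
        x0 = x0 * mask.reshape(mask.shape + extra)
        x1 = x1 * mask.reshape(mask.shape + extra)
    # Flatten each sample to a vector of length prod(dim).
    a = x0.reshape(x0.shape[0], -1)
    b = x1.reshape(x1.shape[0], -1)
    # ||a_i - b_j||^2 for every pair (i, j) via broadcasting.
    diff = a[:, None, :] - b[None, :, :]
    return (diff ** 2).sum(axis=-1)
```

Entry `[i, j]` is the squared distance between source sample `i` and target sample `j`, which is the per-pair cost the OT solver minimizes.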
apply_augmentation(x0, x1, mask=None, replace=False, sort='x0')
Sample indices for noise and data in minibatch according to OT plan.
Compute the OT plan $\pi$ (with respect to the squared Euclidean cost) between a source and a target minibatch, then draw source and target samples $(x, z) \sim \pi$.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x0 | Tensor | Shape (bs, *dim), noise from the source minibatch. | required |
x1 | Tensor | Shape (bs, *dim), data from the target minibatch. | required |
mask | Optional[Tensor] | Mask to apply to the output, shape (batchsize, nodes). If not provided, no mask is applied. Defaults to None. | None |
replace | bool | Whether to sample with replacement from the OT plan. Defaults to False. | False |
sort | str | Optional literal string selecting whether x0 or x1 is sorted, based on the input. | 'x0' |
Returns:
Name | Type | Description |
---|---|---|
Tuple | Tuple[Tensor, Tensor, Optional[Tensor]] | A tuple of 2 tensors, or 3 if a mask is used, representing the noise samples (plus the mask) and the data samples drawn according to the OT plan pi. |
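The pipeline this method describes (cost matrix, exact plan, sampling, optional sorting) can be sketched end to end in NumPy. This is a hypothetical illustration, not the bionemo code: `apply_ot_pairing` is an invented name, the mask is omitted, the exact plan is found by brute-force enumeration of permutations in place of pot.emd (viable only for tiny batches), and `sort` is assumed to mean ordering the sampled pairs by their x0 or x1 index.

```python
import itertools

import numpy as np


def apply_ot_pairing(x0, x1, replace=False, sort="x0", rng=None):
    """Hypothetical sketch: reorder (x0, x1) so matching rows are OT-coupled.

    Assumptions: equal batch sizes, no mask, uniform marginals, and a
    brute-force exact solver (O(bs!)) standing in for pot.emd.
    """
    x0 = np.asarray(x0, dtype=float)
    x1 = np.asarray(x1, dtype=float)
    bs = x0.shape[0]
    rng = np.random.default_rng() if rng is None else rng
    # 1. Squared Euclidean cost between flattened samples.
    a, b = x0.reshape(bs, -1), x1.reshape(bs, -1)
    cost = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
    # 2. Exact plan with uniform marginals: best permutation, scaled by 1/bs.
    rows = np.arange(bs)
    perm = min(itertools.permutations(range(bs)), key=lambda p: cost[rows, list(p)].sum())
    pi = np.zeros((bs, bs))
    pi[rows, list(perm)] = 1.0 / bs
    # 3. Sample (source, target) index pairs with probability proportional to pi.
    flat = pi.flatten() / pi.sum()
    choices = rng.choice(pi.size, size=bs, replace=replace, p=flat)
    i, j = np.divmod(choices, bs)
    # 4. Optionally sort the pairs by x0 or x1 index (assumed meaning of `sort`).
    if sort in ("x0", "x1"):
        order = np.argsort(i if sort == "x0" else j, kind="stable")
        i, j = i[order], j[order]
    return x0[i], x1[j]
```

With `replace=False` and a permutation plan, every source sample is drawn exactly once, so sorting by the x0 index returns x0 in its original order and only x1 is reordered.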
get_ot_matrix(x0, x1, mask=None)
Compute the OT matrix between a source and a target minibatch.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x0 | Tensor | Shape (bs, *dim), noise from the source minibatch. | required |
x1 | Tensor | Shape (bs, *dim), data from the target minibatch. | required |
mask | Optional[Tensor] | Mask to apply to the output, shape (batchsize, nodes). If not provided, no mask is applied. Defaults to None. | None |
Returns:
Name | Type | Description |
---|---|---|
p | Tensor | Shape (bs, bs), the OT matrix between noise and data in the minibatch. |
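For uniform marginals on two minibatches of equal size, an optimal plan can always be taken to be a permutation matrix scaled by 1/bs (Birkhoff-von Neumann theorem). The sketch below uses that fact to compute an exact plan by brute-force enumeration. It is a stand-in for the pot.emd solver this class documents, usable only for very small batches, and the function name `exact_ot_plan_bruteforce` is hypothetical.

```python
import itertools

import numpy as np


def exact_ot_plan_bruteforce(cost):
    """Hypothetical sketch: exact OT plan for uniform marginals, equal batch sizes.

    An optimal plan exists at a vertex of the transport polytope, i.e. a
    permutation matrix scaled by 1/bs. Enumerating all bs! permutations is
    only viable for tiny batches.
    """
    cost = np.asarray(cost, dtype=float)
    bs = cost.shape[0]
    if cost.shape != (bs, bs):
        raise ValueError("this sketch assumes a square (bs, bs) cost matrix")
    rows = np.arange(bs)
    # Pick the one-to-one assignment with the lowest total cost.
    best = min(itertools.permutations(range(bs)), key=lambda p: cost[rows, list(p)].sum())
    plan = np.zeros((bs, bs))
    plan[rows, list(best)] = 1.0 / bs  # each row and column sums to 1/bs
    return plan
```

pot.emd solves the same linear program directly (without enumeration) and is the practical choice beyond a handful of samples.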
sample_map(pi, batch_size, replace=False)
Draw source and target samples $(x, z) \sim \pi$ from the OT plan $\pi$.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pi | Tensor | Shape (bs, bs), the OT matrix between noise and data in the minibatch. | required |
batch_size | int | The batch size of the minibatch, i.e. the number of index pairs to draw. | required |
replace | bool | Whether to sample with replacement from the OT plan. Defaults to False. | False |
|
Returns:
Name | Type | Description |
---|---|---|
Tuple | Tuple[Tensor, Tensor] | A tuple of 2 tensors holding the indices of the noise and data samples drawn from pi. |
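The sampling step can be sketched by treating the flattened plan as a categorical distribution over (source, target) index pairs. This NumPy version follows the flatten, normalize and sample approach of the torchcfm code the class is adapted from, but it is an illustration rather than the bionemo implementation; the `rng` argument is an addition for reproducibility.

```python
import numpy as np


def sample_map(pi, batch_size, replace=False, rng=None):
    """Sketch: draw (source, target) index pairs with probability proportional to pi."""
    pi = np.asarray(pi, dtype=float)
    rng = np.random.default_rng() if rng is None else rng
    p = pi.flatten()
    p = p / p.sum()  # normalize so the flattened plan is a probability distribution
    # Each flat index k encodes the pair (k // n_cols, k % n_cols).
    choices = rng.choice(pi.size, size=batch_size, replace=replace, p=p)
    return np.divmod(choices, pi.shape[1])
```

NumPy raises a ValueError when `replace=False` and `batch_size` exceeds the number of nonzero entries of pi, so a permutation plan supports drawing at most bs pairs without replacement.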
to_device(device)
Moves all internal tensors to the specified device and updates the self.device attribute.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
device | str | The device to move the tensors to (e.g. "cpu", "cuda:0"). | required |
Note
This method is used to transfer the internal state of the OTSampler to a different device. It updates the self.device attribute to reflect the new device and moves all internal tensors to the specified device.