Skip to content

Ot sampler

OTSampler

Sampler for Exact Mini-batch Optimal Transport Plan.

OTSampler implements sampling coordinates according to an OT plan (wrt squared Euclidean cost) with different implementations of the plan calculation. Code is adapted from https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/optimal_transport.py

Source code in bionemo/moco/interpolants/continuous_time/continuous/data_augmentation/ot_sampler.py
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
class OTSampler:
    """Sampler for Exact Mini-batch Optimal Transport Plan.

    OTSampler implements sampling coordinates according to an OT plan (wrt squared Euclidean cost)
    with different implementations of the plan calculation. Code is adapted from https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/optimal_transport.py

    """

    def __init__(
        self,
        method: str = "exact",
        device: Union[str, torch.device] = "cpu",
        num_threads: Union[int, str] = 1,
    ) -> None:
        """Initialize the OTSampler class.

        Args:
            method (str): Which optimal transport solver to use. Currently only the exact OT solver
                (``pot.emd``) is supported.
            device (Union[str, torch.device], optional): The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".
            num_threads (Union[int, str], optional): Number of threads for the OT solver, forwarded to
                ``pot.emd`` as ``numThreads``. If "max", uses the maximum number of threads. Default is 1.

        Raises:
            ValueError: If the OT solver name is not recognized.
            NotImplementedError: If the OT solver is recognized but not implemented
                ("sinkhorn", "unbalanced" or "partial").
        """
        # ot_fn should take (a, b, M) as arguments where a, b are marginals and
        # M is a cost matrix
        if method == "exact":
            self.ot_fn: Callable[..., torch.Tensor] = partial(pot.emd, numThreads=num_threads)  # type: ignore
        elif method in {"sinkhorn", "unbalanced", "partial"}:
            raise NotImplementedError("OT solver other than 'exact' is not implemented.")
        else:
            raise ValueError(f"Unknown method: {method}")
        self.device = device

    def to_device(self, device: Union[str, torch.device]):
        """Move private tensor attributes to ``device`` and update ``self.device``.

        Args:
            device (Union[str, torch.device]): The device to move the tensors to (e.g. "cpu", "cuda:0").

        Returns:
            OTSampler: ``self``, to allow chaining.

        Note:
            Only attributes whose name starts with an underscore and whose value is a
            ``torch.Tensor`` are moved; ``self.device`` is always updated.
        """
        self.device = device
        for attr_name in dir(self):
            if attr_name.startswith("_") and isinstance(getattr(self, attr_name), torch.Tensor):
                setattr(self, attr_name, getattr(self, attr_name).to(device))
        return self

    def sample_map(self, pi: Tensor, batch_size: int, replace: Bool = False) -> Tuple[Tensor, Tensor]:
        r"""Draw source and target samples from pi $(x,z) \sim \pi$.

        Args:
            pi (Tensor): shape (bs, bs), the OT matrix between noise and data in minibatch.
            batch_size (int): The batch size of the minibatch; both dimensions of ``pi`` must equal it.
            replace (bool): Whether to sample with or without replacement from the OT plan. Defaults to False.

        Returns:
            Tuple: ``(row indices, column indices)`` of the sampled entries of ``pi``, i.e. the
            indices of the noise (x0) and data (x1) samples.

        Raises:
            ValueError: If ``pi`` is not of shape (batch_size, batch_size).
        """
        if pi.shape[0] != batch_size or pi.shape[1] != batch_size:
            raise ValueError("Shape mismatch: pi.shape = {}, batch_size = {}".format(pi.shape, batch_size))
        # Treat the plan as a joint distribution over (row, col) pairs.
        p = pi.flatten()
        p = p / p.sum()
        choices = torch.multinomial(p, batch_size, replacement=replace)
        # Convert flat indices back to (row, col).
        return torch.div(choices, pi.shape[1], rounding_mode="floor"), choices % pi.shape[1]

    def _calculate_cost_matrix(self, x0: Tensor, x1: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """Compute the squared-Euclidean cost matrix between a source and a target minibatch.

        Args:
            x0 (Tensor): shape (bs, *dim), noise from the source minibatch.
            x1 (Tensor): shape (bs, *dim), data from the target minibatch.
            mask (Optional[Tensor], optional): shape (batchsize, nodes). When given, row ``i`` of the
                cost matrix is computed with ``mask[i]`` applied to both ``x0[i]`` and every sample of
                ``x1``. If not provided no mask is applied. Defaults to None.

        Returns:
            Tensor: shape (bs, bs), the cost matrix between noise and data in minibatch. It is on
            ``x0``'s device with the dtype of the computed distances in both branches.
        """
        if mask is None:
            # Flatten each sample into a vector; the cost is the squared Euclidean distance.
            x0_flat = x0.reshape(x0.shape[0], -1)
            x1_flat = x1.reshape(x1.shape[0], -1)
            return torch.cdist(x0_flat, x1_flat) ** 2

        # The mask depends on the x0 sample, so each row is computed separately. Rows are
        # concatenated rather than written into a default (CPU, float32) buffer, so the result keeps
        # the inputs' device and dtype, matching the unmasked branch.
        mask = mask.to(x0.device)
        rows = []
        for i in range(x0.shape[0]):
            x0i_mask = mask[i].unsqueeze(-1)
            masked_x0 = (x0[i] * x0i_mask).reshape(1, -1)
            masked_x1 = (x1 * x0i_mask).reshape(x1.shape[0], -1)
            rows.append(torch.cdist(masked_x0, masked_x1) ** 2)
        if not rows:
            # Empty source batch: torch.cat cannot handle an empty list.
            return x0.new_zeros((0, x1.shape[0]))
        return torch.cat(rows, dim=0)

    def get_ot_matrix(self, x0: Tensor, x1: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """Compute the OT matrix between a source and a target minibatch.

        Args:
            x0 (Tensor): shape (bs, *dim), noise from the source minibatch.
            x1 (Tensor): shape (bs, *dim), data from the target minibatch.
            mask (Optional[Tensor], optional): shape (batchsize, nodes), used when computing the cost
                matrix; if not provided no mask is applied. Defaults to None.

        Returns:
            p (Tensor): shape (bs, bs), the OT matrix between noise and data in minibatch. If the
            solver's plan sums to (nearly) zero, a warning is emitted and a uniform plan is returned.

        Raises:
            ValueError: If the OT plan contains non-finite values.
        """
        # Compute the cost matrix
        M = self._calculate_cost_matrix(x0, x1, mask)
        # Set uniform weights for all samples in a minibatch
        a, b = pot.unif(x0.shape[0], type_as=M), pot.unif(x1.shape[0], type_as=M)

        p = self.ot_fn(a, b, M)
        # Handle exceptions
        if not torch.all(torch.isfinite(p)):
            raise ValueError("OT plan map is not finite, cost mean, max: {}, {}".format(M.mean(), M.max()))
        if torch.abs(p.sum()) < 1e-8:
            warnings.warn("Numerical errors in OT matrix, reverting to uniform plan.")
            p = torch.ones_like(p) / p.numel()

        return p

    def apply_augmentation(
        self,
        x0: Tensor,
        x1: Tensor,
        mask: Optional[Tensor] = None,
        replace: Bool = False,
        sort: Optional[Literal["noise", "x0", "data", "x1"]] = "x0",
    ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
        r"""Sample indices for noise and data in minibatch according to OT plan.

        Compute the OT plan $\pi$ (wrt squared Euclidean cost) between a source and a target
        minibatch and draw source and target samples from pi $(x,z) \sim \pi$.

        Args:
            x0 (Tensor): shape (bs, *dim), noise from the source minibatch.
            x1 (Tensor): shape (bs, *dim), data from the target minibatch.
            mask (Optional[Tensor], optional): shape (batchsize, nodes). It is used when computing the
                cost matrix and re-indexed alongside the noise in the output. Defaults to None.
            replace (bool): Whether to sample with or without replacement from the OT plan. Defaults to False.
            sort (str): Which side keeps its original order.
                - "noise"/"x0": returns ``x0`` unchanged and permutes ``x1``.
                - "data"/"x1": returns ``x1`` unchanged and permutes ``x0``.
                - None: applies no sorting.
                Must be None when ``replace`` is True.

        Returns:
            Tuple: ``(noise, data, mask)`` following the OT plan pi. ``mask`` is the input mask moved
            to ``x0``'s device and re-indexed to align with ``noise``, or None if no mask was given.

        Raises:
            ValueError: If ``replace`` is True while ``sort`` is not None. Also raised if the sampled
                indices on the sorted side are not ``0 .. bs - 1``.
        """
        if replace and sort is not None:
            raise ValueError("Cannot sample with replacement and sort")
        # Calculate the optimal transport plan
        pi = self.get_ot_matrix(x0, x1, mask)

        # Sample (x0, x1) mapping indices from the OT matrix
        i, j = self.sample_map(pi, x0.shape[0], replace=replace)
        if not replace and sort in ("noise", "x0"):
            # Order pairs by x0 index. If every x0 sample was drawn exactly once, the indices are
            # then 0..bs-1 (checked below) and x0 can be returned untouched.
            sort_idx = torch.argsort(i)
            i = i[sort_idx]
            j = j[sort_idx]

            if not (i == torch.arange(x0.shape[0], device=i.device)).all():
                raise ValueError("x0_idx should be a tensor from 0 to size - 1 when sort is 'noise' or 'x0'")
            noise = x0
            data = x1[j]
        elif not replace and sort in ("data", "x1"):
            # Same as above, but keep x1 in its original order instead.
            sort_idx = torch.argsort(j)
            i = i[sort_idx]
            j = j[sort_idx]

            if not (j == torch.arange(x1.shape[0], device=j.device)).all():
                raise ValueError("x1_idx should be a tensor from 0 to size - 1 when sort is 'data' or 'x1'")
            noise = x0[i]
            data = x1
        else:
            noise = x0[i]
            data = x1[j]

        # Output the permuted samples in the minibatch; the mask follows the noise indices.
        if mask is not None:
            if mask.device != x0.device:
                mask = mask.to(x0.device)
            mask = mask[i]
        return noise, data, mask

__init__(method='exact', device='cpu', num_threads=1)

Initialize the OTSampler class.

Parameters:

Name Type Description Default
method str

Choose which optimal transport solver to use. Currently only the exact OT solver (pot.emd) is supported.

'exact'
device Union[str, device]

The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".

'cpu'
num_threads Union[int, str]

Number of threads to use for OT solver. If "max", uses the maximum number of threads. Default is 1.

1

Raises:

Type Description
ValueError

If the OT solver name is not recognized.

NotImplementedError

If the OT solver is recognized but not implemented ("sinkhorn", "unbalanced" or "partial").

Source code in bionemo/moco/interpolants/continuous_time/continuous/data_augmentation/ot_sampler.py
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def __init__(
    self,
    method: str = "exact",
    device: Union[str, torch.device] = "cpu",
    num_threads: Union[int, str] = 1,
) -> None:
    """Initialize the OTSampler class.

    Args:
        method (str): Which optimal transport solver to use. Currently only the exact OT solver
            (``pot.emd``) is supported.
        device (Union[str, torch.device], optional): The device on which to run the interpolant, either "cpu" or a CUDA device (e.g. "cuda:0"). Defaults to "cpu".
        num_threads (Union[int, str], optional): Number of threads for the OT solver, forwarded to
            ``pot.emd`` as ``numThreads``. If "max", uses the maximum number of threads. Default is 1.

    Raises:
        ValueError: If the OT solver name is not recognized.
        NotImplementedError: If the OT solver is recognized but not implemented
            ("sinkhorn", "unbalanced" or "partial").
    """
    # ot_fn should take (a, b, M) as arguments where a, b are marginals and
    # M is a cost matrix
    if method == "exact":
        self.ot_fn: Callable[..., torch.Tensor] = partial(pot.emd, numThreads=num_threads)  # type: ignore
    elif method in {"sinkhorn", "unbalanced", "partial"}:
        raise NotImplementedError("OT solver other than 'exact' is not implemented.")
    else:
        raise ValueError(f"Unknown method: {method}")
    self.device = device

_calculate_cost_matrix(x0, x1, mask=None)

Compute the cost matrix between a source and a target minibatch.

Parameters:

Name Type Description Default
x0 Tensor

shape (bs, *dim), noise from source minibatch.

required
x1 Tensor

shape (bs, *dim), data from the target minibatch.

required
mask Optional[Tensor]

mask applied to the inputs when computing the cost, shape (batchsize, nodes); row i of the cost matrix uses mask[i]. If not provided no mask is applied. Defaults to None.

None

Returns:

Name Type Description
Tensor Tensor

shape (bs, bs), the cost matrix between noise and data in minibatch.

Source code in bionemo/moco/interpolants/continuous_time/continuous/data_augmentation/ot_sampler.py
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
def _calculate_cost_matrix(self, x0: Tensor, x1: Tensor, mask: Optional[Tensor] = None) -> Tensor:
    """Compute the squared-Euclidean cost matrix between a source and a target minibatch.

    Args:
        x0 (Tensor): shape (bs, *dim), noise from the source minibatch.
        x1 (Tensor): shape (bs, *dim), data from the target minibatch.
        mask (Optional[Tensor], optional): shape (batchsize, nodes). When given, row ``i`` of the
            cost matrix is computed with ``mask[i]`` applied to both ``x0[i]`` and every sample of
            ``x1``. If not provided no mask is applied. Defaults to None.

    Returns:
        Tensor: shape (bs, bs), the cost matrix between noise and data in minibatch. It is on
        ``x0``'s device with the dtype of the computed distances in both branches.
    """
    if mask is None:
        # Flatten each sample into a vector; the cost is the squared Euclidean distance.
        x0_flat = x0.reshape(x0.shape[0], -1)
        x1_flat = x1.reshape(x1.shape[0], -1)
        return torch.cdist(x0_flat, x1_flat) ** 2

    # The mask depends on the x0 sample, so each row is computed separately. Rows are
    # concatenated rather than written into a default (CPU, float32) buffer, so the result keeps
    # the inputs' device and dtype, matching the unmasked branch.
    mask = mask.to(x0.device)
    rows = []
    for i in range(x0.shape[0]):
        x0i_mask = mask[i].unsqueeze(-1)
        masked_x0 = (x0[i] * x0i_mask).reshape(1, -1)
        masked_x1 = (x1 * x0i_mask).reshape(x1.shape[0], -1)
        rows.append(torch.cdist(masked_x0, masked_x1) ** 2)
    if not rows:
        # Empty source batch: torch.cat cannot handle an empty list.
        return x0.new_zeros((0, x1.shape[0]))
    return torch.cat(rows, dim=0)

apply_augmentation(x0, x1, mask=None, replace=False, sort='x0')

Sample indices for noise and data in minibatch according to OT plan.

Compute the OT plan $\pi$ (wrt squared Euclidean cost) between a source and a target minibatch and draw source and target samples from pi $(x,z) \sim \pi$.

Parameters:

Name Type Description Default
x0 Tensor

shape (bs, *dim), noise from source minibatch.

required
x1 Tensor

shape (bs, *dim), data from the target minibatch.

required
mask Optional[Tensor]

mask used when computing the OT cost and re-indexed alongside the noise in the output, shape (batchsize, nodes). If not provided no mask is applied. Defaults to None.

None
replace bool

Whether to sample with or without replacement from the OT plan. Defaults to False.

False
sort str

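Which side keeps its original order: "noise"/"x0" returns x0 unchanged and permutes x1; "data"/"x1" returns x1 unchanged and permutes x0; None applies no sorting. Must be None when replace is True.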

'x0'

Returns:

Name Type Description
Tuple Tuple[Tensor, Tensor, Optional[Tensor]]

tuple of 3 elements (noise, data, mask) following OT plan pi; mask is the input mask re-indexed to align with the noise, or None if no mask was given.

Source code in bionemo/moco/interpolants/continuous_time/continuous/data_augmentation/ot_sampler.py
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
def apply_augmentation(
    self,
    x0: Tensor,
    x1: Tensor,
    mask: Optional[Tensor] = None,
    replace: Bool = False,
    sort: Optional[Literal["noise", "x0", "data", "x1"]] = "x0",
) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
    r"""Sample indices for noise and data in minibatch according to OT plan.

    Compute the OT plan $\pi$ (wrt squared Euclidean cost) between a source and a target
    minibatch and draw source and target samples from pi $(x,z) \sim \pi$.

    Args:
        x0 (Tensor): shape (bs, *dim), noise from the source minibatch.
        x1 (Tensor): shape (bs, *dim), data from the target minibatch.
        mask (Optional[Tensor], optional): shape (batchsize, nodes). It is used when computing the
            cost matrix and re-indexed alongside the noise in the output. Defaults to None.
        replace (bool): Whether to sample with or without replacement from the OT plan. Defaults to False.
        sort (str): Which side keeps its original order.
            - "noise"/"x0": returns ``x0`` unchanged and permutes ``x1``.
            - "data"/"x1": returns ``x1`` unchanged and permutes ``x0``.
            - None: applies no sorting.
            Must be None when ``replace`` is True.

    Returns:
        Tuple: ``(noise, data, mask)`` following the OT plan pi. ``mask`` is the input mask moved
        to ``x0``'s device and re-indexed to align with ``noise``, or None if no mask was given.

    Raises:
        ValueError: If ``replace`` is True while ``sort`` is not None. Also raised if the sampled
            indices on the sorted side are not ``0 .. bs - 1``.
    """
    if replace and sort is not None:
        raise ValueError("Cannot sample with replacement and sort")
    # Calculate the optimal transport plan
    pi = self.get_ot_matrix(x0, x1, mask)

    # Sample (x0, x1) mapping indices from the OT matrix
    i, j = self.sample_map(pi, x0.shape[0], replace=replace)
    if not replace and sort in ("noise", "x0"):
        # Order pairs by x0 index. If every x0 sample was drawn exactly once, the indices are
        # then 0..bs-1 (checked below) and x0 can be returned untouched.
        sort_idx = torch.argsort(i)
        i = i[sort_idx]
        j = j[sort_idx]

        if not (i == torch.arange(x0.shape[0], device=i.device)).all():
            raise ValueError("x0_idx should be a tensor from 0 to size - 1 when sort is 'noise' or 'x0'")
        noise = x0
        data = x1[j]
    elif not replace and sort in ("data", "x1"):
        # Same as above, but keep x1 in its original order instead.
        sort_idx = torch.argsort(j)
        i = i[sort_idx]
        j = j[sort_idx]

        if not (j == torch.arange(x1.shape[0], device=j.device)).all():
            raise ValueError("x1_idx should be a tensor from 0 to size - 1 when sort is 'data' or 'x1'")
        noise = x0[i]
        data = x1
    else:
        noise = x0[i]
        data = x1[j]

    # Output the permuted samples in the minibatch; the mask follows the noise indices.
    if mask is not None:
        if mask.device != x0.device:
            mask = mask.to(x0.device)
        mask = mask[i]
    return noise, data, mask

get_ot_matrix(x0, x1, mask=None)

Compute the OT matrix between a source and a target minibatch.

Parameters:

Name Type Description Default
x0 Tensor

shape (bs, *dim), noise from source minibatch.

required
x1 Tensor

shape (bs, *dim), data from the target minibatch.

required
mask Optional[Tensor]

mask applied to the inputs when computing the cost matrix, shape (batchsize, nodes). If not provided no mask is applied. Defaults to None.

None

Returns:

Name Type Description
p Tensor

shape (bs, bs), the OT matrix between noise and data in minibatch.

Source code in bionemo/moco/interpolants/continuous_time/continuous/data_augmentation/ot_sampler.py
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
def get_ot_matrix(self, x0: Tensor, x1: Tensor, mask: Optional[Tensor] = None) -> Tensor:
    """Solve for the OT plan between a source and a target minibatch.

    Args:
        x0 (Tensor): shape (bs, *dim), noise from the source minibatch.
        x1 (Tensor): shape (bs, *dim), data from the target minibatch.
        mask (Optional[Tensor], optional): shape (batchsize, nodes), forwarded to the cost-matrix
            computation; if not provided no mask is applied. Defaults to None.

    Returns:
        p (Tensor): shape (bs, bs), the OT matrix between noise and data in minibatch. When the
        solver's plan sums to (nearly) zero, a warning is emitted and a uniform plan is returned.

    Raises:
        ValueError: If the solver returns a plan with non-finite entries.
    """
    cost = self._calculate_cost_matrix(x0, x1, mask)

    # Every sample in each minibatch carries the same mass.
    src_weights = pot.unif(x0.shape[0], type_as=cost)
    tgt_weights = pot.unif(x1.shape[0], type_as=cost)
    plan = self.ot_fn(src_weights, tgt_weights, cost)

    if not torch.isfinite(plan).all():
        raise ValueError("OT plan map is not finite, cost mean, max: {}, {}".format(cost.mean(), cost.max()))
    if plan.sum().abs() < 1e-8:
        # Degenerate plan: fall back to spreading mass uniformly over all pairs.
        warnings.warn("Numerical errors in OT matrix, reverting to uniform plan.")
        plan = torch.ones_like(plan) / plan.numel()
    return plan

sample_map(pi, batch_size, replace=False)

Draw source and target samples from pi $(x,z) \sim \pi$.

Parameters:

Name Type Description Default
pi Tensor

shape (bs, bs), the OT matrix between noise and data in minibatch.

required
batch_size int

The batch size of the minibatch.

required
replace bool

Whether to sample with or without replacement from the OT plan. Defaults to False.

False

Returns:

Name Type Description
Tuple Tuple[Tensor, Tensor]

tuple of 2 tensors, represents the indices of noise and data samples from pi.

Source code in bionemo/moco/interpolants/continuous_time/continuous/data_augmentation/ot_sampler.py
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
def sample_map(self, pi: Tensor, batch_size: int, replace: Bool = False) -> Tuple[Tensor, Tensor]:
    r"""Draw source and target samples from pi $(x,z) \sim \pi$.

    Args:
        pi (Tensor): shape (bs, bs), the OT matrix between noise and data in minibatch.
        batch_size (int): The batch size of the minibatch; both dimensions of ``pi`` must equal it.
        replace (bool): Whether to sample with or without replacement from the OT plan. Defaults to False.

    Returns:
        Tuple: ``(row indices, column indices)`` of the sampled entries of ``pi``, i.e. the indices
        of the noise (x0) and data (x1) samples.

    Raises:
        ValueError: If ``pi`` is not of shape (batch_size, batch_size).
    """
    n_rows, n_cols = pi.shape[0], pi.shape[1]
    if not (n_rows == batch_size and n_cols == batch_size):
        raise ValueError("Shape mismatch: pi.shape = {}, batch_size = {}".format(pi.shape, batch_size))

    # Treat the plan as a joint distribution over flattened (row, col) cells.
    weights = pi.flatten()
    probs = weights / weights.sum()
    flat_idx = torch.multinomial(probs, batch_size, replacement=replace)

    # Recover 2-D coordinates from the flat cell indices.
    src_idx = torch.div(flat_idx, n_cols, rounding_mode="floor")
    tgt_idx = flat_idx % n_cols
    return src_idx, tgt_idx

to_device(device)

Moves all internal tensors to the specified device and updates the self.device attribute.

Parameters:

Name Type Description Default
device str

The device to move the tensors to (e.g. "cpu", "cuda:0").

required
Note

This method is used to transfer the internal state of the OTSampler to a different device. It updates the self.device attribute to reflect the new device and moves all internal tensors to the specified device.

Source code in bionemo/moco/interpolants/continuous_time/continuous/data_augmentation/ot_sampler.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def to_device(self, device: Union[str, torch.device]):
    """Move private tensor attributes to ``device`` and update ``self.device``.

    Args:
        device (Union[str, torch.device]): The device to move the tensors to (e.g. "cpu", "cuda:0").
            A ``torch.device`` is accepted as well, matching the ``device`` argument of ``__init__``.

    Returns:
        The same object (``self``), to allow chaining.

    Note:
        Only attributes whose name starts with an underscore and whose value is a ``torch.Tensor``
        are moved; ``self.device`` is always updated.
    """
    self.device = device
    for attr_name in dir(self):
        if attr_name.startswith("_") and isinstance(getattr(self, attr_name), torch.Tensor):
            setattr(self, attr_name, getattr(self, attr_name).to(device))
    return self