Kabsch augmentation

KabschAugmentation

Point-wise Kabsch alignment.

Source code in bionemo/moco/interpolants/continuous_time/continuous/data_augmentation/kabsch_augmentation.py
class KabschAugmentation:
    """Point-wise Kabsch alignment."""

    def __init__(self):
        """Initialize the KabschAugmentation instance.

        Notes:
            - This implementation assumes no required initialization arguments.
            - You can add instance variables (e.g., `self.variable_name`) as needed.
        """
        pass  # No operations are performed when initializing with no args

    def kabsch_align(self, target: Tensor, noise: Tensor):
        """Find the Rotation matrix (R) such that RMSD is minimized between target @ R.T and noise.

        Args:
            target (Tensor): shape (N, *dim), data from source minibatch.
            noise (Tensor): shape (N, *dim), noise from source minibatch.

        Returns:
            R (Tensor): shape (dim, dim), the rotation matrix.
            Aligned Target (Tensor): target tensor rotated and shifted to reduce RMSD with noise.
        """
        dimension = target.shape[-1]
        noise_translation = noise.mean(dim=0)
        noise_centered = noise - noise_translation
        target_centered = target - target.mean(dim=0)

        # Compute the covariance matrix
        covariance_matrix = target_centered.T @ noise_centered

        # Compute the SVD of the covariance matrix
        U, S, Vt = torch.linalg.svd(covariance_matrix)
        d = torch.sign(torch.linalg.det(Vt.T @ U.T)).item()
        d_mat = torch.tensor([1] * (dimension - 1) + [d], device=Vt.device, dtype=Vt.dtype)
        R = Vt.T @ torch.diag(d_mat) @ U.T

        target_aligned = target_centered @ R.T + noise_translation

        return R, target_aligned

    def batch_kabsch_align(self, target: Tensor, noise: Tensor):
        """Find the Rotation matrix (R) such that RMSD is minimized between target @ R.T and noise.

        Args:
            target (Tensor): shape (B, N, *dim), data from source minibatch.
            noise (Tensor): shape (B, N, *dim), noise from source minibatch.

        Returns:
            R (Tensor): shape (B, dim, dim), the batch of rotation matrices.
            Aligned Target (Tensor): target tensor rotated and shifted to reduce RMSD with noise.
        """
        # Corrected Batched Kabsch Alignment
        batch_size, _, dimension = target.shape

        # Center the target and noise tensors along the middle dimension (N) for each batch item
        noise_translation = noise.mean(dim=1, keepdim=True)
        noise_centered = noise - noise_translation
        target_centered = target - target.mean(dim=1, keepdim=True)

        # Compute the covariance matrix for each batch item
        covariance_matrix = torch.matmul(target_centered.transpose(1, 2), noise_centered)

        # Compute the SVD of the covariance matrix for each batch item
        U, S, Vt = torch.linalg.svd(covariance_matrix)

        # Adjust for proper rotation (determinant=1) for each batch item
        d = torch.sign(torch.linalg.det(Vt @ U.transpose(-1, -2)))  # Keep as tensor for batch operations
        d_mat = torch.diag_embed(
            torch.cat(
                [torch.ones(batch_size, dimension - 1, device=Vt.device, dtype=Vt.dtype), d.unsqueeze(-1)], dim=-1
            )
        )

        R_batch = torch.matmul(torch.matmul(Vt.transpose(-1, -2), d_mat), U.transpose(-1, -2))

        target_aligned = target_centered @ R_batch.transpose(-1, -2) + noise_translation
        return R_batch, target_aligned

    def apply_augmentation(
        self,
        x0: Tensor,
        x1: Tensor,
        mask: Optional[Tensor] = None,
        align_noise_to_data=True,
    ) -> Tuple[Tensor, Tensor]:
        r"""Kabsch-align the noise and data samples in the minibatch.

        Rotates (and translates) one tensor onto the other with the rotation that minimizes
        the RMSD (i.e. the squared Euclidean cost after Kabsch alignment) between noise and data.

        Args:
            x0 (Tensor): shape (bs, *dim), noise from source minibatch.
            x1 (Tensor): shape (bs, *dim), data from source minibatch.
            mask (Optional[Tensor], optional): mask to apply to the output, shape (batchsize, nodes); if not provided, no mask is applied. Defaults to None.
            align_noise_to_data (bool): Direction of alignment. Defaults to True, meaning the noise is rotated to reduce its error to the data.

        Returns:
            Tuple: tuple of 2 tensors, the noise and data samples after alignment.
        """
        if x1.ndim > 2:
            align_func = self.batch_kabsch_align
        else:
            align_func = self.kabsch_align
        if mask is not None:
            mask = pad_like(mask, x1)
            x1 = x1 * mask
            x0 = x0 * mask
        if align_noise_to_data:
            # Compute the rotation matrix R that aligns x0 to x1
            R, aligned_x0 = align_func(x0, x1)
            noise = aligned_x0
            data = x1
        else:
            # Compute the rotation matrix R that aligns x1 to x0
            R, aligned_x1 = align_func(x1, x0)
            noise = x0
            data = aligned_x1
        if mask is not None:
            noise = noise * mask
            data = data * mask
        # Return the aligned noise and data samples
        return noise, data
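
A minimal usage sketch (not part of the source above): the import path follows the source file shown at the top of this page, and the shapes are illustrative. Because the identity rotation is always a candidate, the aligned noise should be at least as close to the data as the raw noise.

import torch

from bionemo.moco.interpolants.continuous_time.continuous.data_augmentation.kabsch_augmentation import (
    KabschAugmentation,
)


def rmsd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Root-mean-square deviation between two point clouds of matching shape."""
    return ((a - b) ** 2).sum(dim=-1).mean().sqrt()


augmentation = KabschAugmentation()

# Batched point clouds: 4 samples, 32 points each, in 3-D.
data = torch.randn(4, 32, 3)
noise = torch.randn(4, 32, 3)

# Rotate and translate the noise onto the data (align_noise_to_data=True is the default).
aligned_noise, data_out = augmentation.apply_augmentation(noise, data)

print(rmsd(noise, data), rmsd(aligned_noise, data))  # the second value should not exceed the first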

__init__()

Initialize the KabschAugmentation instance.

Notes
  • This implementation assumes no required initialization arguments.
  • You can add instance variables (e.g., self.variable_name) as needed.
Source code in bionemo/moco/interpolants/continuous_time/continuous/data_augmentation/kabsch_augmentation.py
def __init__(self):
    """Initialize the KabschAugmentation instance.

    Notes:
        - This implementation assumes no required initialization arguments.
        - You can add instance variables (e.g., `self.variable_name`) as needed.
    """
    pass  # No operations are performed when initializing with no args

apply_augmentation(x0, x1, mask=None, align_noise_to_data=True)

Kabsch-align the noise and data samples in the minibatch.

Rotates (and translates) one tensor onto the other with the rotation that minimizes the RMSD (i.e. the squared Euclidean cost after Kabsch alignment) between noise and data.

Parameters:

  • x0 (Tensor): shape (bs, *dim), noise from source minibatch. Required.
  • x1 (Tensor): shape (bs, *dim), data from source minibatch. Required.
  • mask (Optional[Tensor]): mask to apply to the output, shape (batchsize, nodes); if not provided, no mask is applied. Defaults to None.
  • align_noise_to_data (bool): direction of alignment. Defaults to True, meaning the noise is rotated to reduce its error to the data.

Returns:

  • Tuple[Tensor, Tensor]: tuple of 2 tensors, the noise and data samples after alignment.

Source code in bionemo/moco/interpolants/continuous_time/continuous/data_augmentation/kabsch_augmentation.py
def apply_augmentation(
    self,
    x0: Tensor,
    x1: Tensor,
    mask: Optional[Tensor] = None,
    align_noise_to_data=True,
) -> Tuple[Tensor, Tensor]:
    r"""Kabsch-align the noise and data samples in the minibatch.

    Rotates (and translates) one tensor onto the other with the rotation that minimizes
    the RMSD (i.e. the squared Euclidean cost after Kabsch alignment) between noise and data.

    Args:
        x0 (Tensor): shape (bs, *dim), noise from source minibatch.
        x1 (Tensor): shape (bs, *dim), data from source minibatch.
        mask (Optional[Tensor], optional): mask to apply to the output, shape (batchsize, nodes); if not provided, no mask is applied. Defaults to None.
        align_noise_to_data (bool): Direction of alignment. Defaults to True, meaning the noise is rotated to reduce its error to the data.

    Returns:
        Tuple: tuple of 2 tensors, the noise and data samples after alignment.
    """
    if x1.ndim > 2:
        align_func = self.batch_kabsch_align
    else:
        align_func = self.kabsch_align
    if mask is not None:
        mask = pad_like(mask, x1)
        x1 = x1 * mask
        x0 = x0 * mask
    if align_noise_to_data:
        # Compute the rotation matrix R that aligns x0 to x1
        R, aligned_x0 = align_func(x0, x1)
        noise = aligned_x0
        data = x1
    else:
        # Compute the rotation matrix R that aligns x1 to x0
        R, aligned_x1 = align_func(x1, x0)
        noise = x0
        data = aligned_x1
    if mask is not None:
        noise = noise * mask
        data = data * mask
    # Return the aligned noise and data samples
    return noise, data
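
A short sketch of apply_augmentation with a padding mask; the shapes and mask layout are illustrative, and the import path follows the source file listed above.

import torch

from bionemo.moco.interpolants.continuous_time.continuous.data_augmentation.kabsch_augmentation import (
    KabschAugmentation,
)

augmentation = KabschAugmentation()

batch, nodes, dim = 2, 16, 3
x0 = torch.randn(batch, nodes, dim)  # noise
x1 = torch.randn(batch, nodes, dim)  # data

# Binary padding mask of shape (batchsize, nodes); here the last 4 nodes of each sample are padding.
mask = torch.ones(batch, nodes)
mask[:, -4:] = 0.0

# Default direction: rotate the noise onto the data. Padded positions are zeroed in both outputs.
noise, data = augmentation.apply_augmentation(x0, x1, mask=mask)

# Reverse direction: keep the noise fixed and rotate the data onto it instead.
noise_fixed, data_aligned = augmentation.apply_augmentation(x0, x1, mask=mask, align_noise_to_data=False)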

batch_kabsch_align(target, noise)

Find the Rotation matrix (R) such that RMSD is minimized between target @ R.T and noise.

Parameters:

  • target (Tensor): shape (B, N, *dim), data from source minibatch. Required.
  • noise (Tensor): shape (B, N, *dim), noise from source minibatch. Required.

Returns:

  • R (Tensor): shape (B, dim, dim), the batch of rotation matrices.
  • Aligned Target (Tensor): target tensor rotated and shifted to reduce RMSD with noise.

Source code in bionemo/moco/interpolants/continuous_time/continuous/data_augmentation/kabsch_augmentation.py
def batch_kabsch_align(self, target: Tensor, noise: Tensor):
    """Find the Rotation matrix (R) such that RMSD is minimized between target @ R.T and noise.

    Args:
        target (Tensor): shape (B, N, *dim), data from source minibatch.
        noise (Tensor): shape (B, N, *dim), noise from source minibatch.

    Returns:
        R (Tensor): shape (B, dim, dim), the batch of rotation matrices.
        Aligned Target (Tensor): target tensor rotated and shifted to reduce RMSD with noise.
    """
    # Corrected Batched Kabsch Alignment
    batch_size, _, dimension = target.shape

    # Center the target and noise tensors along the middle dimension (N) for each batch item
    noise_translation = noise.mean(dim=1, keepdim=True)
    noise_centered = noise - noise_translation
    target_centered = target - target.mean(dim=1, keepdim=True)

    # Compute the covariance matrix for each batch item
    covariance_matrix = torch.matmul(target_centered.transpose(1, 2), noise_centered)

    # Compute the SVD of the covariance matrix for each batch item
    U, S, Vt = torch.linalg.svd(covariance_matrix)

    # Adjust for proper rotation (determinant=1) for each batch item
    d = torch.sign(torch.linalg.det(Vt @ U.transpose(-1, -2)))  # Keep as tensor for batch operations
    d_mat = torch.diag_embed(
        torch.cat(
            [torch.ones(batch_size, dimension - 1, device=Vt.device, dtype=Vt.dtype), d.unsqueeze(-1)], dim=-1
        )
    )

    R_batch = torch.matmul(torch.matmul(Vt.transpose(-1, -2), d_mat), U.transpose(-1, -2))

    target_aligned = target_centered @ R_batch.transpose(-1, -2) + noise_translation
    return R_batch, target_aligned
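
A sanity-check sketch for batch_kabsch_align: each "noise" cloud is just the target rotated by a known random rotation, so the alignment should recover it almost exactly. The random_rotation helper is hypothetical and defined here only for the example.

import torch

from bionemo.moco.interpolants.continuous_time.continuous.data_augmentation.kabsch_augmentation import (
    KabschAugmentation,
)


def random_rotation(dim: int) -> torch.Tensor:
    """Sample a proper rotation matrix via QR decomposition, flipping a column if det is -1."""
    q, _ = torch.linalg.qr(torch.randn(dim, dim))
    if torch.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q


augmentation = KabschAugmentation()

batch, nodes, dim = 3, 20, 3
target = torch.randn(batch, nodes, dim)
rotations = torch.stack([random_rotation(dim) for _ in range(batch)])

# Rotate each target cloud by its own rotation; perfect alignment is then achievable.
noise = target @ rotations.transpose(-1, -2)

R, target_aligned = augmentation.batch_kabsch_align(target, noise)
print(torch.allclose(target_aligned, noise, atol=1e-4))  # expected: True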

kabsch_align(target, noise)

Find the Rotation matrix (R) such that RMSD is minimized between target @ R.T and noise.

Parameters:

  • target (Tensor): shape (N, *dim), data from source minibatch. Required.
  • noise (Tensor): shape (N, *dim), noise from source minibatch. Required.

Returns:

  • R (Tensor): shape (dim, dim), the rotation matrix.
  • Aligned Target (Tensor): target tensor rotated and shifted to reduce RMSD with noise.

Source code in bionemo/moco/interpolants/continuous_time/continuous/data_augmentation/kabsch_augmentation.py
def kabsch_align(self, target: Tensor, noise: Tensor):
    """Find the Rotation matrix (R) such that RMSD is minimized between target @ R.T and noise.

    Args:
        target (Tensor): shape (N, *dim), data from source minibatch.
        noise (Tensor): shape (N, *dim), noise from source minibatch.

    Returns:
        R (Tensor): shape (dim, dim), the rotation matrix.
        Aligned Target (Tensor): target tensor rotated and shifted to reduce RMSD with noise.
    """
    dimension = target.shape[-1]
    noise_translation = noise.mean(dim=0)
    noise_centered = noise - noise_translation
    target_centered = target - target.mean(dim=0)

    # Compute the covariance matrix
    covariance_matrix = target_centered.T @ noise_centered

    # Compute the SVD of the covariance matrix
    U, S, Vt = torch.linalg.svd(covariance_matrix)
    d = torch.sign(torch.linalg.det(Vt.T @ U.T)).item()
    d_mat = torch.tensor([1] * (dimension - 1) + [d], device=Vt.device, dtype=Vt.dtype)
    R = Vt.T @ torch.diag(d_mat) @ U.T

    target_aligned = target_centered @ R.T + noise_translation

    return R, target_aligned
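
A quick sketch checking the advertised properties of kabsch_align on random point clouds: the returned R should be a proper rotation, and alignment should not increase the RMSD to the noise, since the identity rotation and zero shift are always feasible.

import torch

from bionemo.moco.interpolants.continuous_time.continuous.data_augmentation.kabsch_augmentation import (
    KabschAugmentation,
)

augmentation = KabschAugmentation()

target = torch.randn(10, 3)  # a single point cloud of 10 points in 3-D
noise = torch.randn(10, 3)

R, target_aligned = augmentation.kabsch_align(target, noise)

# R is orthogonal with determinant +1, i.e. a proper rotation.
print(torch.allclose(R @ R.T, torch.eye(3), atol=1e-5))  # expected: True
print(torch.linalg.det(R))                               # expected: ~1.0

# Alignment never increases the RMSD to the noise.
before = ((target - noise) ** 2).sum(dim=-1).mean().sqrt()
after = ((target_aligned - noise) ** 2).sum(dim=-1).mean().sqrt()
print(bool(after <= before + 1e-6))  # expected: True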