Utils

remove_center_of_mass(data, mask=None)

Removes the center of mass (CoM) from the given data.

Parameters:

Name | Type | Description | Default
data | Tensor | The input data with shape (..., nodes, features). | required
mask | Optional[Tensor] | An optional binary mask with shape (..., nodes). Nodes where the mask is 0 are excluded from the CoM calculation. | None

Returns: The data with the CoM subtracted, with shape (..., nodes, features).
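
As a reading aid, the computation can be written out as a formula. The notation is an assumption introduced here, not part of the library: x_i is the feature vector of node i, m_i in {0, 1} is its mask value, and N is the number of nodes. When no mask is given, every m_i is treated as 1.

```latex
\bar{x} = \frac{\sum_{i=1}^{N} m_i \, x_i}{\sum_{i=1}^{N} m_i},
\qquad
\tilde{x}_i = x_i - \bar{x} \quad \text{for all } i = 1, \dots, N
```

The mask only decides which nodes contribute to the CoM. Going by the source listing, the subtraction is then applied to every node, masked or not.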

Source code in bionemo/moco/distributions/prior/continuous/utils.py
from typing import Optional

from torch import Tensor


def remove_center_of_mass(data: Tensor, mask: Optional[Tensor] = None) -> Tensor:
    """Removes the center of mass (CoM) from the given data.

    Args:
        data: The input data with shape (..., nodes, features).
        mask: An optional binary mask with shape (..., nodes). Nodes where the mask is 0 are excluded from the CoM calculation. Defaults to None.

    Returns:
        The data with the CoM subtracted, with shape (..., nodes, features).
    """
    if mask is None:
        # Unmasked case: the CoM is the plain mean over the node dimension.
        com = data.mean(dim=-2, keepdim=True)
    else:
        # Zero out masked nodes, then average over the number of unmasked nodes.
        masked_data = data * mask.unsqueeze(-1)
        num_nodes = mask.sum(dim=-1, keepdim=True).unsqueeze(-1)
        com = masked_data.sum(dim=-2, keepdim=True) / num_nodes
    # The CoM is subtracted from every node, including masked ones.
    return data - com
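
A minimal usage sketch follows. It assumes PyTorch is installed and redefines the function locally (copied from the listing above) so that it runs without bionemo. The toy shapes, the random seed, and the variable names are illustrative assumptions, not part of the library.

```python
from typing import Optional

import torch
from torch import Tensor


def remove_center_of_mass(data: Tensor, mask: Optional[Tensor] = None) -> Tensor:
    # Local copy of the listed function, so the example is self-contained.
    if mask is None:
        com = data.mean(dim=-2, keepdim=True)
    else:
        masked_data = data * mask.unsqueeze(-1)
        num_nodes = mask.sum(dim=-1, keepdim=True).unsqueeze(-1)
        com = masked_data.sum(dim=-2, keepdim=True) / num_nodes
    return data - com


# Hypothetical toy batch: 2 samples, 4 nodes, 3 features (e.g. 3D coordinates).
torch.manual_seed(0)
data = torch.randn(2, 4, 3)

# Without a mask, every node contributes, so the mean over nodes becomes ~0.
centered = remove_center_of_mass(data)
unmasked_com = centered.mean(dim=-2)

# With a mask, only nodes marked 1 contribute to the CoM. Here the last node
# of each sample is treated as padding.
mask = torch.tensor([[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 0.0]])
centered_masked = remove_center_of_mass(data, mask)

# Recompute the masked CoM of the output; it should be ~0 as well.
masked_com = (centered_masked * mask.unsqueeze(-1)).sum(dim=-2) / mask.sum(dim=-1, keepdim=True)
```

In both cases the output keeps the input shape (2, 4, 3). The padded node is still shifted by the CoM, so callers that need padded positions to stay zero would have to re-apply the mask afterwards. An all-zero mask row would divide by zero in the listed code, so this sketch assumes at least one unmasked node per sample.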