Skip to content

Harmonic

LinearHarmonicPrior

Bases: PriorDistribution

A subclass representing a Linear Harmonic prior distribution from Jing et al. https://arxiv.org/abs/2304.02198.

Source code in bionemo/moco/distributions/prior/continuous/harmonic.py
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
class LinearHarmonicPrior(PriorDistribution):
    """A subclass representing a Linear Harmonic prior distribution from Jing et al. https://arxiv.org/abs/2304.02198."""

    def __init__(
        self,
        distance: Float = 3.8,
        length: Optional[int] = None,
        center: Bool = False,
        rng_generator: Optional[torch.Generator] = None,
        device: Union[str, torch.device] = "cpu",
    ) -> None:
        """Linear Harmonic prior distribution.

        Args:
            distance (Float): RMS distance between adjacent points in the line graph.
            length (Optional[int]): The number of points per sample (the L dimension of a B x L x D sample).
                If given, the eigendecomposition is precomputed here instead of on the first ``sample`` call.
            center (bool): Whether to center the samples around the mean. Defaults to False.
            rng_generator: An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
            device (Optional[str]): Device on which to place the precomputed prior terms (default is "cpu").
        """
        self.distance = distance
        self.length = length
        self.center = center
        self.rng_generator = rng_generator
        self.device = device
        # Number of points the cached ``P`` / ``D_inv`` were built for. Kept separate from
        # ``self.length`` (the configured length) so that sampling a different length does
        # not leave a stale cache that is later mistaken for the configured one.
        self._cached_length: Optional[int] = None
        if length:
            self._calculate_terms(length, device)

    def _calculate_terms(self, N, device):
        """Eigendecompose the harmonic coupling matrix of an N-point linear chain and cache the result.

        Sets ``self.P`` (eigenvectors as columns), ``self.D_inv`` (inverse eigenvalues with the
        smallest one zeroed), and ``self._cached_length = N``.
        """
        a = 3 / (self.distance * self.distance)
        J = torch.zeros(N, N)
        # Edges (k, k + 1) of the chain; each edge adds ``a`` to both endpoints' diagonal
        # entries and ``-a`` to the symmetric off-diagonal pair.
        left = torch.arange(N - 1)
        right = left + 1
        J[left, left] += a
        J[right, right] += a
        J[left, right] = -a
        J[right, left] = -a
        # eigh returns eigenvalues in ascending order, so index 0 is the chain's zero
        # (translation) mode; its inverse would be unbounded, so it is dropped.
        D, P = torch.linalg.eigh(J)
        D_inv = 1 / D
        D_inv[0] = 0
        self.P, self.D_inv = P.to(device), D_inv.to(device)
        self._cached_length = N

    def sample(
        self,
        shape: Tuple,
        mask: Optional[Tensor] = None,
        device: Union[str, torch.device] = "cpu",
        rng_generator: Optional[torch.Generator] = None,
    ) -> Tensor:
        """Generates a specified number of samples from the Harmonic prior distribution.

        Args:
            shape (Tuple): The shape of the samples to generate, ``(B, L, D)``.
            mask (Optional[Tensor]): An optional mask to apply to the samples. Defaults to None.
            device (str): Device on which to generate the samples (e.g. "cpu" or "cuda").
            rng_generator: An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.

        Returns:
            Tensor: A tensor of samples with shape ``shape``.

        Raises:
            ValueError: If ``shape`` does not have exactly three dimensions.
        """
        if len(shape) != 3:
            raise ValueError("Input shape can only work for B x L x D")
        if rng_generator is None:
            rng_generator = self.rng_generator

        samples = torch.randn(*shape, device=device, generator=rng_generator)
        N = shape[1]

        # Compare against the length the cached terms were actually built for; falls back to
        # ``self.length`` for instances created without ``_cached_length``.
        if N != getattr(self, "_cached_length", self.length):
            self._calculate_terms(N, device)

        # Cached terms may live on a different device than requested; ``.to`` is a no-op otherwise.
        P = self.P.to(samples.device)
        std = torch.sqrt(self.D_inv.to(samples.device)).unsqueeze(-1)
        samples = P @ (std * samples)

        if self.center:
            samples = remove_center_of_mass(samples, mask)

        if mask is not None:
            samples = samples * mask.unsqueeze(-1)
        return samples

__init__(distance=3.8, length=None, center=False, rng_generator=None, device='cpu')

Linear Harmonic prior distribution.

Parameters:

Name Type Description Default
distance Float

RMS distance between adjacent points in the line graph.

3.8
length Optional[int]

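The number of points per sample (the L dimension of a B x L x D sample). If given, the eigendecomposition is precomputed at construction.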

None
center bool

Whether to center the samples around the mean. Defaults to False.

False
rng_generator Optional[Generator]

An optional `torch.Generator` for reproducible sampling. Defaults to None.

None
device Optional[str]

Device on which to place the precomputed prior terms (default is "cpu").

'cpu'
Source code in bionemo/moco/distributions/prior/continuous/harmonic.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
def __init__(
    self,
    distance: Float = 3.8,
    length: Optional[int] = None,
    center: Bool = False,
    rng_generator: Optional[torch.Generator] = None,
    device: Union[str, torch.device] = "cpu",
) -> None:
    """Configure a linear-chain harmonic prior.

    Args:
        distance (Float): RMS distance between neighbouring points of the chain.
        length (Optional[int]): Number of points per sample; when truthy, the
            eigendecomposition is computed immediately.
        center (bool): Whether samples are re-centered around their mean. Defaults to False.
        rng_generator: Optional :class:`torch.Generator` used for reproducible sampling. Defaults to None.
        device (Optional[str]): Device for the precomputed prior terms (default is "cpu").
    """
    # Record the configuration; ``sample`` reads ``length``, ``center`` and ``rng_generator``.
    self.distance, self.length, self.center = distance, length, center
    self.rng_generator, self.device = rng_generator, device

    # Nothing to precompute until the number of points is known.
    if not length:
        return
    self._calculate_terms(length, device)

sample(shape, mask=None, device='cpu', rng_generator=None)

Generates a specified number of samples from the Harmonic prior distribution.

Parameters:

Name Type Description Default
shape Tuple

The shape of the samples to generate.

required
device str

Device on which to generate the samples, e.g. "cpu" or "cuda".

'cpu'
mask Optional[Tensor]

An optional mask to apply to the samples. Defaults to None.

None
rng_generator Optional[Generator]

An optional `torch.Generator` for reproducible sampling. Defaults to None.

None

Returns:

Name Type Description
Float Tensor

A tensor of samples.

Source code in bionemo/moco/distributions/prior/continuous/harmonic.py
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
def sample(
    self,
    shape: Tuple,
    mask: Optional[Tensor] = None,
    device: Union[str, torch.device] = "cpu",
    rng_generator: Optional[torch.Generator] = None,
) -> Tensor:
    """Generates a specified number of samples from the Harmonic prior distribution.

    Args:
        shape (Tuple): The shape of the samples to generate, ``(B, L, D)``.
        mask (Optional[Tensor]): An optional mask to apply to the samples. Defaults to None.
        device (str): Device on which to generate the samples (e.g. "cpu" or "cuda").
        rng_generator: An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.

    Returns:
        Tensor: A tensor of samples with shape ``shape``.

    Raises:
        ValueError: If ``shape`` does not have exactly three dimensions.
    """
    if len(shape) != 3:
        raise ValueError("Input shape can only work for B x L x D")
    if rng_generator is None:
        rng_generator = self.rng_generator

    samples = torch.randn(*shape, device=device, generator=rng_generator)
    N = shape[1]

    # Compare against the length the cached terms were actually built for. Comparing
    # against ``self.length`` alone reuses stale terms after sampling a different N
    # (e.g. length=4, sample N=6, then N=4 would multiply by a 6x6 ``P`` and fail).
    if N != getattr(self, "_cached_length", self.length):
        self._calculate_terms(N, device)
        self._cached_length = N

    # Cached terms may live on a different device than requested; ``.to`` is a no-op otherwise.
    P = self.P.to(samples.device)
    std = torch.sqrt(self.D_inv.to(samples.device)).unsqueeze(-1)
    samples = P @ (std * samples)

    if self.center:
        samples = remove_center_of_mass(samples, mask)

    if mask is not None:
        samples = samples * mask.unsqueeze(-1)
    return samples