Skip to content

Logit normal

LogitNormalTimeDistribution

Bases: TimeDistribution

A class representing a logit normal time distribution.

Source code in bionemo/moco/distributions/time/logit_normal.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
class LogitNormalTimeDistribution(TimeDistribution):
    """A time distribution whose samples are the sigmoid of a normal variate.

    A draw is computed as ``sigmoid(p1 + p2 * z)`` with ``z ~ N(0, 1)``, which
    lies in (0, 1), and is then optionally rescaled to ``[min_t, max_t]``.
    """

    def __init__(
        self,
        p1: Float = 0.0,
        p2: Float = 1.0,
        min_t: Float = 0.0,
        max_t: Float = 1.0,
        discrete_time: Bool = False,
        nsteps: Optional[int] = None,
        rng_generator: Optional[torch.Generator] = None,
    ):
        """Initializes a LogitNormalTimeDistribution object.

        Args:
            p1 (Float): The mean of the underlying normal distribution (before the sigmoid).
            p2 (Float): The standard deviation of the underlying normal distribution (before the sigmoid).
            min_t (Float): The minimum time value.
            max_t (Float): The maximum time value.
            discrete_time (Bool): Whether the time is discrete.
            nsteps (Optional[int]): Number of steps for discretization.
            rng_generator: An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
        """
        super().__init__(discrete_time, nsteps, min_t, max_t, rng_generator)
        self.p1 = p1
        self.p2 = p2

    def sample(
        self, n_samples: int, device: Union[str, torch.device] = "cpu", rng_generator: Optional[torch.Generator] = None
    ):
        """Generates a specified number of samples from the logit normal time distribution.

        Args:
            n_samples (int): The number of samples to generate.
            device (str): The device on which to create the samples, e.g. "cpu".
            rng_generator: An optional :class:`torch.Generator` for reproducible sampling.
                Falls back to the generator stored on the instance when None.

        Returns:
            A tensor of ``n_samples`` time values; when ``discrete_time`` is set, the
            result of ``float_time_to_index`` applied to those values.

        Raises:
            ValueError: If ``discrete_time`` is set but ``nsteps`` is None.
        """
        if rng_generator is None:
            rng_generator = self.rng_generator
        # Normal draw with mean p1 and std p2, squashed into (0, 1) by the sigmoid.
        time_step = torch.randn(n_samples, device=device, generator=rng_generator) * self.p2 + self.p1
        time_step = torch.sigmoid(time_step)
        # Compare against None explicitly: a truthiness test would treat min_t == 0.0 as
        # "unset" and silently skip the rescale for ranges such as [0.0, 0.5].
        if self.min_t is not None and self.max_t is not None and (self.min_t > 0 or self.max_t < 1):
            time_step = time_step * (self.max_t - self.min_t) + self.min_t
        if self.discrete_time:
            if self.nsteps is None:
                raise ValueError("nsteps cannot be None for discrete time sampling")
            time_step = float_time_to_index(time_step, self.nsteps)
        return time_step

__init__(p1=0.0, p2=1.0, min_t=0.0, max_t=1.0, discrete_time=False, nsteps=None, rng_generator=None)

Initializes a LogitNormalTimeDistribution object.

Parameters:

Name Type Description Default
p1 Float

The first shape parameter of the logit normal distribution i.e. the mean.

0.0
p2 Float

The second shape parameter of the logit normal distribution i.e. the std.

1.0
min_t Float

The minimum time value.

0.0
max_t Float

The maximum time value.

1.0
discrete_time Bool

Whether the time is discrete.

False
nsteps Optional[int]

Number of steps for discretization.

None
rng_generator Optional[Generator]

An optional `torch.Generator` for reproducible sampling. Defaults to None.

None
Source code in bionemo/moco/distributions/time/logit_normal.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def __init__(
    self,
    p1: Float = 0.0,
    p2: Float = 1.0,
    min_t: Float = 0.0,
    max_t: Float = 1.0,
    discrete_time: Bool = False,
    nsteps: Optional[int] = None,
    rng_generator: Optional[torch.Generator] = None,
):
    """Initializes a LogitNormalTimeDistribution object.

    Args:
        p1 (Float): The first parameter of the logit normal distribution, i.e. the mean.
        p2 (Float): The second parameter of the logit normal distribution, i.e. the std.
        min_t (Float): The minimum time value.
        max_t (Float): The maximum time value.
        discrete_time (Bool): Whether the time is discrete.
        nsteps (Optional[int]): Number of steps for discretization.
        rng_generator: An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
    """
    # The base-class initializer receives the time-range, discretization and RNG settings.
    super().__init__(discrete_time, nsteps, min_t, max_t, rng_generator)
    # The distribution parameters are kept on the instance for later use.
    self.p1 = p1
    self.p2 = p2

sample(n_samples, device='cpu', rng_generator=None)

Generates a specified number of samples from the logit normal time distribution.

Parameters:

Name Type Description Default
n_samples int

The number of samples to generate.

required
device str

cpu or gpu.

'cpu'
rng_generator Optional[Generator]

An optional `torch.Generator` for reproducible sampling. Defaults to None.

None

Returns:

Type Description

A tensor of samples.

Source code in bionemo/moco/distributions/time/logit_normal.py
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def sample(
    self, n_samples: int, device: Union[str, torch.device] = "cpu", rng_generator: Optional[torch.Generator] = None
):
    """Generates a specified number of samples from the logit normal time distribution.

    Each sample is ``sigmoid(p1 + p2 * z)`` with ``z ~ N(0, 1)``, optionally
    rescaled from (0, 1) to ``[min_t, max_t]``.

    Args:
        n_samples (int): The number of samples to generate.
        device (str): The device on which to create the samples, e.g. "cpu".
        rng_generator: An optional :class:`torch.Generator` for reproducible sampling.
            Falls back to the generator stored on the instance when None.

    Returns:
        A tensor of ``n_samples`` time values; when ``discrete_time`` is set, the
        result of ``float_time_to_index`` applied to those values.

    Raises:
        ValueError: If ``discrete_time`` is set but ``nsteps`` is None.
    """
    if rng_generator is None:
        rng_generator = self.rng_generator
    # Normal draw with mean p1 and std p2, squashed into (0, 1) by the sigmoid.
    time_step = torch.randn(n_samples, device=device, generator=rng_generator) * self.p2 + self.p1
    time_step = torch.sigmoid(time_step)
    # Compare against None explicitly: a truthiness test would treat min_t == 0.0 as
    # "unset" and silently skip the rescale for ranges such as [0.0, 0.5].
    if self.min_t is not None and self.max_t is not None and (self.min_t > 0 or self.max_t < 1):
        time_step = time_step * (self.max_t - self.min_t) + self.min_t
    if self.discrete_time:
        if self.nsteps is None:
            raise ValueError("nsteps cannot be None for discrete time sampling")
        time_step = float_time_to_index(time_step, self.nsteps)
    return time_step