Skip to content

Uniform

DiscreteUniformPrior

Bases: DiscretePriorDistribution

A subclass representing a discrete uniform prior distribution.

Source code in `bionemo/moco/distributions/prior/discrete/uniform.py` (lines 25–60):
class DiscreteUniformPrior(DiscretePriorDistribution):
    """A subclass representing a discrete uniform prior distribution.

    Each of the ``num_classes`` categories is assigned probability ``1 / num_classes``.
    """

    def __init__(self, num_classes: int = 10) -> None:
        """Initializes a discrete uniform prior distribution.

        Args:
            num_classes (int): The number of classes in the discrete uniform distribution. Must be at
                least 1. Defaults to 10.

        Raises:
            ValueError: If ``num_classes`` is less than 1, or if the prior probabilities stored by the
                base class do not sum to 1.0 within an absolute tolerance of 1e-5.
        """
        if num_classes < 1:
            # torch.ones(0) would silently produce an empty "distribution", and a negative size fails
            # inside torch with an opaque RuntimeError; reject both up front with a clear message.
            raise ValueError(f"num_classes must be a positive integer, got {num_classes}")
        prior_dist = torch.ones((num_classes,)) / num_classes
        super().__init__(num_classes, prior_dist)
        # Measure the deviation in both directions so an under-normalized prior is caught as well.
        if abs(torch.sum(self.prior_dist).item() - 1.0) > 1e-5:
            raise ValueError("Prior distribution probabilities do not sum up to 1.0")

    def sample(
        self,
        shape: Tuple,
        mask: Optional[Tensor] = None,
        device: Union[str, torch.device] = "cpu",
        rng_generator: Optional[torch.Generator] = None,
    ) -> Tensor:
        """Draws class indices uniformly at random for a tensor of the given shape.

        Args:
            shape (Tuple): The shape of the sample tensor to generate.
            mask (Optional[Tensor]): An optional mask that is multiplied into the samples. If it has
                fewer dimensions than ``shape``, singleton dimensions are appended at the end so it
                broadcasts over the remaining axes. Defaults to None.
            device (Union[str, torch.device]): Device on which to create the samples
                (e.g. ``"cpu"`` or ``"cuda"``). Defaults to ``"cpu"``.
            rng_generator (Optional[torch.Generator]): An optional generator for reproducible
                sampling. Defaults to None.

        Returns:
            Tensor: Class indices in ``[0, num_classes)`` with ``torch.randint``'s default int64 dtype.
            When a mask is given, the indices are multiplied by it, so positions where the mask is 0
            become 0. Multiplying by a floating-point mask promotes the result to a floating-point
            dtype under torch's type-promotion rules.
        """
        samples = torch.randint(0, self.num_classes, shape, device=device, generator=rng_generator)
        if mask is not None:
            # Append trailing None axes so a lower-rank mask (e.g. per-position) broadcasts over the rest.
            extra_dims = len(samples.shape) - len(mask.shape)
            samples = samples * mask[(...,) + (None,) * extra_dims]
        return samples

__init__(num_classes=10)

Initializes a discrete uniform prior distribution.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `num_classes` | `int` | The number of classes in the discrete uniform distribution. | `10` |
Source code in `bionemo/moco/distributions/prior/discrete/uniform.py` (lines 28–37):
def __init__(self, num_classes: int = 10) -> None:
    """Initializes a discrete uniform prior distribution.

    Args:
        num_classes (int): The number of classes in the discrete uniform distribution. Must be at
            least 1. Defaults to 10.

    Raises:
        ValueError: If ``num_classes`` is less than 1, or if the prior probabilities stored by the
            base class do not sum to 1.0 within an absolute tolerance of 1e-5.
    """
    if num_classes < 1:
        # torch.ones(0) would silently produce an empty "distribution", and a negative size fails
        # inside torch with an opaque RuntimeError; reject both up front with a clear message.
        raise ValueError(f"num_classes must be a positive integer, got {num_classes}")
    prior_dist = torch.ones((num_classes,)) / num_classes
    super().__init__(num_classes, prior_dist)
    # Measure the deviation in both directions so an under-normalized prior is caught as well.
    if abs(torch.sum(self.prior_dist).item() - 1.0) > 1e-5:
        raise ValueError("Prior distribution probabilities do not sum up to 1.0")

sample(shape, mask=None, device='cpu', rng_generator=None)

Generates uniformly distributed class-index samples of the specified shape.

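Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `shape` | `Tuple` | The shape of the samples to generate. | *required* |
| `device` | `Union[str, torch.device]` | Device on which to create the samples (e.g. `'cpu'` or `'cuda'`). | `'cpu'` |
| `mask` | `Optional[Tensor]` | An optional mask that is multiplied into the samples; singleton dimensions are appended so it broadcasts over any trailing dimensions it lacks. | `None` |
| `rng_generator` | `Optional[Generator]` | An optional `torch.Generator` for reproducible sampling. | `None` |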

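Returns:

| Type | Description |
|------|-------------|
| `Tensor` | Class indices drawn uniformly from `[0, num_classes)` (int64, as produced by `torch.randint`), multiplied by the mask when one is given. |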

Source code in `bionemo/moco/distributions/prior/discrete/uniform.py` (lines 39–60):
def sample(
    self,
    shape: Tuple,
    mask: Optional[Tensor] = None,
    device: Union[str, torch.device] = "cpu",
    rng_generator: Optional[torch.Generator] = None,
) -> Tensor:
    """Draws class indices uniformly at random for a tensor of the given shape.

    Args:
        shape (Tuple): The shape of the sample tensor to generate.
        mask (Optional[Tensor]): An optional mask that is multiplied into the samples. If it has
            fewer dimensions than ``shape``, singleton dimensions are appended at the end so it
            broadcasts over the remaining axes. Defaults to None.
        device (Union[str, torch.device]): Device on which to create the samples
            (e.g. ``"cpu"`` or ``"cuda"``). Defaults to ``"cpu"``.
        rng_generator (Optional[torch.Generator]): An optional generator for reproducible
            sampling. Defaults to None.

    Returns:
        Tensor: Class indices in ``[0, self.num_classes)`` with ``torch.randint``'s default int64
        dtype. When a mask is given, the indices are multiplied by it, so positions where the mask
        is 0 become 0. Multiplying by a floating-point mask promotes the result to a floating-point
        dtype under torch's type-promotion rules.
    """
    samples = torch.randint(0, self.num_classes, shape, device=device, generator=rng_generator)
    if mask is not None:
        # Append trailing None axes so a lower-rank mask (e.g. per-position) broadcasts over the rest.
        extra_dims = len(samples.shape) - len(mask.shape)
        samples = samples * mask[(...,) + (None,) * extra_dims]
    return samples