Base interpolant

Interpolant

Bases: ABC

An abstract base class representing an Interpolant.

This class serves as a foundation for creating interpolants that can be used in various applications, providing a basic structure and interface for interpolation-related operations.

Source code in bionemo/moco/interpolants/base_interpolant.py
class Interpolant(ABC):
    """An abstract base class representing an Interpolant.

    This class serves as a foundation for creating interpolants that can be used
    in various applications, providing a basic structure and interface for
    interpolation-related operations.
    """

    def __init__(
        self,
        time_distribution: TimeDistribution,
        prior_distribution: PriorDistribution,
        device: Union[str, torch.device] = "cpu",
        rng_generator: Optional[torch.Generator] = None,
    ):
        """Initializes the Interpolant class.

        Args:
            time_distribution (TimeDistribution): The distribution of time steps.
            prior_distribution (PriorDistribution): The prior distribution of the variable.
            device (Union[str, torch.device], optional): The device on which to operate. Defaults to "cpu".
            rng_generator: An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
        """
        self.time_distribution = time_distribution
        self.prior_distribution = prior_distribution
        self.device = device
        self.rng_generator = rng_generator

    @abstractmethod
    def interpolate(self, *args, **kwargs) -> Tensor:
        """Get x(t) with given time t from noise and data.

        Interpolate between x0 and x1 at the given time t.
        """
        pass

    @abstractmethod
    def step(self, *args, **kwargs) -> Tensor:
        """Do one step integration."""
        pass

    def general_step(self, method_name: str, kwargs: dict):
        """Calls a step method of the class by its name, passing the provided keyword arguments.

        Args:
            method_name (str): The name of the step method to call.
            kwargs (dict): Keyword arguments to pass to the step method.

        Returns:
            The result of the step method call.

        Raises:
            ValueError: If the provided method name does not start with 'step'.
            Exception: If the step method call fails. The error message includes a list of available step methods.

        Note:
            This method allows for dynamic invocation of step methods, providing flexibility in the class's usage.
        """
        if not method_name.startswith("step"):
            raise ValueError(f"Method name '{method_name}' does not start with 'step'")

        try:
            # Get the step method by its name
            func = getattr(self, method_name)
            # Call the step method with the provided keyword arguments
            return func(**kwargs)
        except Exception as e:
            # Get a list of available step methods
            available_methods = "\n".join([f"  - {attr}" for attr in dir(self) if attr.startswith("step")])
            # Create a detailed error message
            error_message = f"Error calling method '{method_name}': {e}\nAvailable step methods:\n{available_methods}"
            # Re-raise the exception with the detailed error message
            raise type(e)(error_message)

    def sample_prior(self, *args, **kwargs) -> Tensor:
        """Sample from prior distribution.

        This method generates a sample from the prior distribution specified by the
        `prior_distribution` attribute.

        Returns:
            Tensor: The generated sample from the prior distribution.
        """
        # Ensure the device is specified, default to self.device if not provided
        if "device" not in kwargs:
            kwargs["device"] = self.device
        kwargs["rng_generator"] = self.rng_generator
        # Sample from the prior distribution
        return self.prior_distribution.sample(*args, **kwargs)

    def sample_time(self, *args, **kwargs) -> Tensor:
        """Sample from time distribution."""
        # Ensure the device is specified, default to self.device if not provided
        if "device" not in kwargs:
            kwargs["device"] = self.device
        kwargs["rng_generator"] = self.rng_generator
        # Sample from the time distribution
        return self.time_distribution.sample(*args, **kwargs)

    def to_device(self, device: str):
        """Moves all internal tensors to the specified device and updates the `self.device` attribute.

        Args:
            device (str): The device to move the tensors to (e.g. "cpu", "cuda:0").

        Note:
            This method is used to transfer the internal state of the DDPM interpolant to a different device.
            It updates the `self.device` attribute to reflect the new device and moves all internal tensors to the specified device.
        """
        self.device = device
        for attr_name in dir(self):
            if attr_name.startswith("_") and isinstance(getattr(self, attr_name), torch.Tensor):
                setattr(self, attr_name, getattr(self, attr_name).to(device))
        return self

    def clean_mask_center(self, data: Tensor, mask: Optional[Tensor] = None, center: Bool = False) -> Tensor:
        """Returns a clean tensor that has been masked and/or centered based on the function arguments.

        Args:
            data: The input data with shape (..., nodes, features).
            mask: An optional mask to apply to the data with shape (..., nodes). If provided, it is used to calculate the CoM. Defaults to None.
            center: A boolean indicating whether to center the data around the calculated CoM. Defaults to False.

        Returns:
            The data with shape (..., nodes, features) either centered around the CoM if `center` is True or unchanged if `center` is False.
        """
        if mask is not None:
            data = data * mask.unsqueeze(-1)
        if not center:
            return data
        if mask is None:
            num_nodes = torch.tensor(data.shape[1], device=data.device)
        else:
            num_nodes = torch.clamp(mask.sum(dim=-1), min=1)  # clamp used to prevent divide by 0
        com = data.sum(dim=-2) / num_nodes.unsqueeze(-1)
        return data - com.unsqueeze(-2)
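
To make the abstract interface concrete, the sketch below subclasses Interpolant with a simple linear interpolation and an Euler-style update. It is a minimal illustration only: the import path follows the "Source code in" location above, while LinearInterpolant, the toy time and prior distribution stand-ins, and the step signature are hypothetical and not part of the library.

import torch
from torch import Tensor

from bionemo.moco.interpolants.base_interpolant import Interpolant  # assumed import path


class _ToyGaussianPrior:
    """Hypothetical duck-typed stand-in for a PriorDistribution."""

    def sample(self, shape, device="cpu", rng_generator=None):
        return torch.randn(*shape, device=device, generator=rng_generator)


class _ToyUniformTime:
    """Hypothetical duck-typed stand-in for a TimeDistribution."""

    def sample(self, batch_size, device="cpu", rng_generator=None):
        return torch.rand(batch_size, device=device, generator=rng_generator)


class LinearInterpolant(Interpolant):
    """Toy interpolant: x(t) = (1 - t) * x0 + t * x1, stepped with a plain Euler update."""

    def interpolate(self, x0: Tensor, x1: Tensor, t: Tensor) -> Tensor:
        t = t.view(-1, *([1] * (x0.ndim - 1)))  # broadcast per-sample t over node/feature dims
        return (1 - t) * x0 + t * x1

    def step(self, model_out: Tensor, xt: Tensor, dt: float) -> Tensor:
        # Euler update along the model's predicted velocity.
        return xt + dt * model_out


interpolant = LinearInterpolant(_ToyUniformTime(), _ToyGaussianPrior(), device="cpu")
x0 = interpolant.sample_prior((4, 8, 3))  # noise with shape (batch, nodes, features)
t = interpolant.sample_time(4)            # one time value per batch element
x1 = torch.zeros(4, 8, 3)                 # placeholder "data"
xt = interpolant.interpolate(x0, x1, t)

The stand-in distributions only need a sample method that accepts device and rng_generator keywords, because sample_prior and sample_time forward those arguments as shown in the source above.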

__init__(time_distribution, prior_distribution, device='cpu', rng_generator=None)

Initializes the Interpolant class.

Parameters:

time_distribution (TimeDistribution): The distribution of time steps. Required.
prior_distribution (PriorDistribution): The prior distribution of the variable. Required.
device (Union[str, torch.device]): The device on which to operate. Defaults to "cpu".
rng_generator (Optional[torch.Generator]): An optional torch.Generator for reproducible sampling. Defaults to None.
Source code in bionemo/moco/interpolants/base_interpolant.py
def __init__(
    self,
    time_distribution: TimeDistribution,
    prior_distribution: PriorDistribution,
    device: Union[str, torch.device] = "cpu",
    rng_generator: Optional[torch.Generator] = None,
):
    """Initializes the Interpolant class.

    Args:
        time_distribution (TimeDistribution): The distribution of time steps.
        prior_distribution (PriorDistribution): The prior distribution of the variable.
        device (Union[str, torch.device], optional): The device on which to operate. Defaults to "cpu".
        rng_generator: An optional :class:`torch.Generator` for reproducible sampling. Defaults to None.
    """
    self.time_distribution = time_distribution
    self.prior_distribution = prior_distribution
    self.device = device
    self.rng_generator = rng_generator

clean_mask_center(data, mask=None, center=False)

Returns a clean tensor that has been masked and/or centered based on the function arguments.

Parameters:

data (Tensor): The input data with shape (..., nodes, features). Required.
mask (Optional[Tensor]): An optional mask to apply to the data with shape (..., nodes). If provided, it is used to calculate the CoM. Defaults to None.
center (Bool): A boolean indicating whether to center the data around the calculated CoM. Defaults to False.

Returns:

Tensor: The data with shape (..., nodes, features), either centered around the CoM if center is True or unchanged if center is False.
Source code in bionemo/moco/interpolants/base_interpolant.py
def clean_mask_center(self, data: Tensor, mask: Optional[Tensor] = None, center: Bool = False) -> Tensor:
    """Returns a clean tensor that has been masked and/or centered based on the function arguments.

    Args:
        data: The input data with shape (..., nodes, features).
        mask: An optional mask to apply to the data with shape (..., nodes). If provided, it is used to calculate the CoM. Defaults to None.
        center: A boolean indicating whether to center the data around the calculated CoM. Defaults to False.

    Returns:
        The data with shape (..., nodes, features) either centered around the CoM if `center` is True or unchanged if `center` is False.
    """
    if mask is not None:
        data = data * mask.unsqueeze(-1)
    if not center:
        return data
    if mask is None:
        num_nodes = torch.tensor(data.shape[1], device=data.device)
    else:
        num_nodes = torch.clamp(mask.sum(dim=-1), min=1)  # clamp used to prevent divide by 0
    com = data.sum(dim=-2) / num_nodes.unsqueeze(-1)
    return data - com.unsqueeze(-2)
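
The snippet below illustrates the masking and centering behavior with made-up shapes (a batch of 2 samples, 5 nodes, 3 features), reusing the hypothetical LinearInterpolant instance from the earlier sketch; any concrete subclass behaves identically here, since the method only uses the code shown above.

import torch

# Two samples with five nodes each and 3-D features; the second sample has only 3 real nodes.
data = torch.randn(2, 5, 3)
mask = torch.tensor([[1.0, 1.0, 1.0, 1.0, 1.0],
                     [1.0, 1.0, 1.0, 0.0, 0.0]])

centered = interpolant.clean_mask_center(data, mask=mask, center=True)

# The center of mass over the unmasked nodes of each sample is now ~0.
com = (centered * mask.unsqueeze(-1)).sum(dim=-2) / mask.sum(dim=-1, keepdim=True)
print(com)  # close to zeros, up to floating point error

Note that masked-out positions in the returned tensor hold the negated center of mass rather than zero, because the shift is applied after masking; re-apply the mask afterwards if those entries must stay zero.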

general_step(method_name, kwargs)

Calls a step method of the class by its name, passing the provided keyword arguments.

Parameters:

method_name (str): The name of the step method to call. Required.
kwargs (dict): Keyword arguments to pass to the step method. Required.

Returns:

The result of the step method call.

Raises:

ValueError: If the provided method name does not start with 'step'.
Exception: If the step method call fails. The error message includes a list of available step methods.

Note

This method allows for dynamic invocation of step methods, providing flexibility in the class's usage.

Source code in bionemo/moco/interpolants/base_interpolant.py
def general_step(self, method_name: str, kwargs: dict):
    """Calls a step method of the class by its name, passing the provided keyword arguments.

    Args:
        method_name (str): The name of the step method to call.
        kwargs (dict): Keyword arguments to pass to the step method.

    Returns:
        The result of the step method call.

    Raises:
        ValueError: If the provided method name does not start with 'step'.
        Exception: If the step method call fails. The error message includes a list of available step methods.

    Note:
        This method allows for dynamic invocation of step methods, providing flexibility in the class's usage.
    """
    if not method_name.startswith("step"):
        raise ValueError(f"Method name '{method_name}' does not start with 'step'")

    try:
        # Get the step method by its name
        func = getattr(self, method_name)
        # Call the step method with the provided keyword arguments
        return func(**kwargs)
    except Exception as e:
        # Get a list of available step methods
        available_methods = "\n".join([f"  - {attr}" for attr in dir(self) if attr.startswith("step")])
        # Create a detailed error message
        error_message = f"Error calling method '{method_name}': {e}\nAvailable step methods:\n{available_methods}"
        # Re-raise the exception with the detailed error message
        raise type(e)(error_message)
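
A brief usage sketch for general_step, again using the hypothetical LinearInterpolant and tensors defined earlier; the keyword names simply have to match whatever signature the concrete step method defines.

# Dispatch to the subclass's step method by name.
result = interpolant.general_step("step", {"model_out": x1 - x0, "xt": xt, "dt": 0.1})

# Names that do not start with "step" are rejected before any lookup happens.
try:
    interpolant.general_step("interpolate", {"x0": x0, "x1": x1, "t": t})
except ValueError as err:
    print(err)  # Method name 'interpolate' does not start with 'step'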

interpolate(*args, **kwargs) abstractmethod

Get x(t) with given time t from noise and data.

Interpolate between x0 and x1 at the given time t.

Source code in bionemo/moco/interpolants/base_interpolant.py
@abstractmethod
def interpolate(self, *args, **kwargs) -> Tensor:
    """Get x(t) with given time t from noise and data.

    Interpolate between x0 and x1 at the given time t.
    """
    pass

sample_prior(*args, **kwargs)

Sample from prior distribution.

This method generates a sample from the prior distribution specified by the prior_distribution attribute.

Returns:

Tensor: The generated sample from the prior distribution.

Source code in bionemo/moco/interpolants/base_interpolant.py
def sample_prior(self, *args, **kwargs) -> Tensor:
    """Sample from prior distribution.

    This method generates a sample from the prior distribution specified by the
    `prior_distribution` attribute.

    Returns:
        Tensor: The generated sample from the prior distribution.
    """
    # Ensure the device is specified, default to self.device if not provided
    if "device" not in kwargs:
        kwargs["device"] = self.device
    kwargs["rng_generator"] = self.rng_generator
    # Sample from the prior distribution
    return self.prior_distribution.sample(*args, **kwargs)

sample_time(*args, **kwargs)

Sample from time distribution.

Source code in bionemo/moco/interpolants/base_interpolant.py
def sample_time(self, *args, **kwargs) -> Tensor:
    """Sample from time distribution."""
    # Ensure the device is specified, default to self.device if not provided
    if "device" not in kwargs:
        kwargs["device"] = self.device
    kwargs["rng_generator"] = self.rng_generator
    # Sample from the time distribution
    return self.time_distribution.sample(*args, **kwargs)

step(*args, **kwargs) abstractmethod

Do one step integration.

Source code in bionemo/moco/interpolants/base_interpolant.py
@abstractmethod
def step(self, *args, **kwargs) -> Tensor:
    """Do one step integration."""
    pass

to_device(device)

Moves all internal tensors to the specified device and updates the self.device attribute.

Parameters:

device (str): The device to move the tensors to (e.g. "cpu", "cuda:0"). Required.

Note

This method is used to transfer the internal state of the DDPM interpolant to a different device. It updates the self.device attribute to reflect the new device and moves all internal tensors to the specified device.

Source code in bionemo/moco/interpolants/base_interpolant.py
def to_device(self, device: str):
    """Moves all internal tensors to the specified device and updates the `self.device` attribute.

    Args:
        device (str): The device to move the tensors to (e.g. "cpu", "cuda:0").

    Note:
        This method is used to transfer the internal state of the DDPM interpolant to a different device.
        It updates the `self.device` attribute to reflect the new device and moves all internal tensors to the specified device.
    """
    self.device = device
    for attr_name in dir(self):
        if attr_name.startswith("_") and isinstance(getattr(self, attr_name), torch.Tensor):
            setattr(self, attr_name, getattr(self, attr_name).to(device))
    return self

PredictionType

Bases: Enum

An enumeration representing the type of prediction a Denoising Diffusion Probabilistic Model (DDPM) can be used for.

DDPMs are versatile models that can be utilized for various prediction tasks, including:

  • Data: Predicting the original data distribution from a noisy input.
  • Noise: Predicting the noise that was added to the original data to obtain the input.
  • Velocity: Predicting the velocity or rate of change of the data, particularly useful for modeling temporal dynamics.

These prediction types can be used to train neural networks for specific tasks, such as denoising, image synthesis, or time-series forecasting.

Source code in bionemo/moco/interpolants/base_interpolant.py
class PredictionType(Enum):
    """An enumeration representing the type of prediction a Denoising Diffusion Probabilistic Model (DDPM) can be used for.

    DDPMs are versatile models that can be utilized for various prediction tasks, including:

    - **Data**: Predicting the original data distribution from a noisy input.
    - **Noise**: Predicting the noise that was added to the original data to obtain the input.
    - **Velocity**: Predicting the velocity or rate of change of the data, particularly useful for modeling temporal dynamics.

    These prediction types can be used to train neural networks for specific tasks, such as denoising, image synthesis, or time-series forecasting.
    """

    DATA = "data"
    NOISE = "noise"
    VELOCITY = "velocity"

pad_like(source, target)

Pads the dimensions of the source tensor to match the dimensions of the target tensor.

Parameters:

source (Tensor): The tensor to be padded. Required.
target (Tensor): The tensor that the source tensor should match in dimensions. Required.

Returns:

Tensor: The padded source tensor.

Raises:

ValueError: If the source tensor has more dimensions than the target tensor.

Example

>>> source = torch.tensor([1, 2, 3])  # shape: (3,)
>>> target = torch.tensor([[1, 2], [4, 5], [7, 8]])  # shape: (3, 2)
>>> padded_source = pad_like(source, target)  # shape: (3, 1)

Source code in bionemo/moco/interpolants/base_interpolant.py
def pad_like(source: Tensor, target: Tensor) -> Tensor:
    """Pads the dimensions of the source tensor to match the dimensions of the target tensor.

    Args:
        source (Tensor): The tensor to be padded.
        target (Tensor): The tensor that the source tensor should match in dimensions.

    Returns:
        Tensor: The padded source tensor.

    Raises:
        ValueError: If the source tensor has more dimensions than the target tensor.

    Example:
        >>> source = torch.tensor([1, 2, 3])  # shape: (3,)
        >>> target = torch.tensor([[1, 2], [4, 5], [7, 8]])  # shape: (3, 2)
        >>> padded_source = pad_like(source, target)  # shape: (3, 1)
    """
    if source.ndim == target.ndim:
        return source
    elif source.ndim > target.ndim:
        raise ValueError(f"Cannot pad {source.shape} to {target.shape}")
    return source.view(list(source.shape) + [1] * (target.ndim - source.ndim))
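
A short usage sketch showing why pad_like is convenient: a per-sample time tensor can be broadcast against batched data without manual reshaping. The shapes and the import path (taken from the "Source code in" note above) are assumptions.

import torch

from bionemo.moco.interpolants.base_interpolant import pad_like  # assumed import path

x0 = torch.randn(4, 8, 3)  # (batch, nodes, features)
x1 = torch.randn(4, 8, 3)
t = torch.rand(4)          # one interpolation time per batch element, shape (4,)

t_pad = pad_like(t, x0)    # shape (4, 1, 1): trailing singleton dims enable broadcasting
xt = (1 - t_pad) * x0 + t_pad * x1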

string_to_enum(value, enum_type)

Converts a string to an enum value of the specified type. If the input is already an enum instance, it is returned as-is.

Parameters:

value (Union[str, E]): The string to convert or an existing enum instance. Required.
enum_type (Type[E]): The enum type to convert to. Required.

Returns:

E (AnyEnum): The corresponding enum value.

Raises:

ValueError: If the string does not correspond to any enum member.

Source code in bionemo/moco/interpolants/base_interpolant.py
def string_to_enum(value: Union[str, AnyEnum], enum_type: Type[AnyEnum]) -> AnyEnum:
    """Converts a string to an enum value of the specified type. If the input is already an enum instance, it is returned as-is.

    Args:
        value (Union[str, E]): The string to convert or an existing enum instance.
        enum_type (Type[E]): The enum type to convert to.

    Returns:
        E: The corresponding enum value.

    Raises:
        ValueError: If the string does not correspond to any enum member.
    """
    if isinstance(value, enum_type):
        # If the value is already an enum, return it
        return value

    try:
        # Match the value to the Enum, case-insensitively
        return enum_type(value)
    except ValueError:
        # Raise a helpful error if the value is invalid
        valid_values = [e.value for e in enum_type]
        raise ValueError(f"Invalid value '{value}'. Expected one of {valid_values}.")
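
A small usage sketch for string_to_enum with the PredictionType enum documented above; the import path is assumed from the "Source code in" note.

from bionemo.moco.interpolants.base_interpolant import PredictionType, string_to_enum  # assumed import path

assert string_to_enum("noise", PredictionType) is PredictionType.NOISE

# Enum instances pass through unchanged.
assert string_to_enum(PredictionType.VELOCITY, PredictionType) is PredictionType.VELOCITY

# Invalid strings raise a ValueError that lists the accepted values.
try:
    string_to_enum("score", PredictionType)
except ValueError as err:
    print(err)  # Invalid value 'score'. Expected one of ['data', 'noise', 'velocity'].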