Building Generative Models for Continuous Data via Continuous Interpolants¶
import math
import os
import time
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.datasets import make_moons
Task Setup¶
To demonstrate how Conditional Flow Matching works, we use sklearn to create and sample from custom 2D distributions.
To start, we define our "dataloader", the `sample_moons` function.
Next we need a prior distribution for the interpolant to transport to the moon distribution; in this example we use a simple GaussianPrior.
def sample_moons(n, normalize = False):
x1, _ = make_moons(n_samples=n, noise=0.08)
x1 = torch.Tensor(x1)
x1 = x1 * 3 - 1
if normalize:
x1 = (x1 - x1.mean(0))/x1.std(0) * 2
return x1
x1 = sample_moons(1000)
plt.scatter(x1[:, 0], x1[:, 1])
Model Creation¶
Here we define a simple 4-layer MLP and set up our optimizer.
dim = 2
hidden_size = 64
batch_size = 256
model = torch.nn.Sequential(
torch.nn.Linear(dim + 1, hidden_size),
torch.nn.SELU(),
torch.nn.Linear(hidden_size, hidden_size),
torch.nn.SELU(),
torch.nn.Linear(hidden_size, hidden_size),
torch.nn.SELU(),
torch.nn.Linear(hidden_size, dim),
)
optimizer = torch.optim.Adam(model.parameters())
Continuous Flow Matching Interpolant¶
Here we import our desired interpolant objects: the continuous flow matcher, the time distribution, and the prior distribution.
from bionemo.moco.interpolants import ContinuousFlowMatcher
from bionemo.moco.distributions.time import UniformTimeDistribution
from bionemo.moco.distributions.prior import GaussianPrior
uniform_time = UniformTimeDistribution()
simple_prior = GaussianPrior()
sigma = 0.1
cfm = ContinuousFlowMatcher(time_distribution=uniform_time,
prior_distribution=simple_prior,
sigma=sigma,
prediction_type="velocity")
# Place both the model and the interpolant on the same device
DEVICE = "cuda"
model = model.to(DEVICE)
cfm = cfm.to_device(DEVICE)
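Before training, it helps to see what the interpolant computes. For a prior sample x0 and a data sample x1, conditional flow matching interpolates along the straight line between them and regresses the model onto the constant velocity x1 - x0. Below is a minimal, framework-free sketch of that idea; the exact conventions inside `ContinuousFlowMatcher` (time direction, how `sigma` perturbs the path) may differ.
import torch

def toy_interpolate(x1, x0, t, sigma=0.1):
    # Point on the straight line from the prior sample x0 (t=0) to the data
    # sample x1 (t=1), with a small Gaussian perturbation of width sigma.
    t = t[:, None]  # broadcast time over the feature dimension
    return (1 - t) * x0 + t * x1 + sigma * torch.randn_like(x0)

def toy_velocity_target(x1, x0):
    # For a straight-line path the conditional velocity is constant and simply
    # points from the prior sample to the data sample.
    return x1 - x0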
Training Loop¶
for k in range(20000):
optimizer.zero_grad()
shape = (batch_size, dim)
x0 = cfm.sample_prior(shape).to(DEVICE)
x1 = sample_moons(batch_size).to(DEVICE)
t = cfm.sample_time(batch_size)
xt = cfm.interpolate(x1, t, x0)
ut = cfm.calculate_target(x1, x0)
vt = model(torch.cat([xt, t[:, None]], dim=-1))
loss = cfm.loss(vt, ut, target_type="velocity").mean()
loss.backward()
optimizer.step()
if (k + 1) % 5000 == 0:
print(f"{k+1}: loss {loss.item():0.3f}")
5000: loss 2.752 10000: loss 2.838 15000: loss 2.709 20000: loss 3.096
Setting Up Generation¶
Now we need to import the desired inference time schedule. This provides the time values we iterate over to generate from our model.
Below we show the output time schedule as well as the discretization (dt) between time points. Note that different inference time schedules may have different shapes, resulting in non-uniform dt.
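For intuition, here is a small standalone example (plain torch, not the bionemo.moco API) of a non-linear time schedule whose dt values are not constant; the LinearInferenceSchedule used below is the special case where every dt equals 1/nsteps.
import torch

nsteps = 10
# A quadratic schedule packs more, smaller steps near t=0 and larger ones near t=1.
quadratic_schedule = torch.linspace(0, 1, nsteps + 1) ** 2
dts = quadratic_schedule[1:] - quadratic_schedule[:-1]
print(dts)  # non-uniform step sizes that still sum to 1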
from bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule
inference_sched = LinearInferenceSchedule(nsteps = 100)
schedule = inference_sched.generate_schedule().to(DEVICE)
dts = inference_sched.discretize().to(DEVICE)
schedule, dts
(tensor([0.0000, 0.0100, 0.0200, ..., 0.9700, 0.9800, 0.9900], device='cuda:0'), tensor([0.0100, 0.0100, 0.0100, ..., 0.0100, 0.0100, 0.0100], device='cuda:0'))
Sample from the trained model¶
inf_size = 1024
sample = cfm.sample_prior((inf_size, 2)).to(DEVICE) # Start with noise
trajectory = [sample]
for dt, t in zip(dts, schedule):
full_t = inference_sched.pad_time(inf_size, t, DEVICE)
vt = model(torch.cat([sample, full_t[:, None]], dim=-1)) # calculate the vector field based on the definition of the model
sample = cfm.step(vt, sample, dt, full_t)
trajectory.append(sample) # save the trajectory for plotting purposes
import matplotlib.pyplot as plt
traj = torch.stack(trajectory).cpu().detach().numpy()
n = 2000
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(0)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
Sample from underlying score model¶
Low-temperature sampling is a heuristic: intuitively it cuts the tails and concentrates samples around the modes, but its exact effect on the final distribution is hard to characterize in practice.¶
`gt_mode` is a hyperparameter that must be chosen experimentally.¶
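To make the temperature knob concrete, here is a generic Euler-Maruyama style update in which the injected noise is scaled by the temperature. This is only an illustrative sketch of the idea, not the actual `step_score_stochastic` implementation.
import torch

def toy_stochastic_step(x, drift, g, dt, noise_temperature=1.0):
    # Deterministic drift plus Gaussian noise whose magnitude is scaled by the
    # temperature: < 1 concentrates samples near the modes, > 1 spreads them out.
    noise = torch.randn_like(x)
    return x + drift * dt + noise_temperature * g * (dt ** 0.5) * noise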
inf_size = 1024
sample = cfm.sample_prior((inf_size, 2)).to(DEVICE)
trajectory_stoch = [sample]
vts = []
for dt, t in zip(dts, schedule):
time = inference_sched.pad_time(inf_size, t, DEVICE) #torch.full((inf_size,), t).to(DEVICE)
vt = model(torch.cat([sample, time[:, None]], dim=-1))
sample = cfm.step_score_stochastic(vt, sample, dt, time, noise_temperature=1.0, gt_mode = "tan")
trajectory_stoch.append(sample)
vts.append(vt)
traj = torch.stack(trajectory_stoch).cpu().detach().numpy()
n = 2000
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(0)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
#for i in range(0, traj.shape[0]-1):
# plt.plot(traj[i, :n, 0], traj[i, :n, 1], c="olive", alpha=0.2) #, s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(1)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.title("Stochastic score sampling Temperature = 1.0")
plt.show()
What happens if you just sample from a random model?¶
fmodel = torch.nn.Sequential(
torch.nn.Linear(dim + 1, hidden_size),
torch.nn.SELU(),
torch.nn.Linear(hidden_size, hidden_size),
torch.nn.SELU(),
torch.nn.Linear(hidden_size, hidden_size),
torch.nn.SELU(),
torch.nn.Linear(hidden_size, dim),
).to(DEVICE)
inf_size = 1024
sample = cfm.sample_prior((inf_size, 2)).to(DEVICE)
trajectory2 = [sample]
for dt, t in zip(dts, schedule):
time = inference_sched.pad_time(inf_size, t, DEVICE) #torch.full((inf_size,), t).to(DEVICE)
vt = fmodel(torch.cat([sample, time[:, None]], dim=-1))
sample = cfm.step(vt, sample, dt, time)
trajectory2.append(sample)
n = 2000
traj = torch.stack(trajectory2).cpu().detach().numpy()
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(0)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(1)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
import math
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
class Network(nn.Module):
def __init__(
self, dim_in: int, dim_out: int, dim_hids: List[int],
):
super().__init__()
self.layers = nn.ModuleList([
TimeLinear(dim_in, dim_hids[0]),
*[TimeLinear(dim_hids[i-1], dim_hids[i]) for i in range(1, len(dim_hids))],
TimeLinear(dim_hids[-1], dim_out)
])
def forward(self, x: torch.Tensor, t: torch.Tensor):
for i, layer in enumerate(self.layers):
x = layer(x, t)
if i < len(self.layers) - 1:
x = F.relu(x)
return x
class TimeLinear(nn.Module):
def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.dim_in = dim_in
self.dim_out = dim_out
self.time_embedding = TimeEmbedding(dim_out)
self.fc = nn.Linear(dim_in, dim_out)
def forward(self, x: torch.Tensor, t: torch.Tensor):
x = self.fc(x)
alpha = self.time_embedding(t).view(-1, self.dim_out)
return alpha * x
class TimeEmbedding(nn.Module):
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
return embedding
def forward(self, t: torch.Tensor):
if t.ndim == 0:
t = t.unsqueeze(-1)
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
from bionemo.moco.distributions.time import UniformTimeDistribution
from bionemo.moco.interpolants import DDPM
from bionemo.moco.schedules.noise.discrete_noise_schedules import DiscreteCosineNoiseSchedule, DiscreteLinearNoiseSchedule
from bionemo.moco.schedules.inference_time_schedules import DiscreteLinearInferenceSchedule
from bionemo.moco.distributions.prior import GaussianPrior
DEVICE = "cuda:0"
uniform_time = UniformTimeDistribution(discrete_time=True, nsteps = 1000)
simple_prior = GaussianPrior()
ddpm = DDPM(time_distribution=uniform_time,
prior_distribution=simple_prior,
prediction_type = "noise",
noise_schedule = DiscreteLinearNoiseSchedule(nsteps = 1000),
device=DEVICE)
Train the Model¶
# Place both the model and the interpolant on the same device
dim = 2
hidden_size = 128
num_hiddens = 3
batch_size = 256
model = Network(dim_in=dim,
dim_out=dim,
dim_hids=[hidden_size]*num_hiddens)
optimizer = torch.optim.Adam(model.parameters(), lr = 1.e-3)
DEVICE = "cuda"
model = model.to(DEVICE)
ddpm = ddpm.to_device(DEVICE)
for k in range(20000):
optimizer.zero_grad()
shape = (batch_size, dim)
x0 = ddpm.sample_prior(shape).to(DEVICE)
x1 = sample_moons(batch_size).to(DEVICE)
t = ddpm.sample_time(batch_size)
xt = ddpm.interpolate(x1, t, x0)
eps = model(xt, t)
loss = ddpm.loss(eps, x0, t).mean()
loss.backward()
optimizer.step()
if (k + 1) % 1000 == 0:
print(f"{k+1}: loss {loss.item():0.3f}")
1000: loss 0.320 2000: loss 0.372 3000: loss 0.330 4000: loss 0.409 5000: loss 0.338 6000: loss 0.378 7000: loss 0.355 8000: loss 0.394 9000: loss 0.359 10000: loss 0.338 11000: loss 0.257 12000: loss 0.293 13000: loss 0.333 14000: loss 0.329 15000: loss 0.322 16000: loss 0.302 17000: loss 0.282 18000: loss 0.331 19000: loss 0.289 20000: loss 0.322
Let's visualize what the interpolation looks like during training at different times¶
x0 = ddpm.sample_prior(shape).to(DEVICE)
x1 = sample_moons(batch_size).to(DEVICE)
for t in range(0, 900, 100):
tt = ddpm.sample_time(batch_size)*0 + t
out = ddpm.interpolate(x1, tt, x0)
plt.scatter(out[:, 0].cpu().detach(), out[:, 1].cpu().detach())
plt.title(f"Time = {t}")
plt.show()
Create the inference time schedule and sample from the model¶
inf_size = 1024
schedule = DiscreteLinearInferenceSchedule(nsteps = 1000, direction = "diffusion").generate_schedule(device= DEVICE)
sample = ddpm.sample_prior((inf_size, 2)).to(DEVICE) # Start with noise
trajectory = [sample]
for t in schedule:
full_t = torch.full((inf_size,), t).to(DEVICE)
    vt = model(sample, full_t) # predict the noise at time t (the model was trained with prediction_type="noise")
sample = ddpm.step_noise(vt, full_t, sample)
trajectory.append(sample) # save the trajectory for plotting purposes
import matplotlib.pyplot as plt
traj = torch.stack(trajectory).cpu().detach().numpy()
n = 2000
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(0)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
inf_size = 1024
sample = ddpm.sample_prior((inf_size, 2)).to(DEVICE) # Start with noise
trajectory = [sample]
for t in schedule:
full_t = torch.full((inf_size,), t).to(DEVICE)
    eps_hat = model(sample, full_t) # predict the noise at time t
sample = ddpm.step(eps_hat, full_t, sample)
trajectory.append(sample) # save the trajectory for plotting purposes
traj = torch.stack(trajectory).cpu().detach().numpy()
n = 2000
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(0)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
Notice that this yields very similar results to using the underlying score function in the stochastic score-based CFM example¶
Notice that there is no difference whether or not we convert the predicted noise to data inside the .step() function¶
Let's try other cool sampling functions¶
inf_size = 1024
schedule = DiscreteLinearInferenceSchedule(nsteps = 1000, direction = "diffusion").generate_schedule(device= DEVICE)
sample = ddpm.sample_prior((inf_size, 2)).to(DEVICE) # Start with noise
trajectory = [sample]
for t in schedule:
full_t = torch.full((inf_size,), t).to(DEVICE)
    eps_hat = model(sample, full_t) # predict the noise at time t
sample = ddpm.step_ddim(eps_hat, full_t, sample)
trajectory.append(sample) # save the trajectory for plotting purposes
traj = torch.stack(trajectory).cpu().detach().numpy()
n = 2000
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(0)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
What happens when you sample from an untrained model with DDPM¶
model = Network(dim_in=dim,
dim_out=dim,
dim_hids=[hidden_size]*num_hiddens).to(DEVICE)
inf_size = 1024
sample = ddpm.sample_prior((inf_size, 2)).to(DEVICE)
trajectory2 = [sample]
for t in schedule:
full_t = torch.full((inf_size,), t).to(DEVICE)
    vt = model(sample, full_t) # predict the noise with the (untrained) model
sample = ddpm.step_noise(vt, full_t, sample)
trajectory2.append(sample) #
n = 2000
traj = torch.stack(trajectory2).cpu().detach().numpy()
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(0)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(1)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
Now let's switch the parameterization of DDPM from noise to data¶
Here, instead of training the model to predict the noise, we train it to predict the raw data. Both options are valid; the choice between them depends on the underlying modeling task.
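The two parameterizations are tied together by the forward process: if x_t = sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * eps, a noise prediction can always be converted into a data prediction and vice versa. A small illustrative sketch (the variable names are ours, not the library's API):
import torch

def noise_to_data(xt, eps_hat, alpha_bar_t):
    # alpha_bar_t: tensor of cumulative signal levels in (0, 1).
    # Invert x_t = sqrt(alpha_bar) * x + sqrt(1 - alpha_bar) * eps for x.
    return (xt - torch.sqrt(1 - alpha_bar_t) * eps_hat) / torch.sqrt(alpha_bar_t)

def data_to_noise(xt, x_hat, alpha_bar_t):
    # Invert the same relation for eps.
    return (xt - torch.sqrt(alpha_bar_t) * x_hat) / torch.sqrt(1 - alpha_bar_t)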
from bionemo.moco.distributions.time.uniform import UniformTimeDistribution
from bionemo.moco.interpolants.discrete_time.continuous.ddpm import DDPM
from bionemo.moco.schedules.noise.discrete_noise_schedules import DiscreteCosineNoiseSchedule, DiscreteLinearNoiseSchedule
from bionemo.moco.schedules.inference_time_schedules import DiscreteLinearInferenceSchedule
from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior
DEVICE = "cuda:0"
uniform_time = UniformTimeDistribution(discrete_time=True, nsteps = 1000)
simple_prior = GaussianPrior()
ddpm = DDPM(time_distribution=uniform_time,
prior_distribution=simple_prior,
prediction_type = "data",
noise_schedule = DiscreteLinearNoiseSchedule(nsteps = 1000),
device=DEVICE)
Let us first train the model with a loss weight that makes the objective theoretically equivalent to the simple noise-matching loss. See Equation 9 of https://arxiv.org/pdf/2202.00512¶
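Rescaling the data-space error by the signal-to-noise ratio makes it match the plain noise-matching objective term by term: since x_t - sqrt(alpha_bar_t) * x = sqrt(1 - alpha_bar_t) * eps, we have ||eps - eps_hat||^2 = (alpha_bar_t / (1 - alpha_bar_t)) * ||x - x_hat||^2, which is how we read the "data_to_noise" weight. A quick numerical check of that identity, independent of the library's internals:
import torch

alpha_bar = torch.tensor(0.7)
x, eps = torch.randn(4, 2), torch.randn(4, 2)
xt = torch.sqrt(alpha_bar) * x + torch.sqrt(1 - alpha_bar) * eps

x_hat = torch.randn(4, 2)  # stand-in for a model's data-space prediction
eps_hat = (xt - torch.sqrt(alpha_bar) * x_hat) / torch.sqrt(1 - alpha_bar)

noise_loss = ((eps - eps_hat) ** 2).mean()
weighted_data_loss = (alpha_bar / (1 - alpha_bar)) * ((x - x_hat) ** 2).mean()
print(torch.allclose(noise_loss, weighted_data_loss))  # True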
# Place both the model and the interpolant on the same device
dim = 2
hidden_size = 128
num_hiddens = 3
batch_size = 256
model = Network(dim_in=dim,
dim_out=dim,
dim_hids=[hidden_size]*num_hiddens)
optimizer = torch.optim.Adam(model.parameters(), lr = 1.e-3)
DEVICE = "cuda"
model = model.to(DEVICE)
ddpm = ddpm.to_device(DEVICE)
for k in range(20000):
optimizer.zero_grad()
shape = (batch_size, dim)
x0 = ddpm.sample_prior(shape).to(DEVICE)
x1 = sample_moons(batch_size).to(DEVICE)
t = ddpm.sample_time(batch_size)
xt = ddpm.interpolate(x1, t, x0)
x_hat = model(xt, t)
loss = ddpm.loss(x_hat, x1, t, weight_type="data_to_noise").mean()
loss.backward()
optimizer.step()
if (k + 1) % 1000 == 0:
print(f"{k+1}: loss {loss.item():0.3f}")
1000: loss 0.504 2000: loss 1.002 3000: loss 0.446 4000: loss 1.014 5000: loss 0.375 6000: loss 1.849 7000: loss 0.489 8000: loss 1.577 9000: loss 0.314 10000: loss 0.468 11000: loss 0.332 12000: loss 1.729 13000: loss 0.374 14000: loss 0.779 15000: loss 0.536 16000: loss 6.597 17000: loss 1.269 18000: loss 0.501 19000: loss 0.546 20000: loss 0.490
inf_size = 1024
sample = ddpm.sample_prior((inf_size, 2)).to(DEVICE) # Start with noise
trajectory = [sample]
for t in schedule:
full_t = torch.full((inf_size,), t).to(DEVICE)
    x_hat = model(sample, full_t) # predict the clean data at time t
sample = ddpm.step(x_hat, full_t, sample)
trajectory.append(sample) # save the trajectory for plotting purposes
import matplotlib.pyplot as plt
traj = torch.stack(trajectory).cpu().detach().numpy()
n = 2000
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(0)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
Now let us train with no loss weighting to optimize a true data matching loss for comparison¶
# Place both the model and the interpolant on the same device
dim = 2
hidden_size = 128
num_hiddens = 3
batch_size = 256
model = Network(dim_in=dim,
dim_out=dim,
dim_hids=[hidden_size]*num_hiddens)
optimizer = torch.optim.Adam(model.parameters(), lr = 1.e-3)
DEVICE = "cuda"
model = model.to(DEVICE)
ddpm = ddpm.to_device(DEVICE)
for k in range(20000):
optimizer.zero_grad()
shape = (batch_size, dim)
x0 = ddpm.sample_prior(shape).to(DEVICE)
x1 = sample_moons(batch_size).to(DEVICE)
t = ddpm.sample_time(batch_size)
xt = ddpm.interpolate(x1, t, x0)
x_hat = model(xt, t)
loss = ddpm.loss(x_hat, x1, t, weight_type="ones").mean()
loss.backward()
optimizer.step()
if (k + 1) % 1000 == 0:
print(f"{k+1}: loss {loss.item():0.3f}")
1000: loss 2.651 2000: loss 2.659 3000: loss 2.603 4000: loss 2.507 5000: loss 2.650 6000: loss 2.792 7000: loss 2.670 8000: loss 2.550 9000: loss 2.685 10000: loss 2.410 11000: loss 2.290 12000: loss 2.755 13000: loss 2.521 14000: loss 2.505 15000: loss 2.196 16000: loss 2.702 17000: loss 2.933 18000: loss 2.350 19000: loss 2.397 20000: loss 2.382
inf_size = 1024
sample = ddpm.sample_prior((inf_size, 2)).to(DEVICE) # Start with noise
trajectory = [sample]
for t in schedule:
full_t = torch.full((inf_size,), t).to(DEVICE)
    x_hat = model(sample, full_t) # predict the clean data at time t
sample = ddpm.step(x_hat, full_t, sample)
trajectory.append(sample) # save the trajectory for plotting purposes
import matplotlib.pyplot as plt
traj = torch.stack(trajectory).cpu().detach().numpy()
n = 2000
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(0)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
Now let's try a continuous-time analog of DDPM called VDM¶
This interpolant was used in Chroma and is described in great detail here https://www.biorxiv.org/content/10.1101/2022.12.01.518682v1.full.pdf¶
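In the VDM parameterization the forward process is written through a log signal-to-noise ratio: x_t = alpha_t * x + sigma_t * eps with alpha_t^2 = sigmoid(logSNR(t)) and sigma_t^2 = sigmoid(-logSNR(t)), so picking a noise schedule amounts to picking the logSNR curve. A tiny illustrative sketch of that relationship (with an arbitrary logSNR curve, not the library's LinearLogInterpolatedSNRTransform):
import torch

def alpha_sigma_from_logsnr(logsnr):
    # alpha_t^2 + sigma_t^2 = 1 by construction, since sigmoid(x) + sigmoid(-x) = 1.
    return torch.sqrt(torch.sigmoid(logsnr)), torch.sqrt(torch.sigmoid(-logsnr))

t = torch.linspace(0.01, 0.99, 5)
logsnr = torch.log((1 - t) / t)  # one simple, monotonically decreasing choice
alpha, sigma = alpha_sigma_from_logsnr(logsnr)
print(alpha ** 2 + sigma ** 2)  # ~1 for every t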
from bionemo.moco.distributions.time import UniformTimeDistribution
from bionemo.moco.interpolants import VDM
from bionemo.moco.schedules.noise.continuous_snr_transforms import CosineSNRTransform, LinearSNRTransform, LinearLogInterpolatedSNRTransform
from bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule
from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior
DEVICE = "cuda:0"
uniform_time = UniformTimeDistribution(discrete_time=False)
simple_prior = GaussianPrior()
vdm = VDM(time_distribution=uniform_time,
prior_distribution=simple_prior,
prediction_type = "data",
noise_schedule = LinearLogInterpolatedSNRTransform(),
device=DEVICE)
schedule = LinearInferenceSchedule(nsteps = 1000, direction="diffusion")
# Place both the model and the interpolant on the same device
dim = 2
hidden_size = 128
num_hiddens = 3
batch_size = 256
model = Network(dim_in=dim,
dim_out=dim,
dim_hids=[hidden_size]*num_hiddens)
DEVICE = "cuda"
model = model.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr = 1.e-3)
for k in range(20000):
optimizer.zero_grad()
shape = (batch_size, dim)
x0 = vdm.sample_prior(shape).to(DEVICE)
x1 = sample_moons(batch_size).to(DEVICE)
t = vdm.sample_time(batch_size)
xt = vdm.interpolate(x1, t, x0)
x_hat = model(xt, t)
loss = vdm.loss(x_hat, x1, t, weight_type="ones").mean()
loss.backward()
optimizer.step()
if (k + 1) % 1000 == 0:
print(f"{k+1}: loss {loss.item():0.3f}")
1000: loss 1.251 2000: loss 1.152 3000: loss 1.156 4000: loss 0.908 5000: loss 1.174 6000: loss 1.355 7000: loss 1.008 8000: loss 1.567 9000: loss 1.092 10000: loss 1.290 11000: loss 1.149 12000: loss 1.350 13000: loss 1.480 14000: loss 1.061 15000: loss 1.223 16000: loss 1.180 17000: loss 1.127 18000: loss 1.351 19000: loss 1.059 20000: loss 1.074
# DEVICE="cuda:1"
# model = model.to(DEVICE)
# vdm = vdm.to_device(DEVICE)
inf_size = 1024
sample = vdm.sample_prior((inf_size, 2)).to(DEVICE) # Start with noise
trajectory = [sample]
ts = schedule.generate_schedule()
dts = schedule.discretize()
for dt, t in zip(dts, ts):
full_t = torch.full((inf_size,), t).to(DEVICE)
    x_hat = model(sample, full_t) # predict the clean data at time t
sample = vdm.step(x_hat, full_t, sample, dt)
trajectory.append(sample) # save the trajectory for plotting purposes
import matplotlib.pyplot as plt
traj = torch.stack(trajectory).cpu().detach().numpy()
n = 2000
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(0)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
inf_size = 1024
sample = vdm.sample_prior((inf_size, 2)).to(DEVICE) # Start with noise
trajectory = [sample]
ts = schedule.generate_schedule()
dts = schedule.discretize()
for dt, t in zip(dts, ts):
full_t = torch.full((inf_size,), t).to(DEVICE)
    x_hat = model(sample, full_t) # predict the clean data at time t
sample = vdm.step_ddim(x_hat, full_t, sample, dt)
trajectory.append(sample) # save the trajectory for plotting purposes
import matplotlib.pyplot as plt
traj = torch.stack(trajectory).cpu().detach().numpy()
n = 2000
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(0)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
What is interesting here is that deterministic DDIM sampling best recovers the Flow Matching ODE samples¶
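One way to see why: with no injected noise, each DDIM update recovers the noise implied by the current sample and the model's data estimate, then deterministically re-noises to the next level, so it traces a deterministic path much like the flow matching ODE sampler. A generic sketch of the eta = 0 DDIM update in the data parameterization (not the exact `step_ddim` signature):
import torch

def toy_ddim_step(xt, x_hat, alpha_bar_t, alpha_bar_next):
    # Recover the implied noise, then move deterministically to the next noise level.
    eps_hat = (xt - torch.sqrt(alpha_bar_t) * x_hat) / torch.sqrt(1 - alpha_bar_t)
    return torch.sqrt(alpha_bar_next) * x_hat + torch.sqrt(1 - alpha_bar_next) * eps_hat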
inf_size = 1024
sample = vdm.sample_prior((inf_size, 2)).to(DEVICE) # Start with noise
trajectory = [sample]
ts = schedule.generate_schedule()
dts = schedule.discretize()
for dt, t in zip(dts, ts):
full_t = torch.full((inf_size,), t).to(DEVICE)
    x_hat = model(sample, full_t) # predict the clean data at time t
# sample = vdm.step_hybrid_sde(x_hat, full_t, sample, dt)
sample = vdm.step_ode(x_hat, full_t, sample, dt)
trajectory.append(sample) # save the trajectory for plotting purposes
traj = torch.stack(trajectory).cpu().detach().numpy()
n = 2000
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(0)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
inf_size = 1024
sample = vdm.sample_prior((inf_size, 2)).to(DEVICE) # Start with noise
trajectory = [sample]
ts = schedule.generate_schedule()
dts = schedule.discretize()
for dt, t in zip(dts, ts):
full_t = torch.full((inf_size,), t).to(DEVICE)
    x_hat = model(sample, full_t) # predict the clean data at time t
# sample = vdm.step_hybrid_sde(x_hat, full_t, sample, dt)
sample = vdm.step_ode(x_hat, full_t, sample, dt, temperature = 1.5)
trajectory.append(sample) # save the trajectory for plotting purposes
traj = torch.stack(trajectory).cpu().detach().numpy()
n = 2000
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(0)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
inf_size = 1024
sample = vdm.sample_prior((inf_size, 2)).to(DEVICE) # Start with noise
trajectory = [sample]
ts = schedule.generate_schedule()
dts = schedule.discretize()
for dt, t in zip(dts, ts):
full_t = torch.full((inf_size,), t).to(DEVICE)
    x_hat = model(sample, full_t) # predict the clean data at time t
# sample = vdm.step_hybrid_sde(x_hat, full_t, sample, dt)
sample = vdm.step_ode(x_hat, full_t, sample, dt, temperature = 0.5)
trajectory.append(sample) # save the trajectory for plotting purposes
traj = torch.stack(trajectory).cpu().detach().numpy()
n = 2000
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(0)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
inf_size = 1024
sample = vdm.sample_prior((inf_size, 2)).to(DEVICE) # Start with noise
trajectory = [sample]
ts = schedule.generate_schedule()
dts = schedule.discretize()
for dt, t in zip(dts, ts):
full_t = torch.full((inf_size,), t).to(DEVICE)
    x_hat = model(sample, full_t) # predict the clean data at time t
sample = vdm.step_hybrid_sde(x_hat, full_t, sample, dt)
# sample = vdm.step_ode(x_hat, full_t, sample, dt)
trajectory.append(sample) # save the trajectory for plotting purposes
traj = torch.stack(trajectory).cpu().detach().numpy()
n = 2000
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(0)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()