Building Generative Models for Discrete Data via Discrete Interpolants¶
In [1]:
import math
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from tqdm import tqdm

torch.cuda.manual_seed(42)
Tutorial¶
This notebook walks through how to use three discrete data interpolants: (1) Discrete Flow Matching (DFM), (2) Discrete Denoising Diffusion Probabilistic Models (D3PM), and (3) Masked Diffusion Language Modeling (MDLM).
Task¶
Here our object consists of 10 binary elements. The target distribution places 1s in the first k positions and 0s elsewhere, with k drawn uniformly from 0 to 10.
We initialize our interpolants with a binary uniform prior, so on average each prior sample has 5 ones out of 10.
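To make this concrete, here is a small standalone illustration (plain PyTorch, independent of the interpolant classes used below) of what the target samples and the binary uniform prior look like, and why both average 5 ones out of 10:

import torch

D = 10

# Target samples: the first k positions are 1, with k drawn uniformly from {0, ..., 10}.
k = torch.randint(0, D + 1, (5,))
x1 = (torch.arange(D)[None, :] < k[:, None]).long()
print(x1)             # rows like [1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
print(x1.sum(dim=1))  # number of ones per sample, uniform over 0..10 (mean 5)

# Binary uniform prior: each position is 0 or 1 with probability 1/2,
# so the expected number of ones is D / 2 = 5.
x0 = torch.randint(0, 2, (10000, D))
print(x0.sum(dim=1).float().mean())  # ≈ 5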
Define the Model Architecture¶
In [2]:
# training
B = 32  # batch size
D = 10  # dimension
S = 2   # state space


class Model(nn.Module):
    def __init__(self, D, S):
        super().__init__()
        self.embedding = nn.Embedding(S + 1, 16)
        self.net = nn.Sequential(
            nn.Linear(17 * D, 128),  # 17 = 16 embedding dims + 1 time dim per position
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, S * D),
        )

    def forward(self, x, t):
        B, D = x.shape
        x_emb = self.embedding(x)  # (B, D, 16)
        net_input = torch.cat([x_emb, t[:, None, None].repeat(1, D, 1)], dim=-1).reshape(B, -1)  # (B, D * 17)
        return self.net(net_input).reshape(B, D, S)  # (B, D, S)
Define the Discrete Flow Matching Interpolant¶
In [3]:
from bionemo.moco.distributions.prior import DiscreteUniformPrior
from bionemo.moco.interpolants import DiscreteFlowMatcher
from bionemo.moco.distributions.time import UniformTimeDistribution
from bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule

B = 32  # batch size
D = 10  # dimension
S = 2   # state space
DEVICE = "cuda:0"

prior = DiscreteUniformPrior(num_classes=S)
time_distribution = UniformTimeDistribution()
dfm = DiscreteFlowMatcher(
    time_distribution=time_distribution,
    prior_distribution=prior,
    device=DEVICE,
)
schedule = LinearInferenceSchedule(nsteps=1000)
In [4]:
model = Model(D, S)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
Train DFM¶
In [5]:
model = model.to(DEVICE)
losses = []
for _ in tqdm(range(50000)):
    num_ones = torch.randint(0, D + 1, (B,))
    x1 = (torch.arange(D)[None, :] < num_ones[:, None]).long().to(DEVICE)
    # x1 e.g. [1, 1, 1, 0, 0, 0, 0, 0, 0, 0] or [1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
    optimizer.zero_grad()
    x0 = dfm.sample_prior(x1.shape)  # (B, D)
    t = dfm.sample_time(B)
    xt = dfm.interpolate(x1, t, x0)
    logits = model(xt, t)  # (B, D, S)
    loss = dfm.loss(logits, x1, t).mean()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
100%|██████████| 50000/50000 [00:54<00:00, 923.41it/s]
In [6]:
plt.plot(losses, label='Training Loss', linestyle='-', color='blue', marker='o')
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.legend()
plt.grid(True)
plt.show()
Sample from DFM¶
In [7]:
num_samples = 1000
xt = dfm.sample_prior((num_samples, D))
print(xt.shape)
ts = schedule.generate_schedule(device=DEVICE)
dts = schedule.discretize(device=DEVICE)
torch.Size([1000, 10])
In [8]:
ts
Out[8]:
tensor([0.0000, 0.0010, 0.0020, 0.0030, 0.0040, 0.0050, 0.0060, 0.0070, 0.0080, 0.0090,
        ...,
        0.9900, 0.9910, 0.9920, 0.9930, 0.9940, 0.9950, 0.9960, 0.9970, 0.9980, 0.9990], device='cuda:0')
In [9]:
LinearInferenceSchedule(nsteps=100, min_t=0, inclusive_end=False).generate_schedule()
Out[9]:
tensor([0.0000, 0.0100, 0.0200, 0.0300, 0.0400, 0.0500, 0.0600, 0.0700, 0.0800, 0.0900, 0.1000, 0.1100, 0.1200, 0.1300, 0.1400, 0.1500, 0.1600, 0.1700, 0.1800, 0.1900, 0.2000, 0.2100, 0.2200, 0.2300, 0.2400, 0.2500, 0.2600, 0.2700, 0.2800, 0.2900, 0.3000, 0.3100, 0.3200, 0.3300, 0.3400, 0.3500, 0.3600, 0.3700, 0.3800, 0.3900, 0.4000, 0.4100, 0.4200, 0.4300, 0.4400, 0.4500, 0.4600, 0.4700, 0.4800, 0.4900, 0.5000, 0.5100, 0.5200, 0.5300, 0.5400, 0.5500, 0.5600, 0.5700, 0.5800, 0.5900, 0.6000, 0.6100, 0.6200, 0.6300, 0.6400, 0.6500, 0.6600, 0.6700, 0.6800, 0.6900, 0.7000, 0.7100, 0.7200, 0.7300, 0.7400, 0.7500, 0.7600, 0.7700, 0.7800, 0.7900, 0.8000, 0.8100, 0.8200, 0.8300, 0.8400, 0.8500, 0.8600, 0.8700, 0.8800, 0.8900, 0.9000, 0.9100, 0.9200, 0.9300, 0.9400, 0.9500, 0.9600, 0.9700, 0.9800, 0.9900])
In [10]:
for dt, t in zip(dts, ts):
    t = schedule.pad_time(num_samples, t, DEVICE)
    logits = model(xt, t)
    xt = dfm.step(logits, t, xt, dt, stochasticity=0)
Generated DFM Samples¶
In [11]:
counts = xt.cpu().sum(dim=1).float()
plt.hist(counts.numpy(), bins=range(D+2))
plt.show()
Ground Truth Distribution¶
In [12]:
num_ones = torch.randint(0, D+1, (1000,))
x1 = (torch.arange(D)[None, :] < num_ones[:, None]).long()
counts = x1.cpu().sum(dim=1).float()
plt.hist(counts.numpy(), bins=range(D+2))
plt.show()
Discrete Uniform Prior Distribution¶
In [13]:
x0 = dfm.sample_prior((10000, D))
counts = x0.cpu().sum(dim=1).float()
plt.hist(counts.numpy(), bins=range(D+2))
plt.show()
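As a quick numerical check (a sketch in plain PyTorch; `xt` and `x1` refer to the generated and ground-truth samples from the cells above), we can compare the two count distributions with a total variation distance:

# Sketch: total variation distance between the empirical count distributions of the
# generated samples (xt) and the ground-truth samples (x1); closer to 0 is a better match.
gen_counts = torch.bincount(xt.cpu().sum(dim=1), minlength=D + 1).float()
gt_counts = torch.bincount(x1.cpu().sum(dim=1), minlength=D + 1).float()
tv = 0.5 * (gen_counts / gen_counts.sum() - gt_counts / gt_counts.sum()).abs().sum()
print(f"Total variation distance: {tv:.3f}")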
We see that with DFM we are able to approximate the ground truth distribution. Now let's try a different interpolant¶
D3PM Interpolant¶
In [30]:
from bionemo.moco.distributions.prior import DiscreteUniformPrior
from bionemo.moco.interpolants import D3PM
from bionemo.moco.distributions.time import UniformTimeDistribution
from bionemo.moco.schedules.noise.discrete_noise_schedules import DiscreteCosineNoiseSchedule
from bionemo.moco.schedules.inference_time_schedules import DiscreteLinearInferenceSchedule

B = 32  # batch size
D = 10  # dimension
S = 2   # state space
DEVICE = "cuda:0"

prior = DiscreteUniformPrior(num_classes=S)
time_distribution = UniformTimeDistribution(discrete_time=True, nsteps=1000)
noise_schedule = DiscreteCosineNoiseSchedule(nsteps=1000)
d3pm = D3PM(
    time_distribution=time_distribution,
    prior_distribution=prior,
    noise_schedule=noise_schedule,
    device=DEVICE,
)
schedule = DiscreteLinearInferenceSchedule(nsteps=1000, direction="diffusion", device=DEVICE)
In [31]:
model = Model(D, S)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
d3pm.terminal_distribution
Out[31]:
tensor([0.5000, 0.5000])
Train D3PM¶
In [32]:
model = model.to(DEVICE)
losses = []
for _ in tqdm(range(50000)):
    num_ones = torch.randint(0, D + 1, (B,))
    x1 = (torch.arange(D)[None, :] < num_ones[:, None]).long().to(DEVICE)
    # x1 e.g. [1, 1, 1, 0, 0, 0, 0, 0, 0, 0] or [1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
    optimizer.zero_grad()
    t = d3pm.sample_time(B)
    xt = d3pm.interpolate(x1, t)  # D3PM corrupts x1 directly; no explicit prior sample is needed
    logits = model(xt, t)  # (B, D, S)
    loss = d3pm.loss(logits, x1, xt, t).mean()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
100%|██████████| 50000/50000 [01:08<00:00, 727.62it/s]
In [33]:
plt.plot(losses, label='Training Loss', linestyle='-', color='blue', marker='o')
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.legend()
plt.grid(True)
plt.ylim([0, 1])
# plt.yscale('log')
plt.show()
Sample from D3PM¶
In [34]:
ts = schedule.generate_schedule()
num_samples = 1000
xt = d3pm.sample_prior((num_samples, D))
for t in ts:
    t = torch.full((xt.shape[0],), t).to(DEVICE)
    logits = model(xt, t)
    xt = d3pm.step(logits, t, xt)
D3PM Generated Distribution¶
In [35]:
counts = xt.cpu().sum(dim=1).float()
plt.hist(counts.numpy(), bins=range(D+2))
plt.show()
D3PM Prior Distribution¶
In [20]:
xt = d3pm.sample_prior((num_samples, D))
counts = xt.cpu().sum(dim=1).float()
plt.hist(counts.numpy(), bins=range(D+2))
plt.show()
Now let's try a new interpolant and a new prior¶
MDLM Interpolant¶
In [21]:
from bionemo.moco.distributions.prior import DiscreteMaskedPrior
from bionemo.moco.interpolants import MDLM
from bionemo.moco.distributions.time import UniformTimeDistribution
from bionemo.moco.schedules.noise.continuous_noise_transforms import CosineExpNoiseTransform
from bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule

DEVICE = "cuda:0"

prior = DiscreteMaskedPrior(num_classes=2, inclusive=False)
time_distribution = UniformTimeDistribution(discrete_time=False)
noise_schedule = CosineExpNoiseTransform()
mdlm = MDLM(
    time_distribution=time_distribution,
    prior_distribution=prior,
    noise_schedule=noise_schedule,
    device=DEVICE,
)
schedule = LinearInferenceSchedule(direction="diffusion", nsteps=1000)
In [22]:
prior.num_classes  # The inclusive flag allows us to choose whether or not to add an extra class for the mask token
Out[22]:
3
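Because `inclusive=False`, the mask token is appended as an extra class (index 2 here), and samples from the masked prior start out fully masked. A quick check (a sketch, reusing the `mdlm` object defined above):

# Sketch: with inclusive=False the mask token takes the extra index S = 2,
# so prior samples are expected to contain only the mask index before denoising.
example = mdlm.sample_prior((2, 5))
print(example)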
Train MDLM¶
In [23]:
# training
B = 32  # batch size
D = 10  # dimension
S = 3   # state space (2 data states + 1 mask state)

model = Model(D, S)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
model = model.to(DEVICE)
losses = []
for _ in tqdm(range(50000)):
    num_ones = torch.randint(0, D + 1, (B,))
    x1 = (torch.arange(D)[None, :] < num_ones[:, None]).long().to(DEVICE)
    # x1 e.g. [1, 1, 1, 0, 0, 0, 0, 0, 0, 0] or [1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
    optimizer.zero_grad()
    t = mdlm.sample_time(B)
    xt = mdlm.interpolate(x1, t)  # masking corruption; no explicit prior sample is needed
    logits = model(xt, t)  # (B, D, S)
    loss = mdlm.loss(logits, x1, xt, t).mean()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
100%|██████████| 50000/50000 [01:34<00:00, 530.83it/s]
In [24]:
plt.plot(losses, label='Training Loss', linestyle='-', color='blue', marker='o')
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.legend()
plt.grid(True)
plt.ylim([0, 1])
plt.show()
Visualize the MASK Prior¶
In [25]:
num_samples = 1000
xt = mdlm.sample_prior((num_samples, D))
counts = xt.flatten().cpu()
# Compute frequency of each class index
class_counts = torch.bincount(counts)
# Plotting
plt.figure(figsize=(8, 5))
plt.bar(range(len(class_counts)), class_counts.numpy(), color='red')
plt.xlabel('Class Index')
plt.ylabel('Frequency')
plt.title('Discrete Distribution of Class Indices')
plt.xticks(range(len(class_counts)))  # Set x-ticks to class indices
plt.show()
Sample from the MDLM trained model¶
In [26]:
ts = schedule.generate_schedule()
dts = schedule.discretize()
num_samples = 1000
xt = mdlm.sample_prior((num_samples, D))
for dt, t in zip(dts, ts):
    t = torch.full((xt.shape[0],), t).to(DEVICE)
    logits = model(xt, t)
    xt = mdlm.step(logits, t, xt, dt)
Visualize the class breakdown (green) and the generated sample counts (blue)¶
In [27]:
counts = xt.flatten().cpu()
# Compute frequency of each class index
class_counts = torch.bincount(counts)
# Plotting
plt.figure(figsize=(8, 5))
plt.bar(range(len(class_counts)), class_counts.numpy(), color='green')
plt.xlabel('Class Index')
plt.ylabel('Frequency')
plt.title('Discrete Distribution of Class Indices')
plt.xticks(range(len(class_counts)))  # Set x-ticks to class indices
plt.show()
In [28]:
counts = xt.cpu().sum(dim=1).float()
plt.hist(counts.numpy(), bins=range(D+2))
plt.show()