Generate random microbatches for testing.
Note that this follows the convention that `token_logits` is laid out as (sequence, batch), while all other fields are (batch, sequence).
Source code in bionemo/testing/lightning.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def get_random_microbatch(
    microbatch_size: int,
    max_sequence_length: int,
    vocab_size: int,
    seed: int,
    device: "torch.device | str | None" = None,
) -> Dict[str, Dict[str, torch.Tensor]]:
    """Generate a random microbatch for testing.

    Note that this follows the convention that token_logits are (s, b), while
    other fields are (b, s).

    Args:
        microbatch_size: Number of sequences in the batch (b).
        max_sequence_length: Length of each sequence (s).
        vocab_size: Vocabulary size (v); labels are drawn from [0, vocab_size).
        seed: Seed for the RNG so the generated batch is reproducible.
        device: Device on which to allocate the tensors. Defaults to the
            current CUDA device, preserving the original behavior; pass
            "cpu" to run without a GPU.

    Returns:
        A dict with two keys:
            "batch": {"labels": [b, s] long, "loss_mask": [b, s] long}
            "forward_out": {"token_logits": [s, b, v] float}
    """
    if device is None:
        device = torch.cuda.current_device()
    generator = torch.Generator(device=device).manual_seed(seed)
    labels = torch.randint(
        low=0,
        high=vocab_size,
        size=(microbatch_size, max_sequence_length),
        generator=generator,
        device=device,
    )  # [b s]
    # NOTE: low=1, high=2 means every entry is 1 — the mask is all-ones, i.e.
    # no position is masked out. Kept as-is to preserve existing behavior.
    loss_mask = torch.randint(
        low=1,
        high=1 + 1,
        size=(microbatch_size, max_sequence_length),
        dtype=torch.long,
        device=device,
        generator=generator,
    )  # [b s]
    token_logits = torch.rand(
        max_sequence_length, microbatch_size, vocab_size, device=device, generator=generator
    )  # [s b v]
    # Propagate masking to labels; a no-op while the mask is all ones (see NOTE above).
    labels[loss_mask == 0] = -100
    microbatch_output = {
        "batch": {"labels": labels, "loss_mask": loss_mask},
        "forward_out": {"token_logits": token_logits},
    }
    return microbatch_output
|