Lightning

get_random_microbatch(microbatch_size, max_sequence_length, vocab_size, seed)

Generate random microbatches for testing.

Note that this follows the convention that token_logits has shape [s, b], while the other fields have shape [b, s].

Source code in bionemo/testing/lightning.py
# Imports required to run this snippet standalone.
from typing import Dict

import torch


def get_random_microbatch(
    microbatch_size: int, max_sequence_length: int, vocab_size: int, seed: int
) -> Dict[str, Dict[str, torch.Tensor]]:
    """Generate random microbatches for testing.

    Note that this follows the convention that token_logits has shape [s, b], while the other fields have shape [b, s].
    """
    generator = torch.Generator(device=torch.cuda.current_device()).manual_seed(seed)
    labels = torch.randint(
        low=0,
        high=vocab_size,
        size=(microbatch_size, max_sequence_length),
        generator=generator,
        device=torch.cuda.current_device(),
    )  # [b s]
    loss_mask = torch.randint(
        low=1,
        high=1 + 1,  # low == high - 1, so every entry is 1: no positions are masked out in this test helper
        size=(microbatch_size, max_sequence_length),
        dtype=torch.long,
        device=torch.cuda.current_device(),
        generator=generator,
    )  # [b s]
    token_logits = torch.rand(
        max_sequence_length, microbatch_size, vocab_size, device=torch.cuda.current_device(), generator=generator
    )  # [s b v]
    labels[loss_mask == 0] = -100  # propagate masking to labels via the -100 ignore index
    microbatch_output = {
        "batch": {"labels": labels, "loss_mask": loss_mask},
        "forward_out": {"token_logits": token_logits},
    }
    return microbatch_output
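
For reference, a minimal usage sketch follows. It checks the shape convention described above; the concrete sizes are arbitrary example values, and a CUDA device is assumed since the function allocates tensors on torch.cuda.current_device().

from bionemo.testing.lightning import get_random_microbatch

microbatch = get_random_microbatch(microbatch_size=2, max_sequence_length=8, vocab_size=32, seed=42)

# Fields under "batch" follow the [b, s] convention...
assert microbatch["batch"]["labels"].shape == (2, 8)  # [b s]
assert microbatch["batch"]["loss_mask"].shape == (2, 8)  # [b s]
# ...while token_logits follows [s, b, v].
assert microbatch["forward_out"]["token_logits"].shape == (8, 2, 32)  # [s b v]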