Skip to content

Compare

assert_model_equivalence(ckpt_path, model_tag, precision='fp32', rtol=None, atol=None)

Testing utility to compare the outputs of a NeMo2 checkpoint to the original HuggingFace model weights.

Compares the cosine similarity of the logit and hidden state outputs of a NeMo2 model checkpoint to the outputs of the corresponding HuggingFace model.

Parameters:

Name Type Description Default
ckpt_path Path | str

A path to a NeMo2 checkpoint for an ESM-2 model.

required
model_tag str

The HuggingFace model tag for the model to compare against.

required
precision PrecisionTypes

The precision type to use for the comparison. Defaults to "fp32".

'fp32'
rtol float | None

The relative tolerance to use for the comparison. Defaults to None, which chooses the tolerance based on the precision.

None
atol float | None

The absolute tolerance to use for the comparison. Defaults to None, which chooses the tolerance based on the precision.

None
Source code in bionemo/esm2/testing/compare.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
def assert_model_equivalence(
    ckpt_path: Path | str,
    model_tag: str,
    precision: PrecisionTypes = "fp32",
    rtol: float | None = None,
    atol: float | None = None,
) -> None:
    """Testing utility to compare the outputs of a NeMo2 checkpoint to the original HuggingFace model weights.

    Compares the cosine similarity of the logit and hidden state outputs of a NeMo2 model checkpoint to the outputs of
    the corresponding HuggingFace model.

    Args:
        ckpt_path: A path to a NeMo2 checkpoint for an ESM-2 model.
        model_tag: The HuggingFace model tag for the model to compare against.
        precision: The precision type to use for the comparison. Defaults to "fp32".
        rtol: The relative tolerance to use for the comparison. Defaults to None, which chooses the tolerance based on
            the precision.
        atol: The absolute tolerance to use for the comparison. Defaults to None, which chooses the tolerance based on
            the precision.
    """
    tokenizer = get_tokenizer()

    test_proteins = [
        "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA",
        "MKTVRQERLKSI<mask>RILERSKEPVSGAQLAEELS<mask>SRQVIVQDIAYLRSLGYN<mask>VATPRGYVLAGG",
    ]
    tokens = tokenizer(test_proteins, return_tensors="pt", padding=True, truncation=True).to("cuda")
    input_ids = tokens["input_ids"]
    attention_mask = tokens["attention_mask"]

    dtype = get_autocast_dtype(precision)
    nemo_config = ESM2Config(
        initial_ckpt_path=str(ckpt_path),
        include_embeddings=True,
        include_hiddens=True,
        params_dtype=dtype,
        pipeline_dtype=dtype,
        autocast_dtype=dtype,
        bf16=dtype is torch.bfloat16,
        fp16=dtype is torch.float16,
    )

    nemo_model = nemo_config.configure_model(tokenizer).to("cuda").eval()

    if dtype is torch.float16 or dtype is torch.bfloat16:
        nemo_model = Float16Module(nemo_config, nemo_model)

    nemo_output = nemo_model(input_ids, attention_mask)
    nemo_logits = nemo_output["token_logits"].transpose(0, 1).contiguous()[..., : tokenizer.vocab_size]
    nemo_hidden_state = nemo_output["hidden_states"]

    del nemo_model
    gc.collect()
    torch.cuda.empty_cache()

    hf_model = AutoModelForMaskedLM.from_pretrained(model_tag, torch_dtype=get_autocast_dtype(precision)).cuda().eval()
    hf_output_all = hf_model(input_ids, attention_mask, output_hidden_states=True)
    hf_hidden_state = hf_output_all.hidden_states[-1]

    # Rather than directly comparing the logit or hidden state tensors, we compare their cosine similarity. These
    # should be essentially 1 if the outputs are equivalent, but is less sensitive to small numerical differences.
    # We don't care about the padding tokens, so we only compare the non-padding tokens.
    logit_similarity = torch.nn.functional.cosine_similarity(nemo_logits, hf_output_all.logits, dim=2)
    logit_similarity = logit_similarity[attention_mask == 1]

    hidden_state_similarity = torch.nn.functional.cosine_similarity(nemo_hidden_state, hf_hidden_state, dim=2)
    hidden_state_similarity = hidden_state_similarity[attention_mask == 1]

    torch.testing.assert_close(logit_similarity, torch.ones_like(logit_similarity), rtol=rtol, atol=atol)
    torch.testing.assert_close(hidden_state_similarity, torch.ones_like(hidden_state_similarity), rtol=rtol, atol=atol)