Skip to content

Compare

ForwardHook

A forward hook to extract a desired intermediate tensor for later comparison.

Source code in bionemo/esm2/testing/compare.py
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
class ForwardHook:
    """A forward hook to extract a desired intermediate tensor for later comparison."""

    def __init__(self, transform_fn: TransformFn) -> None:
        """A forward hook to extract a desired intermediate tensor for later comparison.

        The resulting tensor is saved in the `data` attribute of the hook.

        Args:
            transform_fn: A function that maps the input and output tensors of the module to the desired tensor.
        """
        self._transform_fn = transform_fn
        self._data: torch.Tensor | None = None

    def __call__(self, module, module_in, module_out):
        """The forward hook function."""
        if not isinstance(module_out, tuple):
            module_out = (module_out,)
        if not isinstance(module_in, tuple):
            module_in = (module_in,)

        self._data = self._transform_fn(module_in, module_out).detach().cpu()

    @property
    def data(self) -> torch.Tensor:
        """The extracted tensor from the forward hook."""
        if self._data is None:
            raise ValueError("No data has been saved in this hook.")
        return self._data

data property

The extracted tensor from the forward hook.

__call__(module, module_in, module_out)

The forward hook function.

Source code in bionemo/esm2/testing/compare.py
229
230
231
232
233
234
235
236
def __call__(self, module, module_in, module_out):
    """The forward hook function."""
    if not isinstance(module_out, tuple):
        module_out = (module_out,)
    if not isinstance(module_in, tuple):
        module_in = (module_in,)

    self._data = self._transform_fn(module_in, module_out).detach().cpu()

__init__(transform_fn)

A forward hook to extract a desired intermediate tensor for later comparison.

The resulting tensor is saved in the data attribute of the hook.

Parameters:

Name Type Description Default
transform_fn TransformFn

A function that maps the input and output tensors of the module to the desired tensor.

required
Source code in bionemo/esm2/testing/compare.py
218
219
220
221
222
223
224
225
226
227
def __init__(self, transform_fn: TransformFn) -> None:
    """A forward hook to extract a desired intermediate tensor for later comparison.

    The resulting tensor is saved in the `data` attribute of the hook.

    Args:
        transform_fn: A function that maps the input and output tensors of the module to the desired tensor.
    """
    self._transform_fn = transform_fn
    self._data: torch.Tensor | None = None

TestHook

A test hook that just captures the raw inputs and outputs.

Source code in bionemo/esm2/testing/compare.py
246
247
248
249
250
251
252
253
254
255
256
257
class TestHook:
    """A test hook that just captures the raw inputs and outputs."""

    def __init__(self) -> None:
        """A test hook that just captures the raw inputs and outputs."""
        self.inputs: tuple[torch.Tensor, ...] | None = None
        self.outputs: tuple[torch.Tensor, ...] | None = None

    def __call__(self, module, inputs, outputs):
        """The forward hook function."""
        self.inputs = inputs
        self.outputs = outputs

__call__(module, inputs, outputs)

The forward hook function.

Source code in bionemo/esm2/testing/compare.py
254
255
256
257
def __call__(self, module, inputs, outputs):
    """The forward hook function."""
    self.inputs = inputs
    self.outputs = outputs

__init__()

A test hook that just captures the raw inputs and outputs.

Source code in bionemo/esm2/testing/compare.py
249
250
251
252
def __init__(self) -> None:
    """A test hook that just captures the raw inputs and outputs."""
    self.inputs: tuple[torch.Tensor, ...] | None = None
    self.outputs: tuple[torch.Tensor, ...] | None = None

assert_cosine_similarity(tensor1, tensor2, mask, rtol=None, atol=None, magnitude_rtol=0.01, magnitude_atol=0.01, msg=None)

Assert that both the cosine similarity between two tensors is close to 1, and the ratio of their magnitudes is 1.

Parameters:

Name Type Description Default
tensor1 Tensor

The first tensor to compare.

required
tensor2 Tensor

The second tensor to compare.

required
mask Tensor

A mask tensor to apply to the comparison.

required
rtol float | None

The relative tolerance to use for the comparison. Defaults to 1e-4.

None
atol float | None

The absolute tolerance to use for the comparison. Defaults to 1e-4.

None
magnitude_rtol float

The relative tolerance to use for the magnitude comparison. Defaults to 1e-2.

0.01
magnitude_atol float

The absolute tolerance to use for the magnitude comparison. Defaults to 1e-2.

0.01
msg str | None

An optional message to include in the assertion error.

None
Source code in bionemo/esm2/testing/compare.py
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
def assert_cosine_similarity(
    tensor1: torch.Tensor,
    tensor2: torch.Tensor,
    mask: torch.Tensor,
    rtol: float | None = None,
    atol: float | None = None,
    magnitude_rtol: float = 1e-2,
    magnitude_atol: float = 1e-2,
    msg: str | None = None,
) -> None:
    """Assert that both the cosine similarity between two tensors is close to 1, and the ratio of their magnitudes is 1.

    Args:
        tensor1: The first tensor to compare.
        tensor2: The second tensor to compare.
        mask: A mask tensor to apply to the comparison.
        rtol: The relative tolerance to use for the comparison. Defaults to 1e-4.
        atol: The absolute tolerance to use for the comparison. Defaults to 1e-4.
        magnitude_rtol: The relative tolerance to use for the magnitude comparison. Defaults to 1e-2.
        magnitude_atol: The absolute tolerance to use for the magnitude comparison. Defaults to 1e-2.
        msg: An optional message to include in the assertion error.
    """
    assert tensor1.size() == tensor2.size()

    similarity = torch.nn.functional.cosine_similarity(tensor1, tensor2, dim=2)
    similarity = similarity[mask == 1]

    torch.testing.assert_close(
        similarity,
        torch.ones_like(similarity),
        rtol=rtol,
        atol=atol,
        msg=lambda x: f"{msg} (angle): {x}",
    )

    magnitude_similarity = torch.norm(tensor1, dim=2) / torch.norm(tensor2, dim=2)
    magnitude_similarity = magnitude_similarity[mask == 1]
    torch.testing.assert_close(
        magnitude_similarity,
        torch.ones_like(magnitude_similarity),
        rtol=magnitude_rtol,
        atol=magnitude_atol,
        msg=lambda x: f"{msg} (magnitude): {x}",
    )

assert_esm2_equivalence(ckpt_path, model_tag, precision='fp32', rtol=None, atol=None)

Testing utility to compare the outputs of a NeMo2 checkpoint to the original HuggingFace model weights.

Compares the cosine similarity of the logit and hidden state outputs of a NeMo2 model checkpoint to the outputs of the corresponding HuggingFace model.

Parameters:

Name Type Description Default
ckpt_path Path | str

A path to a NeMo2 checkpoint for an ESM-2 model.

required
model_tag str

The HuggingFace model tag for the model to compare against.

required
precision PrecisionTypes

The precision type to use for the comparison. Defaults to "fp32".

'fp32'
rtol float | None

The relative tolerance to use for the comparison. Defaults to None, which chooses the tolerance based on the precision.

None
atol float | None

The absolute tolerance to use for the comparison. Defaults to None, which chooses the tolerance based on the precision.

None
Source code in bionemo/esm2/testing/compare.py
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
def assert_esm2_equivalence(
    ckpt_path: Path | str,
    model_tag: str,
    precision: PrecisionTypes = "fp32",
    rtol: float | None = None,
    atol: float | None = None,
) -> None:
    """Testing utility to compare the outputs of a NeMo2 checkpoint to the original HuggingFace model weights.

    Compares the cosine similarity of the logit and hidden state outputs of a NeMo2 model checkpoint to the outputs of
    the corresponding HuggingFace model.

    Args:
        ckpt_path: A path to a NeMo2 checkpoint for an ESM-2 model.
        model_tag: The HuggingFace model tag for the model to compare against.
        precision: The precision type to use for the comparison. Defaults to "fp32".
        rtol: The relative tolerance to use for the comparison. Defaults to None, which chooses the tolerance based on
            the precision.
        atol: The absolute tolerance to use for the comparison. Defaults to None, which chooses the tolerance based on
            the precision.
    """
    tokenizer = get_tokenizer()

    input_ids, attention_mask = get_input_tensors(tokenizer)

    nemo_logits, nemo_hidden_state = load_and_evaluate_nemo_esm2(ckpt_path, precision, input_ids, attention_mask)
    gc.collect()
    torch.cuda.empty_cache()
    hf_logits, hf_hidden_state = load_and_evaluate_hf_model(model_tag, precision, input_ids, attention_mask)

    # Rather than directly comparing the logit or hidden state tensors, we compare their cosine similarity. These
    # should be essentially 1 if the outputs are equivalent, but is less sensitive to small numerical differences.
    # We don't care about the padding tokens, so we only compare the non-padding tokens.
    assert_cosine_similarity(nemo_logits, hf_logits, attention_mask, rtol, atol)
    assert_cosine_similarity(nemo_hidden_state, hf_hidden_state, attention_mask, rtol, atol)

get_input_tensors(tokenizer)

Get input tensors for testing.

Parameters:

Name Type Description Default
tokenizer

A huggingface-like tokenizer object.

required

Returns:

Type Description
tuple[Tensor, Tensor]

A tuple of the input IDs and attention mask tensors.

Source code in bionemo/esm2/testing/compare.py
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
def get_input_tensors(tokenizer) -> tuple[torch.Tensor, torch.Tensor]:
    """Get input tensors for testing.

    Args:
        tokenizer: A huggingface-like tokenizer object.

    Returns:
        A tuple of the input IDs and attention mask tensors.
    """
    test_proteins = [
        "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA",
        "MKTVRQERLKSI<mask>RILERSKEPVSGAQLAEELS<mask>SRQVIVQDIAYLRSLGYN<mask>VATPRGYVLAGG",
    ]
    tokens = tokenizer(test_proteins, return_tensors="pt", padding=True, truncation=True)
    input_ids: torch.Tensor = tokens["input_ids"]  # type: ignore
    attention_mask: torch.Tensor = tokens["attention_mask"]  # type: ignore

    # Pad the input IDs and attention mask to be divisible by 8 so xformers doesn't fail.
    padded_shape = math.ceil(attention_mask.size(1) / 8)
    padded_input_ids = torch.full((input_ids.size(0), padded_shape * 8), tokenizer.pad_token_id, dtype=torch.long)
    padded_input_ids[: input_ids.size(0), : input_ids.size(1)] = input_ids

    padded_attention_mask = torch.zeros((attention_mask.size(0), padded_shape * 8), dtype=torch.bool)
    padded_attention_mask[: attention_mask.size(0), : attention_mask.size(1)] = attention_mask

    return padded_input_ids.to("cuda"), padded_attention_mask.to("cuda")

load_and_evaluate_hf_model(model_tag, precision, input_ids, attention_mask)

Load a HuggingFace model and evaluate it on the given inputs.

Parameters:

Name Type Description Default
model_tag str

The HuggingFace model tag for the model to compare against.

required
precision PrecisionTypes

The precision type to use for the comparison.

required
input_ids Tensor

The input IDs tensor to evaluate.

required
attention_mask Tensor

The attention mask tensor to evaluate.

required

Returns:

Type Description
tuple[Tensor, Tensor]

A tuple of the logits and hidden states tensors calculated by the HuggingFace model, respectively.

Source code in bionemo/esm2/testing/compare.py
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
def load_and_evaluate_hf_model(
    model_tag: str, precision: PrecisionTypes, input_ids: torch.Tensor, attention_mask: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Load a HuggingFace model and evaluate it on the given inputs.

    Args:
        model_tag: The HuggingFace model tag for the model to compare against.
        precision: The precision type to use for the comparison.
        input_ids: The input IDs tensor to evaluate.
        attention_mask: The attention mask tensor to evaluate.

    Returns:
        A tuple of the logits and hidden states tensors calculated by the HuggingFace model, respectively.
    """
    hf_model = AutoModelForMaskedLM.from_pretrained(
        model_tag,
        torch_dtype=get_autocast_dtype(precision),
        trust_remote_code=True,
    )
    hf_model = hf_model.to("cuda").eval()
    hf_output_all = hf_model(input_ids, attention_mask, output_hidden_states=True)
    hf_hidden_state = hf_output_all.hidden_states[-1]
    return hf_output_all.logits, hf_hidden_state

load_and_evaluate_nemo_esm2(ckpt_path, precision, input_ids, attention_mask)

Load a NeMo2 ESM-2 model and evaluate it on the given inputs.

Parameters:

Name Type Description Default
ckpt_path Path | str

A path to a NeMo2 checkpoint for an ESM-2 model.

required
precision PrecisionTypes

The precision type to use for the comparison.

required
input_ids Tensor

The input IDs tensor to evaluate.

required
attention_mask Tensor

The attention mask tensor to evaluate.

required

Returns:

Type Description
tuple[Tensor, Tensor]

A tuple of the logits and hidden states tensors calculated by the NeMo2 model, respectively.

Source code in bionemo/esm2/testing/compare.py
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
def load_and_evaluate_nemo_esm2(
    ckpt_path: Path | str,
    precision: PrecisionTypes,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Load a NeMo2 ESM-2 model and evaluate it on the given inputs.

    Args:
        ckpt_path: A path to a NeMo2 checkpoint for an ESM-2 model.
        precision: The precision type to use for the comparison.
        input_ids: The input IDs tensor to evaluate.
        attention_mask: The attention mask tensor to evaluate.

    Returns:
        A tuple of the logits and hidden states tensors calculated by the NeMo2 model, respectively.
    """
    tokenizer = get_tokenizer()

    dtype = get_autocast_dtype(precision)
    nemo_config = ESM2Config(
        initial_ckpt_path=str(ckpt_path),
        include_embeddings=True,
        include_hiddens=True,
        params_dtype=dtype,
        pipeline_dtype=dtype,
        autocast_dtype=dtype,
        bf16=dtype is torch.bfloat16,
        fp16=dtype is torch.float16,
    )

    nemo_model = nemo_config.configure_model(tokenizer).to("cuda").eval()

    if dtype is torch.float16 or dtype is torch.bfloat16:
        nemo_model = Float16Module(nemo_config, nemo_model)

    nemo_output = nemo_model(input_ids, attention_mask)
    nemo_logits = nemo_output["token_logits"].transpose(0, 1).contiguous()[..., : tokenizer.vocab_size]
    nemo_hidden_state = nemo_output["hidden_states"]
    return nemo_logits, nemo_hidden_state