Skip to content

Embedding

ESM2Embedding

Bases: LanguageModelEmbedding

ESM2 Embedding with custom logic for attention masking and token dropout.

Source code in bionemo/esm2/model/embedding.py
(Source listing below corresponds to lines 34–156 of the file.)
class ESM2Embedding(LanguageModelEmbedding):
    """ESM2 Embedding with custom logic for attention masking and token dropout.

    Extends ``LanguageModelEmbedding`` with two ESM2-specific transformations of the word
    embeddings, applied before positional embeddings are added:

    * token dropout: embeddings at ``mask_token_id`` positions are zeroed and every embedding in
      the row is rescaled by ``(1 - ESM2_MASK_RATIO_TRAIN) / (1 - observed_mask_ratio)``;
    * attention masking: embeddings are multiplied by ``attention_mask`` (zeroing padded positions
      for a 0/1 or boolean mask).

    Both transformations only take effect when an ``attention_mask`` is passed to :meth:`forward`.
    """

    def __init__(
        self,
        config: TransformerConfig,
        vocab_size: int,
        max_sequence_length: int,
        position_embedding_type: Literal["learned_absolute", "rope"] = "rope",
        num_tokentypes: int = 0,
        # ESM2 NEW ARGS
        token_dropout: bool = True,
        use_attention_mask: bool = True,
        mask_token_id: Optional[int] = torch.nan,
    ) -> None:
        """Initialize the ESM2 Embedding module.

        Args:
            config: Transformer configuration forwarded to ``LanguageModelEmbedding``.
            vocab_size: Vocabulary size forwarded to ``LanguageModelEmbedding``.
            max_sequence_length: Maximum sequence length forwarded to ``LanguageModelEmbedding``.
            position_embedding_type: Position embedding type forwarded to ``LanguageModelEmbedding``.
            num_tokentypes: Number of token types forwarded to ``LanguageModelEmbedding``.
            token_dropout: Whether to apply ESM2 token-dropout rescaling (requires an attention mask).
            use_attention_mask: Whether to multiply embeddings by the attention mask.
            mask_token_id: Token id treated as the mask token for token dropout. The NaN default
                (or ``None``) matches no token, so only the constant rescaling is applied.
        """
        super().__init__(
            config=config,
            vocab_size=vocab_size,
            max_sequence_length=max_sequence_length,
            position_embedding_type=position_embedding_type,
            num_tokentypes=num_tokentypes,
        )
        self.token_dropout = token_dropout
        self.use_attention_mask = use_attention_mask
        self.mask_token_id = mask_token_id

    @property
    def dtype(self) -> torch.dtype:
        """The dtype of the embedding weights."""
        return self.word_embeddings.weight.dtype

    def _apply_esm2_customization(
        self, word_embeddings: Tensor, input_ids: Tensor, attention_mask: Optional[Tensor]
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """ESM2 customization for attention masking and token dropout.

        Args:
            word_embeddings (Tensor[float]): The word embeddings of ``input_ids``. Shape: [b, s, h]
            input_ids (Tensor[int]): The input tokens. Shape: [b, s]
            attention_mask (Tensor[bool], optional): attention mask. Shape: [b, s]

        Returns:
            Tuple[Tensor, Optional[Tensor]]: (Updated embeddings, embedding mask) Shape: ([b, s, h], [b, s]).
            The embedding mask is ``attention_mask`` when either customization is enabled, else ``None``.
        """
        embeddings_mask = None
        if attention_mask is not None and (self.token_dropout or self.use_attention_mask):
            embeddings_mask = attention_mask

        if embeddings_mask is not None and self.token_dropout:
            # ``mask_token_id`` is annotated Optional; comparing a tensor with None does not give a
            # tensor, so treat None like the NaN default: no position is a mask token.
            if self.mask_token_id is None:
                is_mask_token = torch.zeros_like(input_ids, dtype=torch.bool)
            else:
                is_mask_token = input_ids == self.mask_token_id
            word_embeddings = word_embeddings.masked_fill(is_mask_token.unsqueeze(-1), 0.0)
            src_lengths = embeddings_mask.sum(-1)
            # A row with no attended positions would give 0/0 = NaN, and NaN survives the later
            # multiplication by the (zero) mask; only those rows get a unit denominator.
            src_lengths = src_lengths.masked_fill(src_lengths == 0, 1)
            mask_ratio_observed = is_mask_token.sum(-1).to(self.dtype) / src_lengths

            # Rescale so the expected embedding magnitude matches training-time masking.
            scale_factor = (1 - ESM2_MASK_RATIO_TRAIN) / (1 - mask_ratio_observed)[:, None, None]
            word_embeddings = (word_embeddings * scale_factor).to(word_embeddings.dtype)
        if embeddings_mask is not None and self.use_attention_mask:
            word_embeddings = (word_embeddings * embeddings_mask.unsqueeze(-1)).to(word_embeddings.dtype)
        return word_embeddings, embeddings_mask

    def forward(
        self,
        input_ids: Tensor,
        position_ids: Tensor,
        tokentype_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Forward pass of the embedding module.

        Args:
            input_ids (Tensor): The input tokens. Shape: [b, s]
            position_ids (Tensor): The position IDs used to calculate position embeddings. Shape: [b, s].
                Only read when ``self.add_position_embedding`` is true.
            tokentype_ids (Tensor, optional): The token type IDs. Shape: [b, s]. Used when
                args.bert_binary_head is set to True. Defaults to None.
            attention_mask (Tensor, optional): attention mask. Shape: [b, s]

        Returns:
            Tensor: The output embeddings in sequence-first layout [s, b, h]. When
            ``config.sequence_parallel`` is set, the result has additionally been passed through
            ``tensor_parallel.scatter_to_sequence_parallel_region``.

        Raises:
            ValueError: If ``tokentype_ids`` is given but the module has no token-type embeddings.
        """
        word_embeddings = self.word_embeddings(input_ids)  # [b, s, h]

        # ESM2 Customization
        word_embeddings, embeddings_mask = self._apply_esm2_customization(word_embeddings, input_ids, attention_mask)

        if self.add_position_embedding:
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = word_embeddings + position_embeddings
        else:
            embeddings = word_embeddings

        # ESM2 Customization: include attention masking from ESM2 (re-applied so that position
        # embeddings at masked positions are zeroed as well).
        if embeddings_mask is not None and self.use_attention_mask:
            embeddings = (embeddings * embeddings_mask.unsqueeze(-1)).to(embeddings.dtype)

        # Data format change to avoid explicit transposes : [b s h] --> [s b h].
        embeddings = embeddings.transpose(0, 1).contiguous()

        if tokentype_ids is not None:
            if self.tokentype_embeddings is None:
                raise ValueError("tokentype_embedding is needed to process tokentype_ids")
            # [b s h] -> [s b h] (So that it can be added with embeddings)
            tokentype_embedding = self.tokentype_embeddings(tokentype_ids).permute(1, 0, 2)
            embeddings = embeddings + tokentype_embedding
        else:
            assert self.tokentype_embeddings is None

        # If the input flag for fp32 residual connection is set, convert for float.
        if self.config.fp32_residual_connection:
            embeddings = embeddings.float()

        # Dropout.
        if self.config.sequence_parallel:
            embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
            # `scatter_to_sequence_parallel_region` returns a view, which prevents
            # the original tensor from being garbage collected. Clone to facilitate GC.
            # Has a small runtime cost (~0.5%).
            if self.config.clone_scatter_output_in_embedding:
                embeddings = embeddings.clone()
            with tensor_parallel.get_cuda_rng_tracker().fork():
                embeddings = self.embedding_dropout(embeddings)
        else:
            embeddings = self.embedding_dropout(embeddings)

        return embeddings

dtype: torch.dtype property

The dtype of the embedding weights.

__init__(config, vocab_size, max_sequence_length, position_embedding_type='rope', num_tokentypes=0, token_dropout=True, use_attention_mask=True, mask_token_id=torch.nan)

Initialize the ESM2 Embedding module.

Source code in bionemo/esm2/model/embedding.py
(Source listing below corresponds to lines 37–59 of the file.)
def __init__(
    self,
    config: TransformerConfig,
    vocab_size: int,
    max_sequence_length: int,
    position_embedding_type: Literal["learned_absolute", "rope"] = "rope",
    num_tokentypes: int = 0,
    # ESM2 NEW ARGS
    token_dropout: bool = True,
    use_attention_mask: bool = True,
    mask_token_id: Optional[int] = torch.nan,
) -> None:
    """Build the base embedding layers, then record the ESM2-specific switches.

    Args:
        config: Transformer configuration, forwarded unchanged to the parent constructor.
        vocab_size: Vocabulary size, forwarded to the parent constructor.
        max_sequence_length: Maximum sequence length, forwarded to the parent constructor.
        position_embedding_type: Position embedding type, forwarded to the parent constructor.
        num_tokentypes: Number of token types, forwarded to the parent constructor.
        token_dropout: Stored as ``self.token_dropout``.
        use_attention_mask: Stored as ``self.use_attention_mask``.
        mask_token_id: Stored as ``self.mask_token_id``; the NaN default compares unequal to every token id.
    """
    parent_kwargs = dict(
        config=config,
        vocab_size=vocab_size,
        max_sequence_length=max_sequence_length,
        position_embedding_type=position_embedding_type,
        num_tokentypes=num_tokentypes,
    )
    super().__init__(**parent_kwargs)

    # ESM2-only settings; the base class knows nothing about them.
    self.token_dropout, self.use_attention_mask, self.mask_token_id = (
        token_dropout,
        use_attention_mask,
        mask_token_id,
    )

forward(input_ids, position_ids, tokentype_ids=None, attention_mask=None)

Forward pass of the embedding module.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `input_ids` | `Tensor` | The input tokens. Shape: [b, s] | *required* |
| `position_ids` | `Tensor` | The position IDs used to calculate position embeddings. Shape: [b, s] | *required* |
| `tokentype_ids` | `int` | The token type IDs. Used when `args.bert_binary_head` is set to True. | `None` |
| `attention_mask` | `Tensor` | Attention mask. Shape: [b, s] | `None` |

Returns:

| Name | Type | Description |
|------|------|-------------|
| `Tensor` | `Tensor` | The output embeddings |

Source code in bionemo/esm2/model/embedding.py
(Source listing below corresponds to lines 94–156 of the file.)
def forward(
    self,
    input_ids: Tensor,
    position_ids: Tensor,
    tokentype_ids: Optional[Tensor] = None,
    attention_mask: Optional[Tensor] = None,
) -> Tensor:
    """Forward pass of the embedding module.

    Args:
        input_ids (Tensor): The input tokens. Shape: [b, s]
        position_ids (Tensor): The position IDs used to calculate position embeddings. Shape: [b, s].
            Only read when ``self.add_position_embedding`` is true.
        tokentype_ids (Tensor, optional): The token type IDs. Shape: [b, s]. Used when
            args.bert_binary_head is set to True. Defaults to None.
        attention_mask (Tensor, optional): attention mask. Shape: [b, s]

    Returns:
        Tensor: The output embeddings in sequence-first layout [s, b, h]. When
        ``config.sequence_parallel`` is set, the result has additionally been passed through
        ``tensor_parallel.scatter_to_sequence_parallel_region``.

    Raises:
        ValueError: If ``tokentype_ids`` is given but the module has no token-type embeddings.
    """
    word_embeddings = self.word_embeddings(input_ids)  # [b, s, h]

    # ESM2 Customization
    word_embeddings, embeddings_mask = self._apply_esm2_customization(word_embeddings, input_ids, attention_mask)

    if self.add_position_embedding:
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = word_embeddings + position_embeddings
    else:
        embeddings = word_embeddings

    # ESM2 Customization: include attention masking from ESM2 (re-applied so that position
    # embeddings at masked positions are zeroed as well).
    if embeddings_mask is not None and self.use_attention_mask:
        embeddings = (embeddings * embeddings_mask.unsqueeze(-1)).to(embeddings.dtype)

    # Data format change to avoid explicit transposes : [b s h] --> [s b h].
    embeddings = embeddings.transpose(0, 1).contiguous()

    if tokentype_ids is not None:
        if self.tokentype_embeddings is None:
            raise ValueError("tokentype_embedding is needed to process tokentype_ids")
        # [b s h] -> [s b h] (So that it can be added with embeddings)
        tokentype_embedding = self.tokentype_embeddings(tokentype_ids).permute(1, 0, 2)
        embeddings = embeddings + tokentype_embedding
    else:
        assert self.tokentype_embeddings is None

    # If the input flag for fp32 residual connection is set, convert for float.
    if self.config.fp32_residual_connection:
        embeddings = embeddings.float()

    # Dropout.
    if self.config.sequence_parallel:
        embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
        # `scatter_to_sequence_parallel_region` returns a view, which prevents
        # the original tensor from being garbage collected. Clone to facilitate GC.
        # Has a small runtime cost (~0.5%).
        if self.config.clone_scatter_output_in_embedding:
            embeddings = embeddings.clone()
        with tensor_parallel.get_cuda_rng_tracker().fork():
            embeddings = self.embedding_dropout(embeddings)
    else:
        embeddings = self.embedding_dropout(embeddings)

    return embeddings