Skip to content

Token model

ESM2FineTuneTokenConfig dataclass

Bases: ESM2GenericConfig[ESM2FineTuneTokenModel, ClassifierLossReduction], IOMixinWithGettersSetters

ESM2FineTuneTokenConfig is a dataclass that is used to configure the token-level fine-tuning model (ESM2FineTuneTokenModel).

Timers from ModelParallelConfig are required for megatron forward compatibility.

Source code in bionemo/esm2/model/finetune/token_model.py
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
@dataclass
class ESM2FineTuneTokenConfig(
    ESM2GenericConfig[ESM2FineTuneTokenModel, ClassifierLossReduction], iom.IOMixinWithGettersSetters
):
    """Configuration dataclass for ``ESM2FineTuneTokenModel`` (token-level fine-tuning of ESM2).

    Extends the generic ESM2 config with the options read by ``ESM2FineTuneTokenModel`` and
    ``MegatronConvNetHead``: the task type, whether to freeze the encoder, and the sizes/dropout of the
    convolutional head.

    Timers from ModelParallelConfig are required for megatron forward compatibility.
    """

    model_cls: Type[ESM2FineTuneTokenModel] = ESM2FineTuneTokenModel
    # typical case is fine-tune the base biobert that doesn't have this head. If you are instead loading a checkpoint
    # that has this new head and want to keep using these weights, please drop this next line or set to []
    # NOTE(review): the model names its head attribute f"{task_type}_head", so with task_type="regression" the head
    # is "regression_head", which this default prefix does not cover -- confirm whether it should be added.
    initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=lambda: ["classification_head"])

    # Selects the head attribute name (f"{task_type}_head") and the output key (f"{task_type}_output") in the model.
    task_type: Literal["classification", "regression"] = "classification"
    # If True, every parameter registered by the ESM2 base model gets requires_grad=False (the head stays trainable).
    encoder_frozen: bool = True
    cnn_num_classes: int = 3  # output channels of the head's final conv layer, i.e. values predicted per token
    cnn_dropout: float = 0.25  # dropout probability applied after the first conv + ReLU in the head
    cnn_hidden_size: int = 32  # The number of output channels in the bottleneck layer of the convolution.

    def get_loss_reduction_class(self) -> Type[ClassifierLossReduction]:
        """Return the loss-reduction class (``ClassifierLossReduction``) used with this model."""
        return ClassifierLossReduction

get_loss_reduction_class()

The loss function type.

Source code in bionemo/esm2/model/finetune/token_model.py
132
133
134
def get_loss_reduction_class(self) -> Type[ClassifierLossReduction]:
    """Return the loss-reduction class for this config.

    The class object itself is returned (not an instance); it is always ``ClassifierLossReduction``.
    """
    reduction_cls = ClassifierLossReduction
    return reduction_cls

ESM2FineTuneTokenModel

Bases: ESM2Model

An ESM2 model that is suitable for fine tuning.

Source code in bionemo/esm2/model/finetune/token_model.py
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
class ESM2FineTuneTokenModel(ESM2Model):
    """An ESM2 model with a per-token (residue-level) convolutional head for fine-tuning.

    The base ESM2 model is always built with ``include_hiddens=True`` so that its output contains
    ``hidden_states``. On the last pipeline stage (``post_process=True``) those hidden states are fed to a
    ``MegatronConvNetHead`` stored under the attribute ``f"{config.task_type}_head"``.
    """

    def __init__(self, config, *args, include_hiddens: bool = False, post_process: bool = True, **kwargs):
        """Build the ESM2 base model, optionally freeze it, and attach the task head.

        Args:
            config: Model config; this constructor reads ``encoder_frozen`` and ``task_type`` from it.
            *args: Forwarded to ``ESM2Model.__init__``.
            include_hiddens: Whether ``forward`` keeps ``hidden_states`` in its returned output. The base model is
                always constructed with ``include_hiddens=True`` regardless, because the head needs them.
            post_process: Whether this is the last pipeline stage; the head is only created when True.
            **kwargs: Forwarded to ``ESM2Model.__init__``.
        """
        super().__init__(config, *args, include_hiddens=True, post_process=post_process, **kwargs)

        # Freeze encoder parameters. This runs before the head is attached below, so only parameters registered
        # by the base model are frozen and the head stays trainable.
        if config.encoder_frozen:
            for _, param in self.named_parameters():
                param.requires_grad = False

        self.include_hiddens_finetuning = (
            include_hiddens  # this include_hiddens is for the final output of fine-tuning
        )
        # If post_process is True that means that we are at the last megatron parallelism stage and we can
        #   apply the head.
        if post_process:
            self.task_type = config.task_type
            # if we are doing post process (eg pipeline last stage) then we need to add the output layers
            self.head_name = f"{self.task_type}_head"  # Example: 'regression_head' or 'classification_head'
            setattr(self, self.head_name, MegatronConvNetHead(config))

    def forward(self, *args, **kwargs) -> Tensor | BioBertOutput:
        """Run the base model and, on the last pipeline stage, apply the per-token head.

        Args:
            *args: Forwarded to ``ESM2Model.forward``.
            **kwargs: Forwarded to ``ESM2Model.forward``.

        Returns:
            The parent's output unchanged when ``self.post_process`` is False. Otherwise the parent's output dict
            with the head's prediction added under ``f"{task_type}_output"``; ``hidden_states`` is removed from it
            unless the model was built with ``include_hiddens=True``.

        Raises:
            ValueError: If the parent output is not dict-like or lacks ``hidden_states``.
        """
        output = super().forward(*args, **kwargs)
        # Stop early if we are not in post_process mode (for example if we are in the middle of model parallelism)
        if not self.post_process:
            return output  # we are not at the last pipeline stage so just return what the parent has
        # Double check that the output from the parent has everything we need to do prediction in this head.
        if not isinstance(output, dict) or "hidden_states" not in output:
            raise ValueError(
                f"Expected to find 'hidden_states' in the output, and output to be dictionary-like, found {output},\n"
                "Make sure include_hiddens=True in the call to super().__init__"
            )
        # Token-level task: the head consumes the hidden states of every position, not just a [CLS] token.
        hidden_states: Tensor = output["hidden_states"]
        # Per-token prediction from the task head (classification or regression, per task_type).
        task_head = getattr(self, self.head_name)
        output[f"{self.task_type}_output"] = task_head(hidden_states)
        if not self.include_hiddens_finetuning:
            del output["hidden_states"]
        return output

__init__(config, *args, include_hiddens=False, post_process=True, **kwargs)

Constructor.

Source code in bionemo/esm2/model/finetune/token_model.py
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
def __init__(self, config, *args, include_hiddens: bool = False, post_process: bool = True, **kwargs):
    """Create the ESM2 backbone, freeze it if requested, then attach the task head on the last stage.

    The backbone is always built with ``include_hiddens=True`` because the head consumes ``hidden_states``;
    the caller's ``include_hiddens`` only controls whether ``forward`` keeps them in its output.
    """
    super().__init__(config, *args, include_hiddens=True, post_process=post_process, **kwargs)

    # Freezing happens before the head exists, so only the backbone's parameters are affected.
    if config.encoder_frozen:
        for frozen_param in self.parameters():
            frozen_param.requires_grad = False

    # Remembered for forward(): whether to keep hidden_states in the fine-tuning output.
    self.include_hiddens_finetuning = include_hiddens

    # Only the last pipeline stage (post_process=True) owns the head.
    if not post_process:
        return
    self.task_type = config.task_type
    self.head_name = f"{self.task_type}_head"  # e.g. 'classification_head' or 'regression_head'
    setattr(self, self.head_name, MegatronConvNetHead(config))

forward(*args, **kwargs)

Inference.

Source code in bionemo/esm2/model/finetune/token_model.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
def forward(self, *args, **kwargs) -> Tensor | BioBertOutput:
    """Run the base model and, on the last pipeline stage, apply the per-token head.

    Args:
        *args: Forwarded to ``ESM2Model.forward``.
        **kwargs: Forwarded to ``ESM2Model.forward``.

    Returns:
        The parent's output unchanged when ``self.post_process`` is False. Otherwise the parent's output dict
        with the head's prediction added under ``f"{task_type}_output"``; ``hidden_states`` is removed from it
        unless the model was built with ``include_hiddens=True``.

    Raises:
        ValueError: If the parent output is not dict-like or lacks ``hidden_states``.
    """
    output = super().forward(*args, **kwargs)
    # Stop early if we are not in post_process mode (for example if we are in the middle of model parallelism)
    if not self.post_process:
        return output  # we are not at the last pipeline stage so just return what the parent has
    # Double check that the output from the parent has everything we need to do prediction in this head.
    if not isinstance(output, dict) or "hidden_states" not in output:
        raise ValueError(
            f"Expected to find 'hidden_states' in the output, and output to be dictionary-like, found {output},\n"
            "Make sure include_hiddens=True in the call to super().__init__"
        )
    # Token-level task: the head consumes the hidden states of every position, not just a [CLS] token.
    hidden_states: Tensor = output["hidden_states"]
    # Per-token prediction from the task head (classification or regression, per task_type).
    task_head = getattr(self, self.head_name)
    output[f"{self.task_type}_output"] = task_head(hidden_states)
    if not self.include_hiddens_finetuning:
        del output["hidden_states"]
    return output

MegatronConvNetHead

Bases: MegatronModule

A convolutional neural network class for residue-level classification.

Source code in bionemo/esm2/model/finetune/token_model.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
class MegatronConvNetHead(MegatronModule):
    """A convolutional neural network head for residue-level (per-token) prediction.

    Two 2D convolutions with ``(7, 1)`` kernels slide along the sequence axis only. The first maps
    ``hidden_size`` -> ``cnn_hidden_size`` channels and is followed by ReLU and dropout. The second maps
    ``cnn_hidden_size`` -> ``cnn_num_classes`` channels. Padding ``(3, 0)`` keeps the sequence length unchanged.
    """

    def __init__(self, config: TransformerConfig):
        """Build the head from ``config``.

        Args:
            config: Reads ``hidden_size``, ``cnn_hidden_size``, ``cnn_dropout`` and ``cnn_num_classes``.
        """
        super().__init__(config)

        self.finetune_model = torch.nn.Sequential(
            torch.nn.Conv2d(config.hidden_size, config.cnn_hidden_size, kernel_size=(7, 1), padding=(3, 0)),
            torch.nn.ReLU(),
            torch.nn.Dropout(config.cnn_dropout),
        )
        # Single conv layer producing `cnn_num_classes` output channels per token. Its input channel count must
        # equal the bottleneck width above. It is therefore taken from `cnn_hidden_size` instead of being
        # hard-coded to 32 (the default value); a hard-coded 32 fails at forward time for any other size.
        self.class_heads = torch.nn.Conv2d(
            config.cnn_hidden_size, config.cnn_num_classes, kernel_size=(7, 1), padding=(3, 0)
        )

    def forward(self, hidden_states: Tensor) -> Tensor:
        """Compute per-token outputs.

        Args:
            hidden_states: Tensor laid out as ``[b, s, h]`` with ``h == config.hidden_size``.

        Returns:
            A single Tensor of shape ``[b, s, cnn_num_classes]``.
        """
        # [b, s, h] -> [b, h, s, 1]: channels-first with a dummy width dim, so Conv2d acts along s only.
        hidden_states = hidden_states.permute(0, 2, 1).unsqueeze(dim=-1)
        hidden_states = self.finetune_model(hidden_states)  # [b, cnn_hidden_size, s, 1]
        # [b, cnn_num_classes, s, 1] -> [b, s, cnn_num_classes]
        output = self.class_heads(hidden_states).squeeze(dim=-1).permute(0, 2, 1)
        return output

__init__(config)

Constructor.

Source code in bionemo/esm2/model/finetune/token_model.py
45
46
47
48
49
50
51
52
53
54
55
56
def __init__(self, config: TransformerConfig):
    """Build the per-token convolutional head from ``config``.

    Args:
        config: Reads ``hidden_size``, ``cnn_hidden_size``, ``cnn_dropout`` and ``cnn_num_classes``.
    """
    super().__init__(config)

    # Bottleneck: hidden_size -> cnn_hidden_size channels; the (7, 1) kernel with (3, 0) padding convolves along
    # the sequence axis only and keeps its length unchanged.
    self.finetune_model = torch.nn.Sequential(
        torch.nn.Conv2d(config.hidden_size, config.cnn_hidden_size, kernel_size=(7, 1), padding=(3, 0)),
        torch.nn.ReLU(),
        torch.nn.Dropout(config.cnn_dropout),
    )
    # Single conv layer producing `cnn_num_classes` output channels per token. Its input channel count must
    # equal the bottleneck width above. It is therefore taken from `cnn_hidden_size` instead of being
    # hard-coded to 32 (the default value); a hard-coded 32 fails at forward time for any other size.
    self.class_heads = torch.nn.Conv2d(
        config.cnn_hidden_size, config.cnn_num_classes, kernel_size=(7, 1), padding=(3, 0)
    )

forward(hidden_states)

Inference.

Source code in bionemo/esm2/model/finetune/token_model.py
58
59
60
61
62
63
64
def forward(self, hidden_states: Tensor) -> Tensor:
    """Compute per-token outputs of the convolutional head.

    Args:
        hidden_states: Tensor laid out as ``[b, s, h]``, where ``h`` matches the first conv's input channels.

    Returns:
        A single Tensor of shape ``[b, s, output_size]``, where ``output_size`` is the output channel count of
        ``self.class_heads``.
    """
    # [b, s, h] -> [b, h, s, 1]: channels-first with a dummy width dim, so Conv2d acts along s only.
    hidden_states = hidden_states.permute(0, 2, 1).unsqueeze(dim=-1)
    hidden_states = self.finetune_model(hidden_states)  # [b, cnn_hidden_size, s, 1]
    # [b, output_size, s, 1] -> [b, s, output_size]
    output = self.class_heads(hidden_states).squeeze(dim=-1).permute(0, 2, 1)
    return output