Skip to content

Finetune token classifier

ClassifierLossReduction

Bases: BERTMLMLossWithReduction

A class for calculating the cross entropy loss of classification output.

This class is used for calculating the loss and for logging the reduced loss across micro-batches.

Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
class ClassifierLossReduction(BERTMLMLossWithReduction):
    """A class for calculating the cross entropy loss of per-token classification output.

    This class is used for calculating the loss within a micro-batch, and for logging the reduced loss across
    micro-batches.
    """

    def forward(
        self, batch: Dict[str, Tensor], forward_out: Dict[str, Tensor]
    ) -> Tuple[Tensor, PerTokenLossDict | SameSizeLossDict]:
        """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

        Args:
            batch: A batch of data that gets passed to the original forward inside LitAutoEncoder. Must contain
                "labels" ([b, s]) and "loss_mask" ([b, s]).
            forward_out: the output of the forward method inside classification head. Must contain
                "classification_output" ([b, s, num_class]).

        Returns:
            A tuple where the loss tensor will be used for backpropagation and the dict will be passed to
            the reduce method, which currently only works for logging. The loss is the mean cross entropy over
            the positions selected by ``loss_mask``; it is 0 (not NaN) when no position is selected.

        Raises:
            NotImplementedError: if the context-parallel world size is not 1.
        """
        targets = batch["labels"]  # [b, s]
        # [b, s, num_class] -> [b, num_class, s] to satisfy input dims for cross_entropy loss
        classification_output = forward_out["classification_output"].permute(0, 2, 1)
        # masked_select requires a boolean mask; .bool() is a no-op when the mask is already boolean.
        loss_mask = batch["loss_mask"].bool()  # [b, s]

        cp_size = parallel_state.get_context_parallel_world_size()
        if cp_size != 1:  # TODO: support CP with masked_token_loss_context_parallel
            raise NotImplementedError("Context Parallel support is not implemented for this loss")

        losses = torch.nn.functional.cross_entropy(classification_output, targets, reduction="none")  # [b, s]
        # losses may contain NaNs at masked locations. masked_select drops those entries entirely
        # (multiplying by the mask would not, because NaN * 0 is still NaN).
        masked_loss = torch.masked_select(losses, loss_mask)
        # Clamp the denominator so a micro-batch without any valid token yields 0 instead of 0/0 = NaN,
        # which would otherwise poison both the gradients and the logged loss.
        num_valid = loss_mask.sum().clamp(min=1)
        loss = masked_loss.sum() / num_valid

        return loss, {"avg": loss}

    def reduce(self, losses_reduced_per_micro_batch: Sequence[SameSizeLossDict]) -> Tensor:
        """Reduces the per-micro-batch losses (data on a single GPU) to one value.

        Note: This currently only works for logging and this loss will not be used for backpropagation.

        Args:
            losses_reduced_per_micro_batch: a list of the outputs of forward

        Returns:
            A tensor that is the mean of the losses (used for logging), or a zero tensor when no micro-batch
            outputs were provided.
        """
        if not losses_reduced_per_micro_batch:
            # torch.stack raises on an empty list; report 0 instead of crashing the logging path.
            return torch.tensor(0.0)
        losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
        return losses.mean()

forward(batch, forward_out)

Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

Parameters:

Name Type Description Default
batch Dict[str, Tensor]

A batch of data that gets passed to the original forward inside LitAutoEncoder.

required
forward_out Dict[str, Tensor]

the output of the forward method inside classification head.

required

Returns:

Type Description
Tensor

The loss tensor, which is used for backpropagation.

PerTokenLossDict | SameSizeLossDict

The loss dict, which is passed to the reduce method (currently used only for logging).

Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def forward(
    self, batch: Dict[str, Tensor], forward_out: Dict[str, Tensor]
) -> Tuple[Tensor, PerTokenLossDict | SameSizeLossDict]:
    """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

    Args:
        batch: A batch of data that gets passed to the original forward inside LitAutoEncoder. Must contain
            "labels" ([b, s]) and "loss_mask" ([b, s]).
        forward_out: the output of the forward method inside classification head. Must contain
            "classification_output" ([b, s, num_class]).

    Returns:
        A tuple where the loss tensor will be used for backpropagation and the dict will be passed to
        the reduce method, which currently only works for logging. The loss is the mean cross entropy over
        the positions selected by ``loss_mask``; it is 0 (not NaN) when no position is selected.

    Raises:
        NotImplementedError: if the context-parallel world size is not 1.
    """
    targets = batch["labels"]  # [b, s]
    # [b, s, num_class] -> [b, num_class, s] to satisfy input dims for cross_entropy loss
    classification_output = forward_out["classification_output"].permute(0, 2, 1)
    # masked_select requires a boolean mask; .bool() is a no-op when the mask is already boolean.
    loss_mask = batch["loss_mask"].bool()  # [b, s]

    cp_size = parallel_state.get_context_parallel_world_size()
    if cp_size != 1:  # TODO: support CP with masked_token_loss_context_parallel
        raise NotImplementedError("Context Parallel support is not implemented for this loss")

    losses = torch.nn.functional.cross_entropy(classification_output, targets, reduction="none")  # [b, s]
    # losses may contain NaNs at masked locations. masked_select drops those entries entirely
    # (multiplying by the mask would not, because NaN * 0 is still NaN).
    masked_loss = torch.masked_select(losses, loss_mask)
    # Clamp the denominator so a micro-batch without any valid token yields 0 instead of 0/0 = NaN,
    # which would otherwise poison both the gradients and the logged loss.
    num_valid = loss_mask.sum().clamp(min=1)
    loss = masked_loss.sum() / num_valid

    return loss, {"avg": loss}

reduce(losses_reduced_per_micro_batch)

Reduces the per-micro-batch losses (data on a single GPU) to a single value.

Note: This currently only works for logging and this loss will not be used for backpropagation.

Parameters:

Name Type Description Default
losses_reduced_per_micro_batch Sequence[SameSizeLossDict]

a list of the outputs of forward

required

Returns:

Type Description
Tensor

A tensor that is the mean of the losses (used for logging).

Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py
79
80
81
82
83
84
85
86
87
88
89
90
91
def reduce(self, losses_reduced_per_micro_batch: Sequence[SameSizeLossDict]) -> Tensor:
    """Reduces the per-micro-batch losses (data on a single GPU) to one value.

    Note: This currently only works for logging and this loss will not be used for backpropagation.

    Args:
        losses_reduced_per_micro_batch: a list of the outputs of forward

    Returns:
        A tensor that is the mean of the losses (used for logging), or a zero tensor when no micro-batch
        outputs were provided.
    """
    if not losses_reduced_per_micro_batch:
        # torch.stack raises on an empty list; report 0 instead of crashing the logging path.
        return torch.tensor(0.0)
    losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
    return losses.mean()

ESM2FineTuneTokenConfig dataclass

Bases: ESM2GenericConfig[ESM2FineTuneTokenModel, ClassifierLossReduction], IOMixinWithGettersSetters

ESM2FineTuneTokenConfig is a dataclass that is used to configure the token-classification fine-tuning model.

Timers from ModelParallelConfig are required for megatron forward compatibility.

Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
@dataclass
class ESM2FineTuneTokenConfig(
    ESM2GenericConfig[ESM2FineTuneTokenModel, ClassifierLossReduction], iom.IOMixinWithGettersSetters
):
    """ESM2FineTuneTokenConfig is a dataclass that is used to configure the token-classification fine-tuning model.

    It extends the generic ESM2 config with the settings read by ESM2FineTuneTokenModel (``encoder_frozen``)
    and MegatronConvNetHead (the ``cnn_*`` fields).

    Timers from ModelParallelConfig are required for megatron forward compatibility.
    """

    # Model class paired with this config (presumably instantiated by ESM2GenericConfig -- confirm there).
    model_cls: Type[ESM2FineTuneTokenModel] = ESM2FineTuneTokenModel
    # typical case is fine-tune the base biobert that doesn't have this head. If you are instead loading a checkpoint
    # that has this new head and want to keep using these weights, please drop this next line or set to []
    initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=lambda: ["classification_head"])

    # If True, ESM2FineTuneTokenModel sets requires_grad=False on every parameter registered by the encoder;
    # the classification head is created afterwards and therefore stays trainable.
    encoder_frozen: bool = True  # freeze encoder parameters
    cnn_num_classes: int = 3  # number of classes in each label (output channels of the head's final conv)
    cnn_dropout: float = 0.25  # dropout probability applied after the head's bottleneck conv
    cnn_hidden_dim: int = 32  # The number of output channels in the bottleneck layer of the convolution.

    def get_loss_reduction_class(self) -> Type[ClassifierLossReduction]:
        """Return the loss-reduction class (not an instance) used for this fine-tuning task."""
        return ClassifierLossReduction

get_loss_reduction_class()

The loss function type.

Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py
181
182
183
def get_loss_reduction_class(self) -> Type[ClassifierLossReduction]:
    """Return the loss-reduction class to pair with this config.

    The class object itself is returned (not an instance); it computes masked per-token cross entropy.
    """
    loss_reduction_cls = ClassifierLossReduction
    return loss_reduction_cls

ESM2FineTuneTokenModel

Bases: ESM2Model

An ESM2 model that is suitable for fine tuning.

Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
class ESM2FineTuneTokenModel(ESM2Model):
    """ESM2 encoder topped with a per-residue convolutional classification head, for fine-tuning."""

    def __init__(self, config, *args, include_hiddens: bool = False, post_process: bool = True, **kwargs):
        """Constructor.

        Args:
            config: model config; ``encoder_frozen`` and the head's ``cnn_*`` settings are read from it.
            include_hiddens: keep "hidden_states" in the dict returned by ``forward``.
            post_process: True on the last pipeline stage, where the classification head is built.
        """
        # The head consumes the encoder's hidden states, so the parent is always asked to return them;
        # the caller's include_hiddens only controls whether they survive into our own output.
        super().__init__(config, *args, include_hiddens=True, post_process=post_process, **kwargs)
        self.include_hiddens_finetuning = include_hiddens

        if config.encoder_frozen:
            # Only the parent's (encoder) parameters are registered at this point, so the head
            # created below is left trainable.
            for param in self.parameters():
                param.requires_grad = False

        if post_process:
            # Last pipeline stage: attach the output head.
            self.classification_head = MegatronConvNetHead(config)

    def forward(self, *args, **kwargs) -> Tensor | BioBertOutput:
        """Run the encoder and, on the last pipeline stage, add per-token classification logits.

        Returns:
            On intermediate pipeline stages, the parent's output unchanged. Otherwise the parent's output
            dict with "classification_output" added, and "hidden_states" removed unless include_hiddens
            was requested at construction.

        Raises:
            ValueError: if the parent's output is not dict-like or lacks "hidden_states".
        """
        output = super().forward(*args, **kwargs)
        if not self.post_process:
            # Intermediate pipeline stage: no head here, pass the activations through.
            return output

        has_hiddens = isinstance(output, dict) and "hidden_states" in output
        if not has_hiddens:
            raise ValueError(
                f"Expected to find 'hidden_states' in the output, and output to be dictionary-like, found {output},\n"
                "Make sure include_hiddens=True in the call to super().__init__"
            )

        # The head consumes the hidden state of every token (not just [CLS]) to produce per-token logits.
        output["classification_output"] = self.classification_head(output["hidden_states"])
        if not self.include_hiddens_finetuning:
            del output["hidden_states"]
        return output

__init__(config, *args, include_hiddens=False, post_process=True, **kwargs)

Constructor.

Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
def __init__(self, config, *args, include_hiddens: bool = False, post_process: bool = True, **kwargs):
    """Constructor.

    Args:
        config: model config; ``encoder_frozen`` and the head's ``cnn_*`` settings are read from it.
        include_hiddens: keep "hidden_states" in the dict returned by ``forward``.
        post_process: True on the last pipeline stage, where the classification head is built.
    """
    # The head consumes the encoder's hidden states, so the parent is always asked to return them;
    # the caller's include_hiddens only controls whether they survive into our own output.
    super().__init__(config, *args, include_hiddens=True, post_process=post_process, **kwargs)
    self.include_hiddens_finetuning = include_hiddens

    if config.encoder_frozen:
        # Only the parent's (encoder) parameters are registered at this point, so the head
        # created below is left trainable.
        for param in self.parameters():
            param.requires_grad = False

    if post_process:
        # Last pipeline stage: attach the output head.
        self.classification_head = MegatronConvNetHead(config)

forward(*args, **kwargs)

Inference.

Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
def forward(self, *args, **kwargs) -> Tensor | BioBertOutput:
    """Run the encoder and, on the last pipeline stage, add per-token classification logits.

    Returns:
        On intermediate pipeline stages, the parent's output unchanged. Otherwise the parent's output
        dict with "classification_output" added, and "hidden_states" removed unless include_hiddens
        was requested at construction.

    Raises:
        ValueError: if the parent's output is not dict-like or lacks "hidden_states".
    """
    output = super().forward(*args, **kwargs)
    if not self.post_process:
        # Intermediate pipeline stage: no head here, pass the activations through.
        return output

    has_hiddens = isinstance(output, dict) and "hidden_states" in output
    if not has_hiddens:
        raise ValueError(
            f"Expected to find 'hidden_states' in the output, and output to be dictionary-like, found {output},\n"
            "Make sure include_hiddens=True in the call to super().__init__"
        )

    # The head consumes the hidden state of every token (not just [CLS]) to produce per-token logits.
    output["classification_output"] = self.classification_head(output["hidden_states"])
    if not self.include_hiddens_finetuning:
        del output["hidden_states"]
    return output

MegatronConvNetHead

Bases: MegatronModule

A convolutional neural network class for residue-level classification.

Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
class MegatronConvNetHead(MegatronModule):
    """A convolutional neural network class for residue-level classification.

    Two (7, 1) convolutions slide along the sequence axis: a bottleneck conv (hidden_size -> cnn_hidden_dim)
    followed by ReLU and dropout, then a conv producing cnn_num_classes logits for every position.
    """

    def __init__(self, config: TransformerConfig):
        """Constructor.

        Args:
            config: model config; reads hidden_size, cnn_hidden_dim, cnn_dropout and cnn_num_classes.
        """
        super().__init__(config)

        # kernel (7, 1) with padding (3, 0) keeps the sequence length unchanged.
        self.finetune_model = torch.nn.Sequential(
            torch.nn.Conv2d(config.hidden_size, config.cnn_hidden_dim, kernel_size=(7, 1), padding=(3, 0)),
            torch.nn.ReLU(),
            torch.nn.Dropout(config.cnn_dropout),
        )
        # class_heads (torch.nn.Conv2d): maps bottleneck features to cnn_num_classes logits per position.
        # Its input channels must equal the bottleneck's output channels (cnn_hidden_dim), not a fixed 32.
        self.class_heads = torch.nn.Conv2d(
            config.cnn_hidden_dim, config.cnn_num_classes, kernel_size=(7, 1), padding=(3, 0)
        )

    def forward(self, hidden_states: Tensor) -> Tensor:
        """Inference.

        Args:
            hidden_states: encoder hidden states, [b, s, h] with h == config.hidden_size.

        Returns:
            Per-position logits, [b, s, cnn_num_classes].
        """
        # [b, s, h] -> [b, h, s, 1]: Conv2d expects channels in dim 1; the trailing singleton dim lets
        # the (7, 1) kernel slide along the sequence only.
        hidden_states = hidden_states.permute(0, 2, 1).unsqueeze(dim=-1)
        hidden_states = self.finetune_model(hidden_states)  # [b, cnn_hidden_dim, s, 1]
        # [b, num_classes, s, 1] -> [b, num_classes, s] -> [b, s, num_classes]
        output = self.class_heads(hidden_states).squeeze(dim=-1).permute(0, 2, 1)
        return output

__init__(config)

Constructor.

Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py
 97
 98
 99
100
101
102
103
104
105
106
107
108
def __init__(self, config: TransformerConfig):
    """Constructor.

    Args:
        config: model config; reads hidden_size, cnn_hidden_dim, cnn_dropout and cnn_num_classes.
    """
    super().__init__(config)

    # kernel (7, 1) with padding (3, 0) keeps the sequence length unchanged.
    self.finetune_model = torch.nn.Sequential(
        torch.nn.Conv2d(config.hidden_size, config.cnn_hidden_dim, kernel_size=(7, 1), padding=(3, 0)),
        torch.nn.ReLU(),
        torch.nn.Dropout(config.cnn_dropout),
    )
    # class_heads (torch.nn.Conv2d): maps bottleneck features to cnn_num_classes logits per position.
    # Its input channels must equal the bottleneck's output channels (cnn_hidden_dim), not a fixed 32.
    self.class_heads = torch.nn.Conv2d(
        config.cnn_hidden_dim, config.cnn_num_classes, kernel_size=(7, 1), padding=(3, 0)
    )

forward(hidden_states)

Inference.

Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py
110
111
112
113
114
115
116
def forward(self, hidden_states: Tensor) -> Tensor:
    """Inference.

    Args:
        hidden_states: encoder hidden states, [b, s, h] with h == config.hidden_size.

    Returns:
        Per-position logits, [b, s, num_classes] (a single tensor, not a list).
    """
    # [b, s, h] -> [b, h, s, 1]: Conv2d expects channels in dim 1; the trailing singleton dim lets
    # the (7, 1) kernel slide along the sequence only.
    hidden_states = hidden_states.permute(0, 2, 1).unsqueeze(dim=-1)
    hidden_states = self.finetune_model(hidden_states)  # [b, cnn_hidden_dim, s, 1]
    # [b, num_classes, s, 1] -> [b, num_classes, s] -> [b, s, num_classes]
    output = self.class_heads(hidden_states).squeeze(dim=-1).permute(0, 2, 1)
    return output