Skip to content

Peft

ESM2LoRA

Bases: LoRA

LoRA for the BioNeMo2 ESM Model.

Source code in `bionemo/esm2/model/finetune/peft.py` (lines 37–69)
class ESM2LoRA(LoRA):
    """LoRA variant for the BioNeMo2 ESM model.

    Before applying the adapter transform, the submodules named ``"encoder"``
    and ``"embedding"`` are frozen.
    """

    def __call__(self, model: nn.Module) -> nn.Module:
        """Freeze the selected submodules of ``model``, then apply the LoRA transform.

        Two passes are made over ``model`` with ``fn.walk``. The first uses
        ``selective_freeze``. The second uses ``self.transform``, which is not
        defined in this class (presumably inherited from ``LoRA``).

        The order matters and is preserved. Presumably, whatever ``transform``
        inserts must not be frozen by the first pass.

        Args:
            model: The model to adapt.

        Returns:
            The same ``model`` object that was passed in. Whatever ``fn.walk``
            returns is ignored.
        """
        for visitor in (self.selective_freeze, self.transform):
            fn.walk(model, visitor)
        return model

    def selective_freeze(self, m: nn.Module, name=None, prefix=None):
        """Freeze ``m`` when it is the encoder or the embedding submodule.

        This is the callback for the first ``fn.walk`` pass in ``__call__``.

        Args:
            m (nn.Module): The module currently being visited.
            name (str): Name of ``m``. Only ``"encoder"`` and ``"embedding"``
                cause ``m`` to be frozen. Any other value (including ``None``)
                leaves it untouched.
            prefix (str): Not used by this method. It is presumably accepted to
                match the callback signature expected by ``fn.walk``.

        Returns:
            nn.Module: ``m`` itself. ``FNMixin.freeze`` has been applied to it
            when its name matched.

        See Also:
            nemo.collections.llm.fn.mixin.FNMixin
        """
        frozen_names = ("encoder", "embedding")
        if name not in frozen_names:
            return m
        # Called unbound on ``m``; presumably disables gradients on its
        # parameters (see FNMixin).
        FNMixin.freeze(m)
        return m

__call__(model)

Called when the object is invoked as a function. It walks `model` first with `selective_freeze`, then with `transform`, and returns the same `model` object.

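Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | The input model. | required |

Returns:

| Type | Description |
| --- | --- |
| `Module` | The modified model (the same object that was passed in). |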

Source code in `bionemo/esm2/model/finetune/peft.py` (lines 40–51)
def __call__(self, model: nn.Module) -> nn.Module:
    """Freeze the selected submodules of ``model``, then apply the LoRA transform.

    Two passes are made over ``model`` with ``fn.walk``. The first uses
    ``selective_freeze``. The second uses ``self.transform`` (presumably
    inherited from ``LoRA``).

    The order matters and is preserved. Presumably, whatever ``transform``
    inserts must not be frozen by the first pass.

    Args:
        model: The model to adapt.

    Returns:
        The same ``model`` object that was passed in. Whatever ``fn.walk``
        returns is ignored.
    """
    for visitor in (self.selective_freeze, self.transform):
        fn.walk(model, visitor)
    return model

selective_freeze(m, name=None, prefix=None)

Freezes specific modules in the given model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `m` | `Module` | The module currently being visited. | required |
| `name` | `str` | The name of the module. Modules named `"encoder"` or `"embedding"` are frozen; all other modules are left unchanged. | `None` |
| `prefix` | `str` | The prefix of the module. Not used by this method. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `nn.Module` | The same module `m`, frozen if its name matched. |

See Also

nemo.collections.llm.fn.mixin.FNMixin

Source code in `bionemo/esm2/model/finetune/peft.py` (lines 53–69)
def selective_freeze(self, m: nn.Module, name=None, prefix=None):
    """Freeze ``m`` when it is the encoder or the embedding submodule.

    This is the callback for the first ``fn.walk`` pass in ``__call__``.

    Args:
        m (nn.Module): The module currently being visited.
        name (str): Name of ``m``. Only ``"encoder"`` and ``"embedding"``
            cause ``m`` to be frozen. Any other value (including ``None``)
            leaves it untouched.
        prefix (str): Not used by this method. It is presumably accepted to
            match the callback signature expected by ``fn.walk``.

    Returns:
        nn.Module: ``m`` itself. ``FNMixin.freeze`` has been applied to it
        when its name matched.

    See Also:
        nemo.collections.llm.fn.mixin.FNMixin
    """
    frozen_names = ("encoder", "embedding")
    if name not in frozen_names:
        return m
    # Called unbound on ``m``; presumably disables gradients on its
    # parameters (see FNMixin).
    FNMixin.freeze(m)
    return m