Peft

ESM2LoRA

Bases: LoRA

LoRA for the BioNeMo2 ESM Model.

Source code in bionemo/esm2/model/finetune/peft.py
class ESM2LoRA(LoRA):
    """LoRA for the BioNeMo2 ESM Model."""

    def __init__(
        self,
        peft_ckpt_path: Optional[str] = None,
        freeze_modules: List[str] = ["encoder", "embedding"],
        *args,
        **kwarg,
    ):
        """Initialize the LoRA Adapter.

        Args:
            peft_ckpt_path: path to the PEFT checkpoint.
            freeze_modules: modules to freeze.
            *args: args for the LoRA class.
            **kwarg: kwargs for the LoRA class.
        """
        super().__init__(*args, **kwarg)
        self.freeze_modules = freeze_modules
        self.peft_ckpt_path = peft_ckpt_path

    def setup(self, *args, **kwarg):
        """Initialize the LoRA Adapter. Pass the peft_ckpt_path to the wrapped io.

        Args:
            *args: args for the LoRA class.
            **kwarg: kwargs for the LoRA class.
        """
        super().setup(*args, **kwarg)
        self.wrapped_io.adapter_ckpt_path = self.peft_ckpt_path

    def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Event hook.

        Args:
            trainer: The trainer object.
            pl_module: The LightningModule object.
        """
        self._maybe_apply_transform(trainer)

    def adapter_key_filter(self, key: str) -> bool:
        """Given a key in the state dict, return whether the key is an adapter (or base model).

        Args:
            key: the key to filter
        """
        if isinstance(key, tuple):
            return key[1].requires_grad
        if "_extra_state" in key:
            return False
        return (
            (not any(substring in key for substring in self.freeze_modules))
            or ".adapter." in key
            or key.endswith(".adapters")
        )

    def __call__(self, model: nn.Module) -> nn.Module:
        """This method is called when the object is called as a function.

        Args:
            model: The input model.

            The modified model.
        """
        fn.walk(model, self.selective_freeze)
        fn.walk(model, self.transform)
        return model

    def selective_freeze(self, m: nn.Module, name=None, prefix=None):
        """Freezes specific modules in the given model.

        Args:
            m (nn.Module): The model to selectively freeze.
            name (str): The name of the module to freeze. Valid values are "encoder" and "embedding".
            prefix (str): The prefix of the module to freeze.

        Returns:
            nn.Module: The modified model with the specified modules frozen.

        See Also:
            nemo.collections.llm.fn.mixin.FNMixin
        """
        if name in self.freeze_modules:
            FNMixin.freeze(m)
        return m
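
A minimal usage sketch, assuming a BioNeMo ESM-2 fine-tuning setup in which this adapter instance is handed to the training entry point as the PEFT transform (the exact wiring is not shown here and may differ in your script):

from bionemo.esm2.model.finetune.peft import ESM2LoRA

# Construct the adapter. When applied to the model, the listed submodules are frozen
# and LoRA adapters are injected into the remaining layers.
lora_transform = ESM2LoRA(
    peft_ckpt_path=None,                      # optionally, a previously saved PEFT checkpoint
    freeze_modules=["encoder", "embedding"],  # default value, shown explicitly
)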

__call__(model)

Apply the adapter to a model: freeze the configured modules, then inject LoRA adapters.

Parameters:

Name    Type    Description         Default
model   Module  The input model.    required

Returns:

Type       Description
nn.Module  The modified model.

Source code in bionemo/esm2/model/finetune/peft.py
def __call__(self, model: nn.Module) -> nn.Module:
    """This method is called when the object is called as a function.

    Args:
        model: The input model.

        The modified model.
    """
    fn.walk(model, self.selective_freeze)
    fn.walk(model, self.transform)
    return model
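
As a sketch, the adapter object can be called directly on a model; esm2_model below stands in for an already-constructed BioNeMo ESM-2 module and is assumed, not defined here:

lora_transform = ESM2LoRA(freeze_modules=["encoder", "embedding"])

# First walk freezes the configured submodules, second walk injects LoRA adapters.
esm2_model = lora_transform(esm2_model)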

__init__(peft_ckpt_path=None, freeze_modules=['encoder', 'embedding'], *args, **kwarg)

Initialize the LoRA Adapter.

Parameters:

Name            Type           Description                    Default
peft_ckpt_path  Optional[str]  Path to the PEFT checkpoint.   None
freeze_modules  List[str]      Modules to freeze.             ['encoder', 'embedding']
*args                          Args for the LoRA class.       ()
**kwarg                        Kwargs for the LoRA class.     {}

Source code in bionemo/esm2/model/finetune/peft.py
def __init__(
    self,
    peft_ckpt_path: Optional[str] = None,
    freeze_modules: List[str] = ["encoder", "embedding"],
    *args,
    **kwarg,
):
    """Initialize the LoRA Adapter.

    Args:
        peft_ckpt_path: path to the PEFT checkpoint.
        freeze_modules: modules to freeze.
        *args: args for the LoRA class.
        **kwarg: kwargs for the LoRA class.
    """
    super().__init__(*args, **kwarg)
    self.freeze_modules = freeze_modules
    self.peft_ckpt_path = peft_ckpt_path
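
For example, to resume from previously trained adapter weights, pass the checkpoint location as peft_ckpt_path; the path below is purely illustrative:

lora_transform = ESM2LoRA(
    peft_ckpt_path="/results/esm2_lora_checkpoint",  # hypothetical checkpoint path
    freeze_modules=["encoder", "embedding"],
)
# setup() later forwards this path to the wrapped IO so the adapter weights are restored.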

adapter_key_filter(key)

Given a key in the state dict, return True if the key belongs to an adapter and False if it belongs to the base model.

Parameters:

Name  Type  Description         Default
key   str   The key to filter.  required

Source code in bionemo/esm2/model/finetune/peft.py
def adapter_key_filter(self, key: str) -> bool:
    """Given a key in the state dict, return whether the key is an adapter (or base model).

    Args:
        key: the key to filter
    """
    if isinstance(key, tuple):
        return key[1].requires_grad
    if "_extra_state" in key:
        return False
    return (
        (not any(substring in key for substring in self.freeze_modules))
        or ".adapter." in key
        or key.endswith(".adapters")
    )
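
A sketch of splitting a model state dict into adapter and base-model entries with this filter; state_dict is assumed to be an existing mapping of parameter names to tensors:

lora_transform = ESM2LoRA(freeze_modules=["encoder", "embedding"])

adapter_state = {k: v for k, v in state_dict.items() if lora_transform.adapter_key_filter(k)}
base_state = {k: v for k, v in state_dict.items() if not lora_transform.adapter_key_filter(k)}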

on_predict_epoch_start(trainer, pl_module)

Event hook called at the start of the predict epoch; applies the adapter transform if needed.

Parameters:

Name       Type             Description                  Default
trainer    Trainer          The trainer object.          required
pl_module  LightningModule  The LightningModule object.  required

Source code in bionemo/esm2/model/finetune/peft.py
def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Event hook.

    Args:
        trainer: The trainer object.
        pl_module: The LightningModule object.
    """
    self._maybe_apply_transform(trainer)

selective_freeze(m, name=None, prefix=None)

Freezes specific modules in the given model.

Parameters:

Name    Type    Description                                                                    Default
m       Module  The model to selectively freeze.                                               required
name    str     The name of the module to freeze. Valid values are "encoder" and "embedding".  None
prefix  str     The prefix of the module to freeze.                                            None

Returns:

Type       Description
nn.Module  The modified model with the specified modules frozen.

See Also

nemo.collections.llm.fn.mixin.FNMixin

Source code in bionemo/esm2/model/finetune/peft.py
def selective_freeze(self, m: nn.Module, name=None, prefix=None):
    """Freezes specific modules in the given model.

    Args:
        m (nn.Module): The model to selectively freeze.
        name (str): The name of the module to freeze. Valid values are "encoder" and "embedding".
        prefix (str): The prefix of the module to freeze.

    Returns:
        nn.Module: The modified model with the specified modules frozen.

    See Also:
        nemo.collections.llm.fn.mixin.FNMixin
    """
    if name in self.freeze_modules:
        FNMixin.freeze(m)
    return m
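
selective_freeze is normally applied to every submodule through fn.walk inside __call__; the simplified sketch below applies it only to the top-level children of an assumed esm2_model to illustrate the name check:

lora_transform = ESM2LoRA(freeze_modules=["encoder", "embedding"])

# Only children whose name appears in freeze_modules are frozen; all others pass through unchanged.
for child_name, child_module in esm2_model.named_children():
    lora_transform.selective_freeze(child_module, name=child_name)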

setup(*args, **kwarg)

Set up the LoRA adapter and forward peft_ckpt_path to the wrapped IO.

Parameters:

Name     Type  Description                 Default
*args          Args for the LoRA class.    ()
**kwarg        Kwargs for the LoRA class.  {}

Source code in bionemo/esm2/model/finetune/peft.py
def setup(self, *args, **kwarg):
    """Initialize the LoRA Adapter. Pass the peft_ckpt_path to the wrapped io.

    Args:
        *args: args for the LoRA class.
        **kwarg: kwargs for the LoRA class.
    """
    super().setup(*args, **kwarg)
    self.wrapped_io.adapter_ckpt_path = self.peft_ckpt_path
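
setup() is called by the trainer rather than by hand; beyond the base LoRA setup, its effect is to forward the constructor's peft_ckpt_path to the wrapped IO. A minimal sketch of that effect (trainer wiring assumed and omitted, checkpoint path hypothetical):

lora_transform = ESM2LoRA(peft_ckpt_path="/results/esm2_lora_checkpoint")  # hypothetical path

# After the trainer calls setup(), the wrapped IO knows where to load adapter weights from:
# lora_transform.wrapped_io.adapter_ckpt_path == "/results/esm2_lora_checkpoint"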