Convert

HFESM2Importer

Bases: ModelConnector[AutoModelForMaskedLM, BionemoLightningModule]

Converts a Hugging Face ESM-2 model to a NeMo ESM-2 model.

Source code in bionemo/esm2/model/convert.py
@io.model_importer(BionemoLightningModule, "hf")
class HFESM2Importer(io.ModelConnector[AutoModelForMaskedLM, BionemoLightningModule]):
    """Converts a Hugging Face ESM-2 model to a NeMo ESM-2 model."""

    def init(self) -> BionemoLightningModule:
        """Initialize the converted model."""
        return biobert_lightning_module(self.config, tokenizer=self.tokenizer)

    def apply(self, output_path: Path) -> Path:
        """Applies the transformation.

        Largely inspired by
        https://docs.nvidia.com/nemo-framework/user-guide/latest/nemo-2.0/features/hf-integration.html
        """
        source = AutoModelForMaskedLM.from_pretrained(str(self), trust_remote_code=True, torch_dtype="auto")
        target = self.init()
        trainer = self.nemo_setup(target)
        self.convert_state(source, target)
        self.nemo_save(output_path, trainer)

        print(f"Converted ESM-2 model to Nemo, model saved to {output_path}")

        teardown(trainer, target)
        del trainer, target

        return output_path

    def convert_state(self, source, target):
        """Converting HF state dict to NeMo state dict."""
        mapping = {
            # "esm.encoder.layer.0.attention.self.rotary_embeddings.inv_freq": "rotary_pos_emb.inv_freq",
            "esm.encoder.layer.*.attention.output.dense.weight": "encoder.layers.*.self_attention.linear_proj.weight",
            "esm.encoder.layer.*.attention.output.dense.bias": "encoder.layers.*.self_attention.linear_proj.bias",
            "esm.encoder.layer.*.attention.LayerNorm.weight": "encoder.layers.*.self_attention.linear_qkv.layer_norm_weight",
            "esm.encoder.layer.*.attention.LayerNorm.bias": "encoder.layers.*.self_attention.linear_qkv.layer_norm_bias",
            "esm.encoder.layer.*.intermediate.dense.weight": "encoder.layers.*.mlp.linear_fc1.weight",
            "esm.encoder.layer.*.intermediate.dense.bias": "encoder.layers.*.mlp.linear_fc1.bias",
            "esm.encoder.layer.*.output.dense.weight": "encoder.layers.*.mlp.linear_fc2.weight",
            "esm.encoder.layer.*.output.dense.bias": "encoder.layers.*.mlp.linear_fc2.bias",
            "esm.encoder.layer.*.LayerNorm.weight": "encoder.layers.*.mlp.linear_fc1.layer_norm_weight",
            "esm.encoder.layer.*.LayerNorm.bias": "encoder.layers.*.mlp.linear_fc1.layer_norm_bias",
            "esm.encoder.emb_layer_norm_after.weight": "encoder.final_layernorm.weight",
            "esm.encoder.emb_layer_norm_after.bias": "encoder.final_layernorm.bias",
            "lm_head.dense.weight": "lm_head.dense.weight",
            "lm_head.dense.bias": "lm_head.dense.bias",
            "lm_head.layer_norm.weight": "lm_head.layer_norm.weight",
            "lm_head.layer_norm.bias": "lm_head.layer_norm.bias",
        }

        # lm_head.bias
        return io.apply_transforms(
            source,
            target,
            mapping=mapping,
            transforms=[_pad_embeddings, _pad_bias, _import_qkv_weight, _import_qkv_bias],
        )

    @property
    def tokenizer(self) -> BioNeMoESMTokenizer:
        """We just have the one tokenizer for ESM-2."""
        return get_tokenizer()

    @property
    def config(self) -> ESM2Config:
        """Returns the transformed ESM-2 config given the model tag."""
        source = HFAutoConfig.from_pretrained(str(self), trust_remote_code=True)
        output = ESM2Config(
            num_layers=source.num_hidden_layers,
            hidden_size=source.hidden_size,
            ffn_hidden_size=source.intermediate_size,
            position_embedding_type="rope",
            num_attention_heads=source.num_attention_heads,
            seq_length=source.max_position_embeddings,
            fp16=(dtype_from_hf(source) == torch.float16),
            bf16=(dtype_from_hf(source) == torch.bfloat16),
            params_dtype=dtype_from_hf(source),
        )

        return output
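
Because the class is registered with @io.model_importer(BionemoLightningModule, "hf"), NeMo's io machinery can resolve it for Hugging Face sources (typically written as "hf://<model-tag>"). The sketch below instead drives the connector directly. It is a minimal example, assuming that the connector, like other io.ModelConnector subclasses, is path-like and can be constructed straight from a Hugging Face model tag (the source above passes str(self) to AutoModelForMaskedLM.from_pretrained); facebook/esm2_t6_8M_UR50D is used only as an illustrative checkpoint.

from pathlib import Path

from bionemo.esm2.model.convert import HFESM2Importer

# Illustrative checkpoint tag; substitute the ESM-2 variant you actually need.
importer = HFESM2Importer("facebook/esm2_t6_8M_UR50D")

# apply() downloads the HF weights, maps them onto the NeMo module, saves a NeMo
# checkpoint at the given path, and returns that path.
nemo_ckpt = importer.apply(Path("/tmp/esm2_8m_nemo"))
print(nemo_ckpt)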

config property

Returns the transformed ESM-2 config given the model tag.
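
As a quick way to see what the translation produces, the hedged sketch below constructs the connector for an illustrative checkpoint tag and prints a few fields of the resulting ESM2Config; the attribute names are exactly the keyword arguments set in the property above, and the tag is only an example.

from bionemo.esm2.model.convert import HFESM2Importer

importer = HFESM2Importer("facebook/esm2_t6_8M_UR50D")  # illustrative checkpoint tag
cfg = importer.config  # reads the HF config and translates it into an ESM2Config

# Field names mirror the ESM2Config arguments assigned in the property above.
print(cfg.num_layers, cfg.hidden_size, cfg.num_attention_heads, cfg.seq_length)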

tokenizer property

We just have the one tokenizer for ESM-2.

apply(output_path)

Applies the transformation.

Largely inspired by https://docs.nvidia.com/nemo-framework/user-guide/latest/nemo-2.0/features/hf-integration.html

Source code in bionemo/esm2/model/convert.py
def apply(self, output_path: Path) -> Path:
    """Applies the transformation.

    Largely inspired by
    https://docs.nvidia.com/nemo-framework/user-guide/latest/nemo-2.0/features/hf-integration.html
    """
    source = AutoModelForMaskedLM.from_pretrained(str(self), trust_remote_code=True, torch_dtype="auto")
    target = self.init()
    trainer = self.nemo_setup(target)
    self.convert_state(source, target)
    self.nemo_save(output_path, trainer)

    print(f"Converted ESM-2 model to Nemo, model saved to {output_path}")

    teardown(trainer, target)
    del trainer, target

    return output_path

convert_state(source, target)

Converting HF state dict to NeMo state dict.

Source code in bionemo/esm2/model/convert.py
def convert_state(self, source, target):
    """Converting HF state dict to NeMo state dict."""
    mapping = {
        # "esm.encoder.layer.0.attention.self.rotary_embeddings.inv_freq": "rotary_pos_emb.inv_freq",
        "esm.encoder.layer.*.attention.output.dense.weight": "encoder.layers.*.self_attention.linear_proj.weight",
        "esm.encoder.layer.*.attention.output.dense.bias": "encoder.layers.*.self_attention.linear_proj.bias",
        "esm.encoder.layer.*.attention.LayerNorm.weight": "encoder.layers.*.self_attention.linear_qkv.layer_norm_weight",
        "esm.encoder.layer.*.attention.LayerNorm.bias": "encoder.layers.*.self_attention.linear_qkv.layer_norm_bias",
        "esm.encoder.layer.*.intermediate.dense.weight": "encoder.layers.*.mlp.linear_fc1.weight",
        "esm.encoder.layer.*.intermediate.dense.bias": "encoder.layers.*.mlp.linear_fc1.bias",
        "esm.encoder.layer.*.output.dense.weight": "encoder.layers.*.mlp.linear_fc2.weight",
        "esm.encoder.layer.*.output.dense.bias": "encoder.layers.*.mlp.linear_fc2.bias",
        "esm.encoder.layer.*.LayerNorm.weight": "encoder.layers.*.mlp.linear_fc1.layer_norm_weight",
        "esm.encoder.layer.*.LayerNorm.bias": "encoder.layers.*.mlp.linear_fc1.layer_norm_bias",
        "esm.encoder.emb_layer_norm_after.weight": "encoder.final_layernorm.weight",
        "esm.encoder.emb_layer_norm_after.bias": "encoder.final_layernorm.bias",
        "lm_head.dense.weight": "lm_head.dense.weight",
        "lm_head.dense.bias": "lm_head.dense.bias",
        "lm_head.layer_norm.weight": "lm_head.layer_norm.weight",
        "lm_head.layer_norm.bias": "lm_head.layer_norm.bias",
    }

    # lm_head.bias
    return io.apply_transforms(
        source,
        target,
        mapping=mapping,
        transforms=[_pad_embeddings, _pad_bias, _import_qkv_weight, _import_qkv_bias],
    )
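
Each "*" in the mapping is a wildcard over the per-layer index: the layer number captured on the Hugging Face side is substituted into the NeMo parameter name, so one entry covers every encoder layer. Weights that are not simple renames are handled by the transform functions instead: judging by their names, _pad_embeddings and _pad_bias pad the word embeddings and output bias to NeMo's padded vocabulary size, and _import_qkv_weight and _import_qkv_bias fuse the separate HF query/key/value projections into the single linear_qkv tensor. The snippet below is a self-contained illustration of the wildcard naming convention only, not NeMo's actual matching code.

import re

def expand_wildcard(hf_key: str, pattern: str, target: str) -> str | None:
    """Illustrative only: substitute the layer index captured by '*' into the target name."""
    regex = "^" + re.escape(pattern).replace(r"\*", r"(\d+)") + "$"
    match = re.match(regex, hf_key)
    if match is None:
        return None
    return target.replace("*", match.group(1))

print(
    expand_wildcard(
        "esm.encoder.layer.3.attention.output.dense.weight",
        "esm.encoder.layer.*.attention.output.dense.weight",
        "encoder.layers.*.self_attention.linear_proj.weight",
    )
)  # -> encoder.layers.3.self_attention.linear_proj.weight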

init()

Initialize the converted model.

Source code in bionemo/esm2/model/convert.py
def init(self) -> BionemoLightningModule:
    """Initialize the converted model."""
    return biobert_lightning_module(self.config, tokenizer=self.tokenizer)