Skip to content

Weight utils

load_weights_sharded_inplace_nemo2_to_mcore(model, distributed_checkpoint_dir, skip_keys_with_these_prefixes)

Given a megatron module, this function determines which keys/subsets of weights to load given the parallel/distributed state. It assumes the checkpoint was saved by a NeMo 2 trainer, which places the `module.` prefix on all key names, whereas we load directly into the megatron module without the `module.` prefix. Note that if there are any *extra* keys that you do not want to search the checkpoint for (for example, if you add new layers/heads onto your module), you need to supply the prefix path to those keys in your model and they will be ignored. This latter feature is key for flexible fine-tuning strategies where you load weights partially from other models with partially overlapping structures.

Parameters:

Name Type Description Default
model MegatronModelType

Megatron model that you want to load weights into.

required
distributed_checkpoint_dir str | Path

Root directory of the NeMo 2 checkpoint; the weights are read from its `weights` subdirectory.

required
skip_keys_with_these_prefixes Set[str]

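Prefixes of keys in the model's sharded state dict that should not be looked up in the checkpoint (e.g. newly added layers/heads).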

required
Source code in bionemo/llm/utils/weight_utils.py (lines 129–154)
def load_weights_sharded_inplace_nemo2_to_mcore(
    model: MegatronModelType, distributed_checkpoint_dir: str | Path, skip_keys_with_these_prefixes: Set[str]
) -> None:
    """Load weights from a NeMo 2 distributed checkpoint into a Megatron module, in place.

    The model's own ``sharded_state_dict()`` describes which keys/shards this rank needs under the current
    parallel/distributed state. The checkpoint is assumed to have been written by a NeMo 2 trainer, which puts a
    ``module.`` prefix on every key, while here we load straight into the bare Megatron module (no ``module.``
    prefix). Each requested key is therefore passed through ``_munge_key_megatron_to_nemo2``, and each sharded
    value through ``_munge_sharded_tensor_key_megatron_to_nemo2``, before the load.

    Keys you do not want to search the checkpoint for -- e.g. parameters of new layers/heads added on top of a
    pretrained model -- are left out by listing their prefixes in ``skip_keys_with_these_prefixes``. This supports
    fine-tuning setups that only partially reuse weights from a model with a partially overlapping structure.
    Entries whose key contains ``_extra_state`` are always left out.

    Args:
        model: Megatron model that you want to load weights into.
        distributed_checkpoint_dir: Root of the NeMo 2 checkpoint; weights are read from its ``weights``
            subdirectory.
        skip_keys_with_these_prefixes: Key prefixes, in the model's own (un-munged) naming, to exclude from the load.
    """
    requested = {}
    for megatron_key, sharded_value in model.sharded_state_dict().items():
        # Skip user-excluded prefixes first, then any `_extra_state` entries.
        if _key_in_filter(megatron_key, skip_keys_with_these_prefixes) or "_extra_state" in megatron_key:
            continue
        # Compute the key before the value, mirroring dict-comprehension evaluation order.
        nemo2_key = _munge_key_megatron_to_nemo2(megatron_key)
        requested[nemo2_key] = _munge_sharded_tensor_key_megatron_to_nemo2(sharded_value)

    weights_dir = Path(distributed_checkpoint_dir) / "weights"
    # ASSUME_OK_UNEXPECTED: checkpoint entries this model does not request (e.g. when loading only part of
    # another model) should not cause the load to fail.
    dist_checkpointing.load(
        sharded_state_dict=requested,
        checkpoint_dir=str(weights_dir),
        strict=dist_checkpointing.serialization.StrictHandling.ASSUME_OK_UNEXPECTED,
    )

nemo1_to_nemo2_biobert_key_mapping(old_key, new_model_prefix='module', old_model_prefix='model', te_mapping=False)

Map a key from an old NeMo 1 BERT checkpoint to the corresponding key name in the new NeMo 2 BioBERT model.

Parameters:

Name Type Description Default
old_key str

old key we want to map to the expected new key name.

required
new_model_prefix str

The new prefix for the base weights. If you point this at the core megatron model, set it to "". For the regular NeMo 2 lightning module following standards, set it to "module". Defaults to "module".

'module'
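old_model_prefix str

The previously saved weight prefix. Defaults to "model", which was the standard in NeMo 1.

'model'
te_mapping bool

If True, additionally remap layer-norm parameters to where the Transformer Engine layer spec stores them (fused into `self_attention.linear_qkv` and `mlp.linear_fc1`).

False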

Returns:

Name Type Description
str str

New key name

Source code in bionemo/llm/utils/weight_utils.py (lines 31–98)
def nemo1_to_nemo2_biobert_key_mapping(
    old_key: str,
    new_model_prefix: str = "module",
    old_model_prefix: str = "model",
    te_mapping: bool = False,
) -> str:
    """Map a key from an old NeMo 1 BERT checkpoint to the corresponding NeMo 2 BioBERT key.

    Args:
        old_key: Old key we want to map to the expected new key name.
        new_model_prefix: The new prefix for the base weights. If you point this at the core megatron model set it
            to "". For the regular nemo2 lightning module following standards, set it to "module".
            A trailing "." is optional. Defaults to "module".
        old_model_prefix: The previously saved weight prefix. A trailing "." is optional. Defaults to "model", which
            was the standard in nemo1. Only a *leading* occurrence of this prefix is rewritten; when it is "",
            ``new_model_prefix`` is simply prepended.
        te_mapping: If True, additionally move the layer-norm parameters to where the Transformer Engine layer spec
            stores them (inside ``self_attention.linear_qkv`` and ``mlp.linear_fc1``). Defaults to False.

    Returns:
        The new key name. Keys that match none of the renames are returned with only the prefix swapped; keys that
        do not start with the old prefix are returned unchanged.
    """
    # Normalise so that a non-empty prefix always ends with exactly one "." (callers may pass either form).
    if old_model_prefix != "":
        old_model_prefix = f"{old_model_prefix.rstrip('.')}."
    if new_model_prefix != "":
        new_model_prefix = f"{new_model_prefix.rstrip('.')}."

    # Swap the leading prefix only; nemo1 nested the encoder under `language_model.`, which nemo2 drops.
    # Matching at the start of the key (instead of str.replace over the whole key) matters because
    # "abc".replace("", "x.") inserts "x." between every character when old_model_prefix is "", and because
    # the prefix text could otherwise be rewritten in the middle of a key.
    language_model_prefix = f"{old_model_prefix}language_model."
    if old_key.startswith(language_model_prefix):
        base_rename = new_model_prefix + old_key[len(language_model_prefix) :]
    elif old_key.startswith(old_model_prefix):
        base_rename = new_model_prefix + old_key[len(old_model_prefix) :]
    else:
        base_rename = old_key

    # Direct renames, checked in order; the first one that matches determines the result.
    direct_renames = (
        ("dense_h_to_4h", "linear_fc1"),
        ("dense_4h_to_h", "linear_fc2"),
        ("query_key_value", "linear_qkv"),
        # nemo1's attention `dense` is the output projection (mcore `linear_proj`), not the fused qkv: in the
        # reference model its weight was 256x256, matching linear_proj, while linear_qkv's was 768x256.
        ("self_attention.dense", "self_attention.linear_proj"),
        ("lm_head.bias", "output_layer.bias"),
        ("lm_head.weight", "output_layer.weight"),
        ("lm_head.layernorm", "lm_head.layer_norm"),
    )
    for old_name, new_name in direct_renames:
        if old_name in base_rename:
            return base_rename.replace(old_name, new_name)

    # Not returned immediately: the TE mapping below may relocate the renamed pre-MLP layernorm further.
    if "post_attention_layernorm" in base_rename:
        base_rename = base_rename.replace("post_attention_layernorm", "pre_mlp_layernorm")

    # The Transformer Engine spec stores layernorm parameters inside the following linear layer rather than as a
    # separate module of the transformer layer (presumably to fuse the ops), so those keys move there.
    if te_mapping:
        te_renames = (
            (".input_layernorm.weight", ".self_attention.linear_qkv.layer_norm_weight"),
            (".input_layernorm.bias", ".self_attention.linear_qkv.layer_norm_bias"),
            (".pre_mlp_layernorm.bias", ".mlp.linear_fc1.layer_norm_bias"),
            (".pre_mlp_layernorm.weight", ".mlp.linear_fc1.layer_norm_weight"),
        )
        for old_name, new_name in te_renames:
            if old_name in base_rename:
                return base_rename.replace(old_name, new_name)
    return base_rename