Skip to content

Convert to nemo

HyenaOptimizerRemover

Bases: ModelConnector['HyenaModel', HyenaModel]

Removes the optimizer state from a nemo2 format model checkpoint.

Source code in bionemo/evo2/utils/checkpoint/convert_to_nemo.py
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
@io.model_importer(HyenaModel, "pytorch")
class HyenaOptimizerRemover(io.ModelConnector["HyenaModel", HyenaModel]):
    """Removes the optimizer state from a nemo2 format model checkpoint.

    Loads an existing nemo2 checkpoint, copies its weights into a freshly
    initialized model through an identity key mapping, and saves the result.
    Optimizer state is never carried across, so the saved checkpoint contains
    model weights only.
    """

    def __new__(cls, path: str, model_config=None):
        """Creates a new importer instance.

        Args:
            path: Path to the nemo2 checkpoint to strip
            model_config: Optional model configuration

        Returns:
            HyenaOptimizerRemover instance
        """
        # NOTE(review): `path` is forwarded to super().__new__ — ModelConnector is
        # presumably path-like (a Path subclass); confirm against NeMo's io module.
        instance = super().__new__(cls, path)
        instance.model_config = model_config
        return instance

    def init(self) -> HyenaModel:
        """Initializes a new HyenaModel instance.

        Returns:
            HyenaModel: Initialized model built from this connector's config and tokenizer
        """
        return HyenaModel(self.config, tokenizer=self.tokenizer)

    def get_source_model(self):
        """Returns the source model loaded from this connector's checkpoint path."""
        # `self` is passed as the path argument because the connector itself is
        # path-like; the second element of the returned tuple is discarded.
        model, _ = self.nemo_load(self)
        return model

    def apply(self, output_path: Path, checkpoint_format: str = "torch_dist") -> Path:
        """Applies the model conversion from PyTorch to NeMo format.

        Args:
            output_path: Path to save the converted model
            checkpoint_format: Format for saving checkpoints (e.g. "torch_dist")

        Returns:
            Path: Path to the saved NeMo model
        """
        source = self.get_source_model()

        target = self.init()
        trainer = self.nemo_setup(target, ckpt_async_save=False, save_ckpt_format=checkpoint_format)
        # Cast both models to the configured parameter dtype before copying weights.
        source.to(self.config.params_dtype)
        target.to(self.config.params_dtype)
        self.convert_state(source, target)
        self.nemo_save(output_path, trainer)

        logging.info(f"Converted Hyena model to Nemo, model saved to {output_path}")

        # Release trainer/model resources once the checkpoint is on disk.
        teardown(trainer, target)
        del trainer, target

        return output_path

    def convert_state(self, source, target):
        """Converts the state dictionary from source format to target format.

        Args:
            source: Source model state
            target: Target model

        Returns:
            Result of applying state transforms
        """
        # Identity mapping: every weight keeps its name, so parameters transfer
        # unchanged; optimizer state is not part of state_dict() and is dropped.
        mapping = {k: k for k in source.module.state_dict().keys()}
        return io.apply_transforms(
            source,
            target,
            mapping=mapping,
        )

    @property
    def tokenizer(self):
        """Gets the tokenizer for the model.

        Returns:
            Tokenizer instance built from the model config's tokenizer_library
        """
        # Imported locally — presumably to defer heavy NeMo NLP imports; confirm.
        from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

        tokenizer = get_nmt_tokenizer(
            library=self.model_config.tokenizer_library,
        )

        return tokenizer

    @property
    def config(self) -> HyenaConfig:
        """Gets the model configuration.

        Returns:
            HyenaConfig: Model configuration supplied at construction time
        """
        return self.model_config

config property

Gets the model configuration.

Returns:

Name Type Description
HyenaConfig HyenaConfig

Model configuration

tokenizer property

Gets the tokenizer for the model.

Returns:

Type Description

Tokenizer instance

__new__(path, model_config=None)

Creates a new importer instance.

Parameters:

Name Type Description Default
path str

Path to the PyTorch model

required
model_config

Optional model configuration

None

Returns:

Type Description

HyenaOptimizerRemover instance

Source code in bionemo/evo2/utils/checkpoint/convert_to_nemo.py
66
67
68
69
70
71
72
73
74
75
76
77
78
def __new__(cls, path: str, model_config=None):
    """Creates a new importer instance.

    Args:
        path: Path to the nemo2 checkpoint to strip
        model_config: Optional model configuration

    Returns:
        HyenaOptimizerRemover instance
    """
    # NOTE(review): `path` is forwarded to super().__new__ — the parent connector
    # is presumably path-like (a Path subclass); confirm against NeMo's io module.
    instance = super().__new__(cls, path)
    instance.model_config = model_config
    return instance

apply(output_path, checkpoint_format='torch_dist')

Applies the model conversion from PyTorch to NeMo format.

Parameters:

Name Type Description Default
output_path Path

Path to save the converted model

required
checkpoint_format str

Format for saving checkpoints

'torch_dist'

Returns:

Name Type Description
Path Path

Path to the saved NeMo model

Source code in bionemo/evo2/utils/checkpoint/convert_to_nemo.py
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
def apply(self, output_path: Path, checkpoint_format: str = "torch_dist") -> Path:
    """Convert the loaded checkpoint and write it back out in NeMo format.

    Args:
        output_path: Destination path for the converted checkpoint
        checkpoint_format: Serialization format used when saving

    Returns:
        Path: Location the NeMo checkpoint was written to
    """
    # Load the existing checkpoint, then build a fresh model to copy weights into.
    loaded = self.get_source_model()
    fresh = self.init()
    trainer = self.nemo_setup(fresh, ckpt_async_save=False, save_ckpt_format=checkpoint_format)

    # Keep both models in the configured parameter dtype before transferring state.
    loaded.to(self.config.params_dtype)
    fresh.to(self.config.params_dtype)
    self.convert_state(loaded, fresh)
    self.nemo_save(output_path, trainer)

    logging.info(f"Converted Hyena model to Nemo, model saved to {output_path}")

    # Release trainer/model resources once the checkpoint is on disk.
    teardown(trainer, fresh)
    del trainer, fresh

    return output_path

convert_state(source, target)

Converts the state dictionary from source format to target format.

Parameters:

Name Type Description Default
source

Source model state

required
target

Target model

required

Returns:

Type Description

Result of applying state transforms

Source code in bionemo/evo2/utils/checkpoint/convert_to_nemo.py
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
def convert_state(self, source, target):
    """Transfer the source model's weights into the target under identical names.

    Args:
        source: Loaded source model whose weights are copied
        target: Freshly initialized model receiving the weights

    Returns:
        Result of applying state transforms
    """
    # Identity mapping — each parameter key maps to itself, so weights move
    # over unchanged while optimizer state (absent from state_dict) is dropped.
    identity = {name: name for name in source.module.state_dict()}
    return io.apply_transforms(source, target, mapping=identity)

get_source_model()

Returns the source model.

Source code in bionemo/evo2/utils/checkpoint/convert_to_nemo.py
88
89
90
91
def get_source_model(self):
    """Load and return the model stored at this connector's checkpoint path."""
    # The connector itself is path-like, so it is passed as the load target;
    # the second element of the returned pair is intentionally discarded.
    loaded_model, _unused = self.nemo_load(self)
    return loaded_model

init()

Initializes a new HyenaModel instance.

Returns:

Name Type Description
HyenaModel HyenaModel

Initialized model

Source code in bionemo/evo2/utils/checkpoint/convert_to_nemo.py
80
81
82
83
84
85
86
def init(self) -> HyenaModel:
    """Construct a fresh HyenaModel from this connector's config and tokenizer.

    Returns:
        HyenaModel: Initialized model
    """
    fresh_model = HyenaModel(self.config, tokenizer=self.tokenizer)
    return fresh_model

main()

Convert a PyTorch Evo2 model checkpoint to a NeMo model checkpoint.

Source code in bionemo/evo2/utils/checkpoint/convert_to_nemo.py
161
162
163
164
165
166
167
168
169
170
171
172
173
def main():
    """Convert a PyTorch Evo2 model checkpoint to a NeMo model checkpoint.

    Dispatches to one of three importers based on the CLI flags: an optimizer
    stripper for local nemo2 checkpoints, a Hugging Face importer for
    "hf://"-prefixed model names, or a plain PyTorch checkpoint importer.
    """
    args = parse_args()

    evo2_config = HYENA_MODEL_OPTIONS[args.model_size]()
    if args.strip_optimizer:
        # Validate before constructing the importer: stripping only applies to
        # checkpoints already on disk in nemo2 format, never to hf:// URIs.
        assert not args.model_path.startswith("hf://"), "Strip optimizer only works on local nemo2 format checkpoints."
        importer = HyenaOptimizerRemover(args.model_path, model_config=evo2_config)
    elif args.model_path.startswith("hf://"):
        # BUG FIX: str.lstrip("hf://") strips the character set {h, f, :, /},
        # mangling names such as "hf://facebook/..." (leading 'f' removed).
        # Slice off exactly the "hf://" scheme prefix instead.
        hf_model_name = args.model_path[len("hf://"):]
        importer = HuggingFaceSavannaHyenaImporter(hf_model_name, model_config=evo2_config)
    else:
        importer = PyTorchHyenaImporter(args.model_path, model_config=evo2_config)
    importer.apply(args.output_dir)

parse_args()

Parse command-line arguments.

Source code in bionemo/evo2/utils/checkpoint/convert_to_nemo.py
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
def parse_args():
    """Parse command-line arguments."""
    arg_parser = argparse.ArgumentParser()
    # Source checkpoint: either a local file or an hf:// model name.
    arg_parser.add_argument(
        "--model-path",
        type=str,
        required=True,
        help="Path to the Evo2 un-sharded (MP1) model checkpoint file, or a Hugging Face model name. Any model "
        "from the Savanna Evo2 family is supported such as 'hf://arcinstitute/savanna_evo2_1b_base'.",
    )
    arg_parser.add_argument("--output-dir", type=str, required=True, help="Output directory path for the converted model.")
    # Restrict --model-size to the registered architectures, sorted for stable help output.
    size_choices = sorted(HYENA_MODEL_OPTIONS.keys())
    arg_parser.add_argument(
        "--model-size",
        type=str,
        choices=size_choices,
        default="1b",
        help="Model architecture to use, choose between 1b, 7b, 40b, or test (a sub-model of 4 layers, "
        "less than 1B parameters). '*_arc_longcontext' models have GLU / FFN dimensions that support 1M "
        "context length when trained with TP>>8.",
    )
    arg_parser.add_argument(
        "--strip-optimizer",
        action="store_true",
        help="Strip the optimizer state from the model checkpoint, this works on nemo2 format checkpoints.",
    )
    return arg_parser.parse_args()