Skip to content

Evo2 remove optimizer

HyenaOptimizerRemover

Bases: _OptimizerRemoverBase, ModelConnector['HyenaModel', HyenaModel]

Removes the optimizer state from a nemo2 format model checkpoint.

Source code in bionemo/evo2/utils/checkpoint/evo2_remove_optimizer.py
158
159
160
161
162
@io.model_importer(HyenaModel, "pytorch")
class HyenaOptimizerRemover(_OptimizerRemoverBase, io.ModelConnector["HyenaModel", HyenaModel]):
    """Strips the optimizer state out of a nemo2-format Hyena model checkpoint.

    All removal logic lives in ``_OptimizerRemoverBase``; this subclass only
    binds it to the Hyena architecture.
    """

    # Model class the shared base-class removal logic operates on.
    MODEL_CLS = HyenaModel

LlamaOptimizerRemover

Bases: _OptimizerRemoverBase, ModelConnector['GPTModel', GPTModel]

Removes the optimizer state from a nemo2 format model checkpoint.

Source code in bionemo/evo2/utils/checkpoint/evo2_remove_optimizer.py
165
166
167
168
169
@io.model_importer(GPTModel, "pytorch")
class LlamaOptimizerRemover(_OptimizerRemoverBase, io.ModelConnector["GPTModel", GPTModel]):
    """Strips the optimizer state out of a nemo2-format Llama (GPT) model checkpoint.

    All removal logic lives in ``_OptimizerRemoverBase``; this subclass only
    binds it to the GPT/Llama architecture.
    """

    # Model class the shared base-class removal logic operates on.
    MODEL_CLS = GPTModel

MambaOptimizerRemover

Bases: _OptimizerRemoverBase, ModelConnector['MambaModel', MambaModel]

Removes the optimizer state from a nemo2 format model checkpoint.

Source code in bionemo/evo2/utils/checkpoint/evo2_remove_optimizer.py
172
173
174
175
176
@io.model_importer(MambaModel, "pytorch")
class MambaOptimizerRemover(_OptimizerRemoverBase, io.ModelConnector["MambaModel", MambaModel]):
    """Strips the optimizer state out of a nemo2-format Mamba model checkpoint.

    All removal logic lives in ``_OptimizerRemoverBase``; this subclass only
    binds it to the Mamba architecture.
    """

    # Model class the shared base-class removal logic operates on.
    MODEL_CLS = MambaModel

main()

Remove the optimizer state from a nemo2 format Evo2 model checkpoint and write the result to the output directory.

Source code in bionemo/evo2/utils/checkpoint/evo2_remove_optimizer.py
179
180
181
182
183
184
185
186
187
188
189
190
def main():
    """Remove the optimizer state from a nemo2 format model checkpoint.

    Parses command-line arguments, picks the remover class matching
    ``--model-type``, and writes the optimizer-free checkpoint to
    ``--output-dir``.

    Raises:
        ValueError: If ``args.model_type`` is not one of the supported types.

    NOTE(review): the original docstring claimed this "converts a PyTorch
    checkpoint to a NeMo checkpoint", which contradicts the module name and
    every remover class docstring — corrected here.
    """
    args = parse_args()
    # Single dispatch table instead of an if/elif chain keeps the
    # model-type -> remover mapping in one place.
    removers = {
        "hyena": HyenaOptimizerRemover,
        "mamba": MambaOptimizerRemover,
        "llama": LlamaOptimizerRemover,
    }
    if args.model_type not in removers:
        # argparse `choices` already guards CLI input; this explicit check
        # preserves the original error for any programmatic caller.
        raise ValueError(f"Invalid model type: {args.model_type}.")
    optimizer_remover = removers[args.model_type](args.model_path)
    optimizer_remover.apply(args.output_dir)

parse_args()

Parse command-line arguments.

Source code in bionemo/evo2/utils/checkpoint/evo2_remove_optimizer.py
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def parse_args():
    """Parse command-line arguments.

    Returns:
        argparse.Namespace with ``model_path``, ``output_dir`` and
        ``model_type`` attributes.
    """
    ap = argparse.ArgumentParser()
    # Required checkpoint location: a local un-sharded path or an hf:// name.
    ap.add_argument(
        "--model-path",
        type=str,
        required=True,
        help=(
            "Path to the Evo2 un-sharded (MP1) model checkpoint file, or a Hugging Face model name. Any model "
            "from the Savanna Evo2 family is supported such as 'hf://arcinstitute/savanna_evo2_1b_base'."
        ),
    )
    ap.add_argument("--output-dir", type=str, required=True, help="Output directory path for the converted model.")
    # Architecture selector; defaults to the Hyena family.
    ap.add_argument(
        "--model-type",
        type=str,
        choices=["hyena", "mamba", "llama"],
        default="hyena",
        help="Model architecture to use, choose between 'hyena', 'mamba', or 'llama'.",
    )
    return ap.parse_args()