
Convert to NeMo

main()

Convert a PyTorch Evo2 model checkpoint to a NeMo model checkpoint.

Source code in bionemo/evo2/utils/checkpoint/convert_to_nemo.py
def main():
    """Convert a PyTorch Evo2 model checkpoint to a NeMo model checkpoint."""
    args = parse_args()
    model_type = infer_model_type(args.model_size)
    if model_type == "hyena":
        config_modifiers_init = {}
        if args.use_subquadratic_ops:
            config_modifiers_init["use_subquadratic_ops"] = True
        evo2_config = HYENA_MODEL_OPTIONS[args.model_size](**config_modifiers_init)
        if args.model_path.startswith("hf://"):
            # Strip only the leading "hf://" scheme; lstrip would remove any of the characters h, f, :, /.
            importer = HuggingFaceSavannaHyenaImporter(args.model_path.removeprefix("hf://"), model_config=evo2_config)
        else:
            importer = PyTorchHyenaImporter(args.model_path, model_config=evo2_config)
    elif model_type == "llama":
        importer = HFEdenLlamaImporter(args.model_path)
    else:
        raise ValueError(f"Importer model type: {model_type}.")
    importer.apply(args.output_dir)
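
The snippet below is a minimal sketch of driving the converter end to end, assuming the package is installed and importable under the module path shown above. Because parse_args() reads sys.argv, the command line is simulated before calling main(). The Hugging Face model name follows the example given in the --model-path help text; the output directory is a placeholder.

import sys

from bionemo.evo2.utils.checkpoint.convert_to_nemo import main

# Simulate the command line that parse_args() will read inside main().
sys.argv = [
    "convert_to_nemo",
    "--model-path", "hf://arcinstitute/savanna_evo2_1b_base",  # example name from the help text
    "--model-size", "1b",
    "--output-dir", "/results/evo2_1b_nemo",  # placeholder output path
]
main()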

parse_args()

Parse command-line arguments.

Source code in bionemo/evo2/utils/checkpoint/convert_to_nemo.py
def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        type=str,
        required=True,
        help="Path to the Evo2 un-sharded (MP1) model checkpoint file, or a Hugging Face model name. Any model "
        "from the Savanna Evo2 family is supported such as 'hf://arcinstitute/savanna_evo2_1b_base'.",
    )
    parser.add_argument("--output-dir", type=str, required=True, help="Output directory path for the converted model.")
    parser.add_argument(
        "--use-subquadratic_ops",
        action="store_true",
        help="The checkpoint being converted should use subquadratic_ops.",
    )
    parser.add_argument(
        "--model-size",
        type=str,
        choices=sorted(set(HYENA_MODEL_OPTIONS.keys()) | set(LLAMA_MODEL_OPTIONS.keys())),
        required=True,
        help="Model architecture to use, choose between 1b, 7b, 40b, or test (a sub-model of 4 layers, "
        "less than 1B parameters). '*_arc_longcontext' models have GLU / FFN dimensions that support 1M "
        "context length when trained with TP>>8. Note that Mamba models are not supported for conversion yet.",
    )
    return parser.parse_args()
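
A short sketch of the namespace parse_args() returns, assuming "7b" is among the registered model sizes and using a placeholder checkpoint path. Note that argparse maps the option string --use-subquadratic_ops to the attribute use_subquadratic_ops.

import sys

from bionemo.evo2.utils.checkpoint.convert_to_nemo import parse_args

# Simulate a command line for a local un-sharded (MP1) checkpoint; paths are placeholders.
sys.argv = [
    "convert_to_nemo",
    "--model-path", "/ckpts/evo2_7b_mp1.pt",
    "--model-size", "7b",
    "--output-dir", "/results/evo2_7b_nemo",
    "--use-subquadratic_ops",
]
args = parse_args()
print(args.model_path, args.model_size, args.use_subquadratic_ops)
# Expected: /ckpts/evo2_7b_mp1.pt 7b True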