Mamba remove optimizer

MambaOptimizerRemover

Bases: ModelConnector['MambaModel', MambaModel]

Removes the optimizer state from a nemo2 format model checkpoint.
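
A minimal usage sketch (checkpoint paths are placeholders; importing MAMBA_MODEL_OPTIONS from this module is an assumption based on main() below):

from pathlib import Path

from bionemo.evo2.utils.checkpoint.mamba_remove_optimizer import (  # import location assumed
    MAMBA_MODEL_OPTIONS,
    MambaOptimizerRemover,
)

config = MAMBA_MODEL_OPTIONS["hybrid_mamba_8b"]()  # model arch must match the checkpoint
remover = MambaOptimizerRemover("/ckpts/mamba_nemo2", model_config=config)
remover.apply(Path("/ckpts/mamba_nemo2_no_optim"))  # re-saves without optimizer state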

Source code in bionemo/evo2/utils/checkpoint/mamba_remove_optimizer.py
@io.model_importer(MambaModel, "pytorch")
class MambaOptimizerRemover(io.ModelConnector["MambaModel", MambaModel]):
    """Removes the optimizer state from a nemo2 format model checkpoint."""

    def __new__(cls, path: str, model_config=None):
        """Creates a new importer instance.

        Args:
            path: Path to the nemo2 model checkpoint
            model_config: Optional model configuration

        Returns:
            MambaOptimizerRemover instance
        """
        instance = super().__new__(cls, path)
        instance.model_config = model_config
        return instance

    def init(self) -> MambaModel:
        """Initializes a new HyenaModel instance.

        Returns:
            HyenaModel: Initialized model
        """
        return MambaModel(self.config, tokenizer=self.tokenizer)

    def get_source_model(self):
        """Returns the source model."""
        model, _ = self.nemo_load(self)
        return model

    def apply(self, output_path: Path, checkpoint_format: str = "torch_dist") -> Path:
        """Applies the model conversion from PyTorch to NeMo format.

        Args:
            output_path: Path to save the converted model
            checkpoint_format: Format for saving checkpoints

        Returns:
            Path: Path to the saved NeMo model
        """
        source = self.get_source_model()

        target = self.init()
        trainer = self.nemo_setup(target, ckpt_async_save=False, save_ckpt_format=checkpoint_format)
        source.to(self.config.params_dtype)
        target.to(self.config.params_dtype)
        self.convert_state(source, target)
        self.nemo_save(output_path, trainer)

        logging.info(f"Stripped optimizer state from Mamba model, saved to {output_path}")

        teardown(trainer, target)
        del trainer, target

        return output_path

    def convert_state(self, source, target):
        """Converts the state dictionary from source format to target format.

        Args:
            source: Source model state
            target: Target model

        Returns:
            Result of applying state transforms
        """
        mapping = {k: k for k in source.module.state_dict().keys()}
        return io.apply_transforms(
            source,
            target,
            mapping=mapping,
        )

    @property
    def tokenizer(self):
        """Gets the tokenizer for the model.

        Returns:
            Tokenizer instance
        """
        from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

        tokenizer = get_nmt_tokenizer(
            library=self.model_config.tokenizer_library,
        )

        return tokenizer

    @property
    def config(self) -> NemotronHConfigBase:
        """Gets the model configuration.

        Returns:
            NemotronHConfigBase: Model configuration
        """
        return self.model_config

config property

Gets the model configuration.

Returns:

| Type | Description |
| ---- | ----------- |
| NemotronHConfigBase | Model configuration |

tokenizer property

Gets the tokenizer for the model.

Returns:

| Type | Description |
| ---- | ----------- |
|      | Tokenizer instance |
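
The property simply defers to NeMo's get_nmt_tokenizer with the library named in the model config. A standalone sketch; the 'byte-level' value is illustrative only, the real value comes from model_config.tokenizer_library:

from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

# 'byte-level' is a hypothetical example value; use model_config.tokenizer_library in practice.
tokenizer = get_nmt_tokenizer(library="byte-level")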

__new__(path, model_config=None)

Creates a new importer instance.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| path | str | Path to the nemo2 model checkpoint | required |
| model_config |  | Optional model configuration | None |

Returns:

| Type | Description |
| ---- | ----------- |
|      | MambaOptimizerRemover instance |

Source code in bionemo/evo2/utils/checkpoint/mamba_remove_optimizer.py
def __new__(cls, path: str, model_config=None):
    """Creates a new importer instance.

    Args:
        path: Path to the nemo2 model checkpoint
        model_config: Optional model configuration

    Returns:
        MambaOptimizerRemover instance
    """
    instance = super().__new__(cls, path)
    instance.model_config = model_config
    return instance
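
Note that ModelConnector instances are path-like, which is why super().__new__(cls, path) works: the instance itself stands in for the checkpoint path. A small sketch with a placeholder path and config:

remover = MambaOptimizerRemover("/ckpts/mamba_nemo2", model_config=config)
print(str(remover))           # "/ckpts/mamba_nemo2" - the instance behaves like the path
print(remover.model_config)   # the config attached in __new__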

apply(output_path, checkpoint_format='torch_dist')

Loads the model checkpoint and re-saves it without optimizer state.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| output_path | Path | Path to save the stripped model | required |
| checkpoint_format | str | Format for saving checkpoints | 'torch_dist' |

Returns:

| Name | Type | Description |
| ---- | ---- | ----------- |
| Path | Path | Path to the saved NeMo model |

Source code in bionemo/evo2/utils/checkpoint/mamba_remove_optimizer.py
def apply(self, output_path: Path, checkpoint_format: str = "torch_dist") -> Path:
    """Applies the model conversion from PyTorch to NeMo format.

    Args:
        output_path: Path to save the converted model
        checkpoint_format: Format for saving checkpoints

    Returns:
        Path: Path to the saved NeMo model
    """
    source = self.get_source_model()

    target = self.init()
    trainer = self.nemo_setup(target, ckpt_async_save=False, save_ckpt_format=checkpoint_format)
    source.to(self.config.params_dtype)
    target.to(self.config.params_dtype)
    self.convert_state(source, target)
    self.nemo_save(output_path, trainer)

    logging.info(f"Stripped optimizer state from Mamba model, saved to {output_path}")

    teardown(trainer, target)
    del trainer, target

    return output_path

convert_state(source, target)

Converts the state dictionary from source format to target format.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| source |  | Source model state | required |
| target |  | Target model | required |

Returns:

| Type | Description |
| ---- | ----------- |
|      | Result of applying state transforms |

Source code in bionemo/evo2/utils/checkpoint/mamba_remove_optimizer.py
def convert_state(self, source, target):
    """Converts the state dictionary from source format to target format.

    Args:
        source: Source model state
        target: Target model

    Returns:
        Result of applying state transforms
    """
    mapping = {k: k for k in source.module.state_dict().keys()}
    return io.apply_transforms(
        source,
        target,
        mapping=mapping,
    )
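
The mapping built here is an identity map over the source state dict: every parameter key maps to itself, so weights are copied through unchanged into the freshly initialized target, which carries no optimizer state. A toy illustration of such a mapping:

import torch

toy = torch.nn.Linear(4, 4)
mapping = {k: k for k in toy.state_dict().keys()}
print(mapping)  # {'weight': 'weight', 'bias': 'bias'} - each tensor maps to itself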

get_source_model()

Returns the source model.

Source code in bionemo/evo2/utils/checkpoint/mamba_remove_optimizer.py
def get_source_model(self):
    """Returns the source model."""
    model, _ = self.nemo_load(self)
    return model

init()

Initializes a new MambaModel instance.

Returns:

| Name | Type | Description |
| ---- | ---- | ----------- |
| MambaModel | MambaModel | Initialized model |

Source code in bionemo/evo2/utils/checkpoint/mamba_remove_optimizer.py
def init(self) -> MambaModel:
    """Initializes a new HyenaModel instance.

    Returns:
        HyenaModel: Initialized model
    """
    return MambaModel(self.config, tokenizer=self.tokenizer)

main()

Strip the optimizer state from a local nemo2 format Evo2 model checkpoint.

Source code in bionemo/evo2/utils/checkpoint/mamba_remove_optimizer.py
def main():
    """Convert a PyTorch Evo2 model checkpoint to a NeMo model checkpoint."""
    args = parse_args()

    evo2_config = MAMBA_MODEL_OPTIONS[args.model_size]()
    importer = MambaOptimizerRemover(args.model_path, model_config=evo2_config)
    assert not args.model_path.startswith("hf://"), "Strip optimizer only works on local nemo2 format checkpoints."
    importer.apply(args.output_dir)
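
A typical command line invocation might look like the following (the python -m entry point is an assumption; adapt to however the script is exposed in your install):

python -m bionemo.evo2.utils.checkpoint.mamba_remove_optimizer \
    --model-path /ckpts/mamba_nemo2 \
    --output-dir /ckpts/mamba_nemo2_no_optim \
    --model-size hybrid_mamba_8b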

parse_args()

Parse command-line arguments.

Source code in bionemo/evo2/utils/checkpoint/mamba_remove_optimizer.py
def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        type=str,
        required=True,
        help="Path to the Evo2 un-sharded (MP1) model checkpoint file, or a Hugging Face model name. Any model "
        "from the Savanna Evo2 family is supported such as 'hf://arcinstitute/savanna_evo2_1b_base'.",
    )
    parser.add_argument("--output-dir", type=str, required=True, help="Output directory path for the converted model.")
    parser.add_argument(
        "--model-size",
        type=str,
        choices=sorted(MAMBA_MODEL_OPTIONS.keys()),
        default="hybrid_mamba_8b",
        help="Model arch to use.",
    )
    return parser.parse_args()