Config

IOMixinProto

Bases: Protocol

A Protocol for the get/set hparam functions of the IOMixin class from NeMo.

Source code in bionemo/llm/model/config.py
class IOMixinProto(Protocol):
    """A Protocol for the get/set hparam functions of the IOMixin class from NeMo."""

    def set_hparam(self, attribute: str, value: Any, also_change_value: bool = True) -> None:
        """Set the value of an attribute in the config attached to the class by the IOMixin."""
        ...

    def get_hparam(self, attribute: str) -> Any:
        """Get the value of an attribute in the config attached to the class by the IOMixin."""
        ...
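
For orientation, here is a minimal sketch of a class that structurally satisfies this protocol. It is illustrative only: the real implementations inherit NeMo's IOMixin, which also records hyper-parameters in the serialized io context, whereas this toy class just keeps a plain dict.

from typing import Any


class ToyHparamConfig:
    """Toy stand-in for an IOMixin-backed config; tracks 'init' hparams in a dict."""

    def __init__(self, hidden_size: int = 128):
        self.hidden_size = hidden_size
        self._hparams = {"hidden_size": hidden_size}

    def set_hparam(self, attribute: str, value: Any, also_change_value: bool = True) -> None:
        self._hparams[attribute] = value
        if also_change_value:
            setattr(self, attribute, value)

    def get_hparam(self, attribute: str) -> Any:
        return self._hparams[attribute]


cfg = ToyHparamConfig()
cfg.set_hparam("hidden_size", 256, also_change_value=False)
assert cfg.get_hparam("hidden_size") == 256  # tracked hparam updated
assert cfg.hidden_size == 128                # live attribute left untouched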

get_hparam(attribute)

Get the value of an attribute in the config attached to the class by the IOMixin.

Source code in bionemo/llm/model/config.py
def get_hparam(self, attribute: str) -> Any:
    """Get the value of an attribute in the config attached to the class by the IOMixin."""
    ...

set_hparam(attribute, value, also_change_value=True)

Set the value of an attribute in the config attached to the class by the IOMixin.

Source code in bionemo/llm/model/config.py
def set_hparam(self, attribute: str, value: Any, also_change_value: bool = True) -> None:
    """Set the value of an attribute in the config attached to the class by the IOMixin."""
    ...

MegatronBioNeMoModelConfig

Bases: BionemoModelConfig[MegatronModelType], TransformerConfig, WillHaveGetSetHparam

A ModelConfig class for bionemo that supports usage with Megatron models, for example as NeMo2 requires.

Source code in bionemo/llm/model/config.py
class MegatronBioNeMoModelConfig(BionemoModelConfig[MegatronModelType], TransformerConfig, iom.WillHaveGetSetHparam):
    """A ModelConfig class for bionemo that supports usage with Megatron models, for example as NeMo2 requires."""

    model_cls: Type[MegatronModelType]

MegatronBioNeMoTrainableModelConfig dataclass

Bases: MegatronBioNeMoModelConfig[MegatronModelType], BionemoTrainableModelConfig[MegatronModelType, MegatronLossType], Generic[MegatronModelType, MegatronLossType]

A TrainableModelConfig class for bionemo that supports usage with Megatron models, for example as NeMo2 requires.

Source code in bionemo/llm/model/config.py
@dataclass
class MegatronBioNeMoTrainableModelConfig(
    MegatronBioNeMoModelConfig[MegatronModelType],
    BionemoTrainableModelConfig[MegatronModelType, MegatronLossType],
    Generic[MegatronModelType, MegatronLossType],
):
    """A TrainableModelConfig class for bionemo that supports usage with Megatron models, for example as NeMo2 requires."""

    initial_ckpt_path: str | None = None
    initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=list)
    override_parent_fields: List[str] = field(default_factory=lambda: _OVERRIDE_BIONEMO_CONFIG_DEFAULTS)

    def load_settings_from_checkpoint(self, initial_ckpt_path: str) -> None:
        """Load settings into self from the checkpoint saved in self.

        Any setting in self.override_parent_fields is not overriden. Note that this function will also update the hyper
        parameters in this config, as well as the associated attributes in self in case they were modified post-init.

        Args:
            initial_ckpt_path: The path to the checkpoint to load, note that everything is loaded from this checkpoint
                other than the settings in self.override_parent_fields.

        Returns:
            None, the settings are loaded into self in place, and the hyper-parameters that will later be saved into
                a checkpoint are updated.
        """
        logger.warning(f"Loading {self.initial_ckpt_path}")
        # 1. get the config from the trainer io context by querying the `model.config` subpath of the trainer.
        initial_config: MegatronBioNeMoTrainableModelConfig = io.load_context(
            path=Path(initial_ckpt_path) / "context", subpath="model.config"
        )  # type: ignore
        initial_fields = {f.name for f in fields(initial_config)}
        my_fields = [f.name for f in fields(self)]
        skip_fields = set(self.override_parent_fields)
        override_fields = [f for f in my_fields if f in initial_fields and f not in skip_fields]
        override_mutate_possibly_extra_mutated_fiddle(self, initial_config, override_fields)

    def update_model_from_checkpoint(self, model: MegatronModelType, initial_ckpt_path: str) -> None:
        """Utility function to standardize how to load a megatron model from a checkpoint ignoring user-specified keys.

        Update the model with the weights from the provided checkpoint path, skipping the keys with the prefixes in
            self.initial_ckpt_skip_keys_with_these_prefixes.

        Args:
            model: The Megatron model to update.
            initial_ckpt_path: The path to the megatron checkpoint to load.

        Returns:
            None, the model is updated in place, supporting megatron model parallelism abstractions, and ignoring
                any extra keys that are provided in self.initial_ckpt_skip_keys_with_these_prefixes.
        """
        load_weights_sharded_inplace_nemo2_to_mcore(
            model=model,  # type: ignore
            distributed_checkpoint_dir=initial_ckpt_path,
            skip_keys_with_these_prefixes=set(self.initial_ckpt_skip_keys_with_these_prefixes),
        )
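
A hedged sketch of the typical fine-tuning flow follows. MyTrainableConfig, the checkpoint path, and the skipped prefix are hypothetical; the sketch assumes a concrete subclass of MegatronBioNeMoTrainableModelConfig and a NeMo2-format checkpoint on disk.

config = MyTrainableConfig(  # hypothetical concrete subclass
    initial_ckpt_path="/results/pretrain/checkpoints/last",
    # Keep the freshly initialized task head rather than loading pretrained weights for it.
    initial_ckpt_skip_keys_with_these_prefixes=["regression_head"],
)

# Copy hyper-parameters from the checkpoint config into this one, except for the
# fields listed in config.override_parent_fields.
config.load_settings_from_checkpoint(config.initial_ckpt_path)

# Once the Megatron model has been constructed from this config, load its weights
# in place, skipping parameters that match the prefixes above:
# config.update_model_from_checkpoint(model, config.initial_ckpt_path)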

load_settings_from_checkpoint(initial_ckpt_path)

Load settings into this config from the checkpoint at initial_ckpt_path.

Any setting listed in self.override_parent_fields is not overridden. Note that this function also updates the hyper-parameters tracked by this config, as well as the corresponding attributes on self, in case they were modified post-init.

Parameters:

initial_ckpt_path (str, required): The path to the checkpoint to load. Everything is loaded from this checkpoint other than the settings in self.override_parent_fields.

Returns:

None. The settings are loaded into self in place, and the hyper-parameters that will later be saved into a checkpoint are updated.

Source code in bionemo/llm/model/config.py
def load_settings_from_checkpoint(self, initial_ckpt_path: str) -> None:
    """Load settings into self from the checkpoint saved in self.

    Any setting in self.override_parent_fields is not overriden. Note that this function will also update the hyper
    parameters in this config, as well as the associated attributes in self in case they were modified post-init.

    Args:
        initial_ckpt_path: The path to the checkpoint to load, note that everything is loaded from this checkpoint
            other than the settings in self.override_parent_fields.

    Returns:
        None, the settings are loaded into self in place, and the hyper-parameters that will later be saved into
            a checkpoint are updated.
    """
    logger.warning(f"Loading {self.initial_ckpt_path}")
    # 1. get the config from the trainer io context by querying the `model.config` subpath of the trainer.
    initial_config: MegatronBioNeMoTrainableModelConfig = io.load_context(
        path=Path(initial_ckpt_path) / "context", subpath="model.config"
    )  # type: ignore
    initial_fields = {f.name for f in fields(initial_config)}
    my_fields = [f.name for f in fields(self)]
    skip_fields = set(self.override_parent_fields)
    override_fields = [f for f in my_fields if f in initial_fields and f not in skip_fields]
    override_mutate_possibly_extra_mutated_fiddle(self, initial_config, override_fields)
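
The field-selection rule above reduces to: copy a field from the checkpoint config only if it exists on both configs and is not listed in override_parent_fields. A self-contained illustration with made-up field names:

initial_fields = {"hidden_size", "num_layers", "fp16"}          # fields on the checkpoint config
my_fields = ["hidden_size", "num_layers", "fp16", "new_field"]  # fields on this config
skip_fields = {"fp16"}                                          # self.override_parent_fields

override_fields = [f for f in my_fields if f in initial_fields and f not in skip_fields]
assert override_fields == ["hidden_size", "num_layers"]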

update_model_from_checkpoint(model, initial_ckpt_path)

Utility function that standardizes how to load a Megatron model from a checkpoint while ignoring user-specified keys.

Update the model with the weights from the provided checkpoint path, skipping keys whose prefixes are listed in self.initial_ckpt_skip_keys_with_these_prefixes.

Parameters:

model (MegatronModelType, required): The Megatron model to update.

initial_ckpt_path (str, required): The path to the Megatron checkpoint to load.

Returns:

None. The model is updated in place, supporting Megatron model-parallelism abstractions and skipping any keys that match the prefixes in self.initial_ckpt_skip_keys_with_these_prefixes.

Source code in bionemo/llm/model/config.py
def update_model_from_checkpoint(self, model: MegatronModelType, initial_ckpt_path: str) -> None:
    """Utility function to standardize how to load a megatron model from a checkpoint ignoring user-specified keys.

    Update the model with the weights from the provided checkpoint path, skipping the keys with the prefixes in
        self.initial_ckpt_skip_keys_with_these_prefixes.

    Args:
        model: The Megatron model to update.
        initial_ckpt_path: The path to the megatron checkpoint to load.

    Returns:
        None, the model is updated in place, supporting megatron model parallelism abstractions, and ignoring
            any extra keys that are provided in self.initial_ckpt_skip_keys_with_these_prefixes.
    """
    load_weights_sharded_inplace_nemo2_to_mcore(
        model=model,  # type: ignore
        distributed_checkpoint_dir=initial_ckpt_path,
        skip_keys_with_these_prefixes=set(self.initial_ckpt_skip_keys_with_these_prefixes),
    )
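
A hypothetical call site, assuming the model has already been built from this config and that the checkpoint path exists:

# Weights are loaded into `model` in place; any parameter whose name starts with a prefix
# in config.initial_ckpt_skip_keys_with_these_prefixes keeps its freshly initialized value.
config.update_model_from_checkpoint(model, initial_ckpt_path="/results/pretrain/checkpoints/last")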

TorchmetricsConfig dataclass

TorchmetricsConfig to instantiate a torchmetrics.Metric class.

Fiddle requires every object in the config to be serializable, and torchmetrics.Metric is not, so its instantiation must be deferred to BionemoLightningModule.__init__. Only torchmetrics is currently supported; e.g. users can pass 'text.Perplexity' as 'class_path' to use 'torchmetrics.text.Perplexity'.

Source code in bionemo/llm/model/config.py
@dataclass
class TorchmetricsConfig:
    """TorchmetricsConfig to instantiate torchmetrics.Metric class.

    Fiddle requires all objects in config serializable and torchmetric.Metric is not. Its instantiation must be deferred into BionemoLightningModule.__init__.
    Only support torchmetrics currently, e.g. users can provide 'text.Perplexity' to 'class_path' to use 'torchmetrics.text.Perplexity'.
    """

    class_path: str
    task: Literal["lm", "classification", "regression"]
    metric_name: str
    kwargs: Optional[dict[str, Any]] = None

    def __post_init__(self):
        """__post_init__ in dataclass."""
        self.kwargs = {} if self.kwargs is None else self.kwargs

    def get_instance(self) -> torchmetrics.Metric:
        """Dynamically imports and instantiates the metric class."""
        if "." in self.class_path:
            module_path, class_name = self.class_path.rsplit(".", 1)
            module = importlib.import_module(f"torchmetrics.{module_path}")
        else:
            class_name = self.class_path
            module = importlib.import_module("torchmetrics")

        cls_ = getattr(module, class_name)
        return cls_(**self.kwargs)
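
For example, a deferred perplexity metric for a language-modeling task could be configured as below; the kwargs shown are standard torchmetrics.text.Perplexity arguments.

ppl_config = TorchmetricsConfig(
    class_path="text.Perplexity",   # resolved to torchmetrics.text.Perplexity
    task="lm",
    metric_name="val_ppl",
    kwargs={"ignore_index": -100},  # ignore padded/masked positions
)
metric = ppl_config.get_instance()  # instantiated later, e.g. in BionemoLightningModule.__init__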

__post_init__()

__post_init__ for the dataclass.

Source code in bionemo/llm/model/config.py
def __post_init__(self):
    """__post_init__ in dataclass."""
    self.kwargs = {} if self.kwargs is None else self.kwargs

get_instance()

Dynamically imports and instantiates the metric class.

Source code in bionemo/llm/model/config.py
def get_instance(self) -> torchmetrics.Metric:
    """Dynamically imports and instantiates the metric class."""
    if "." in self.class_path:
        module_path, class_name = self.class_path.rsplit(".", 1)
        module = importlib.import_module(f"torchmetrics.{module_path}")
    else:
        class_name = self.class_path
        module = importlib.import_module("torchmetrics")

    cls_ = getattr(module, class_name)
    return cls_(**self.kwargs)

override_mutate_possibly_extra_mutated_fiddle(target_cfg, source_cfg, maybe_mutated_elements_to_clone)

Override the values of the target config with the values of the source config for the given elements.

This modifies the tracked init hyper-parameter values as well as the associated attributes on the target config, in case they were modified later by __post_init__ code.

Parameters:

target_cfg (IOMixinProto, required): The config to update.

source_cfg (IOMixinProto, required): The config to copy values from.

maybe_mutated_elements_to_clone (List[str], required): The list of elements to copy from the source config to the target config.

Returns:

None. The target config is updated in place.

Source code in bionemo/llm/model/config.py
def override_mutate_possibly_extra_mutated_fiddle(
    target_cfg: IOMixinProto, source_cfg: IOMixinProto, maybe_mutated_elements_to_clone: List[str]
) -> None:
    """Override the values of the target config with the values of the source config for the given elements.

    This will modify the tracked init hyper-parameter values, as well as modifying the associated attributes in
        self incase they were modified later by post_init code.

    Args:
        target_cfg: The config to update.
        source_cfg: The config to copy values from.
        maybe_mutated_elements_to_clone: The list of elements to copy from the source config to the target config.

    Returns:
        None, the target config is updated in place.
    """
    for f in maybe_mutated_elements_to_clone:
        # 1. Update the tracked config values. Note that the associated attribute in self may have been modified
        #  post-init, so we don't want to change the value in self here. We do that separately next.
        target_cfg.set_hparam(f, source_cfg.get_hparam(f), also_change_value=False)
        # 2. Update the lazily untracked values (if the same variable name is used post-init)
        setattr(target_cfg, f, getattr(source_cfg, f))
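
Reusing the illustrative ToyHparamConfig sketched near the top of this page, the behavior can be seen in isolation: the tracked hparam is copied without disturbing the target's live attribute, and the attribute is then synced from the source separately.

source = ToyHparamConfig(hidden_size=512)
target = ToyHparamConfig(hidden_size=128)
target.hidden_size = 130  # pretend __post_init__ code re-derived this value

override_mutate_possibly_extra_mutated_fiddle(target, source, ["hidden_size"])
assert target.get_hparam("hidden_size") == 512  # tracked init hparam copied from the source
assert target.hidden_size == 512                # live attribute synced from the source as well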