nvalchemi.models.base.BaseModelMixin#
- class nvalchemi.models.base.BaseModelMixin[source]#
Abstract mixin class providing a homogenized interface for wrapper models from external machine learning interatomic potential projects.
This mixin defines the core interface that all external model wrappers should implement to ensure consistency across different model types.
The mixin provides abstract methods for:
Computing embeddings at different graph levels
Predicting energies and forces
Defining expected output shapes
Adapting inputs and outputs between framework and external model formats
A concrete implementation of this mixin should utilize the following functions to implement predictions:
_adapt_input, which adapts the input batch to the model’s expected format
_adapt_output, which adapts the model’s output to the framework’s expected format
validate_batch, which ensures that the input batch is compatible with the model
compute_embeddings, which computes embeddings at different graph levels
The mixin also defines several properties that must be implemented when adding a new model, to specify its capabilities:
model_card: Pydantic model that contains information about the model’s capabilities and requirements
embedding_shapes: Expected shapes of node, edge, and graph embeddings
The workflow for using this mixin is:
Implement all required properties to specify model capabilities
Implement _adapt_input to convert framework data to model format
Implement parse_output to convert model output to framework format
Implement prediction methods based on supported capabilities
Use validate_batch to ensure input compatibility
Call parse_output to write model outputs to the Batch data structure
- Raises:
NotImplementedError – If any required abstract methods or properties are not implemented
ValueError – If input validation fails in validate_batch
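The workflow above can be sketched with stand-in classes. All names here (ToyExternalModel, ToyWrapper, the "pos"/"z" batch keys) are hypothetical illustrations of the adapt-input → call → adapt-output flow, not the actual nvalchemi API:

```python
from collections import OrderedDict

class ToyExternalModel:
    """Stand-in for an external MLIP model."""
    def __call__(self, positions, atomic_numbers):
        # Pretend "energy" is just the atom count; a real model runs a network.
        n = len(atomic_numbers)
        return {"energy": float(n), "forces": [[0.0] * 3] * n}

class ToyWrapper:
    """Follows the mixin workflow: adapt input, call the model, adapt output."""
    def __init__(self):
        self.model = ToyExternalModel()

    def _adapt_input(self, batch):
        # Extract batch fields into keyword arguments for the external model.
        return {"positions": batch["pos"], "atomic_numbers": batch["z"]}

    def _adapt_output(self, model_output):
        # Map external keys onto the framework's expected output keys.
        out = OrderedDict(energy=None, forces=None)
        for key in out:
            out[key] = model_output.get(key)
        return out

    def predict(self, batch):
        inputs = self._adapt_input(batch)
        return self._adapt_output(self.model(**inputs))

batch = {"pos": [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], "z": [1, 1]}
outputs = ToyWrapper().predict(batch)
# outputs carries "energy" and one force vector per atom
```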
- adapt_input(data, **kwargs)[source]#
Adapt framework batch data to external model input format.
The base implementation checks the model_config to determine which input keys need gradients enabled.
A subclass implementation should call this method, and additionally extract Batch inputs into arguments for the underlying model’s forward call.
The method should return a dictionary of input arguments that will be unpacked in the actual forward and/or __call__ methods.
- Parameters:
data (AtomicData | Batch | AtomsLike) – Framework batch data
kwargs (Any)
- Returns:
Input in the format expected by the external model (could be dict, custom object, etc.)
- Return type:
dict[str, Any]
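A minimal sketch of the subclass pattern described above, using a plain dict as the batch; the class names, GRAD_KEYS set, and key names are invented for illustration and do not come from nvalchemi:

```python
class BaseWrapper:
    # Stand-in for the base adapt_input: here it just flags which keys would
    # need gradients enabled (the real base consults model_config and would
    # call tensor.requires_grad_() instead).
    GRAD_KEYS = {"pos"}

    def adapt_input(self, data, **kwargs):
        for key in self.GRAD_KEYS:
            data[key + "_requires_grad"] = True  # placeholder flag
        return {}

class MyWrapper(BaseWrapper):
    def adapt_input(self, data, **kwargs):
        super().adapt_input(data, **kwargs)  # base handles gradient flags
        # Subclass extracts the fields the external model's forward() expects,
        # returning a dict that will be unpacked in forward/__call__.
        return {"positions": data["pos"], "species": data["z"], **kwargs}

data = {"pos": [[0.0, 0.0, 0.0]], "z": [8]}
inputs = MyWrapper().adapt_input(data, cutoff=5.0)
# inputs now holds positions, species, and the pass-through cutoff kwarg
```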
- adapt_output(model_output, data)[source]#
Adapt external model output to the framework’s standard output format (ModelOutputs).
This implementation returns a ModelOutputs (OrderedDict) with keys from output_data(), initialized to None, and populates them with values from model_output where the key names can be matched generically. This matching is unlikely to be perfect for all models, so it is imperative to manually check and override this implementation in a subclass.
- Parameters:
model_output (Any) – Raw output from the external model
data (AtomicData | Batch) – Original input data (may be needed for context/metadata)
- Returns:
OrderedDict with expected output keys and their values (or None if not present).
- Return type:
ModelOutputs
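The generic key matching, and why a subclass override is usually needed, can be illustrated with a toy function; the key names here are hypothetical:

```python
from collections import OrderedDict

def adapt_output(model_output, expected_keys):
    # Initialize every expected key to None, then fill only exact name matches.
    outputs = OrderedDict((key, None) for key in expected_keys)
    for key in expected_keys:
        if key in model_output:
            outputs[key] = model_output[key]
    return outputs

raw = {"energy": -1.5, "F": [[0.1, 0.0, 0.0]]}  # external model calls forces "F"
out = adapt_output(raw, expected_keys=("energy", "forces"))
# "energy" matches generically, but "F" does not match "forces", so a
# subclass override must map that key manually.
```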
- add_output_head(prefix)[source]#
Add an output head to the model.
This method should create a multilayer perceptron block for mapping input embeddings to a desired output shape. The construction logic should differentiate between invariant and equivariant models, specifically those that use e3nn layers.
The method should then save the output head to an output_heads ModuleDict attribute.
- Parameters:
prefix (str) – Prefix for the output head
- Return type:
None
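A toy sketch of the storage pattern, with plain strings standing in for real torch.nn modules; a real implementation would build actual MLP layers (and irreps-aware layers for e3nn models) and register them in a torch.nn.ModuleDict:

```python
class HeadOwner:
    def __init__(self, embedding_dim, equivariant=False):
        self.embedding_dim = embedding_dim
        self.equivariant = equivariant
        self.output_heads = {}  # stand-in for a torch.nn.ModuleDict

    def add_output_head(self, prefix):
        if self.equivariant:
            # e3nn-style models would construct irreps-aware layers here.
            head = f"equivariant-head({self.embedding_dim})"
        else:
            # Invariant models would construct a plain MLP here.
            head = f"mlp-head({self.embedding_dim})"
        self.output_heads[prefix] = head  # keyed by the prefix argument

owner = HeadOwner(embedding_dim=128)
owner.add_output_head("energy")
# owner.output_heads maps "energy" to the newly built head
```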
- abstractmethod compute_embeddings(data, **kwargs)[source]#
Compute embeddings at different levels of a batch of atomic graphs.
This method should extract meaningful representations from the model at node (atomic), edge (bond), and/or graph/system (structure) levels. The concrete implementation should check if the model supports computing embeddings, as well as perform validation on kwargs to make sure they are valid for the model.
The method should add graph, node, and/or edge embeddings to the Batch data structure in-place.
- Parameters:
data (AtomicData | Batch) – Input atomic data containing positions, atomic numbers, etc.
kwargs (Any)
- Returns:
Standardized AtomicData or Batch data structure mutated in place.
- Return type:
AtomicData | Batch
- Raises:
NotImplementedError – If the model does not support embeddings computation
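A schematic of what a concrete compute_embeddings might do, validating kwargs and then mutating the batch in place; the keys, dimensions, and allowed kwargs are invented for illustration:

```python
ALLOWED_KWARGS = {"level"}  # hypothetical set of kwargs this model accepts

def compute_embeddings(data, **kwargs):
    # Validate kwargs before running the model.
    unknown = set(kwargs) - ALLOWED_KWARGS
    if unknown:
        raise ValueError(f"Unsupported kwargs: {unknown}")
    n_atoms = len(data["z"])
    # Placeholder embeddings; a real model would run its representation layers.
    data["node_embeddings"] = [[0.0] * 4 for _ in range(n_atoms)]
    data["graph_embeddings"] = [0.0] * 4
    return data  # the same object, mutated in place

batch = {"z": [1, 8, 1]}
result = compute_embeddings(batch)
# result is the same batch object, now carrying node and graph embeddings
```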
- abstract property embedding_shapes: dict[str, tuple[int, ...]]#
Retrieves the expected shapes of the node, edge, and graph embeddings.
- export_model(path, as_state_dict=False)[source]#
Export the current model without the BaseModelMixin interface.
The idea behind this method is to allow users to use the trained model with the same interface as the corresponding ‘upstream’ version, so that they can re-use validation code that might have been written for the upstream case (e.g. ase.Calculator instances).
Essentially, this method should recreate the equivalent base class (by checking the MRO), then run torch.save and serialize the model either directly or as its state_dict.
- Parameters:
path (Path)
as_state_dict (bool)
- Return type:
None
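The MRO walk can be sketched in isolation with stand-in classes; the real method would then call torch.save on the rebuilt instance (or its state_dict) at the given path:

```python
class MixinStandIn:            # plays the role of BaseModelMixin
    pass

class UpstreamModel:           # plays the role of the external project's class
    def __init__(self):
        self.weights = [1.0, 2.0]

class WrappedModel(MixinStandIn, UpstreamModel):
    pass

def find_upstream_base(model):
    # First MRO entry that is neither the wrapper itself nor the mixin:
    # that is the "upstream" class users want back.
    for cls in type(model).__mro__[1:]:
        if cls is not MixinStandIn and cls is not object:
            return cls
    raise TypeError("no upstream base found")

base = find_upstream_base(WrappedModel())
# base is the upstream class, ready to be reinstantiated and serialized
```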
- input_data()[source]#
Returns a set of keys that are expected to be in the input data.
This method provides the base logic that is generally common across all models, but can be overridden by subclasses to add more expected keys.
- Returns:
Set of keys that are expected to be in the input data.
- Return type:
set[str]
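The override pattern might look like this sketch; the key names are illustrative stand-ins, not the framework’s actual input keys:

```python
class Base:
    def input_data(self):
        # Keys common to most models (illustrative).
        return {"pos", "z", "cell"}

class PBCModel(Base):
    def input_data(self):
        # Extend the base set with model-specific expected keys.
        return super().input_data() | {"pbc"}

keys = PBCModel().input_data()
# keys is the base set plus the model-specific "pbc" key
```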
- make_neighbor_hooks()[source]#
Return a list of NeighborListHook instances for this model’s neighbor configuration.
Returns an empty list if the model does not require a neighbor list. Defers the import to avoid circular imports.
- Return type:
list
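The deferred-import pattern can be sketched as follows, with a namedtuple standing in for the real NeighborListHook import; the function signature is hypothetical:

```python
def make_neighbor_hooks(requires_neighbors, cutoff=None):
    if not requires_neighbors:
        return []  # models without a neighbor list get no hooks
    # Importing inside the function body means the module can be imported
    # without triggering the (circular) hook-module import.
    from collections import namedtuple  # stands in for the real hook import
    NeighborListHook = namedtuple("NeighborListHook", ["cutoff"])
    return [NeighborListHook(cutoff=cutoff)]

hooks = make_neighbor_hooks(True, cutoff=6.0)
# one hook carrying the cutoff; make_neighbor_hooks(False) gives []
```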
- abstract property model_card: ModelCard#
Retrieves the model card for the model.
The model card is a Pydantic model that contains information about the model’s capabilities and requirements.
- output_data()[source]#
Returns a set of keys that are expected to be computed by the model and written to the AtomicData or Batch data structure.
This method provides the base logic that is generally common across all models, but can be overridden by subclasses to add more expected keys.
- Returns:
Set of keys that are expected to be computed by the model and written to the AtomicData or Batch data structure.
- Return type:
set[str]