nvalchemi.models.base.BaseModelMixin#

class nvalchemi.models.base.BaseModelMixin[source]#

Abstract mixin class providing a unified interface for wrapper models from external machine-learning interatomic potential projects.

This mixin defines the core interface that all external model wrappers should implement to ensure consistency across different model types.

The mixin provides abstract methods for:

  • Computing embeddings at different graph levels

  • Predicting energies and forces

  • Defining expected output shapes

  • Adapting inputs and outputs between framework and external model formats

A concrete implementation of this mixin should utilize the following functions to implement predictions:

  • _adapt_input, which adapts the input batch to the model’s expected format

  • _adapt_output, which adapts the model’s output to the framework’s expected format

  • validate_batch, which ensures that the input batch is compatible with the model

  • compute_embeddings, which computes embeddings at different graph levels

The mixin also defines several properties that must be implemented to specify model capabilities when adding a new model:

  • model_card: Pydantic model that contains information about the model’s capabilities and requirements

  • embedding_shapes: Expected shapes of node, edge, and graph embeddings

The workflow for using this mixin is:

  1. Implement all required properties to specify model capabilities

  2. Implement _adapt_input to convert framework data to model format

  3. Implement _adapt_output to convert model output to framework format

  4. Implement prediction methods based on supported capabilities

  5. Use validate_batch to ensure input compatibility

  6. Call _adapt_output to write model outputs to the Batch data structure
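The workflow above can be sketched with a toy wrapper. Everything here is hypothetical: plain dicts stand in for the framework’s AtomicData/Batch structures, and ToyExternalModel stands in for a real interatomic potential; none of these names are guaranteed to match the real nvalchemi API.

```python
from collections import OrderedDict

class ToyExternalModel:
    """Stand-in for an external interatomic potential."""
    def __call__(self, positions, atomic_numbers):
        # Pretend energy: one unit per atom.
        return {"E": float(len(atomic_numbers))}

class ToyWrapper:
    def __init__(self):
        self.model = ToyExternalModel()

    def validate_batch(self, batch):
        # Ensure the input batch is compatible with the model.
        missing = {"pos", "z"} - set(batch)
        if missing:
            raise ValueError(f"missing input keys: {missing}")

    def _adapt_input(self, batch):
        # Convert framework data to the model's keyword arguments.
        return {"positions": batch["pos"], "atomic_numbers": batch["z"]}

    def _adapt_output(self, model_output, batch):
        # Map the external model's keys onto the framework's expected keys.
        outputs = OrderedDict([("energy", None), ("forces", None)])
        outputs["energy"] = model_output.get("E")
        return outputs

    def forward(self, batch):
        self.validate_batch(batch)
        inputs = self._adapt_input(batch)
        return self._adapt_output(self.model(**inputs), batch)

batch = {"pos": [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], "z": [1, 1]}
outputs = ToyWrapper().forward(batch)
```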

Raises:
  • NotImplementedError – If any required abstract methods or properties are not implemented

  • ValueError – If input validation fails in validate_batch

adapt_input(data, **kwargs)[source]#

Adapt framework batch data to external model input format.

The base implementation will check the model_config to determine what input keys need gradients enabled, depending on what is required.

A subclass implementation should call this, in addition to doing whatever is needed to extract Batch inputs into arguments for the underlying model forward call.

The method should return a dictionary of input arguments that will be unpacked in the actual forward and/or __call__ methods.

Parameters:
  • data (AtomicData | Batch) – Input atomic data in the framework’s format

  • kwargs (Any)

Returns:

Dictionary of input arguments in the format expected by the external model

Return type:

dict[str, Any]
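As a rough illustration of the gradient-enabling behaviour described above, the sketch below uses a tiny FakeTensor class in place of torch tensors, and takes the gradient keys as an argument; the real base implementation would read them from the model_config instead.

```python
class FakeTensor:
    """Minimal stand-in for a torch tensor with a requires_grad flag."""
    def __init__(self, values):
        self.values = values
        self.requires_grad = False

def adapt_input(data, grad_keys):
    # Enable gradients only for the keys the model configuration
    # requires (e.g. positions when forces come from autograd).
    for key in grad_keys:
        data[key].requires_grad = True
    # Return keyword arguments to unpack into the model's forward call.
    return {"positions": data["pos"], "cell": data["cell"]}

data = {"pos": FakeTensor([[0.0, 0.0, 0.0]]), "cell": FakeTensor([[10.0]])}
kwargs = adapt_input(data, grad_keys=["pos"])
```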

adapt_output(model_output, data)[source]#

Adapt external model output to the framework’s standard output format (ModelOutputs).

This implementation returns a ModelOutputs (OrderedDict) with keys from output_data(), initialized to None, and populates them with values from model_output whenever the key names can be matched generically. Generic matching is unlikely to cover every model’s key names, so it is imperative to check this behaviour and override the implementation in a subclass where needed.

Parameters:
  • model_output (Any) – Raw output from the external model

  • data (AtomicData | Batch) – Original input data (may be needed for context/metadata)

Returns:

OrderedDict with expected output keys and their values (or None if not present).

Return type:

ModelOutputs
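The generic key-matching behaviour can be illustrated as follows; the expected key set and the plain OrderedDict are stand-ins for the real output_data() contents and ModelOutputs type.

```python
from collections import OrderedDict

EXPECTED_KEYS = ["energy", "forces", "stress"]  # stand-in for output_data()

def adapt_output(model_output):
    # Initialize every expected key to None ...
    outputs = OrderedDict((key, None) for key in EXPECTED_KEYS)
    # ... then fill in values whose names match generically.
    for key in EXPECTED_KEYS:
        if key in model_output:
            outputs[key] = model_output[key]
    return outputs

raw = {"energy": -1.5, "virials": [0.0]}  # "virials" has no generic match
parsed = adapt_output(raw)
```

A subclass wrapping a model that names its energy output "E", say, would need to override this mapping explicitly.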

add_output_head(prefix)[source]#

Add an output head to the model.

This method should create a multilayer perceptron block for mapping input embeddings to a desired output shape. The logic should differentiate between invariant and equivariant models, specifically those that use e3nn layers.

The method should then save the output head to an output_heads ModuleDict attribute.

Parameters:

prefix (str) – Prefix for the output head

Return type:

None
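A rough sketch of the bookkeeping described above; a plain dict stands in for torch.nn.ModuleDict, and lists of layer descriptions stand in for real MLP or e3nn blocks.

```python
class HeadOwner:
    def __init__(self, embedding_dim, invariant=True):
        self.embedding_dim = embedding_dim
        self.invariant = invariant
        self.output_heads = {}  # stand-in for torch.nn.ModuleDict

    def add_output_head(self, prefix):
        # Differentiate between invariant and equivariant models.
        if self.invariant:
            head = [f"Linear({self.embedding_dim} -> {self.embedding_dim})",
                    f"Linear({self.embedding_dim} -> 1)"]
        else:
            head = ["equivariant e3nn block (placeholder)"]
        # Save the head under its prefix.
        self.output_heads[prefix] = head

owner = HeadOwner(embedding_dim=64)
owner.add_output_head("energy")
```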

abstractmethod compute_embeddings(data, **kwargs)[source]#

Compute embeddings at different levels of a batch of atomic graphs.

This method should extract meaningful representations from the model at node (atomic), edge (bond), and/or graph/system (structure) levels. The concrete implementation should check whether the model supports computing embeddings and validate that the kwargs are applicable to the model.

The method should add graph, node, and/or edge embeddings to the Batch data structure in-place.

Parameters:
  • data (AtomicData | Batch) – Input atomic data containing positions, atomic numbers, etc.

  • kwargs (Any)

Returns:

Standardized AtomicData or Batch data structure mutated in place.

Return type:

AtomicData | Batch

Raises:

NotImplementedError – If the model does not support embeddings computation
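The in-place mutation pattern might look like the following sketch, where a plain dict stands in for the Batch data structure and constant vectors stand in for real model activations.

```python
def compute_embeddings(model_supports_embeddings, data):
    if not model_supports_embeddings:
        raise NotImplementedError("model does not expose embeddings")
    num_atoms = len(data["z"])
    # Write node- and graph-level embeddings into the batch in place.
    data["node_embeddings"] = [[0.0, 1.0] for _ in range(num_atoms)]
    data["graph_embeddings"] = [[0.5, 0.5]]
    return data

batch = {"z": [8, 1, 1]}  # e.g. a water molecule
batch = compute_embeddings(True, batch)
```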

abstract property embedding_shapes: dict[str, tuple[int, ...]]#

Retrieves the expected shapes of the node, edge, and graph embeddings.

export_model(path, as_state_dict=False)[source]#

Export the current model without the BaseModelMixin interface.

This method lets users run the trained model through the same interface as the corresponding ‘upstream’ version, so that validation code written for the upstream case (e.g. ase.Calculator instances) can be reused.

Essentially, this method should recreate the equivalent base class (by checking MRO), then run torch.save and serialize the model either directly or as its state_dict.

Parameters:
  • path (Path)

  • as_state_dict (bool)

Return type:

None
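The MRO-walking idea can be sketched as below; pickle stands in for torch.save so the example runs without torch, and all class names are hypothetical stand-ins.

```python
import io
import pickle

class Mixin:
    """Stand-in for BaseModelMixin."""

class Upstream:
    """Stand-in for the external project's base model class."""
    def __init__(self, weights):
        self.weights = weights

class Wrapped(Mixin, Upstream):
    pass

def export_model(model, as_state_dict=False):
    # Walk the MRO and pick the first class that is neither the
    # wrapper itself, the mixin, nor object: the 'upstream' base.
    base = next(c for c in type(model).__mro__
                if c not in (type(model), Mixin, object))
    payload = model.weights if as_state_dict else base(model.weights)
    buffer = io.BytesIO()
    pickle.dump(payload, buffer)  # torch.save in the real method
    return base, buffer.getvalue()

model = Wrapped([1.0, 2.0])
base, blob = export_model(model, as_state_dict=True)
```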

input_data()[source]#

Returns a set of keys that are expected to be in the input data.

This method provides the base logic that is generally common across all models, but can be overridden by subclasses to add more expected keys.

Returns:

Set of keys that are expected to be in the input data.

Return type:

set[str]

make_neighbor_hooks()[source]#

Return a list of NeighborListHook instances for this model’s neighbor configuration.

Returns an empty list if the model does not require a neighbor list. Defers the import to avoid circular imports.

Return type:

list

abstract property model_card: ModelCard#

Retrieves the model card for the model.

The model card is a Pydantic model that contains information about the model’s capabilities and requirements.

output_data()[source]#

Returns a set of keys that are expected to be computed by the model and written to the AtomicData or Batch data structure.

This method provides the base logic that is generally common across all models, but can be overridden by subclasses to add more expected keys.

Returns:

Set of keys that are expected to be computed by the model and written to the AtomicData or Batch data structure.

Return type:

set[str]
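The override pattern for input_data and output_data might look like this sketch; the base key sets are illustrative, not the framework’s real defaults.

```python
class Base:
    def input_data(self):
        # Keys generally common across all models (illustrative).
        return {"pos", "z", "cell"}

    def output_data(self):
        return {"energy", "forces"}

class StressModel(Base):
    def output_data(self):
        # Extend the base set with model-specific keys.
        return super().output_data() | {"stress"}

keys = StressModel().output_data()
```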