Models: Wrapping ML Interatomic Potentials#
The ALCHEMI Toolkit uses a standardized interface —
BaseModelMixin — that sits between your
PyTorch model and the rest of the framework (dynamics, data loading, active
learning). Any machine-learning interatomic potential (MLIP) can be used with
the toolkit as long as it is wrapped with this interface.
This guide covers:
- What models are currently supported out of the box.
- The three building blocks: `ModelCard`, `ModelConfig`, and `BaseModelMixin`.
- How to wrap your own model, using `DemoModelWrapper` as a worked example.
Supported models#
The nvalchemi.models package ships wrappers for the following
potentials:
| Wrapper class | Underlying model | Notes |
|---|---|---|
| `DemoModelWrapper` | `DemoModel` | Non-invariant demo; useful for testing and tutorials |
| `AIMNet2Wrapper` | AIMNet2 | Requires an optional dependency |
| `ScaleShiftMACEWrapper` | ScaleShiftMACE (MACE) | Requires an optional dependency |
`AIMNet2Wrapper` and `ScaleShiftMACEWrapper` are lazily imported — they only load when accessed, so a missing optional dependency will not break other imports.
Architecture overview#
A wrapped model uses multiple inheritance: your existing nn.Module
subclass provides the forward pass, while BaseModelMixin adds the
standardized interface.
![digraph model_inheritance {
rankdir=BT
compound=true
fontname="Helvetica"
node [fontname="Helvetica" fontsize=11 shape=box style="filled,rounded"]
edge [fontname="Helvetica" fontsize=10]
YourModel [
label="YourModel(nn.Module)\l- forward()\l- your layers\l"
fillcolor="#E8F4FD"
color="#4A90D9"
]
BaseModelMixin [
label="BaseModelMixin\l- model_card\l- adapt_input()\l- adapt_output()\l"
fillcolor="#E8F4FD"
color="#4A90D9"
]
YourModelWrapper [
label="YourModelWrapper\l(YourModel, BaseModelMixin)\l"
fillcolor="#D5E8D4"
color="#82B366"
]
YourModelWrapper -> YourModel
YourModelWrapper -> BaseModelMixin
}](../_images/graphviz-2fb1abd533c74e9e90644e584c4224aaaa925a6f.png)
Multiple-inheritance pattern for model wrapping.#
The wrapper’s forward method follows a three-step pipeline:
1. `adapt_input` — convert `AtomicData`/`Batch` into the keyword arguments your model expects.
2. `super().forward` — call the underlying model unchanged.
3. `adapt_output` — map raw model outputs to the framework’s `ModelOutputs` ordered dictionary.
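In plain-Python terms, the three steps compose like this (a minimal sketch with stand-in types — `TinyModel` and the dict-based data are illustrative, not the real API):

```python
class TinyModel:
    """Stand-in for your existing model: takes kwargs, returns a raw dict."""

    def forward(self, *, positions):
        return {"raw_energy": sum(positions)}


class TinyWrapper(TinyModel):
    """Illustrates the adapt_input -> forward -> adapt_output pipeline."""

    def adapt_input(self, data):
        # Step 1: convert framework data into the kwargs the model expects.
        return {"positions": data["positions"]}

    def adapt_output(self, raw_output, data):
        # Step 3: map raw outputs onto standardized keys.
        return {"energies": raw_output["raw_energy"]}

    def __call__(self, data):
        model_inputs = self.adapt_input(data)
        raw_output = super().forward(**model_inputs)  # step 2: model unchanged
        return self.adapt_output(raw_output, data)


outputs = TinyWrapper()({"positions": [1.0, 2.0, 3.0]})
# outputs == {"energies": 6.0}
```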
ModelCard: declaring capabilities#
`ModelCard` is an immutable Pydantic model that describes what a model can compute and what inputs it requires. Every wrapper must return a `ModelCard` from its `model_card` property.
Capability fields#
| Field | Default | Meaning |
|---|---|---|
| `forces_via_autograd` | (required) | Forces are obtained by differentiating the predicted energy via autograd |
| `supports_energies` | `False` | Model can predict energies |
| `supports_forces` | `False` | Model can predict forces |
| `supports_stresses` | `False` | Model can predict stress tensors |
| `supports_hessians` | `False` | Model can predict Hessians |
| `supports_dipoles` | `False` | Model can predict dipole moments |
| `supports_pbc` | `False` | Model handles periodic boundary conditions |
| `supports_non_batch` | `False` | Model accepts single `AtomicData` inputs (not only `Batch`) |
| `supports_node_embeddings` | `False` | Model can expose per-atom embeddings |
| `supports_edge_embeddings` | `False` | Model can expose per-edge embeddings |
| `supports_graph_embeddings` | `False` | Model can expose per-graph embeddings |
Requirement fields#
| Field | Default | Meaning |
|---|---|---|
| `needs_pbc` | (required) | Model expects periodic-boundary-condition (cell) information in the input |
| `needs_neighborlist` | `False` | Model expects a precomputed neighbor list |
| `needs_charges` | `False` | Model expects partial charges per atom |
| `needs_total_charge` | `False` | Model expects total system charge |
`ModelCard` uses `ConfigDict(extra="allow")`, so you can attach additional metadata (e.g. `model_name`) without modifying the schema.

```python
from nvalchemi.models.base import ModelCard

card = ModelCard(
    forces_via_autograd=True,
    supports_energies=True,
    supports_forces=True,
    needs_pbc=False,
    needs_neighborlist=False,
    model_name="MyPotential",  # extra metadata
)
```
ModelConfig: runtime computation control#
`ModelConfig` controls what to compute on each forward pass. It lives as the `model_config` attribute on every `BaseModelMixin` instance and can be changed at any time.
| Field | Default | Meaning |
|---|---|---|
| `compute_energies` | `True` | Compute energies |
| `compute_forces` | `True` | Compute forces |
| `compute_stresses` | `False` | Compute stresses |
| `compute_hessians` | `False` | Compute Hessians |
| `compute_dipoles` | `False` | Compute dipoles |
| `compute_charges` | `False` | Compute partial charges |
| `compute_embeddings` | `False` | Compute intermediate embeddings |
| `gradient_keys` | auto-populated | Tensor keys that need `requires_grad=True` |
`gradient_keys` is populated automatically — when `compute_forces` is `True`, `"positions"` is added so that autograd-based force computation works.

```python
from nvalchemi.models.base import ModelConfig

model.model_config = ModelConfig(
    compute_forces=True,
    compute_stresses=True,  # enable stress computation
)
```
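The auto-population of gradient keys can be pictured with a small stand-in (a dataclass sketch, not the real `ModelConfig`):

```python
from dataclasses import dataclass, field


@dataclass
class SketchConfig:
    """Illustrative stand-in for ModelConfig's gradient_keys behavior."""

    compute_forces: bool = False
    gradient_keys: set = field(default_factory=set)

    def __post_init__(self):
        # Forces come from dE/d(positions), so "positions" must carry gradients.
        if self.compute_forces:
            self.gradient_keys.add("positions")


cfg = SketchConfig(compute_forces=True)
# cfg.gradient_keys now contains "positions"
```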
The helper `_verify_request()` checks whether a requested computation is both enabled in `ModelConfig` and supported by `ModelCard`. If it is requested but not supported, a `UserWarning` is issued.
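The check amounts to the following (an illustrative stand-alone function; the real `_verify_request` signature may differ):

```python
import warnings


def verify_request(name: str, requested: bool, supported: bool) -> bool:
    """Return whether a quantity should be computed, warning on mismatches."""
    if requested and not supported:
        warnings.warn(
            f"'{name}' was requested in ModelConfig but is not supported "
            "by this model's ModelCard.",
            UserWarning,
        )
        return False
    return requested


result = verify_request("stresses", requested=True, supported=False)
# result is False, and a UserWarning has been emitted
```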
Wrapping your own model: step by step#
This section walks through every method you need to implement, using
DemoModelWrapper as the running example.
Step 1 — Create the wrapper class#
Use multiple inheritance with your model first and `BaseModelMixin` second:

```python
from nvalchemi.models.base import BaseModelMixin, ModelCard

class DemoModelWrapper(DemoModel, BaseModelMixin):
    ...
```
Step 2 — Implement model_card#
Return a `ModelCard` describing your model’s capabilities. This is a `@property`:

```python
@property
def model_card(self) -> ModelCard:
    return ModelCard(
        forces_via_autograd=True,
        supports_energies=True,
        supports_forces=True,
        supports_non_batch=True,
        needs_pbc=False,
        needs_neighborlist=False,
        model_name=self.__class__.__name__,
    )
```
Step 3 — Implement embedding_shapes#
Return a dictionary mapping embedding names to their trailing shapes. This is used by downstream consumers (e.g. active learning) to know what representations the model can provide:

```python
@property
def embedding_shapes(self) -> dict[str, tuple[int, ...]]:
    return {
        "node_embeddings": (self.hidden_dim,),
        "graph_embedding": (self.hidden_dim,),
    }
```
Step 4 — Implement adapt_input#
Convert framework data to the keyword arguments your underlying model’s `forward()` expects. Always call `super().adapt_input()` first — the base implementation enables gradients on the required tensors and validates that all required input keys (from `model_card`) are present:
```python
def adapt_input(self, data: AtomicData | Batch, **kwargs) -> dict[str, Any]:
    model_inputs = super().adapt_input(data, **kwargs)
    # Extract tensors in the format your model expects
    model_inputs["atomic_numbers"] = data.atomic_numbers
    model_inputs["positions"] = data.positions.to(self.dtype)
    # Handle batched vs. single input
    if isinstance(data, Batch):
        model_inputs["batch_indices"] = data.batch
    else:
        model_inputs["batch_indices"] = None
    # Pass config flags to control model behavior
    model_inputs["compute_forces"] = self.model_config.compute_forces
    return model_inputs
```
Step 5 — Implement adapt_output#
Map the model’s raw output dictionary to `ModelOutputs`, an `OrderedDict[str, Tensor | None]` with standardized keys. Always call `super().adapt_output()` first — it creates the `OrderedDict` pre-filled with expected keys (derived from `model_config` + `model_card`) and auto-maps any keys whose names already match:
```python
def adapt_output(self, model_output, data: AtomicData | Batch) -> ModelOutputs:
    output = super().adapt_output(model_output, data)
    energies = model_output["energies"]
    if isinstance(data, AtomicData) and energies.ndim == 1:
        energies = energies.unsqueeze(-1)  # must be [B, 1]
    output["energies"] = energies
    if self.model_config.compute_forces:
        output["forces"] = model_output["forces"]
    # Validate: no expected key should be None
    for key, value in output.items():
        if value is None:
            raise KeyError(
                f"Key '{key}' not found in model output "
                "but is supported and requested."
            )
    return output
```
The standard output shapes are (`B` = graphs in the batch, `N` = total atoms):
| Key | Shape | Description |
|---|---|---|
| `energies` | `[B, 1]` | Per-graph total energy |
| `forces` | `[N, 3]` | Per-atom forces |
| `stresses` | `[B, 3, 3]` | Per-graph stress tensor |
| `hessians` | `[N, 3, 3]` | Per-atom Hessian |
| `dipoles` | `[B, 3]` | Per-graph dipole moment |
| `charges` | `[N, 1]` | Per-atom partial charges |
Step 6 (optional) — Implement compute_embeddings#
Extract intermediate representations and write them to the data structure in-place. This is used by active learning and other downstream consumers:
```python
def compute_embeddings(self, data: AtomicData | Batch, **kwargs) -> AtomicData | Batch:
    model_inputs = self.adapt_input(data, **kwargs)
    # Run the model's internal layers
    atom_z = self.embedding(model_inputs["atomic_numbers"])
    coord_z = self.coord_embedding(model_inputs["positions"])
    embedding = self.joint_mlp(torch.cat([atom_z, coord_z], dim=-1))
    embedding = embedding + atom_z + coord_z
    # Aggregate to graph level via scatter
    if isinstance(data, Batch):
        batch_indices = data.batch
        num_graphs = data.batch_size
    else:
        batch_indices = torch.zeros_like(model_inputs["atomic_numbers"])
        num_graphs = 1
    graph_shape = self.embedding_shapes["graph_embedding"]
    graph_embedding = torch.zeros(
        (num_graphs, *graph_shape),
        device=embedding.device,
        dtype=embedding.dtype,
    )
    # scatter_add_ requires the index tensor to match the source's shape
    graph_embedding.scatter_add_(
        0, batch_indices.unsqueeze(-1).expand_as(embedding), embedding
    )
    # Write in-place
    data.node_embeddings = embedding
    data.graph_embeddings = graph_embedding
    return data
```
Step 7 — Implement forward#
Wire the three-step pipeline together:
```python
def forward(self, data: AtomicData | Batch, **kwargs) -> ModelOutputs:
    model_inputs = self.adapt_input(data, **kwargs)
    model_outputs = super().forward(**model_inputs)
    return self.adapt_output(model_outputs, data)
```
`super().forward(**model_inputs)` calls the underlying `DemoModel.forward` with the unpacked keyword arguments — your original model is never modified.
Optionally, the `@beartype.beartype` decorator can be applied to the forward method to provide runtime type checking on the inputs and outputs, as well as shape checking.
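Why `super().forward` reaches the original model follows from Python's method resolution order; a toy sketch with stand-in classes:

```python
class BaseMixinSketch:
    """Stand-in for BaseModelMixin: adds interface methods, no forward."""


class OriginalModel:
    """Stand-in for your nn.Module subclass."""

    def forward(self, **kwargs):
        return {"source": "OriginalModel.forward", **kwargs}


class WrapperSketch(OriginalModel, BaseMixinSketch):
    def forward(self, **kwargs):
        # super() follows the MRO (WrapperSketch -> OriginalModel ->
        # BaseMixinSketch -> object), so this dispatches to
        # OriginalModel.forward unchanged.
        return super().forward(**kwargs)


mro_names = [cls.__name__ for cls in WrapperSketch.__mro__]
# ["WrapperSketch", "OriginalModel", "BaseMixinSketch", "object"]
```

Because the original model comes first in the bases, its `forward` is found before anything the mixin contributes.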
Step 8 (optional) — Implement export_model#
Export the model without the BaseModelMixin interface, for use with
external tools (e.g. ASE calculators):
```python
def export_model(self, path: Path, as_state_dict: bool = False) -> None:
    base_cls = self.__class__.__mro__[1]  # the original nn.Module class
    base_model = base_cls()
    for name, module in self.named_children():
        setattr(base_model, name, module)
    if as_state_dict:
        torch.save(base_model.state_dict(), path)
    else:
        torch.save(base_model, path)
```
Putting it all together#
A complete minimal wrapper for a custom potential:
```python
from pathlib import Path
from typing import Any

import torch
from torch import nn

from nvalchemi.data import AtomicData, Batch
from nvalchemi.models.base import BaseModelMixin, ModelCard, ModelConfig
from nvalchemi._typing import ModelOutputs


class MyPotential(nn.Module):
    """Your existing PyTorch MLIP."""

    def __init__(self, hidden_dim: int = 128):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.encoder = nn.Linear(3, hidden_dim)
        self.energy_head = nn.Linear(hidden_dim, 1)

    def forward(self, positions, batch_indices=None, **kwargs):
        h = self.encoder(positions)
        node_energy = self.energy_head(h)
        if batch_indices is not None:
            num_graphs = int(batch_indices.max()) + 1
            energies = torch.zeros(num_graphs, 1, device=h.device, dtype=h.dtype)
            energies.scatter_add_(0, batch_indices.unsqueeze(-1), node_energy)
        else:
            energies = node_energy.sum(dim=0, keepdim=True)
        return {"energies": energies}


class MyPotentialWrapper(MyPotential, BaseModelMixin):
    """Wrapped version for use in nvalchemi."""

    @property
    def model_card(self) -> ModelCard:
        return ModelCard(
            forces_via_autograd=True,
            supports_energies=True,
            supports_forces=True,
            supports_non_batch=True,
            needs_neighborlist=False,
            needs_pbc=False,
        )

    @property
    def embedding_shapes(self) -> dict[str, tuple[int, ...]]:
        return {"node_embeddings": (self.hidden_dim,)}

    def adapt_input(self, data: AtomicData | Batch, **kwargs: Any) -> dict[str, Any]:
        model_inputs = super().adapt_input(data, **kwargs)
        model_inputs["positions"] = data.positions
        model_inputs["batch_indices"] = data.batch if isinstance(data, Batch) else None
        return model_inputs

    def adapt_output(self, model_output: Any, data: AtomicData | Batch) -> ModelOutputs:
        output = super().adapt_output(model_output, data)
        output["energies"] = model_output["energies"]
        if self.model_config.compute_forces:
            output["forces"] = -torch.autograd.grad(
                model_output["energies"],
                data.positions,
                grad_outputs=torch.ones_like(model_output["energies"]),
                create_graph=self.training,
            )[0]
        return output

    def compute_embeddings(self, data: AtomicData | Batch, **kwargs) -> AtomicData | Batch:
        model_inputs = self.adapt_input(data, **kwargs)
        data.node_embeddings = self.encoder(model_inputs["positions"])
        return data

    def forward(self, data: AtomicData | Batch, **kwargs: Any) -> ModelOutputs:
        model_inputs = self.adapt_input(data, **kwargs)
        model_outputs = super().forward(**model_inputs)
        return self.adapt_output(model_outputs, data)
```
Usage:

```python
model = MyPotentialWrapper(hidden_dim=128)
model.model_config = ModelConfig(compute_forces=True)

data = AtomicData(
    positions=torch.randn(5, 3),
    atomic_numbers=torch.tensor([6, 6, 8, 1, 1], dtype=torch.long),
)
batch = Batch.from_data_list([data])

outputs = model(batch)
# outputs["energies"] shape: [1, 1]
# outputs["forces"] shape: [5, 3]
```
How models integrate with dynamics#
Once wrapped, a model plugs directly into the dynamics framework. The
dynamics integrator calls the wrapper’s forward method internally via
BaseDynamics.compute(), and the resulting forces and energies are written
back to the batch:
```python
from nvalchemi.dynamics import DemoDynamics

model = MyPotentialWrapper(hidden_dim=128)
dynamics = DemoDynamics(model=model, n_steps=1000, dt=0.5)
dynamics.run(batch)
```
The `__needs_keys__` set on the dynamics class (e.g. `{"forces"}`) is validated against the model’s output after every `compute()` call, so mismatches between the model’s declared capabilities and the integrator’s requirements are caught immediately at runtime.
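The validation amounts to a set check over the returned keys — a sketch under the assumption that it simply requires every needed key to be present and non-None (not the actual BaseDynamics code):

```python
def check_needs_keys(needs_keys: set[str], outputs: dict) -> None:
    """Every key the integrator needs must be present and non-None."""
    missing = {key for key in needs_keys if outputs.get(key) is None}
    if missing:
        raise KeyError(f"Model outputs missing required keys: {sorted(missing)}")


# A model that returns forces satisfies a {"forces"} requirement:
check_needs_keys({"forces"}, {"energies": 1.0, "forces": [[0.0, 0.0, 0.0]]})
```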
See also#
- Examples: The gallery includes dynamics examples that demonstrate model usage in context.
- API: `nvalchemi.models` for the full reference of `BaseModelMixin`, `ModelCard`, and `ModelConfig`.
- Dynamics guide: dynamics, for how models are used inside optimization and MD workflows.