# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import abc
import warnings
from collections import OrderedDict
from enum import Enum
from pathlib import Path
from typing import Annotated, Any
import torch
from pydantic import BaseModel, ConfigDict, Field
from nvalchemi._typing import AtomsLike, ModelOutputs
from nvalchemi.data import AtomicData, Batch
# Import-time side effect: install a "once" filter for ``UserWarning`` at the
# front of the interpreter-wide warning filter list. With this action each
# distinct UserWarning message is shown only the first time it is issued,
# regardless of where it is raised -- this applies to the whole process, not
# only to warnings emitted from this module.
warnings.simplefilter("once", UserWarning)
class NeighborListFormat(str, Enum):
    """Layout in which neighbor data is written to the batch.

    The enum derives from ``str``, so each member compares equal to its
    lowercase string value (``"coo"`` / ``"matrix"``).

    Attributes
    ----------
    COO : str
        Sparse coordinate layout. Neighbors are stored as an ``edge_index``
        tensor of shape ``[2, E]`` holding source / target global atom
        indices. This is the conventional layout for most GNN-based MLIPs.
    MATRIX : str
        Dense neighbor-matrix layout. Neighbors are stored as a
        ``neighbor_matrix`` tensor of shape ``[N, max_neighbors]`` (global
        atom indices) accompanied by a ``num_neighbors`` tensor of shape
        ``[N]``. Used by Warp interaction kernels (e.g. Lennard-Jones) that
        benefit from fixed-width rows.
    """

    COO = "coo"
    MATRIX = "matrix"
class NeighborConfig(BaseModel):
    """Configuration for on-the-fly neighbor list construction.

    An instance of this class attached to a :class:`ModelCard` signals that
    the model requires a neighbor list and describes the format and parameters
    it expects. At runtime a :class:`~nvalchemi.dynamics.hooks.NeighborListHook`
    reads this config to compute and cache the appropriate neighbor data.

    Attributes
    ----------
    cutoff : float
        Interaction cutoff radius in the same length units as positions.
        Required; it is the only field without a default.
    format : NeighborListFormat
        Whether to build a dense neighbor matrix (``MATRIX``) or a sparse
        edge-index list (``COO``). Defaults to ``COO``.
    half_list : bool
        If ``True``, each pair ``(i, j)`` with ``i < j`` appears only once.
        Newton's third law is applied inside the interaction kernel to recover
        forces on both atoms. Defaults to ``False``.
    skin : float
        Verlet skin distance. The neighbor list is only rebuilt when any atom
        has moved more than ``skin / 2`` since the last build. Set to ``0.0``
        (default) to rebuild every step.
    max_neighbors : int | None
        Maximum number of neighbors per atom. Meant to be supplied when
        ``format=MATRIX``; ignored for ``COO``. Defaults to ``None``.
        This class does not validate the combination itself: a ``MATRIX``
        config with ``max_neighbors=None`` is accepted here.

    Notes
    -----
    There is no ``algorithm`` field on this model; the neighbor-finding
    algorithm cannot be selected through this config. The class does not
    override pydantic's default handling of unknown keyword arguments
    (``extra="ignore"``), so ``NeighborConfig(cutoff=5.0, algorithm="naive")``
    constructs successfully but the ``algorithm`` value is silently discarded.
    """

    cutoff: float
    format: NeighborListFormat = NeighborListFormat.COO
    half_list: bool = False
    skin: float = 0.0
    max_neighbors: int | None = None
class ModelConfig(BaseModel):
    """
    Runtime switches selecting which quantities a model should compute.

    All models that inherit from `BaseModelMixin` should have a `model_config`
    attribute that is an instance of this class, which can be used to
    change the behavior of the model.

    Attributes
    ----------
    compute_forces : bool, default True
        Set to enable or disable force computation.
    compute_stresses : bool, default False
        Set to enable or disable stress computation.
    compute_hessians : bool, default False
        Set to enable or disable Hessian computation.
    compute_dipoles : bool, default False
        Set to enable or disable dipole computation.
    compute_charges : bool, default False
        Set to enable or disable charge computation.
    compute_embeddings : bool, default False
        Set to enable or disable embedding computation.
    compute_energies : bool, default True
        Set to enable or disable energies computation.
    gradient_keys : set[str], default set()
        Set of keys to enable gradients for in the `Batch` of `AtomicData`
        structure. A fresh empty set is created for every instance
        (``default_factory=set``), so instances never share the default.
    """

    compute_forces: Annotated[
        bool,
        Field(description="Set to enable or disable force computation."),
    ] = True
    compute_stresses: Annotated[
        bool,
        Field(description="Set to enable or disable stress computation."),
    ] = False
    compute_hessians: Annotated[
        bool,
        Field(description="Set to enable or disable Hessian computation."),
    ] = False
    compute_dipoles: Annotated[
        bool,
        Field(description="Set to enable or disable dipole computation."),
    ] = False
    compute_charges: Annotated[
        bool,
        Field(description="Set to enable or disable charge computation."),
    ] = False
    compute_embeddings: Annotated[
        bool,
        Field(description="Set to enable or disable embedding computation."),
    ] = False
    compute_energies: Annotated[
        bool,
        Field(description="Set to enable or disable energies computation."),
    ] = True
    gradient_keys: Annotated[
        set[str],
        Field(
            description="Set of keys to compute gradients for in the `Batch` of `AtomicData` structure..",
            default_factory=set,
        ),
    ]
class ModelCard(BaseModel):
    """
    Immutable description of a model's capabilities and input requirements.

    A new model wrapper should return this data structure as the `model_card`
    property.

    ``forces_via_autograd`` and ``needs_pbc`` have no default and must be
    passed explicitly. Every other flag defaults to ``False`` except
    ``supports_energies`` (``True``); ``neighbor_config`` defaults to ``None``.

    Each ``supports_<key>`` flag is the capability counterpart of the
    ``compute_<key>`` switch on :class:`ModelConfig`: an output is only
    produced when it is both requested there and supported here.

    The model is frozen (attributes cannot be reassigned after construction)
    and rejects unknown fields (``extra="forbid"``).
    """

    forces_via_autograd: Annotated[
        bool, Field(description="Whether the model predicts forces via autograd.")
    ]
    supports_node_embeddings: Annotated[
        bool, Field(description="Whether the model supports computing embeddings.")
    ] = False
    supports_edge_embeddings: Annotated[
        bool, Field(description="Whether the model supports computing edge embeddings.")
    ] = False
    supports_graph_embeddings: Annotated[
        bool,
        Field(description="Whether the model supports computing graph embeddings."),
    ] = False
    supports_energies: Annotated[
        bool, Field(description="Whether the model supports energies computation.")
    ] = True
    supports_forces: Annotated[
        bool, Field(description="Whether the model supports forces computation.")
    ] = False
    supports_stresses: Annotated[
        bool,
        Field(description="Whether the model supports stresses/virials computation."),
    ] = False
    supports_hessians: Annotated[
        bool,
        Field(
            description="Whether the model supports computing the Hessians of the energy."
        ),
    ] = False
    supports_pbc: Annotated[
        bool,
        Field(description="Whether the model supports periodic boundary conditions."),
    ] = False
    needs_pbc: Annotated[
        bool,
        Field(
            description="Whether the model needs periodic boundary conditions parameters as part of its input."
        ),
    ]
    needs_node_charges: Annotated[
        bool,
        Field(
            description="Whether the model needs partial atomic charges as part of its input."
        ),
    ] = False
    needs_system_charges: Annotated[
        bool,
        Field(
            description="Whether the model needs the total system charge as part of its input."
        ),
    ] = False
    supports_dipoles: Annotated[
        bool,
        Field(
            description="Whether the model explicitly supports computing the dipole moments."
        ),
    ] = False
    # Capability counterpart of ``ModelConfig.compute_charges``. Without this
    # field, looking up ``supports_charges`` on a card raises AttributeError
    # as soon as charge computation is requested.
    supports_charges: Annotated[
        bool,
        Field(description="Whether the model supports charges computation."),
    ] = False
    supports_non_batch: Annotated[
        bool, Field(description="Whether the model supports non-batch input.")
    ] = False
    neighbor_config: Annotated[
        NeighborConfig | None,
        Field(
            description=(
                "Neighbor list requirements for this model. ``None`` means the "
                "model does not use a neighbor list. When set, a "
                "``NeighborListHook`` should be registered with the dynamics "
                "engine to supply the required neighbor data before each "
                "``compute()`` call."
            )
        ),
    ] = None
    includes_dispersion: Annotated[
        bool,
        Field(
            description="Whether the model already incorporates dispersion (e.g. D3) in its energy."
        ),
    ] = False
    includes_long_range_electrostatics: Annotated[
        bool,
        Field(
            description="Whether the model already incorporates long-range electrostatics in its energy."
        ),
    ] = False

    # Pydantic settings for this model: immutable instances, unknown fields
    # are a validation error.
    model_config = ConfigDict(frozen=True, extra="forbid")

    @property
    def needs_neighborlist(self) -> bool:
        """Convenience accessor: ``True`` when the model requires a neighbor list."""
        return self.neighbor_config is not None
# Names of the output properties a model can be asked to compute.
# Each entry ``key`` is looked up as ``compute_<key>`` on the ``ModelConfig``
# (is it requested?) and as ``supports_<key>`` on the ``ModelCard`` (is it
# supported?) by ``BaseModelMixin.output_data`` / ``_verify_request``, so every
# name listed here is expected to have both attributes.
# Kept as a static tuple so that output_data() can avoid a per-call
# model_dump() serialization of the config.
_COMPUTE_OUTPUT_KEYS: tuple[str, ...] = (
"forces",
"stresses",
"hessians",
"dipoles",
"charges",
"energies",
)
class BaseModelMixin(abc.ABC):
    """
    Abstract mixin giving wrappers around external machine learning
    interatomic potentials (MLIPs) a homogenized interface.

    Members defined by this mixin:

    - Abstract, to be provided by every concrete wrapper: ``model_card``,
      ``embedding_shapes`` and ``compute_embeddings``.
    - Concrete helpers with overridable defaults: ``adapt_output``,
      ``output_data``, ``make_neighbor_hooks`` and the ``+`` operator.
    - Placeholders that raise ``NotImplementedError`` until overridden:
      ``add_output_head`` and ``export_model``.

    Converting framework data into the external model's input format and
    validating batches are the concrete wrapper's responsibility; this mixin
    itself declares no methods for them.

    The workflow for writing a new wrapper is:

    1. Set ``self.model_config = ModelConfig()`` in ``__init__``.
    2. Implement ``model_card`` and ``embedding_shapes`` to specify the
       model's capabilities.
    3. Implement ``compute_embeddings``.
    4. Implement the prediction path, using ``output_data`` to learn which
       outputs are requested and supported.
    5. Convert the raw model output with ``adapt_output``; override it when
       the external model's key names or shapes differ from the framework's.

    Raises
    ------
    TypeError
        When instantiating a subclass that leaves any abstract member
        unimplemented (standard ``abc`` behavior).
    NotImplementedError
        From ``add_output_head`` and ``export_model`` unless overridden.
    """

    # model_config must be set as an instance attribute in each subclass __init__:
    #     self.model_config = ModelConfig()
    # There is intentionally NO class-level default to prevent all instances from
    # sharing a single ModelConfig object (which would cause mutations in one wrapper
    # to silently affect all others).

    @property
    @abc.abstractmethod
    def model_card(self) -> ModelCard:
        """Retrieves the model card for the model.

        The model card is a Pydantic model that contains
        information about the model's capabilities and requirements.
        """
        ...

    @property
    @abc.abstractmethod
    def embedding_shapes(self) -> dict[str, tuple[int, ...]]:
        """Retrieves the expected shapes of the node, edge, and graph embeddings."""
        ...

    @abc.abstractmethod
    def compute_embeddings(
        self, data: AtomicData | Batch, **kwargs: Any
    ) -> AtomicData | Batch:
        """
        Compute embeddings at different levels of a batch of atomic graphs.

        This method should extract meaningful representations from the model
        at node (atomic), edge (bond), and/or graph/system (structure) levels.
        The concrete implementation should check if the model supports
        computing embeddings, as well as perform validation on `kwargs`
        to make sure they are valid for the model.

        The method should add graph, node, and/or edge embeddings to the
        `Batch` data structure in-place.

        Parameters
        ----------
        data : AtomicData | Batch
            Input atomic data containing positions, atomic numbers, etc.
        **kwargs : Any
            Model-specific options; their validation is left to the
            implementation.

        Returns
        -------
        AtomicData | Batch
            Standardized `AtomicData` or `Batch` data structure mutated in place.

        Raises
        ------
        NotImplementedError
            If the model does not support embeddings computation
        """
        ...

    def adapt_output(self, model_output: Any, data: AtomicData | Batch) -> ModelOutputs:
        """
        Adapt external model output to the framework's standard output format
        (ModelOutputs).

        The result has one entry per key returned by ``output_data()``, each
        initialized to ``None``. When ``model_output`` is a ``dict``, entries
        whose key is present there with a non-``None`` value are filled in.
        Keys are matched by name only, so this is unlikely to fit every
        external model; check the key names and override this method in the
        subclass where they differ.

        The objects in ``model_output`` are never modified: a 1-D
        ``"energies"`` value is reshaped into a new view rather than in place.

        Parameters
        ----------
        model_output : Any
            Raw output from the external model. Anything that is not a
            ``dict`` yields an all-``None`` result.
        data : AtomicData | Batch
            Original input data (may be needed for context/metadata). Not
            used by this default implementation.

        Returns
        -------
        ModelOutputs
            OrderedDict with expected output keys and their values (or None
            if not present). Key order follows the iteration order of the
            set returned by ``output_data()``.
        """
        output = OrderedDict((key, None) for key in self.output_data())
        if not isinstance(model_output, dict):
            return output

        for key in output:
            value = model_output.get(key)
            if value is None:
                continue
            # Key-specific adjustments go here.
            if key == "energies" and value.ndim == 1:
                # A 1-D energies tensor of length n becomes [n, 1]. The
                # out-of-place ``unsqueeze`` leaves the caller's tensor
                # (still referenced by ``model_output``) at its original shape.
                value = value.unsqueeze(-1)
            output[key] = value
        return output

    def add_output_head(self, prefix: str) -> None:
        """
        Add an output head to the model.

        This method should create a multilayer perceptron block for
        mapping input embeddings to a desired output shape. The logic
        for this should differentiate based on invariant/equivariant
        models - specifically those that use `e3nn` layers.
        The method should then save the output head to an `output_heads`
        `ModuleDict` attribute.

        Parameters
        ----------
        prefix : str
            Prefix for the output head

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError

    @staticmethod
    def _verify_request(
        model_config: ModelConfig,
        model_card: ModelCard,
        key: str,
    ) -> bool:
        """
        Verify that a requested computation is supported by the model.

        Reads ``compute_<key>`` from the configuration and ``supports_<key>``
        from the model card. If the computation is requested but not
        supported, a ``UserWarning`` is issued.

        Parameters
        ----------
        model_config : ModelConfig
            The model configuration containing computation settings.
        model_card : ModelCard
            The model card containing capability information.
        key : str
            The type of computation to verify (e.g. ``"forces"``).

        Returns
        -------
        bool
            True if the computation is both requested and supported by the
            model, False otherwise.

        Raises
        ------
        AttributeError
            If ``model_config`` has no ``compute_<key>`` attribute.
        """
        is_requested = getattr(model_config, f"compute_{key}")
        # A card that does not declare ``supports_<key>`` at all is treated as
        # not supporting it, so the caller gets the warning below instead of
        # an AttributeError.
        is_supported = getattr(model_card, f"supports_{key}", False)
        if is_requested and not is_supported:
            warnings.warn(
                f"Model does not support {key}, but compute_{key} is set to True.",
                UserWarning,
            )
        return bool(is_requested and is_supported)

    def output_data(self) -> set[str]:
        """
        Returns a set of keys that are expected to be computed by the model
        and written to the `AtomicData` or `Batch` data structure.

        A key from ``_COMPUTE_OUTPUT_KEYS`` is included when it is requested
        in ``self.model_config`` and supported according to
        ``self.model_card``. Requested-but-unsupported keys are left out and
        trigger a warning via ``_verify_request``.

        This method provides the base logic that is generally common across
        all models, but can be overridden by subclasses to add more expected
        keys.

        Returns
        -------
        set[str]
            Set of keys that are expected to be computed by the model
            and written to the `AtomicData` or `Batch` data structure.
        """
        expected_keys: set[str] = set()
        for key in _COMPUTE_OUTPUT_KEYS:
            # Keys that are not requested (or unknown to the config) are
            # skipped without consulting the model card.
            if not getattr(self.model_config, f"compute_{key}", False):
                continue
            if self._verify_request(self.model_config, self.model_card, key):
                expected_keys.add(key)
        return expected_keys

    def export_model(self, path: Path, as_state_dict: bool = False) -> None:
        """
        Export the current model without the ``BaseModelMixin`` interface.

        The idea behind this method is to allow users to use the trained
        model with the same interface as the corresponding 'upstream' version,
        so that they can re-use validation code that might have been written
        for the upstream case (e.g. ``ase.Calculator`` instances).
        Essentially, this method should recreate the equivalent base class
        (by checking MRO), then run ``torch.save`` and serialize the
        model either directly or as its ``state_dict``.

        Parameters
        ----------
        path : Path
            Destination of the exported model.
        as_state_dict : bool, default False
            Whether to serialize only the ``state_dict`` instead of the
            model object.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError

    def __add__(self, other: "BaseModelMixin") -> "ComposableModelWrapper":
        """Compose two models additively via the ``+`` operator.

        Returns a :class:`ComposableModelWrapper` that sums energies, forces,
        and stresses from both models.

        Parameters
        ----------
        other : BaseModelMixin
            Another model to add.
        """
        # Imported at call time rather than at module import time.
        from nvalchemi.models.composable import ComposableModelWrapper  # noqa: PLC0415

        return ComposableModelWrapper(self, other)

    def make_neighbor_hooks(self) -> list:
        """Return a list of :class:`~nvalchemi.dynamics.hooks.NeighborListHook`
        instances for this model's neighbor configuration.

        Returns an empty list if the model does not require a neighbor list
        (``model_card.neighbor_config`` is ``None``); otherwise a list with a
        single hook built from that configuration.
        """
        nc = self.model_card.neighbor_config
        if nc is None:
            return []

        # Deferred import to avoid circular imports; performed only once a
        # hook is actually needed.
        from nvalchemi.dynamics.hooks import NeighborListHook  # noqa: PLC0415

        return [NeighborListHook(nc)]