# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MACE model wrapper.
Wraps any MACE model (``MACE``, ``ScaleShiftMACE``, etc.) as a
:class:`~nvalchemi.models.base.BaseModelMixin`-compatible wrapper, ready for
use in any :class:`~nvalchemi.dynamics.base.BaseDynamics` engine or standalone
inference / fine-tuning.
Usage
-----
Load a named foundation-model checkpoint::
from nvalchemi.models.mace import MACEWrapper
import torch
model = MACEWrapper.from_checkpoint("medium-0b2", device=torch.device("cuda"))
Or wrap an already-instantiated model::
mace_model = torch.load("my_mace.pt", weights_only=False)
model = MACEWrapper(mace_model)
For dynamics, register :class:`~nvalchemi.dynamics.hooks.NeighborListHook`
with ``format=NeighborListFormat.COO`` so that ``edge_index`` and
``unit_shifts`` are populated before each model call::
from nvalchemi.models.base import NeighborConfig, NeighborListFormat
from nvalchemi.dynamics.hooks import NeighborListHook
nl_hook = NeighborListHook(model.model_card.neighbor_config)
dynamics.register_hook(nl_hook)
dynamics.model = model
Notes
-----
* Forces are computed **conservatively** via MACE's internal autograd, so
:attr:`~ModelCard.forces_via_autograd` is ``True``.
* ``node_attrs`` (one-hot atomic-number encodings) are computed via a
pre-built GPU lookup table — no CPU round-trips per step.
* For PBC systems, both ``unit_shifts`` (integer image indices ``[E, 3]``)
and pre-computed ``shifts`` (physical Å vectors ``[E, 3]``) are passed to
MACE. ``shifts`` is always required by ``prepare_graph``; ``unit_shifts``
is additionally used when ``compute_displacement=True`` (stress path).
"""
from __future__ import annotations
import warnings
from importlib.metadata import version
from pathlib import Path
from typing import Any
import torch
from torch import nn
from nvalchemi._typing import ModelOutputs
from nvalchemi.data import AtomicData, Batch
from nvalchemi.models.base import (
BaseModelMixin,
ModelCard,
ModelConfig,
NeighborConfig,
NeighborListFormat,
)
# Optional dependency guard: mace-torch may be absent. Import its helpers
# behind a try/except so this module stays importable either way; runtime
# entry points check _MACE_AVAILABLE and raise a helpful ImportError.
try:
    with warnings.catch_warnings():
        # mace emits noisy UserWarnings at import time; silence them here only.
        warnings.filterwarnings("ignore", category=UserWarning)
        from mace.calculators.foundations_models import download_mace_mp_checkpoint
        from mace.cli.convert_e3nn_cueq import run as _convert_mace_weights

    _MACE_AVAILABLE = True
except ImportError:
    _MACE_AVAILABLE = False
    # Placeholders so module-level names always exist regardless of install.
    download_mace_mp_checkpoint = None
    _convert_mace_weights = None

# Cached once at import: used by from_checkpoint to warn about the
# torch 2.8 / e3nn torch.compile incompatibility.
_torch_version = version("torch")

__all__ = ["MACEWrapper"]
class MACEWrapper(nn.Module, BaseModelMixin):
    """Wrapper for any MACE model implementing the :class:`~nvalchemi.models.base.BaseModelMixin` interface.

    Accepts any MACE model variant (``MACE``, ``ScaleShiftMACE``, cuEq-converted
    models, ``torch.compile``-d models, etc.). The wrapper handles:

    * One-hot ``node_attrs`` encoding via a pre-built GPU lookup table
      (no CPU round-trip per step).
    * Gradient enabling on ``positions`` for conservative force / stress
      computation.
    * PBC via both ``unit_shifts`` (integer image indices) and pre-computed
      ``shifts`` (physical Å vectors from ``unit_shifts @ cell``) passed to
      MACE. ``shifts`` is always required; ``unit_shifts`` is additionally
      consumed when ``compute_displacement=True`` (stress path).

    Parameters
    ----------
    model : nn.Module
        An instantiated MACE model. Any subclass of ``mace.modules.MACE``
        is accepted. The wrapper mirrors the model's training/eval state.

    Attributes
    ----------
    model : nn.Module
        The underlying MACE model.
    model_config : ModelConfig
        Mutable configuration controlling which outputs are computed.
    """

    model: nn.Module

    def __init__(self, model: nn.Module) -> None:
        if not _MACE_AVAILABLE:
            raise ImportError(
                "mace-torch is required to use MACEWrapper. "
                "Install it with: pip install 'nvalchemi-toolkit[mace]'"
            )
        super().__init__()
        self.model = model
        # Mirror the wrapped model's training/eval state on the wrapper.
        self.train(mode=model.training)
        self.model_config = ModelConfig()
        # Cache the model dtype — determined at construction, stable thereafter.
        self._cached_model_dtype: torch.dtype = next(model.parameters()).dtype
        # Pre-build a one-hot lookup table: shape [max_z + 1, num_elements].
        # At runtime, node_attrs = _node_emb.index_select(0, atomic_numbers)
        # — a single GPU op, no CPU round-trips.
        z_table: list[int] = model.atomic_numbers.tolist()
        node_emb = torch.zeros(max(z_table) + 1, len(z_table))
        for i, z in enumerate(z_table):
            node_emb[z, i] = 1.0
        # Cast to model device+dtype so _node_attrs needs no per-step conversion.
        # Must use the model's device here: from_checkpoint moves the inner model
        # to the target device before calling cls(model), so the buffer must be
        # placed on that device from construction rather than relying on a
        # subsequent .to() call that never happens.
        model_device = next(model.parameters()).device
        node_emb = node_emb.to(device=model_device, dtype=self._cached_model_dtype)
        # persistent=False: derived from model.atomic_numbers, excluded from
        # state_dict but still tracked for device / dtype moves.
        self.register_buffer("_node_emb", node_emb, persistent=False)
        self._model_card: ModelCard = self._build_model_card()

    # ------------------------------------------------------------------
    # BaseModelMixin required properties
    # ------------------------------------------------------------------
    def _build_model_card(self) -> ModelCard:
        """Describe MACE's capabilities and neighbor-list requirements."""
        return ModelCard(
            forces_via_autograd=True,
            supports_energies=True,
            supports_forces=True,
            supports_stresses=True,
            supports_pbc=True,
            needs_pbc=False,
            supports_non_batch=True,
            supports_node_embeddings=True,
            supports_graph_embeddings=True,
            neighbor_config=NeighborConfig(
                cutoff=self.cutoff,
                format=NeighborListFormat.COO,
                half_list=False,
            ),
        )

    @property
    def model_card(self) -> ModelCard:
        """Static capability card built once at construction."""
        return self._model_card

    @property
    def embedding_shapes(self) -> dict[str, tuple[int, ...]]:
        """Shapes of the embeddings produced by :meth:`compute_embeddings`."""
        hidden_dim: int = self.model.products[0].linear.irreps_out.dim
        return {
            "node_embeddings": (hidden_dim,),
            "graph_embeddings": (hidden_dim,),
        }

    # ------------------------------------------------------------------
    # Convenience properties
    # ------------------------------------------------------------------
    @property
    def cutoff(self) -> float:
        """Interaction cutoff in Angstroms, read from ``model.r_max``."""
        r_max = self.model.r_max
        return r_max.item() if isinstance(r_max, torch.Tensor) else float(r_max)

    @property
    def _model_dtype(self) -> torch.dtype:
        """Return the current dtype of the model's parameters (live, not cached).

        Reading from parameters() directly ensures this stays correct after
        `.half()` or `.to(dtype=...)` calls post-construction.

        Note: calling `.to(dtype=...)` after construction with cuEquivariance or
        `torch.compile` enabled is unsupported and may produce incorrect results.
        Use `from_checkpoint` with the desired `dtype` parameter instead.
        """
        try:
            return next(self.parameters()).dtype
        except StopIteration:
            # Parameter-less model (pathological but possible): fall back.
            return torch.float32

    # ------------------------------------------------------------------
    # Input / output adaptation
    # ------------------------------------------------------------------
    def _node_attrs(self, data: Batch) -> torch.Tensor:
        """One-hot encode atomic numbers via the pre-built lookup table.

        Uses a single ``index_select`` on GPU — no CPU round-trips.
        ``_node_emb`` is already on the correct device and dtype (set at
        construction and kept in sync by ``nn.Module``'s ``.to()``
        machinery), so no per-step device/dtype conversion is needed.
        """
        return self._node_emb.index_select(0, data.atomic_numbers.long())

    def adapt_output(
        self, raw_output: dict[str, Any], data: AtomicData | Batch
    ) -> ModelOutputs:
        """Map MACE output keys to nvalchemi standard keys.

        MACE uses ``"energy"`` / ``"stress"`` / ``"hessian"``; nvalchemi
        expects ``"energies"`` / ``"stresses"`` / ``"hessians"``.

        Renaming happens *before* calling ``super()`` so the base auto-mapper
        sees the canonical key names.
        """
        energy = raw_output["energy"]
        mapped: dict[str, Any] = {
            # Ensure a trailing singleton dimension: [B] -> [B, 1].
            "energies": energy.unsqueeze(-1) if energy.ndim == 1 else energy,
        }
        if raw_output.get("forces") is not None:
            mapped["forces"] = raw_output["forces"]
        if raw_output.get("stress") is not None:
            mapped["stresses"] = raw_output["stress"]
        if raw_output.get("hessian") is not None:
            mapped["hessians"] = raw_output["hessian"]
        return super().adapt_output(mapped, data)

    # ------------------------------------------------------------------
    # Forward pass
    # ------------------------------------------------------------------
    def forward(self, data: AtomicData | Batch, **kwargs: Any) -> ModelOutputs:
        """Run the MACE model and return the output."""
        model_inputs = self.adapt_input(data, **kwargs)
        compute_forces = self._verify_request(
            self.model_config, self.model_card, "forces"
        )
        compute_stresses = self._verify_request(
            self.model_config, self.model_card, "stresses"
        )
        raw_output = self.model.forward(
            model_inputs,
            compute_force=compute_forces,
            compute_stress=compute_stresses,
            # compute_displacement enables the MACE displacement trick required
            # for stress computation via autograd through cell @ unit_shifts.
            compute_displacement=compute_stresses,
            training=self.training,
        )
        return self.adapt_output(raw_output, data)

    # ------------------------------------------------------------------
    # Embeddings
    # ------------------------------------------------------------------
    def compute_embeddings(
        self, data: AtomicData | Batch, **kwargs: Any
    ) -> AtomicData | Batch:
        """Compute node and graph embeddings without forces or stresses.

        Writes ``node_embeddings`` (shape ``[N, hidden_dim]``) and
        ``graph_embeddings`` (shape ``[B, hidden_dim]``, sum-pooled over atoms)
        into *data* in-place and returns it. Does **not** mutate
        ``model_config``.
        """
        if isinstance(data, AtomicData):
            data = Batch.from_data_list([data])
        model_inputs = self.adapt_input(data, **kwargs)
        # Pass flags as local kwargs — never mutate self.model_config.
        raw_output = self.model.forward(
            model_inputs,
            compute_force=False,
            compute_stress=False,
            compute_displacement=False,
            training=False,
        )
        node_feats = raw_output.get("node_feats")
        if node_feats is None:
            raise RuntimeError(
                "MACE model did not return 'node_feats'. "
                "Ensure the model is a standard MACE variant."
            )
        # Write node embeddings directly to the atoms group to avoid the
        # default "system" routing in MultiLevelStorage for unknown keys.
        # If we wrote via `data.node_embeddings = ...`, it would land in the
        # system group (batch_size = [N]) and then block the graph_embeddings
        # write (batch_size = [B]) from going to the same group.
        atoms_group = data._atoms_group
        if atoms_group is not None:
            atoms_group["node_embeddings"] = node_feats
        else:
            data.node_embeddings = node_feats
        # Sum-pool node features per graph via scatter_add on the batch index.
        hidden_dim = node_feats.shape[-1]
        graph_embeddings = torch.zeros(
            data.num_graphs,
            hidden_dim,
            device=node_feats.device,
            dtype=node_feats.dtype,
        )
        graph_embeddings.scatter_add_(
            0,
            data.batch.long().unsqueeze(-1).expand(-1, hidden_dim),
            node_feats,
        )
        data.graph_embeddings = graph_embeddings
        return data

    # ------------------------------------------------------------------
    # Checkpoint loading
    # ------------------------------------------------------------------
    @classmethod
    def from_checkpoint(
        cls,
        checkpoint_path: Path | str,
        device: torch.device = torch.device("cpu"),
        enable_cueq: bool = False,
        dtype: torch.dtype | None = None,
        compile_model: bool = False,
        **compile_kwargs: Any,
    ) -> "MACEWrapper":
        """Load a MACE model from a checkpoint and return a :class:`MACEWrapper`.

        Accepts local file paths or named MACE-MP foundation-model checkpoints
        (e.g. ``"medium-0b2"``), which are downloaded automatically to the
        MACE cache directory.

        Operations are applied in this order to avoid numerical issues:

        1. **Load** — ``torch.load`` the checkpoint.
        2. **cuEq** — convert to cuEquivariance format (must happen while the
           model is still in its original dtype, because
           ``extract_config_mace_model`` reads the dtype via
           ``torch.set_default_dtype``).
        3. **dtype** — cast all weights (including atomic energies) uniformly
           to the requested dtype.
        4. **compile** — ``torch.compile``; freezes parameters and sets eval
           mode. The model is **inference-only** after this step.

        Parameters
        ----------
        checkpoint_path : Path | str
            Local path to a ``.pt`` file, or a named checkpoint string such as
            ``"medium-0b2"``.
        device : torch.device, optional
            Target device. Defaults to CPU.
        enable_cueq : bool, optional
            Convert to cuEquivariance format for GPU speedup. Requires the
            ``cuequivariance`` package.
        dtype : torch.dtype | None, optional
            If set, cast model weights to this dtype after cuEq conversion.
        compile_model : bool, optional
            Apply ``torch.compile``. Sets eval mode and freezes parameters;
            the model is **inference-only** after this step.
        **compile_kwargs
            Forwarded to ``torch.compile``.

        Returns
        -------
        MACEWrapper

        Raises
        ------
        ImportError
            If ``mace-torch`` is not installed, or if ``enable_cueq=True``
            and ``cuequivariance`` is not installed.
        """
        if not _MACE_AVAILABLE:
            raise ImportError(
                "mace-torch is required for MACEWrapper.from_checkpoint. "
                "Install it with: pip install 'nvalchemi-toolkit[mace]'"
            )
        # Step 1: load (resolves named checkpoints through the MACE cache).
        cached_path = download_mace_mp_checkpoint(checkpoint_path)
        model: nn.Module = torch.load(
            cached_path, weights_only=False, map_location=device
        )
        # Step 2: cuEq conversion before any dtype change (see docstring).
        if enable_cueq:
            try:
                import cuequivariance  # noqa: F401
            except ImportError as err:
                # Chain the original error so the missing-package cause is visible.
                raise ImportError(
                    "cuequivariance is required for enable_cueq=True. "
                    "Install it with: pip install 'nvalchemi-toolkit[mace]'"
                ) from err
            model = _convert_mace_weights(model, return_model=True, device=device)
        # Step 3: dtype conversion.
        if dtype is not None:
            model.to(dtype=dtype)
        model = model.to(device)
        # Step 4: torch.compile — inference-only after this point.
        if compile_model:
            if _torch_version.startswith("2.8"):
                warnings.warn(
                    "torch.compile has known issues with e3nn in torch 2.8. "
                    "You may need to patch e3nn before compiling:\n"
                    " sed -i '238s/raise NotImplementedError/return 2/' "
                    "<site-packages>/e3nn/o3/_irreps.py",
                    stacklevel=2,
                )
            model.eval()
            # Freeze parameters: compiled graph assumes inference-only weights.
            for param in model.parameters():
                param.requires_grad = False
            model = torch.compile(model, **compile_kwargs)
        return cls(model)

    # ------------------------------------------------------------------
    # Export
    # ------------------------------------------------------------------
    def export_model(self, path: Path, as_state_dict: bool = False) -> None:
        """Serialize the underlying MACE model without the wrapper.

        The exported file can be reloaded as a plain MACE ``nn.Module`` and
        used with the standard MACE / ASE interface.

        Parameters
        ----------
        path : Path
            Output path.
        as_state_dict : bool, optional
            If ``True``, save only the ``state_dict``; otherwise pickle the
            full model object. Defaults to ``False``.
        """
        if as_state_dict:
            torch.save(self.model.state_dict(), path)
        else:
            torch.save(self.model, path)