Source code for nvalchemi.models.mace

# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MACE model wrapper.

Wraps any MACE model (``MACE``, ``ScaleShiftMACE``, etc.) as a
:class:`~nvalchemi.models.base.BaseModelMixin`-compatible wrapper, ready for
use in any :class:`~nvalchemi.dynamics.base.BaseDynamics` engine or standalone
inference / fine-tuning.

Usage
-----
Load a named foundation-model checkpoint::

    from nvalchemi.models.mace import MACEWrapper
    import torch

    model = MACEWrapper.from_checkpoint("medium-0b2", device=torch.device("cuda"))

Or wrap an already-instantiated model::

    mace_model = torch.load("my_mace.pt", weights_only=False)
    model = MACEWrapper(mace_model)

For dynamics, register a :class:`~nvalchemi.dynamics.hooks.NeighborListHook`
built from ``model.model_card.neighbor_config`` (which requests
``format=NeighborListFormat.COO``) so that ``edge_index`` and
``unit_shifts`` are populated before each model call::

    from nvalchemi.models.base import NeighborConfig, NeighborListFormat
    from nvalchemi.dynamics.hooks import NeighborListHook

    nl_hook = NeighborListHook(model.model_card.neighbor_config)
    dynamics.register_hook(nl_hook)
    dynamics.model = model

Notes
-----
* Forces are computed **conservatively** via MACE's internal autograd, so
  :attr:`~ModelCard.forces_via_autograd` is ``True``.
* ``node_attrs`` (one-hot atomic-number encodings) are computed via a
  lookup table pre-built on the model's device — no host round-trips per
  step.
* For PBC systems, both ``unit_shifts`` (integer image indices ``[E, 3]``)
  and pre-computed ``shifts`` (physical Å vectors ``[E, 3]``) are passed to
  MACE.  ``shifts`` is always required by ``prepare_graph``; ``unit_shifts``
  is additionally used when ``compute_displacement=True`` (stress path).
"""

from __future__ import annotations

import warnings
from importlib.metadata import version
from pathlib import Path
from typing import Any

import torch
from torch import nn

from nvalchemi._typing import ModelOutputs
from nvalchemi.data import AtomicData, Batch
from nvalchemi.models.base import (
    BaseModelMixin,
    ModelCard,
    ModelConfig,
    NeighborConfig,
    NeighborListFormat,
)

# Optional-dependency guard: ``mace-torch`` may not be installed.  The import
# is attempted once at module load; the outcome is recorded in
# ``_MACE_AVAILABLE`` so that this module can still be imported without MACE
# and ``MACEWrapper`` can raise a helpful ``ImportError`` only when used.
try:
    # UserWarnings emitted while importing the mace modules are silenced for
    # the duration of the import only (the filter is scoped by the context
    # manager and restored on exit).
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning)
        from mace.calculators.foundations_models import download_mace_mp_checkpoint
        from mace.cli.convert_e3nn_cueq import run as _convert_mace_weights

    _MACE_AVAILABLE = True
except ImportError:
    _MACE_AVAILABLE = False
    # Bind the names to ``None`` so later references resolve to a defined
    # module attribute instead of raising ``NameError``.
    download_mace_mp_checkpoint = None
    _convert_mace_weights = None

# Installed torch distribution version string (read from package metadata);
# consulted by ``MACEWrapper.from_checkpoint`` to warn about torch 2.8.
_torch_version = version("torch")

# Public API of this module.
__all__ = ["MACEWrapper"]


class MACEWrapper(nn.Module, BaseModelMixin):
    """Wrapper for any MACE model implementing the
    :class:`~nvalchemi.models.base.BaseModelMixin` interface.

    Accepts any MACE model variant (``MACE``, ``ScaleShiftMACE``,
    cuEq-converted models, ``torch.compile``-d models, etc.).  The wrapper
    handles:

    * One-hot ``node_attrs`` encoding via a lookup table pre-built on the
      model's device (no host round-trip per step).
    * Gradient enabling on ``positions`` for conservative force / stress
      computation.
    * PBC via both ``unit_shifts`` (integer image indices) and pre-computed
      ``shifts`` (physical Å vectors from ``unit_shifts @ cell``) passed to
      MACE.  ``shifts`` is always required; ``unit_shifts`` is additionally
      consumed when ``compute_displacement=True`` (stress path).

    Parameters
    ----------
    model : nn.Module
        An instantiated MACE model.  Any subclass of ``mace.modules.MACE``
        is accepted.  The wrapper mirrors the model's training/eval state.

    Attributes
    ----------
    model : nn.Module
        The underlying MACE model.
    model_config : ModelConfig
        Mutable configuration controlling which outputs are computed.

    Raises
    ------
    ImportError
        If ``mace-torch`` is not installed.
    """

    model: nn.Module

    def __init__(self, model: nn.Module) -> None:
        if not _MACE_AVAILABLE:
            raise ImportError(
                "mace-torch is required to use MACEWrapper. "
                "Install it with: pip install 'nvalchemi-toolkit[mace]'"
            )
        super().__init__()
        self.model = model
        # Mirror the wrapped model's training/eval state on the wrapper.
        self.train(mode=model.training)
        self.model_config = ModelConfig()

        # The first parameter is taken as representative of the whole model
        # for both dtype and device.
        first_param = next(model.parameters())

        # Cache the model dtype — determined at construction, stable thereafter.
        self._cached_model_dtype: torch.dtype = first_param.dtype

        # Pre-build a one-hot lookup table: shape [max_z + 1, num_elements].
        # At runtime, node_attrs = _node_emb.index_select(0, atomic_numbers)
        # — a single tensor op, no host round-trips.
        z_table: list[int] = model.atomic_numbers.tolist()
        node_emb = torch.zeros(max(z_table) + 1, len(z_table))
        for i, z in enumerate(z_table):
            node_emb[z, i] = 1.0

        # Cast to model device+dtype so _node_attrs needs no per-step
        # conversion.  Must use the model's device here: from_checkpoint moves
        # the inner model to the target device before calling cls(model), so
        # the buffer must be placed on that device from construction rather
        # than relying on a subsequent .to() call that never happens.
        model_device = first_param.device
        node_emb = node_emb.to(device=model_device, dtype=self._cached_model_dtype)

        # persistent=False: derived from model.atomic_numbers, excluded from
        # state_dict but still tracked for device / dtype moves.
        self.register_buffer("_node_emb", node_emb, persistent=False)

        self._model_card: ModelCard = self._build_model_card()

    # ------------------------------------------------------------------
    # BaseModelMixin required properties
    # ------------------------------------------------------------------

    def _build_model_card(self) -> ModelCard:
        """Assemble the static capability card advertised by this wrapper.

        The neighbor configuration requests a full (non-half) COO neighbor
        list using the wrapped model's cutoff.
        """
        return ModelCard(
            forces_via_autograd=True,
            supports_energies=True,
            supports_forces=True,
            supports_stresses=True,
            supports_pbc=True,
            needs_pbc=False,
            supports_non_batch=True,
            supports_node_embeddings=True,
            supports_graph_embeddings=True,
            neighbor_config=NeighborConfig(
                cutoff=self.cutoff,
                format=NeighborListFormat.COO,
                half_list=False,
            ),
        )

    @property
    def model_card(self) -> ModelCard:
        """The capability card built once at construction time."""
        return self._model_card

    @property
    def embedding_shapes(self) -> dict[str, tuple[int, ...]]:
        """Per-item shapes reported for node and graph embeddings.

        The dimension is read from the output irreps of the first product
        block's linear layer.
        """
        hidden_dim: int = self.model.products[0].linear.irreps_out.dim
        return {
            "node_embeddings": (hidden_dim,),
            "graph_embeddings": (hidden_dim,),
        }

    # ------------------------------------------------------------------
    # Convenience properties
    # ------------------------------------------------------------------

    @property
    def cutoff(self) -> float:
        """Interaction cutoff in Angstroms, read from ``model.r_max``."""
        r_max = self.model.r_max
        # r_max may be stored either as a tensor or as a plain number.
        return r_max.item() if isinstance(r_max, torch.Tensor) else float(r_max)

    @property
    def _model_dtype(self) -> torch.dtype:
        """Return the current dtype of the model's parameters (live, not cached).

        Reading from parameters() directly ensures this stays correct after
        `.half()` or `.to(dtype=...)` calls post-construction.

        Note: calling `.to(dtype=...)` after construction with cuEquivariance
        or `torch.compile` enabled is unsupported and may produce incorrect
        results.  Use `from_checkpoint` with the desired `dtype` parameter
        instead.
        """
        try:
            return next(self.parameters()).dtype
        except StopIteration:
            # No parameters at all: fall back to single precision.
            return torch.float32

    # ------------------------------------------------------------------
    # Input / output adaptation
    # ------------------------------------------------------------------

    def _node_attrs(self, data: Batch) -> torch.Tensor:
        """One-hot encode atomic numbers via the pre-built lookup table.

        Uses a single ``index_select`` on the table's device — no host
        round-trips.  ``_node_emb`` is placed on the model's device and dtype
        at construction and, being a registered buffer, follows the module
        through ``.to()``, so no per-step device/dtype conversion is needed.
        """
        return self._node_emb.index_select(0, data.atomic_numbers.long())

    def adapt_input(self, data: AtomicData | Batch, **kwargs: Any) -> dict[str, Any]:
        """Build the input dict expected by ``MACE.forward``.

        Handles ``AtomicData → Batch`` promotion, ``node_attrs`` encoding,
        gradient enabling on ``positions``, transposing ``edge_index`` from
        nvalchemi's ``[E, 2]`` to MACE's ``[2, E]`` convention, zero-filling
        of ``unit_shifts`` and identity-filling of ``cell`` when they are
        absent (non-PBC systems), and pre-computation of physical ``shifts``
        vectors from ``unit_shifts @ cell``.

        Parameters
        ----------
        data : AtomicData | Batch
            Input structure(s).  A single ``AtomicData`` is wrapped into a
            one-graph ``Batch``.
        **kwargs
            Accepted for interface compatibility; not used by this
            implementation.

        Returns
        -------
        dict[str, Any]
            Keys ``positions``, ``node_attrs``, ``batch``, ``ptr``,
            ``edge_index``, ``unit_shifts``, ``shifts`` and ``cell``.

        .. note::
            This method does **not** call ``super().adapt_input()`` because
            :class:`~nvalchemi.data.Batch` does not implement
            ``model_dump()``, which the base implementation requires.
            Gradient enabling on ``positions`` is handled manually here
            instead.
        """
        if isinstance(data, AtomicData):
            data = Batch.from_data_list([data])

        dtype = self._model_dtype
        device = data.positions.device

        # nvalchemi stores edge_index as [E, 2]; MACE expects [2, E].
        edge_index = data.edge_index.long().T  # [2, E]
        num_edges = edge_index.shape[1]
        num_graphs = data.num_graphs

        # Cast positions to model dtype, then enable gradients on the
        # converted tensor.  When gradients are needed we clone first so that
        # data.positions is never mutated in-place (which would happen when
        # dtype already matches and .to() returns the same storage).
        positions = data.positions.to(dtype=dtype)
        if self.model_config.compute_forces or self.model_config.compute_stresses:
            positions = positions.clone()
            positions.requires_grad_(True)

        # unit_shifts: integer PBC image indices [E, 3], cast to float for
        # MACE's cell @ unit_shifts contraction.  Zero for non-PBC systems.
        unit_shifts_raw = getattr(data, "unit_shifts", None)
        if unit_shifts_raw is None:
            unit_shifts = torch.zeros(num_edges, 3, dtype=dtype, device=device)
        else:
            unit_shifts = unit_shifts_raw.to(dtype=dtype, device=device)

        # cell: [B, 3, 3].  Identity matrix for non-PBC systems.
        cell_raw = getattr(data, "cell", None)
        if cell_raw is None:
            cell = (
                torch.eye(3, dtype=dtype, device=device)
                .unsqueeze(0)
                .expand(num_graphs, -1, -1)
                .contiguous()
            )
        else:
            cell = cell_raw.to(dtype=dtype, device=device)

        # Pre-compute physical shift vectors [E, 3].
        # MACE's prepare_graph always reads data["shifts"] (physical Å
        # vectors) directly; it only recomputes them internally when
        # compute_displacement=True (stress path).  We must supply "shifts"
        # for the energy/force-only path.
        # Convention: shifts[e] = unit_shifts[e] @ cell[graph_of_sender_e]
        # matching get_symmetric_displacement in mace.modules.utils.
        sender = edge_index[0]  # [E] — source node indices
        batch_per_edge = data.batch[sender]
        shifts = torch.einsum("eb,ebc->ec", unit_shifts, cell[batch_per_edge])

        return {
            "positions": positions,
            "node_attrs": self._node_attrs(data),
            # MACE requires int64 for graph-topology tensors.
            "batch": data.batch.long(),
            "ptr": data.ptr.long(),
            "edge_index": edge_index,  # [2, E] — MACE convention
            "unit_shifts": unit_shifts,
            "shifts": shifts,
            "cell": cell,
        }

    def adapt_output(
        self, raw_output: dict[str, Any], data: AtomicData | Batch
    ) -> ModelOutputs:
        """Map MACE output keys to nvalchemi standard keys.

        MACE uses ``"energy"`` / ``"stress"`` / ``"hessian"``; nvalchemi
        expects ``"energies"`` / ``"stresses"`` / ``"hessians"``.  Renaming
        happens *before* calling ``super()`` so the base auto-mapper sees the
        canonical key names.

        ``"energy"`` is mandatory; a 1-D energy tensor gains a trailing
        singleton dimension.  ``"forces"``, ``"stress"`` and ``"hessian"``
        are forwarded only when present and not ``None``.
        """
        energy = raw_output["energy"]
        mapped: dict[str, Any] = {
            "energies": energy.unsqueeze(-1) if energy.ndim == 1 else energy,
        }
        if raw_output.get("forces") is not None:
            mapped["forces"] = raw_output["forces"]
        if raw_output.get("stress") is not None:
            mapped["stresses"] = raw_output["stress"]
        if raw_output.get("hessian") is not None:
            mapped["hessians"] = raw_output["hessian"]
        return super().adapt_output(mapped, data)

    # ------------------------------------------------------------------
    # Forward pass
    # ------------------------------------------------------------------

    def forward(self, data: AtomicData | Batch, **kwargs: Any) -> ModelOutputs:
        """Run the MACE model and return the adapted output.

        Which derivatives are requested from MACE is decided by
        ``_verify_request`` from ``model_config`` and ``model_card``.
        ``**kwargs`` are forwarded to :meth:`adapt_input`.
        """
        model_inputs = self.adapt_input(data, **kwargs)
        compute_forces = self._verify_request(
            self.model_config, self.model_card, "forces"
        )
        compute_stresses = self._verify_request(
            self.model_config, self.model_card, "stresses"
        )
        raw_output = self.model.forward(
            model_inputs,
            compute_force=compute_forces,
            compute_stress=compute_stresses,
            # compute_displacement enables the MACE displacement trick required
            # for stress computation via autograd through cell @ unit_shifts.
            compute_displacement=compute_stresses,
            training=self.training,
        )
        return self.adapt_output(raw_output, data)

    # ------------------------------------------------------------------
    # Embeddings
    # ------------------------------------------------------------------

    def compute_embeddings(
        self, data: AtomicData | Batch, **kwargs: Any
    ) -> AtomicData | Batch:
        """Compute node and graph embeddings without forces or stresses.

        Writes ``node_embeddings`` (shape ``[N, hidden_dim]``) and
        ``graph_embeddings`` (shape ``[B, hidden_dim]``, sum-pooled over
        atoms) into the batch and returns it.  A ``Batch`` input is modified
        in place; an ``AtomicData`` input is first promoted to a new
        single-graph ``Batch``, and it is that ``Batch`` which receives the
        embeddings and is returned.  Does **not** mutate ``model_config``.

        Raises
        ------
        RuntimeError
            If the wrapped model's output lacks ``"node_feats"``.
        """
        if isinstance(data, AtomicData):
            data = Batch.from_data_list([data])

        model_inputs = self.adapt_input(data, **kwargs)

        # Pass flags as local kwargs — never mutate self.model_config.
        raw_output = self.model.forward(
            model_inputs,
            compute_force=False,
            compute_stress=False,
            compute_displacement=False,
            training=False,
        )

        node_feats = raw_output.get("node_feats")
        if node_feats is None:
            raise RuntimeError(
                "MACE model did not return 'node_feats'. "
                "Ensure the model is a standard MACE variant."
            )

        # Write node embeddings directly to the atoms group to avoid the
        # default "system" routing in MultiLevelStorage for unknown keys.
        # If we wrote via `data.node_embeddings = ...`, it would land in the
        # system group (batch_size = [N]) and then block the graph_embeddings
        # write (batch_size = [B]) from going to the same group.
        atoms_group = data._atoms_group
        if atoms_group is not None:
            atoms_group["node_embeddings"] = node_feats
        else:
            data.node_embeddings = node_feats

        # Sum-pool node features per graph: row i of the result accumulates
        # the features of every atom whose batch index is i.
        hidden_dim = node_feats.shape[-1]
        graph_embeddings = torch.zeros(
            data.num_graphs,
            hidden_dim,
            device=node_feats.device,
            dtype=node_feats.dtype,
        )
        graph_embeddings.scatter_add_(
            0,
            data.batch.long().unsqueeze(-1).expand(-1, hidden_dim),
            node_feats,
        )
        data.graph_embeddings = graph_embeddings
        return data

    # ------------------------------------------------------------------
    # Checkpoint loading
    # ------------------------------------------------------------------

    @classmethod
    def from_checkpoint(
        cls,
        checkpoint_path: Path | str,
        device: torch.device = torch.device("cpu"),
        enable_cueq: bool = False,
        dtype: torch.dtype | None = None,
        compile_model: bool = False,
        **compile_kwargs: Any,
    ) -> "MACEWrapper":
        """Load a MACE model from a checkpoint and return a :class:`MACEWrapper`.

        Accepts local file paths or named MACE-MP foundation-model checkpoints
        (e.g. ``"medium-0b2"``), which are downloaded automatically to the
        MACE cache directory.

        Operations are applied in this order to avoid numerical issues:

        1. **Load** — ``torch.load`` the checkpoint.
        2. **cuEq** — convert to cuEquivariance format (must happen while the
           model is still in its original dtype, because
           ``extract_config_mace_model`` reads the dtype via
           ``torch.set_default_dtype``).
        3. **dtype** — cast all weights (including atomic energies) uniformly
           to the requested dtype.
        4. **compile** — ``torch.compile``; freezes parameters and sets eval
           mode.  The model is **inference-only** after this step.

        Parameters
        ----------
        checkpoint_path : Path | str
            Local path to a ``.pt`` file, or a named checkpoint string such
            as ``"medium-0b2"``.
        device : torch.device, optional
            Target device.  Defaults to CPU.
        enable_cueq : bool, optional
            Convert to cuEquivariance format for GPU speedup.  Requires the
            ``cuequivariance`` package.
        dtype : torch.dtype | None, optional
            If set, cast model weights to this dtype after cuEq conversion.
        compile_model : bool, optional
            Apply ``torch.compile``.  Sets eval mode and freezes parameters;
            the model is **inference-only** after this step.
        **compile_kwargs
            Forwarded to ``torch.compile``.

        Returns
        -------
        MACEWrapper

        Raises
        ------
        ImportError
            If ``mace-torch`` is not installed, or if ``enable_cueq=True``
            and ``cuequivariance`` is not installed (the original import
            failure is attached as ``__cause__``).
        """
        if not _MACE_AVAILABLE:
            raise ImportError(
                "mace-torch is required for MACEWrapper.from_checkpoint. "
                "Install it with: pip install 'nvalchemi-toolkit[mace]'"
            )

        # Step 1: load the checkpoint.
        # SECURITY NOTE: weights_only=False unpickles arbitrary Python
        # objects, which can execute code.  Only load checkpoints from
        # sources you trust.
        cached_path = download_mace_mp_checkpoint(checkpoint_path)
        model: nn.Module = torch.load(
            cached_path, weights_only=False, map_location=device
        )

        # Step 2: cuEq conversion, before any dtype change.
        if enable_cueq:
            try:
                import cuequivariance  # noqa: F401
            except ImportError as err:
                # Chain the cause so the underlying import failure stays
                # visible in the traceback.
                raise ImportError(
                    "cuequivariance is required for enable_cueq=True. "
                    "Install it with: pip install 'nvalchemi-toolkit[mace]'"
                ) from err
            model = _convert_mace_weights(model, return_model=True, device=device)

        # Step 3: dtype conversion, then make sure the model is on `device`.
        if dtype is not None:
            model = model.to(dtype=dtype)
        model = model.to(device)

        # Step 4: torch.compile — inference-only after this point.
        if compile_model:
            if _torch_version.startswith("2.8"):
                warnings.warn(
                    "torch.compile has known issues with e3nn in torch 2.8. "
                    "You may need to patch e3nn before compiling:\n"
                    " sed -i '238s/raise NotImplementedError/return 2/' "
                    "<site-packages>/e3nn/o3/_irreps.py",
                    stacklevel=2,
                )
            model.eval()
            for param in model.parameters():
                param.requires_grad = False
            model = torch.compile(model, **compile_kwargs)

        return cls(model)

    # ------------------------------------------------------------------
    # Export
    # ------------------------------------------------------------------

    def export_model(self, path: Path, as_state_dict: bool = False) -> None:
        """Serialize the underlying MACE model without the wrapper.

        The exported file can be reloaded as a plain MACE ``nn.Module`` and
        used with the standard MACE / ASE interface.

        Parameters
        ----------
        path : Path
            Output path.
        as_state_dict : bool, optional
            If ``True``, save only the ``state_dict``; otherwise pickle the
            full model object.  Defaults to ``False``.
        """
        if as_state_dict:
            torch.save(self.model.state_dict(), path)
        else:
            torch.save(self.model, path)