nvalchemi.models.mace.MACEWrapper#
- class nvalchemi.models.mace.MACEWrapper(model)[source]#
Wrapper for any MACE model implementing the `BaseModelMixin` interface.
Accepts any MACE model variant (`MACE`, `ScaleShiftMACE`, cuEq-converted models, `torch.compile`-d models, etc.). The wrapper handles:
- One-hot `node_attrs` encoding via a pre-built GPU lookup table (no CPU round-trip per step).
- Gradient enabling on `positions` for conservative force / stress computation.
- PBC via both `unit_shifts` (integer image indices) and pre-computed `shifts` (physical Å vectors from `unit_shifts @ cell`) passed to MACE. `shifts` is always required; `unit_shifts` is additionally consumed when `compute_displacement=True` (stress path).
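The pre-built one-hot lookup described above can be sketched in plain Python (an illustrative sketch only; the real wrapper stores the table as a GPU tensor and indexes it with the atomic numbers in one shot, so no per-step CPU work is needed):

```python
# Illustrative sketch of the pre-built one-hot lookup for node_attrs.
# Assumption: the model covers a fixed element set (here H, C, O).
atomic_numbers = [1, 6, 8]                      # elements known to the model
z_to_column = {z: i for i, z in enumerate(atomic_numbers)}

# Lookup table: row z -> one-hot row over the model's element set.
table = [[0.0] * len(atomic_numbers) for _ in range(max(atomic_numbers) + 1)]
for z, col in z_to_column.items():
    table[z][col] = 1.0

def encode(zs):
    """One-hot node_attrs for a list of atomic numbers (a single table lookup)."""
    return [table[z] for z in zs]

node_attrs = encode([8, 1, 1])                  # e.g. a water molecule: O, H, H
```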
- Parameters:
model (nn.Module) – An instantiated MACE model. Any subclass of `mace.modules.MACE` is accepted. The wrapper mirrors the model’s training/eval state.
- model#
The underlying MACE model.
- Type:
nn.Module
- model_config#
Mutable configuration controlling which outputs are computed.
- Type:
- adapt_input(data, **kwargs)[source]#
Build the input dict expected by `MACE.forward`.
Handles `AtomicData → Batch` promotion, `node_attrs` encoding, gradient enabling on `positions`, transposing `edge_index` from nvalchemi’s `[E, 2]` to MACE’s `[2, E]` convention, zero-filling of `unit_shifts`/`cell` for non-PBC systems, and pre-computation of physical `shifts` vectors from `unit_shifts @ cell`.
Note
This method does not call `super().adapt_input()` because `Batch` does not implement `model_dump()`, which the base implementation requires. Gradient enabling on `positions` is handled manually here instead.
- Parameters:
data (AtomicData | Batch)
kwargs (Any)
- Return type:
dict[str, Any]
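The two geometric conversions performed here, the `[E, 2]` → `[2, E]` edge-index transpose and the `shifts = unit_shifts @ cell` pre-computation, can be sketched without torch (values are illustrative; the wrapper operates on GPU tensors):

```python
# Sketch of adapt_input's geometric conversions (plain Python, illustrative).
# edge_index: nvalchemi stores edges as [E, 2] pairs; MACE expects [2, E].
edge_index = [[0, 1], [1, 0], [1, 2]]           # E = 3 edges
edge_index_mace = [list(col) for col in zip(*edge_index)]  # -> [2, E]

# shifts: physical displacement vectors (Å) from integer image indices.
# For edge e: shifts[e] = unit_shifts[e] @ cell (row vector times 3x3 cell).
cell = [[10.0, 0.0, 0.0],
        [0.0, 10.0, 0.0],
        [0.0, 0.0, 10.0]]
unit_shifts = [[0, 0, 0], [1, 0, 0], [0, -1, 0]]

shifts = [
    [sum(u[k] * cell[k][j] for k in range(3)) for j in range(3)]
    for u in unit_shifts
]
# Non-PBC systems get zero-filled unit_shifts/cell, so shifts is all zeros.
```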
- adapt_output(raw_output, data)[source]#
Map MACE output keys to nvalchemi standard keys.
MACE uses `"energy"` / `"stress"` / `"hessian"`; nvalchemi expects `"energies"` / `"stresses"` / `"hessians"`. Renaming happens before calling `super()` so the base auto-mapper sees the canonical key names.
- Parameters:
raw_output (dict[str, Any])
data (AtomicData | Batch)
- Return type:
OrderedDict[str, Float[Tensor, ‘B 1’] | Float[Tensor, ‘V 3’] | Float[Tensor, ‘V 3 3’] | Float[Tensor, ‘B 3 3’] | Float[Tensor, ‘B 3 3’] | Float[Tensor, ‘B 3’] | None]
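The renaming step amounts to a plain dict operation (an illustrative sketch; the actual method renames and then delegates to `super().adapt_output()` for the remaining mapping):

```python
# Sketch of adapt_output's key renaming (MACE names -> nvalchemi names).
# Renaming happens first so the base auto-mapper sees canonical keys.
KEY_MAP = {"energy": "energies", "stress": "stresses", "hessian": "hessians"}

def rename_keys(raw_output):
    """Return raw_output with MACE's singular keys mapped to nvalchemi's plurals."""
    return {KEY_MAP.get(k, k): v for k, v in raw_output.items()}

out = rename_keys({"energy": -76.4, "forces": "F", "stress": "S"})
```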
- compute_embeddings(data, **kwargs)[source]#
Compute node and graph embeddings without forces or stresses.
Writes
node_embeddings(shape[N, hidden_dim]) andgraph_embeddings(shape[B, hidden_dim], sum-pooled over atoms) into data in-place and returns it. Does not mutatemodel_config.- Parameters:
data (AtomicData | Batch)
kwargs (Any)
- Return type:
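The sum-pooling that turns node embeddings into per-graph embeddings can be sketched as follows (illustrative; the wrapper performs the equivalent on GPU via a scatter-add over the batch index):

```python
# Sketch of sum-pooling node_embeddings [N, hidden_dim] into
# graph_embeddings [B, hidden_dim]: each graph's row is the sum of its atoms'.
node_embeddings = [[1.0, 0.0], [0.5, 0.5], [0.0, 2.0]]  # N = 3, hidden_dim = 2
batch_index = [0, 0, 1]                                  # atom -> graph id, B = 2

num_graphs = max(batch_index) + 1
hidden_dim = len(node_embeddings[0])
graph_embeddings = [[0.0] * hidden_dim for _ in range(num_graphs)]
for row, g in zip(node_embeddings, batch_index):
    for j, v in enumerate(row):
        graph_embeddings[g][j] += v      # scatter-add over the atoms of graph g
```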
- property cutoff: float#
Interaction cutoff in Angstroms, read from `model.r_max`.
- property embedding_shapes: dict[str, tuple[int, ...]]#
Retrieves the expected shapes of the node, edge, and graph embeddings.
- export_model(path, as_state_dict=False)[source]#
Serialize the underlying MACE model without the wrapper.
The exported file can be reloaded as a plain MACE
nn.Moduleand used with the standard MACE / ASE interface.- Parameters:
path (Path) – Output path.
as_state_dict (bool, optional) – If `True`, save only the `state_dict`; otherwise pickle the full model object. Defaults to `False`.
- Return type:
None
- forward(data, **kwargs)[source]#
Run the MACE model and return the output.
- Parameters:
data (AtomicData | Batch)
kwargs (Any)
- Return type:
OrderedDict[str, Float[Tensor, ‘B 1’] | Float[Tensor, ‘V 3’] | Float[Tensor, ‘V 3 3’] | Float[Tensor, ‘B 3 3’] | Float[Tensor, ‘B 3 3’] | Float[Tensor, ‘B 3’] | None]
- classmethod from_checkpoint(checkpoint_path, device=torch.device('cpu'), enable_cueq=False, dtype=None, compile_model=False, **compile_kwargs)[source]#
Load a MACE model from a checkpoint and return a `MACEWrapper`.
Accepts local file paths or named MACE-MP foundation-model checkpoints (e.g. `"medium-0b2"`), which are downloaded automatically to the MACE cache directory.
Operations are applied in this order to avoid numerical issues:
1. Load — `torch.load` the checkpoint.
2. cuEq — convert to cuEquivariance format (must happen while the model is still in its original dtype, because `extract_config_mace_model` reads the dtype via `torch.set_default_dtype`).
3. dtype — cast all weights (including atomic energies) uniformly to the requested dtype.
4. compile — `torch.compile`; freezes parameters and sets eval mode. The model is inference-only after this step.
- Parameters:
checkpoint_path (Path | str) – Local path to a `.pt` file, or a named checkpoint string such as `"medium-0b2"`.
device (torch.device, optional) – Target device. Defaults to CPU.
enable_cueq (bool, optional) – Convert to cuEquivariance format for GPU speedup. Requires the `cuequivariance` package.
dtype (torch.dtype | None, optional) – If set, cast model weights to this dtype after cuEq conversion.
compile_model (bool, optional) – Apply `torch.compile`. Sets eval mode and freezes parameters; the model is inference-only after this step.
**compile_kwargs – Forwarded to `torch.compile`.
- Return type:
- Raises:
ImportError – If `mace-torch` is not installed, or if `enable_cueq=True` and `cuequivariance` is not installed.
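A typical loading call might look like the following (an illustrative sketch using only the parameters documented above; guarded so it degrades gracefully when `mace-torch`/`nvalchemi` are not installed or no GPU is available):

```python
# Illustrative from_checkpoint usage; requires mace-torch, and
# cuequivariance when enable_cueq=True.
def load_wrapper():
    import torch
    from nvalchemi.models.mace import MACEWrapper

    # Named MACE-MP checkpoint; downloaded to the MACE cache automatically.
    return MACEWrapper.from_checkpoint(
        "medium-0b2",
        device=torch.device("cuda"),
        enable_cueq=True,        # converted before the dtype cast (step 2)
        dtype=torch.float64,
        compile_model=True,      # inference-only afterwards
    )

try:
    wrapper = load_wrapper()
    print("cutoff =", wrapper.cutoff)
except Exception:
    pass  # missing dependencies or no GPU in this environment
```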