nvalchemi.models.mace.MACEWrapper#
- class nvalchemi.models.mace.MACEWrapper(model)[source]#
Wrapper for any MACE model implementing the BaseModelMixin interface.
Accepts any MACE model variant (MACE, ScaleShiftMACE, cuEq-converted models, torch.compile-d models, etc.). The wrapper handles:
- One-hot node_attrs encoding via a pre-built GPU lookup table (no CPU round-trip per step).
- Gradient enabling on positions for conservative force / stress computation.
- PBC via both neighbor_list_shifts (integer image indices) and pre-computed shifts (physical Å vectors from neighbor_list_shifts @ cell) passed to MACE. shifts is always required; neighbor_list_shifts is additionally consumed when compute_displacement=True (stress path).
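The shifts pre-computation described above can be sketched in a few lines (a standalone sketch with made-up values; the real wrapper reads these tensors from the batch):

```python
import torch

# Sketch of the physical-shift pre-computation: integer periodic-image
# indices combined with the cell's lattice vectors (rows) give per-edge
# shift vectors in Å. The cell and shift values here are illustrative.
cell = torch.tensor([[10.0, 0.0, 0.0],
                     [0.0, 10.0, 0.0],
                     [0.0, 0.0, 10.0]])                  # [3, 3] lattice rows
neighbor_list_shifts = torch.tensor([[0.0, 0.0, 0.0],
                                     [1.0, 0.0, 0.0],
                                     [0.0, -1.0, 1.0]])  # [E, 3] image indices

shifts = neighbor_list_shifts @ cell                     # [E, 3] physical Å vectors
print(shifts[1].tolist())  # [10.0, 0.0, 0.0]
```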
- Parameters:
model (nn.Module) – An instantiated MACE model. Any subclass of
mace.modules.MACEis accepted.
- model#
The underlying MACE model.
- Type:
nn.Module
- model_config#
Mutable configuration controlling which outputs are computed.
- Type:
- adapt_input(data, **kwargs)[source]#
Build the input dict expected by MACE.forward.
Handles AtomicData -> Batch promotion, node_attrs encoding, gradient enabling on positions, transposing edge_index from nvalchemi's [E, 2] to MACE's [2, E] convention, zero-filling of neighbor_list_shifts / cell for non-PBC systems, and pre-computation of physical shifts vectors from neighbor_list_shifts @ cell.
Expects COO neighbor data (neighbor_list, optionally neighbor_list_shifts) to be present on the batch. When used in a PipelineModelWrapper, the pipeline handles format conversion and cutoff filtering before calling this model.
Note
This method does not call super().adapt_input() because Batch does not implement model_dump(), which the base implementation requires. Gradient enabling on positions is handled manually here instead.
- Parameters:
data (AtomicData | Batch)
kwargs (Any)
- Return type:
dict[str, Any]
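The edge_index transposition mentioned above can be sketched as follows (hypothetical values; the real method operates on the batch's neighbor data):

```python
import torch

# nvalchemi stores edges as [E, 2] rows of (src, dst) pairs; MACE expects
# the transposed [2, E] layout. A contiguous copy avoids a strided view.
neighbor_list = torch.tensor([[0, 1], [1, 0], [1, 2]])  # [E, 2]
edge_index = neighbor_list.t().contiguous()             # [2, E] for MACE
print(edge_index.shape)  # torch.Size([2, 3])
```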
- adapt_output(raw_output, data)[source]#
Map MACE output keys to nvalchemi standard keys.
MACE's output keys for energy, stress, and hessian are renamed to nvalchemi's standard key names; the renaming happens before calling super() so the base auto-mapper sees the canonical key names.
- Parameters:
raw_output (dict[str, Any])
data (AtomicData | Batch)
- Return type:
OrderedDict[str, Float[Tensor, ‘B 1’] | Float[Tensor, ‘V 3’] | Float[Tensor, ‘V 3 3’] | Float[Tensor, ‘B 3 3’] | Float[Tensor, ‘B 3 3’] | Float[Tensor, ‘B 3’] | None]
- compute_embeddings(data, **kwargs)[source]#
Compute node and graph embeddings without forces or stresses.
Writes node_embeddings (shape [N, hidden_dim]) and graph_embeddings (shape [B, hidden_dim], sum-pooled over atoms) into data in-place and returns it. Does not mutate model_config.
- Parameters:
data (AtomicData | Batch)
kwargs (Any)
- Return type:
AtomicData | Batch
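The sum-pooling that turns node embeddings into graph embeddings can be sketched like this (toy tensors; shapes follow the docstring above):

```python
import torch

# node_embeddings: [N, hidden_dim]; batch maps each atom to its graph index.
# graph_embeddings: [B, hidden_dim], the per-graph sum over its atoms.
node_embeddings = torch.randn(5, 8)    # N=5 atoms, hidden_dim=8
batch = torch.tensor([0, 0, 0, 1, 1])  # atoms 0-2 -> graph 0, atoms 3-4 -> graph 1
num_graphs = int(batch.max()) + 1

graph_embeddings = torch.zeros(num_graphs, node_embeddings.shape[1])
graph_embeddings.index_add_(0, batch, node_embeddings)  # sum-pool per graph
print(graph_embeddings.shape)  # torch.Size([2, 8])
```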
- property cutoff: float#
Interaction cutoff in Angstroms, read from model.r_max.
- property embedding_shapes: dict[str, tuple[int, ...]]#
Retrieves the expected shapes of the node, edge, and graph embeddings.
- export_model(path, as_state_dict=False)[source]#
Serialize the underlying MACE model without the wrapper.
The exported file can be reloaded as a plain MACE nn.Module and used with the standard MACE / ASE interface.
- Parameters:
path (Path) – Output path.
as_state_dict (bool, optional) – If True, save only the state_dict; otherwise pickle the full model object. Defaults to False.
- Return type:
None
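The two export modes behave like the standard torch save/load idioms, illustrated here on a stand-in nn.Module rather than the real MACE model:

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in for the wrapped MACE model.
model = nn.Linear(3, 1)
tmpdir = tempfile.mkdtemp()

# as_state_dict=True: only the weights are written, so reloading requires a
# freshly constructed model of the same architecture.
weights_path = os.path.join(tmpdir, "weights.pt")
torch.save(model.state_dict(), weights_path)
fresh = nn.Linear(3, 1)
fresh.load_state_dict(torch.load(weights_path))

# as_state_dict=False (the default): the full module object is pickled and
# can be reloaded directly as a plain nn.Module.
full_path = os.path.join(tmpdir, "full.pt")
torch.save(model, full_path)
reloaded = torch.load(full_path, weights_only=False)
```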
- extra_repr()#
Format the model config for nn.Module.__repr__.
- Parameters:
self (Any)
- Return type:
str
- forward(data, **kwargs)[source]#
Run the MACE model and return the output.
- Parameters:
data (AtomicData | Batch)
kwargs (Any)
- Return type:
OrderedDict[str, Float[Tensor, ‘B 1’] | Float[Tensor, ‘V 3’] | Float[Tensor, ‘V 3 3’] | Float[Tensor, ‘B 3 3’] | Float[Tensor, ‘B 3 3’] | Float[Tensor, ‘B 3’] | None]
- classmethod from_checkpoint(checkpoint_path, device=torch.device('cpu'), enable_cueq=False, dtype=None, compile_model=False, **compile_kwargs)[source]#
Load a MACE model from a checkpoint and return a
MACEWrapper.Accepts local file paths or named MACE-MP foundation-model checkpoints (e.g.
"medium-0b2"), which are downloaded automatically to the MACE cache directory.Operations are applied in this order:
Load —
torch.loadthe checkpoint to the specified device.dtype — cast model weights to the requested dtype.
cuEq — convert to cuEquivariance format for GPU speedup.
compile —
torch.compile; freezes parameters and sets eval mode. The model is inference-only after this step.
For best GPU throughput, use
device=torch.device("cuda"),enable_cueq=True,dtype=torch.float32, andcompile_model=True. Example:model = MACEWrapper.from_checkpoint( "medium-mpa-0", device=torch.device("cuda"), dtype=torch.float32, enable_cueq=True, compile_model=True, )
- Parameters:
checkpoint_path (Path | str) – Local path to a .pt file, or a named checkpoint string such as "medium-0b2".
device (torch.device, optional) – Target device. Defaults to CPU.
enable_cueq (bool, optional) – Convert to cuEquivariance format for GPU speedup. Defaults to False. Requires the cuequivariance package.
dtype (torch.dtype | None, optional) – If set, cast model weights to this dtype before cuEq conversion.
compile_model (bool, optional) – Apply torch.compile. Sets eval mode and freezes parameters; the model is inference-only after this step.
**compile_kwargs – Forwarded to torch.compile.
- Return type:
MACEWrapper
- Raises:
ImportError – If mace-torch is not installed, or if enable_cueq=True and cuequivariance is not installed.
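Since cuEq conversion depends on an optional package, a caller can guard the flag before loading (a hypothetical guard; from_checkpoint itself raises ImportError when the dependency is missing):

```python
import importlib.util

# Request cuEquivariance conversion only when the optional dependency is
# importable; otherwise fall back to loading the plain model.
enable_cueq = importlib.util.find_spec("cuequivariance") is not None

# wrapper = MACEWrapper.from_checkpoint("medium-0b2", enable_cueq=enable_cueq)
```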