nvalchemi.models.mace.MACEWrapper#

class nvalchemi.models.mace.MACEWrapper(model)[source]#

Wrapper for any MACE model implementing the BaseModelMixin interface.

Accepts any MACE model variant (MACE, ScaleShiftMACE, cuEq-converted models, torch.compile-d models, etc.). The wrapper handles:

  • One-hot node_attrs encoding via a pre-built GPU lookup table (no CPU round-trip per step).

  • Gradient enabling on positions for conservative force / stress computation.

  • PBC via both unit_shifts (integer image indices) and pre-computed shifts (physical Å vectors from unit_shifts @ cell) passed to MACE. shifts is always required; unit_shifts is additionally consumed when compute_displacement=True (stress path).
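The relation between the integer image indices and the physical shift vectors can be illustrated with a small pure-Python sketch (no MACE or torch dependency; the wrapper itself does this on GPU tensors):

```python
# Sketch of the documented relation: shifts[e] = unit_shifts[e] @ cell,
# where unit_shifts holds integer image indices and cell is the 3x3
# lattice matrix (rows are lattice vectors, in Angstroms).
def physical_shifts(unit_shifts, cell):
    """unit_shifts: [E, 3] integer image indices; cell: [3][3] lattice rows."""
    return [
        [sum(u[k] * cell[k][j] for k in range(3)) for j in range(3)]
        for u in unit_shifts
    ]

cell = [[10.0, 0.0, 0.0], [0.0, 12.0, 0.0], [0.0, 0.0, 14.0]]
print(physical_shifts([[1, 0, 0], [0, -1, 0]], cell))
# [[10.0, 0.0, 0.0], [0.0, -12.0, 0.0]]
```

An edge that crosses the +x boundary of this orthorhombic cell picks up a shift of one full lattice vector, [10.0, 0.0, 0.0].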

Parameters:

model (nn.Module) – An instantiated MACE model. Any subclass of mace.modules.MACE is accepted. The wrapper mirrors the model’s training/eval state.

model#

The underlying MACE model.

Type:

nn.Module

model_config#

Mutable configuration controlling which outputs are computed.

Type:

ModelConfig

adapt_input(data, **kwargs)[source]#

Build the input dict expected by MACE.forward.

Handles AtomicData Batch promotion, node_attrs encoding, gradient enabling on positions, transposing edge_index from nvalchemi’s [E, 2] to MACE’s [2, E] convention, zero-filling of unit_shifts / cell for non-PBC systems, and pre-computation of physical shifts vectors from unit_shifts @ cell.

Note

This method does not call super().adapt_input() because Batch does not implement model_dump(), which the base implementation requires. Gradient enabling on positions is handled manually here instead.
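Two of the transformations described above can be sketched in pure Python (hypothetical list-based stand-ins; the real implementation uses torch tensors and a pre-built GPU lookup table):

```python
# Hypothetical sketch of the node_attrs one-hot encoding and the
# edge_index transpose performed by adapt_input.
def one_hot_encode(atomic_numbers, z_table):
    """Encode atomic numbers as one-hot rows via a lookup table (Z -> column)."""
    column = {z: i for i, z in enumerate(z_table)}
    width = len(z_table)
    return [[1.0 if j == column[z] else 0.0 for j in range(width)]
            for z in atomic_numbers]

def transpose_edge_index(edge_index):
    """nvalchemi stores edges as [E, 2]; MACE expects [2, E]."""
    return [[e[0] for e in edge_index], [e[1] for e in edge_index]]

print(one_hot_encode([1, 8, 1], z_table=[1, 8]))   # H, O, H over a {H, O} table
print(transpose_edge_index([[0, 1], [1, 2], [2, 0]]))
```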

Return type:

dict[str, Any]

adapt_output(raw_output, data)[source]#

Map MACE output keys to nvalchemi standard keys.

MACE uses "energy" / "stress" / "hessian"; nvalchemi expects "energies" / "stresses" / "hessians". Renaming happens before calling super() so the base auto-mapper sees the canonical key names.
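The renaming can be sketched as follows (the mapping table name and helper are hypothetical; the singular-to-plural direction is as documented above):

```python
from collections import OrderedDict

# Hypothetical key map from MACE's singular names to nvalchemi's plural names.
KEY_MAP = {"energy": "energies", "stress": "stresses", "hessian": "hessians"}

def rename_keys(raw_output):
    """Rename MACE output keys, preserving order and leaving unknown keys as-is."""
    return OrderedDict((KEY_MAP.get(k, k), v) for k, v in raw_output.items())

out = rename_keys(OrderedDict([("energy", -1.5), ("forces", None), ("stress", None)]))
print(list(out))  # ['energies', 'forces', 'stresses']
```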

Return type:

OrderedDict[str, Float[Tensor, ‘B 1’] | Float[Tensor, ‘V 3’] | Float[Tensor, ‘V 3 3’] | Float[Tensor, ‘B 3 3’] | Float[Tensor, ‘B 3’] | None]

compute_embeddings(data, **kwargs)[source]#

Compute node and graph embeddings without forces or stresses.

Writes node_embeddings (shape [N, hidden_dim]) and graph_embeddings (shape [B, hidden_dim], sum-pooled over atoms) into data in-place and returns it. Does not mutate model_config.

Return type:

AtomicData | Batch
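The sum-pooling that turns per-node embeddings into per-graph embeddings can be sketched in pure Python (hypothetical list-based stand-in; the wrapper does this on tensors using the batch index):

```python
# Sketch of sum-pooling node embeddings into graph embeddings.
# batch_index assigns each node to its graph within the batch.
def sum_pool(node_embeddings, batch_index, num_graphs):
    """node_embeddings: [N, H]; batch_index: [N]; returns [num_graphs, H]."""
    hidden = len(node_embeddings[0])
    pooled = [[0.0] * hidden for _ in range(num_graphs)]
    for emb, g in zip(node_embeddings, batch_index):
        for j, v in enumerate(emb):
            pooled[g][j] += v
    return pooled

# Two graphs in the batch: nodes 0-1 belong to graph 0, node 2 to graph 1.
print(sum_pool([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [0, 0, 1], 2))
# [[4.0, 6.0], [5.0, 6.0]]
```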

property cutoff: float#

Interaction cutoff in Angstroms, read from model.r_max.

property embedding_shapes: dict[str, tuple[int, ...]]#

Retrieves the expected shapes of the node, edge, and graph embeddings.

export_model(path, as_state_dict=False)[source]#

Serialize the underlying MACE model without the wrapper.

The exported file can be reloaded as a plain MACE nn.Module and used with the standard MACE / ASE interface.

Parameters:
  • path (Path) – Output path.

  • as_state_dict (bool, optional) – If True, save only the state_dict; otherwise pickle the full model object. Defaults to False.

Return type:

None
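The two serialization branches can be sketched with a pure-Python stand-in (the real method serializes a torch nn.Module with torch facilities, not pickle-on-dict as here):

```python
import os
import pickle
import tempfile

class TinyModel:
    """Stand-in for an nn.Module exposing state_dict(); purely illustrative."""
    def __init__(self):
        self.weights = {"w": [1.0, 2.0]}
    def state_dict(self):
        return dict(self.weights)

def export_model(model, path, as_state_dict=False):
    # as_state_dict=True saves only the parameter mapping (smaller, but the
    # model class is needed to rebuild); False serializes the full object.
    payload = model.state_dict() if as_state_dict else model
    with open(path, "wb") as f:
        pickle.dump(payload, f)

path = os.path.join(tempfile.mkdtemp(), "model.pkl")
export_model(TinyModel(), path, as_state_dict=True)
with open(path, "rb") as f:
    print(pickle.load(f))  # {'w': [1.0, 2.0]}
```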

forward(data, **kwargs)[source]#

Run the MACE model and return the output.

Return type:

OrderedDict[str, Float[Tensor, ‘B 1’] | Float[Tensor, ‘V 3’] | Float[Tensor, ‘V 3 3’] | Float[Tensor, ‘B 3 3’] | Float[Tensor, ‘B 3’] | None]

classmethod from_checkpoint(checkpoint_path, device=torch.device('cpu'), enable_cueq=False, dtype=None, compile_model=False, **compile_kwargs)[source]#

Load a MACE model from a checkpoint and return a MACEWrapper.

Accepts local file paths or named MACE-MP foundation-model checkpoints (e.g. "medium-0b2"), which are downloaded automatically to the MACE cache directory.

Operations are applied in this order to avoid numerical issues:

  1. load — torch.load the checkpoint.

  2. cuEq — convert to cuEquivariance format (must happen while the model is still in its original dtype, because extract_config_mace_model reads the dtype via torch.set_default_dtype).

  3. dtype — cast all weights (including atomic energies) uniformly to the requested dtype.

  4. compile — torch.compile; freezes parameters and sets eval mode. The model is inference-only after this step.
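The ordering constraint can be sketched as a small pure-Python pipeline (hypothetical helper name; the real method calls torch.load, the cuEquivariance converter, a dtype cast, and torch.compile in this order):

```python
# Sketch of the documented operation ordering in from_checkpoint.
def load_order(enable_cueq=False, dtype=None, compile_model=False):
    steps = ["load"]
    if enable_cueq:
        steps.append("cueq")            # must precede the dtype cast: the
                                        # converter reads the original dtype
    if dtype is not None:
        steps.append(f"cast:{dtype}")   # uniform cast, incl. atomic energies
    if compile_model:
        steps.append("compile")         # freezes params; inference-only after
    return steps

print(load_order(enable_cueq=True, dtype="float32", compile_model=True))
# ['load', 'cueq', 'cast:float32', 'compile']
```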

Parameters:
  • checkpoint_path (Path | str) – Local path to a .pt file, or a named checkpoint string such as "medium-0b2".

  • device (torch.device, optional) – Target device. Defaults to CPU.

  • enable_cueq (bool, optional) – Convert to cuEquivariance format for GPU speedup. Requires the cuequivariance package.

  • dtype (torch.dtype | None, optional) – If set, cast model weights to this dtype after cuEq conversion.

  • compile_model (bool, optional) – Apply torch.compile. Sets eval mode and freezes parameters; the model is inference-only after this step.

  • **compile_kwargs – Forwarded to torch.compile.

Return type:

MACEWrapper

Raises:

ImportError – If mace-torch is not installed, or if enable_cueq=True and cuequivariance is not installed.

property model_card: ModelCard#

Retrieves the model card for the model.

The model card is a Pydantic model that contains information about the model’s capabilities and requirements.