nvalchemi.models.aimnet2.AIMNet2Wrapper
- class nvalchemi.models.aimnet2.AIMNet2Wrapper(model)
Wrapper for AIMNet2 interatomic potentials.
Energy is always computed as the primitive differentiable output via the raw AIMNet2 model. Forces and stresses are derived from energy via autograd. Partial charges and node embeddings (AIM features) are taken directly from the model outputs.
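The pattern of deriving forces from a single differentiable energy can be illustrated with plain PyTorch autograd. In this minimal sketch, toy_energy is a hypothetical harmonic pair potential standing in for the AIMNet2 model. It is not part of nvalchemi or AIMNet2. Forces are the negative gradient of the scalar energy with respect to positions.

```python
import torch


def toy_energy(positions: torch.Tensor, k: float = 1.0, r0: float = 1.0) -> torch.Tensor:
    """Hypothetical harmonic pair energy standing in for the AIMNet2 model."""
    n = positions.shape[0]
    # All unique atom pairs i < j.
    i, j = torch.triu_indices(n, n, offset=1)
    r = (positions[i] - positions[j]).norm(dim=-1)
    return 0.5 * k * ((r - r0) ** 2).sum()


def energy_and_forces(positions: torch.Tensor):
    # Leaf copy with gradients enabled so autograd can differentiate w.r.t. positions.
    pos = positions.detach().clone().requires_grad_(True)
    energy = toy_energy(pos)
    # Forces are the negative gradient of the scalar energy.
    (grad,) = torch.autograd.grad(energy, pos)
    return energy.detach(), -grad


positions = torch.tensor([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]])
energy, forces = energy_and_forces(positions)
# The stretched bond pulls the two atoms toward each other.
```

Because only the energy is computed by the model, forces (and, via a strain tensor, stresses) stay consistent with it by construction.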
The wrapper declares an external MATRIX-format neighbor list requirement at the model’s AEV cutoff. The NeighborListHook (or the pipeline’s synthesized hook) populates neighbor_matrix on the batch before each forward pass. The wrapper converts this to AIMNet2’s internal nbmat format (with a padding row for the padding atom).
Coulomb and D3 dispersion are disabled. Use PipelineModelWrapper to compose AIMNet2 with electrostatics or dispersion models.
- Parameters:
model (nn.Module) – An AIMNet2 model (loaded from checkpoint or instantiated directly). Use from_checkpoint() for the common construction path.
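The MATRIX-format neighbor list is produced by NeighborListHook, not by the wrapper. As a rough illustration, assume that format is an (N, max_neighbors) index tensor whose unused slots hold the fill value N. Under that assumption, a brute-force O(N^2) builder for a non-periodic system could look like the sketch below. dense_neighbor_matrix is hypothetical and ignores periodic images, which a real hook must handle.

```python
import torch


def dense_neighbor_matrix(positions: torch.Tensor, cutoff: float, max_neighbors: int):
    """Hypothetical brute-force MATRIX-format neighbor list (no periodic images)."""
    n = positions.shape[0]
    dist = torch.cdist(positions, positions)  # (N, N) pairwise distances
    # Neighbors are atoms within the cutoff, excluding each atom itself.
    self_pair = torch.eye(n, device=positions.device).bool()
    mask = (dist < cutoff) & ~self_pair
    num_neighbors = mask.sum(dim=1)
    if n > 0 and int(num_neighbors.max()) > max_neighbors:
        raise ValueError("max_neighbors is too small for this cutoff")
    # Unused slots hold the fill value N (assumed convention).
    neighbor_matrix = torch.full(
        (n, max_neighbors), n, dtype=torch.long, device=positions.device
    )
    for atom in range(n):
        idx = torch.nonzero(mask[atom]).flatten()
        neighbor_matrix[atom, : idx.numel()] = idx
    return neighbor_matrix, num_neighbors


positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 0.0, 0.0]])
neighbor_matrix, num_neighbors = dense_neighbor_matrix(positions, cutoff=1.5, max_neighbors=2)
```

Here atoms 0 and 1 are mutual neighbors and atom 2 is isolated, so its row contains only the fill value.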
- model_config
Configuration with capability and runtime fields.
- Type:
- model
The underlying AIMNet2 model. To compile it, wrap it with torch.compile(model, **kwargs) before passing it here.
- Type:
nn.Module
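Both this attribute note and from_checkpoint(compile_model=True) involve torch.compile, and the latter is also documented to set eval mode and freeze parameters. The following minimal sketch mimics those steps with a hypothetical helper, prepare_for_inference. It shows that frozen parameters still allow gradients with respect to positions, which is all autograd forces require. The small nn.Sequential network is a stand-in, not an AIMNet2 model.

```python
import torch
from torch import nn


def prepare_for_inference(
    model: nn.Module, compile_model: bool = False, **compile_kwargs
) -> nn.Module:
    """Hypothetical helper: eval mode, frozen parameters, optional torch.compile."""
    model.eval()
    for param in model.parameters():
        param.requires_grad_(False)
    if compile_model:
        model = torch.compile(model, **compile_kwargs)
    return model


net = prepare_for_inference(nn.Sequential(nn.Linear(3, 8), nn.Tanh(), nn.Linear(8, 1)))
positions = torch.randn(5, 3, requires_grad=True)
energy = net(positions).sum()
# Parameters are frozen, but the graph still connects energy to positions.
(grad,) = torch.autograd.grad(energy, positions)
forces = -grad
```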
- adapt_input(data, **kwargs)
Build the flat-padded input dict expected by the AIMNet2 model.
Handles:
AtomicData → Batch promotion.
Gradient enabling on positions when autograd outputs are active.
Collecting positions, numbers, charges, cell from the batch.
Converting the batch’s neighbor_matrix (from NeighborListHook) to AIMNet2’s internal nbmat format by appending a padding row.
Running mol_flatten and pad_input to produce the flat-padded layout the model architecture expects.
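Assume that unused neighbor_matrix slots already hold the index N, one past the last real atom. Then appending a padding atom at index N, plus a padding row for it, makes every filler entry point at that padding atom. The sketch below shows that layout with plain tensor operations. to_flat_padded is a hypothetical helper and does not reproduce AIMNet2’s own mol_flatten or pad_input.

```python
import torch


def to_flat_padded(positions, numbers, neighbor_matrix):
    """Hypothetical helper: append a padding atom and a padding nbmat row."""
    n = positions.shape[0]
    # Padding atom at index n: zero coordinates and atomic number 0.
    coord = torch.cat([positions, positions.new_zeros(1, 3)], dim=0)
    numbers = torch.cat([numbers, numbers.new_zeros(1)], dim=0)
    # Padding row for the padding atom: all of its "neighbors" are itself (index n).
    pad_row = neighbor_matrix.new_full((1, neighbor_matrix.shape[1]), n)
    nbmat = torch.cat([neighbor_matrix, pad_row], dim=0)
    return {"coord": coord, "numbers": numbers, "nbmat": nbmat}


positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 0.0, 0.0]])
numbers = torch.tensor([8, 1, 1])
neighbor_matrix = torch.tensor([[1, 3], [0, 3], [3, 3]])
inputs = to_flat_padded(positions, numbers, neighbor_matrix)
```

Because torch.cat is differentiable, gradients on the padded coordinates still flow back to the original positions tensor when it requires grad.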
Note
This method does not call super().adapt_input() because AIMNet2 uses its own input key conventions (coord, numbers, nbmat) rather than the framework’s standard keys.
- Parameters:
data (AtomicData | Batch) – Input batch with positions, atomic_numbers, charge, and neighbor_matrix / num_neighbors (from NeighborListHook).
kwargs (Any)
- Returns:
Flat-padded dict ready for self._calculator.model().
- Return type:
dict[str, Any]
- adapt_output(model_output, data)
Map AIMNet2 outputs to nvalchemi standard keys.
- Parameters:
model_output (dict[str, Any])
data (AtomicData | Batch)
- Return type:
OrderedDict[str, Float[Tensor, ‘B 1’] | Float[Tensor, ‘V 3’] | Float[Tensor, ‘V 3 3’] | Float[Tensor, ‘B 3 3’] | Float[Tensor, ‘B 3 3’] | Float[Tensor, ‘B 3’] | None]
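Here is a sketch of the kind of mapping adapt_output performs. It assumes raw AIMNet2 energies come back with shape (B,) and that per-atom tensors may still carry the trailing padding atom. The names in KEY_MAP are hypothetical because this page does not list the actual AIMNet2 or nvalchemi keys. Absent outputs map to None, mirroring the None member of the return type.

```python
from collections import OrderedDict
from typing import Any, Dict, Optional

import torch

# Hypothetical raw-key -> framework-key mapping, for illustration only.
KEY_MAP = {"energy": "energy", "forces": "forces", "charges": "partial_charges"}
PER_ATOM_KEYS = {"forces", "charges"}


def map_outputs(raw: Dict[str, Any], num_atoms: int) -> "OrderedDict[str, Optional[torch.Tensor]]":
    out: "OrderedDict[str, Optional[torch.Tensor]]" = OrderedDict()
    for src, dst in KEY_MAP.items():
        value = raw.get(src)
        if value is None:
            out[dst] = None  # absent outputs stay None
        elif src == "energy":
            out[dst] = value.reshape(-1, 1)  # (B,) -> (B, 1), matching 'B 1'
        elif src in PER_ATOM_KEYS:
            out[dst] = value[:num_atoms]  # drop trailing padding atom, if present
        else:
            out[dst] = value
    return out


raw = {"energy": torch.tensor([1.0, 2.0]), "forces": torch.zeros(5, 3)}
out = map_outputs(raw, num_atoms=4)
```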
- compute_embeddings(data, **kwargs)
Compute AIMNet2 AIM feature embeddings and attach to data.
- Parameters:
data (AtomicData | Batch)
kwargs (Any)
- Return type:
- property embedding_shapes: dict[str, tuple[int, ...]]
Return AIMNet2 AIM feature embedding shapes.
- export_model(path, as_state_dict=False)
Export the raw AIMNet2 model.
- Parameters:
path (Path)
as_state_dict (bool)
- Return type:
None
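The as_state_dict flag suggests the usual PyTorch choice between saving only the weights and pickling the whole module. If that is its meaning, a minimal equivalent with torch.save is sketched below. export_module is hypothetical, not the wrapper’s implementation, and nn.Linear stands in for the AIMNet2 model.

```python
import tempfile
from pathlib import Path

import torch
from torch import nn


def export_module(model: nn.Module, path, as_state_dict: bool = False) -> None:
    """Hypothetical export helper mirroring an as_state_dict switch."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    if as_state_dict:
        # Tensors only: reloading requires code that rebuilds the architecture.
        torch.save(model.state_dict(), path)
    else:
        # Pickled module: reloading requires the class to be importable.
        torch.save(model, path)


with tempfile.TemporaryDirectory() as tmp:
    original = nn.Linear(3, 1)
    export_module(original, Path(tmp) / "weights.pt", as_state_dict=True)
    restored = nn.Linear(3, 1)
    restored.load_state_dict(torch.load(Path(tmp) / "weights.pt"))
    export_module(original, Path(tmp) / "full" / "model.pt")
    reloaded = torch.load(Path(tmp) / "full" / "model.pt", weights_only=False)
```

Recent PyTorch releases default torch.load to weights_only=True, so reloading a pickled module needs weights_only=False. A state dict of plain tensors loads under either setting.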
- extra_repr()
Format the model config for nn.Module.__repr__.
- Parameters:
self (Any)
- Return type:
str
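nn.Module.__repr__ inserts whatever extra_repr() returns after the class name, above the listing of child modules. A minimal sketch of a config-backed extra_repr follows. ToyConfig, ToyWrapper, and their fields are hypothetical.

```python
from dataclasses import asdict, dataclass

from torch import nn


@dataclass
class ToyConfig:
    # Hypothetical config fields for illustration only.
    cutoff: float = 5.0
    compute_stresses: bool = False


class ToyWrapper(nn.Module):
    def __init__(self, config: ToyConfig):
        super().__init__()
        self.model_config = config
        self.model = nn.Linear(3, 1)

    def extra_repr(self) -> str:
        # Rendered inside ToyWrapper(...) before the "(model): ..." child line.
        return ", ".join(f"{k}={v!r}" for k, v in asdict(self.model_config).items())


text = repr(ToyWrapper(ToyConfig()))
```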
- forward(data, **kwargs)
Run the AIMNet2 model and return outputs.
Energy is always computed as the primitive differentiable output via the raw model. Forces and stresses are derived from energy via autograd when requested.
For stresses, the affine strain trick is applied before the forward pass using prepare_strain(). This scales positions and cell through a displacement tensor so that dE/d(displacement) gives the strain derivative.
In a pipeline with use_autograd=True, the pipeline handles derivative computation externally: it strips forces and stresses from active_outputs, so this method only computes energy.
- Parameters:
data (AtomicData | Batch) – Input batch with positions, atomic_numbers, charge, and neighbor_matrix (from NeighborListHook).
kwargs (Any)
- Returns:
OrderedDict with requested output keys.
- Return type:
ModelOutputs
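The affine strain trick can be checked against a hypothetical toy energy E = 0.5 * k * sum_i |r_i|^2. At zero strain its derivative is dE/d(eps) = k * r^T r, so the autograd result can be verified analytically. This sketch assumes a single system, row-vector positions, and a symmetrized displacement. It is not prepare_strain(). It uses the convention sigma = (1/V) dE/d(eps), and sign conventions for stress differ between codes.

```python
import torch


def toy_energy(positions: torch.Tensor, cell: torch.Tensor, k: float = 1.0) -> torch.Tensor:
    """Hypothetical stand-in energy E = 0.5 * k * sum_i |r_i|^2 (ignores the cell)."""
    return 0.5 * k * (positions ** 2).sum()


def strain_derivative(positions, cell, energy_fn):
    dtype = positions.dtype
    # Leaf displacement tensor at zero strain; autograd differentiates w.r.t. it.
    displacement = torch.zeros(3, 3, dtype=dtype, requires_grad=True)
    strain = 0.5 * (displacement + displacement.T)  # symmetrized (assumption)
    deform = torch.eye(3, dtype=dtype) + strain
    # Row-vector convention: positions and cell vectors are both multiplied by (I + eps).
    energy = energy_fn(positions @ deform, cell @ deform)
    (dE_deps,) = torch.autograd.grad(energy, displacement)
    volume = torch.linalg.det(cell).abs()
    # sigma = (1/V) dE/d(eps) in this sketch's sign convention.
    return dE_deps, dE_deps / volume


torch.manual_seed(0)
positions = torch.randn(4, 3, dtype=torch.float64)
cell = torch.diag(torch.tensor([2.0, 3.0, 4.0], dtype=torch.float64))
dE_deps, stress = strain_derivative(positions, cell, toy_energy)
```

A real model also consumes the strained cell (for periodic image shifts). This is why the cell is deformed alongside the positions even though the toy energy ignores it.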
- classmethod from_checkpoint(checkpoint_path, device='cpu', compile_model=False, **compile_kwargs)
Load an AIMNet2 model and return a wrapped instance.
Uses AIMNet2Calculator to resolve and load the checkpoint, then extracts the raw nn.Module and wraps it.
- Parameters:
checkpoint_path (str | Path) – Path to an AIMNet2 checkpoint file, or a model alias recognized by AIMNet2Calculator (e.g. "aimnet2").
device (torch.device | str, optional) – Target device. Defaults to "cpu".
compile_model (bool, optional) – Apply torch.compile. Sets eval mode and freezes parameters; the model is inference-only after this step.
**compile_kwargs – Forwarded to torch.compile.
- Return type: