nvalchemi.models.aimnet2.AIMNet2Wrapper

class nvalchemi.models.aimnet2.AIMNet2Wrapper(model)

Wrapper for AIMNet2 interatomic potentials.

Energy is always computed as the primitive differentiable output via the raw AIMNet2 model. Forces and stresses are derived from energy via autograd. Partial charges and node embeddings (AIM features) are taken directly from the model outputs.
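The force derivation described above follows the standard pattern F = -dE/dr. Below is a minimal sketch of that pattern in plain PyTorch. A hypothetical two-atom harmonic spring (spring_energy) stands in for AIMNet2, and only the differentiation step mirrors the wrapper.

```python
import torch


def spring_energy(positions: torch.Tensor, k: float = 2.0, r_eq: float = 1.0) -> torch.Tensor:
    # Hypothetical stand-in for a potential: one harmonic bond between atoms 0 and 1.
    d = torch.linalg.norm(positions[0] - positions[1])
    return 0.5 * k * (d - r_eq) ** 2


# Positions must require gradients so autograd can differentiate the energy.
positions = torch.tensor(
    [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]], dtype=torch.float64, requires_grad=True
)
energy = spring_energy(positions)
(dE_dr,) = torch.autograd.grad(energy, positions)
forces = -dE_dr  # the stretched bond pulls the two atoms toward each other
```

To train on forces, pass create_graph=True to torch.autograd.grad so that a loss on the forces can itself be backpropagated.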

The wrapper declares an external MATRIX-format neighbor list requirement at the model’s AEV cutoff. The NeighborListHook (or the pipeline’s synthesized hook) populates neighbor_matrix on the batch before each forward pass. The wrapper converts this to AIMNet2’s internal nbmat format (with a padding row for the padding atom).
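The following brute-force sketch builds a dense neighbor matrix for a single non-periodic system. It is meant to make the MATRIX format concrete and relies on several assumptions:

- Each row lists the indices of one atom's neighbors within the cutoff.
- Unused slots hold the fill value N (the atom count).
- A separate num_neighbors tensor records how many slots are valid.

NeighborListHook's actual algorithm, fill value, and periodic-image handling are not given above. The helper name brute_force_neighbor_matrix is hypothetical.

```python
import torch


def brute_force_neighbor_matrix(positions: torch.Tensor, cutoff: float, max_neighbors: int):
    # O(N^2) sketch for one non-periodic system; production hooks use cell lists and PBC.
    n = positions.shape[0]
    dist = torch.cdist(positions, positions)
    within = (dist < cutoff) & ~torch.eye(n, dtype=torch.bool)  # exclude self-pairs
    num_neighbors = within.sum(dim=1)
    if n > 0 and int(num_neighbors.max()) > max_neighbors:
        raise ValueError("max_neighbors is too small for this cutoff")
    # Unused slots keep the assumed fill value n (one past the last real atom).
    neighbor_matrix = torch.full((n, max_neighbors), n, dtype=torch.long)
    for i in range(n):
        idx = torch.nonzero(within[i]).flatten()
        neighbor_matrix[i, : idx.numel()] = idx
    return neighbor_matrix, num_neighbors


positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0]])
neighbor_matrix, num_neighbors = brute_force_neighbor_matrix(positions, cutoff=2.0, max_neighbors=2)
```
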

Coulomb and D3 dispersion are disabled. Use PipelineModelWrapper to compose AIMNet2 with electrostatics or dispersion models.

Parameters:

model (nn.Module) – An AIMNet2 model (loaded from checkpoint or instantiated directly). Use from_checkpoint() for the common construction path.

model_config

Configuration with capability and runtime fields.

Type:

ModelConfig

model

The underlying AIMNet2 model. To compile it, wrap it with torch.compile(model, **kwargs) before passing it to the wrapper.

Type:

nn.Module

adapt_input(data, **kwargs)

Build the flat-padded input dict expected by the AIMNet2 model.

Handles:

  1. AtomicDataBatch promotion.

  2. Gradient enabling on positions when autograd outputs are active.

  3. Collecting positions, numbers, charges, cell from the batch.

  4. Converting the batch’s neighbor_matrix (from NeighborListHook) to AIMNet2’s internal nbmat format by appending a padding row.

  5. Running mol_flatten and pad_input to produce the flat-padded layout the model architecture expects.
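Step 4 can be sketched as follows. The sketch assumes that the flat-padded layout places one padding atom at index N (one past the last real atom) and that unused neighbor slots should point at that atom. The helper to_padded_nbmat is hypothetical and is not the wrapper's internal code.

```python
import torch


def to_padded_nbmat(neighbor_matrix: torch.Tensor, num_neighbors: torch.Tensor) -> torch.Tensor:
    n, k = neighbor_matrix.shape
    pad_index = n  # assumed index of the extra padding atom in the flat layout
    # Slots beyond each atom's neighbor count may hold arbitrary fill values;
    # redirect them to the padding atom under the assumed convention.
    slot = torch.arange(k, device=neighbor_matrix.device).unsqueeze(0)
    valid = slot < num_neighbors.unsqueeze(1)
    nbmat = torch.where(valid, neighbor_matrix, torch.full_like(neighbor_matrix, pad_index))
    # The padding atom itself gets a row that points only at the padding index.
    padding_row = torch.full(
        (1, k), pad_index, dtype=neighbor_matrix.dtype, device=neighbor_matrix.device
    )
    return torch.cat([nbmat, padding_row], dim=0)


neighbor_matrix = torch.tensor([[1, -1], [0, 2], [7, 7]])
num_neighbors = torch.tensor([1, 2, 0])
nbmat = to_padded_nbmat(neighbor_matrix, num_neighbors)  # shape (N + 1, K) = (4, 2)
```
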

Note

This method does not call super().adapt_input() because AIMNet2 uses its own input key conventions (coord, numbers, nbmat) rather than the framework’s standard keys.

Parameters:
  • data (AtomicData | Batch) – Input batch with positions, atomic_numbers, charge, and neighbor_matrix / num_neighbors (from NeighborListHook).

  • kwargs (Any)

Returns:

Flat-padded dict ready for self._calculator.model().

Return type:

dict[str, Any]

adapt_output(model_output, data)

Map AIMNet2 outputs to nvalchemi standard keys.

Parameters:
  • model_output

  • data

Return type:

OrderedDict[str, Float[Tensor, 'B 1'] | Float[Tensor, 'V 3'] | Float[Tensor, 'V 3 3'] | Float[Tensor, 'B 3 3'] | Float[Tensor, 'B 3 3'] | Float[Tensor, 'B 3'] | None]

compute_embeddings(data, **kwargs)

Compute AIMNet2 AIM feature embeddings and attach to data.

Parameters:
  • data

  • kwargs (Any)

Return type:

AtomicData | Batch

property embedding_shapes: dict[str, tuple[int, ...]]

Return AIMNet2 AIM feature embedding shapes.

export_model(path, as_state_dict=False)

Export the raw AIMNet2 model.

Parameters:
  • path (Path)

  • as_state_dict (bool)

Return type:

None
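The as_state_dict flag suggests the usual PyTorch choice between saving a whole module and saving only its state dict. The sketch below illustrates that general distinction with torch.save on a stand-in nn.Linear. The text above does not say whether export_model uses torch.save, TorchScript, or another format, so treat this as background rather than a description of its implementation.

```python
import io

import torch
from torch import nn

model = nn.Linear(3, 1)  # hypothetical stand-in for the raw AIMNet2 module

# State-dict export: only parameter and buffer tensors are stored, so the
# architecture must be rebuilt in code before the weights can be loaded.
state_buffer = io.BytesIO()
torch.save(model.state_dict(), state_buffer)
state_buffer.seek(0)
restored_from_state = nn.Linear(3, 1)
restored_from_state.load_state_dict(torch.load(state_buffer))

# Full-module export: the pickle references the module's class, so loading needs
# the defining code to be importable and weights_only=False on recent PyTorch.
module_buffer = io.BytesIO()
torch.save(model, module_buffer)
module_buffer.seek(0)
restored_module = torch.load(module_buffer, weights_only=False)
```

A state dict is more portable across code versions. A full-module pickle is more convenient but couples the file to the class definition.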

extra_repr()

Format the model config for nn.Module.__repr__.

Parameters:

self (Any)

Return type:

str

forward(data, **kwargs)

Run the AIMNet2 model and return outputs.

Energy is always computed as the primitive differentiable output via the raw model. Forces and stresses are derived from energy via autograd when requested.

For stresses, the affine strain trick is applied before the forward pass using prepare_strain(). This scales positions and cell through a displacement tensor so that dE/d(displacement) gives the strain derivative.
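The following is a minimal sketch of the affine strain trick in plain PyTorch. It makes two assumptions:

- The strain enters as positions @ (I + displacement) and cell @ (I + displacement), with the displacement tensor set to zero.
- Stress is taken as (1/V) dE/d(displacement).

The conventions used by prepare_strain(), including sign, symmetrization, and units, are not documented here, and sign conventions vary between codes. The functions toy_energy and energy_forces_stress are hypothetical.

```python
import torch


def toy_energy(positions: torch.Tensor, cell: torch.Tensor, k: float = 1.0) -> torch.Tensor:
    # Hypothetical stand-in for a potential: each atom is tethered to the origin.
    # The cell is unused by this toy energy but is strained to mirror the real workflow.
    return 0.5 * k * (positions ** 2).sum()


def energy_forces_stress(positions: torch.Tensor, cell: torch.Tensor, energy_fn):
    positions = positions.detach().requires_grad_(True)
    # Zero displacement: the model sees the unstrained geometry, but autograd can
    # still differentiate the energy with respect to an infinitesimal strain.
    displacement = torch.zeros(3, 3, dtype=positions.dtype, requires_grad=True)
    strain = torch.eye(3, dtype=positions.dtype) + displacement
    energy = energy_fn(positions @ strain, cell @ strain)
    # One backward pass yields both dE/d(positions) and dE/d(displacement).
    d_pos, d_disp = torch.autograd.grad(energy, (positions, displacement))
    volume = torch.linalg.det(cell).abs()
    return energy.detach(), -d_pos, d_disp / volume


positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.5, 0.0], [0.0, 1.0, 2.0]], dtype=torch.float64)
cell = 4.0 * torch.eye(3, dtype=torch.float64)
energy, forces, stress = energy_forces_stress(positions, cell, toy_energy)
```

Because the displacement is zero, the model's energy is unchanged. Forces and the strain derivative therefore come out of the same graph at no extra forward cost.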

In a pipeline with use_autograd=True, the pipeline computes derivatives externally: it strips forces and stresses from active_outputs, so this method computes only the energy.
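Differentiation is linear, which is why the pipeline can strip forces and stresses from each component. One autograd call on the summed energy gives the same forces as differentiating every component and adding the results, and it needs only one backward pass instead of one per component. In the minimal sketch below, two hypothetical energy terms stand in for AIMNet2 and for an electrostatics or dispersion model.

```python
import torch


def short_range_energy(positions: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the MLIP energy term.
    return 0.5 * (positions ** 2).sum()


def long_range_energy(positions: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for an electrostatics or dispersion term.
    return (positions[:, 0] * positions[:, 1]).sum()


positions = torch.tensor([[0.5, -1.0, 2.0], [1.0, 3.0, 0.0]], dtype=torch.float64, requires_grad=True)

# Components return energy only; the total is differentiated once.
total_energy = short_range_energy(positions) + long_range_energy(positions)
(grad_total,) = torch.autograd.grad(total_energy, positions)
forces = -grad_total

# Reference: differentiating each component separately and summing gives the same result.
(grad_a,) = torch.autograd.grad(short_range_energy(positions), positions)
(grad_b,) = torch.autograd.grad(long_range_energy(positions), positions)
reference_forces = -(grad_a + grad_b)
```
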

Parameters:
  • data (AtomicData | Batch) – Input batch with positions, atomic_numbers, charge, and neighbor_matrix (from NeighborListHook).

  • kwargs (Any)

Returns:

OrderedDict with requested output keys.

Return type:

ModelOutputs

classmethod from_checkpoint(checkpoint_path, device='cpu', compile_model=False, **compile_kwargs)

Load an AIMNet2 model and return a wrapped instance.

Uses AIMNet2Calculator to resolve and load the checkpoint, then extracts the raw nn.Module and wraps it.

Parameters:
  • checkpoint_path (str | Path) – Path to an AIMNet2 checkpoint file, or a model alias recognized by AIMNet2Calculator (e.g. "aimnet2").

  • device (torch.device | str, optional) – Target device. Defaults to "cpu".

  • compile_model (bool, optional) – If True, apply torch.compile. This sets eval mode and freezes parameters, so the model is inference-only afterwards. Defaults to False.

  • **compile_kwargs – Forwarded to torch.compile.

Return type:

AIMNet2Wrapper
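compile_model is documented as setting eval mode and freezing parameters. The minimal sketch below shows that preparation on a toy module, with two assumptions:

- "Freezing" means calling requires_grad_(False) on every parameter.
- prepare_for_inference and ToyEnergy are hypothetical names, not nvalchemi's code.

The usage passes compile_model=False so that the example runs without a compiler backend. The example checks one property: frozen parameters do not block forces. Forces differentiate the energy with respect to positions, not weights, so only fine-tuning is ruled out.

```python
import torch
from torch import nn


def prepare_for_inference(model: nn.Module, compile_model: bool = False, **compile_kwargs) -> nn.Module:
    # Hypothetical helper mirroring the documented behavior of compile_model.
    model.eval()
    for p in model.parameters():
        p.requires_grad_(False)
    if compile_model:
        model = torch.compile(model, **compile_kwargs)
    return model


class ToyEnergy(nn.Module):
    # Hypothetical energy model: squared linear per-atom terms summed to a total.
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(3, 1)

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        return self.linear(positions).pow(2).sum()


torch.manual_seed(0)
model = prepare_for_inference(ToyEnergy(), compile_model=False)
positions = torch.randn(4, 3, requires_grad=True)
energy = model(positions)
# Gradients still flow to positions even though every parameter is frozen.
(grad,) = torch.autograd.grad(energy, positions)
forces = -grad
```
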