Source code for nvalchemi.data.atomic_data

# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import numbers
import warnings
from collections.abc import Sequence
from hashlib import blake2s
from typing import TYPE_CHECKING, Annotated, Any, ClassVar

import numpy as np
import periodictable as pt
import torch
from pydantic import BaseModel, ConfigDict, Field, PlainSerializer, model_validator

from nvalchemi import OptionalDependency
from nvalchemi import _typing as t
from nvalchemi.data.data import DataMixin  # type: ignore

if TYPE_CHECKING:
    from ase import Atoms
    from pymatgen.core import Molecule, Structure


def _tensor_serialization(
    tensor: torch.Tensor | None,
) -> list[float | int | list] | None:
    """
    Map a PyTorch tensor to JSON serializable values.

    Used as the JSON-mode ``PlainSerializer`` for tensor-valued fields, which
    are optional and therefore may hold ``None``.

    Parameters
    ----------
    tensor: torch.Tensor | None
        The tensor to serialize.

    Returns
    -------
    list[float | int | list] | None
        The tensor contents as (nested) Python lists, or None if *tensor*
        is None.
    """
    if tensor is None:
        return None
    # detach() drops any autograd graph and cpu() copies device tensors to
    # host memory before conversion to Python lists.
    return tensor.detach().cpu().tolist()


class AtomicNumberTable:
    """
    Ordered table mapping atomic numbers to contiguous indices and back.

    The position of an atomic number in ``zs`` is its index.
    """

    def __init__(self, zs: Sequence[int]):
        # Materialize as a list of plain Python ints so that any iterable of
        # integers (tuple, NumPy array, ...) supports ``.index`` lookups and
        # prints without NumPy scalar wrappers.
        self.zs = [int(z) for z in zs]

    def __len__(self) -> int:
        return len(self.zs)

    def __str__(self) -> str:
        return f"AtomicNumberTable: {tuple(self.zs)}"

    def index_to_z(self, index: int) -> int:
        """
        Convert index to atomic number

        Raises ``IndexError`` if *index* is out of range.
        """
        return self.zs[index]

    def z_to_index(self, atomic_number: int) -> int:
        """
        Convert atomic number to index

        Returns the first position of *atomic_number* in the table and raises
        ``ValueError`` if it is not present.
        """
        return self.zs.index(atomic_number)


[docs] class AtomicData(BaseModel, DataMixin): """Atomic data structure for molecular systems. Represents molecular systems as graphs with atomic properties and interactions. Uses Pydantic for validation and serialization, with DataMixin for graph functionality. Attributes ---------- atomic_numbers : torch.Tensor Atomic numbers of each atom [n_nodes] positions : torch.Tensor Cartesian coordinates [n_nodes, 3] atomic_masses : torch.Tensor Atomic masses [n_nodes] neighbor_list : torch.Tensor Neighbor list [n_edges, 2] node_attrs : torch.Tensor Node attributes [n_nodes, n_node_feats] shifts : torch.Tensor Cartesian displacement vectors for each edge [n_edges, 3], computed as ``neighbor_list_shifts @ cell``. neighbor_list_shifts : torch.Tensor Integer lattice image indices for periodic edges [n_edges, 3]. neighbor_matrix : torch.Tensor Dense neighbor matrix [n_nodes, max_neighbors] neighbor_matrix_shifts : torch.Tensor Periodic shifts for the dense neighbor matrix [n_nodes, max_neighbors, 3] num_neighbors : torch.Tensor Number of valid neighbors per atom [n_nodes] cell : torch.Tensor Unit cell vectors [3, 3] pbc : torch.Tensor Periodic boundary conditions [3] forces : torch.Tensor Atomic forces [n_nodes, 3] energy : torch.Tensor Total energy [1] stress : torch.Tensor Stress tensor [1, 3, 3] virial : torch.Tensor Virial tensor [1, 3, 3] dipole : torch.Tensor Dipole moment [1, 3] charges : torch.Tensor Partial atomic charges [n_nodes] charge : torch.Tensor Total system charge [1] info : dict Additional information about the system """ # Required fields atomic_numbers: Annotated[ t.AtomicNumbers, Field(description="Atomic numbers for each node [n_nodes]"), PlainSerializer(_tensor_serialization, when_used="json"), ] positions: Annotated[ t.NodePositions, Field(description="Cartesian coordinates for each atom [n_nodes, 3]"), PlainSerializer(_tensor_serialization, when_used="json"), ] # Optional fields with defaults atomic_masses: Annotated[ t.AtomicMasses | None, 
Field(description="Atomic masses [n_nodes]"), PlainSerializer(_tensor_serialization, when_used="json"), ] = None atom_categories: Annotated[ list[t.AtomCategory] | t.AtomCategories | None, Field( description="Atom categorical index, based on _typing.AtomCategory Enum [n_nodes]" ), PlainSerializer(_tensor_serialization, when_used="json"), ] = None neighbor_list: Annotated[ t.NeighborList | None, Field(description="Neighbor list [n_edges, 2]"), PlainSerializer(_tensor_serialization, when_used="json"), ] = None shifts: Annotated[ t.PeriodicShifts | None, Field( description="Cartesian displacement vectors for each edge (neighbor_list_shifts @ cell) [n_edges, 3]" ), PlainSerializer(_tensor_serialization, when_used="json"), ] = None neighbor_list_shifts: Annotated[ t.NeighborListShifts | None, Field( description="Integer lattice image indices for periodic edges [n_edges, 3]" ), PlainSerializer(_tensor_serialization, when_used="json"), ] = None neighbor_matrix: Annotated[ t.NeighborMatrix | None, Field(description="Dense neighbor matrix [n_nodes, max_neighbors]"), PlainSerializer(_tensor_serialization, when_used="json"), ] = None neighbor_matrix_shifts: Annotated[ t.NeighborMatrixShifts | None, Field( description="Periodic shifts for the dense neighbor matrix [n_nodes, max_neighbors, 3]" ), PlainSerializer(_tensor_serialization, when_used="json"), ] = None num_neighbors: Annotated[ t.NumNeighbors | None, Field(description="Number of valid neighbors per atom [n_nodes]"), PlainSerializer(_tensor_serialization, when_used="json"), ] = None cell: Annotated[ t.LatticeVectors | None, Field(description="Unit cell vectors [3, 3]"), PlainSerializer(_tensor_serialization, when_used="json"), ] = None pbc: Annotated[ t.Periodicity | None, Field( description="Boolean tensor indicating periodic boundary conditions along each dimension" ), PlainSerializer(_tensor_serialization, when_used="json"), ] = None forces: Annotated[ t.Forces | None, Field(description="Atomic forces [n_nodes, 
3]"), PlainSerializer(_tensor_serialization, when_used="json"), ] = None energy: Annotated[ t.Energy | None, Field(description="Total energy [1]"), PlainSerializer(_tensor_serialization, when_used="json"), ] = None stress: Annotated[ t.Stress | None, Field(description="Cauchy stress W/V (eV/A^3) [1, 3, 3]"), PlainSerializer(_tensor_serialization, when_used="json"), ] = None virial: Annotated[ t.Virials | None, Field(description="Virial tensor [1, 3, 3]"), PlainSerializer(_tensor_serialization, when_used="json"), ] = None dipole: Annotated[ t.Dipole | None, Field(description="Dipole moment of the system."), PlainSerializer(_tensor_serialization, when_used="json"), ] = None charges: Annotated[ t.NodeCharges | None, Field(description="Partial atomic charges [n_nodes]"), PlainSerializer(_tensor_serialization, when_used="json"), ] = None charge: Annotated[ t.GraphCharges | None, Field(description="Total system charge [1]"), PlainSerializer(_tensor_serialization, when_used="json"), ] = None node_attrs: Annotated[ t.NodeAttributes | None, Field(description="Node attributes [n_nodes, n_node_attrs]"), PlainSerializer(_tensor_serialization, when_used="json"), ] = None node_alpha_spins: Annotated[ t.NodeSpins | None, Field( description="Alpha spins for each atom, [n_nodes, 1]. Use this field for closed-shell spins." ), PlainSerializer(_tensor_serialization, when_used="json"), ] = None node_beta_spins: Annotated[ t.NodeSpins | None, Field( description="Beta spins for each atom, [n_nodes, 1]. For restricted spin, use ``node_alpha_spins`` instead." 
), PlainSerializer(_tensor_serialization, when_used="json"), ] = None spin: Annotated[ t.GraphSpins | None, Field(description="Spin or multiplicity value for the system, [1, 1]"), PlainSerializer(_tensor_serialization, when_used="json"), ] = None graph_alpha_spins: Annotated[ t.GraphSpins | None, Field(description="Alpha spins for the entire graph, [1, 1]"), PlainSerializer(_tensor_serialization, when_used="json"), ] = None node_embeddings: Annotated[ t.NodeEmbeddings | None, Field(description="Embeddings for each node within the batch/graph."), PlainSerializer(_tensor_serialization, when_used="json"), ] = None edge_embeddings: Annotated[ t.EdgeEmbeddings | None, Field(description="Embeddings for each edge within the batch/graph."), PlainSerializer(_tensor_serialization, when_used="json"), ] = None graph_embeddings: Annotated[ t.GraphEmbeddings | None, Field(description="Embeddings for the entire graph/graphs within a batch."), PlainSerializer(_tensor_serialization, when_used="json"), ] = None velocities: Annotated[ t.NodeVelocities | None, Field(description="Atomic velocities [n_nodes, 3], in units set by positions."), PlainSerializer(_tensor_serialization, when_used="json"), ] = None momenta: Annotated[ t.NodeMomentum | None, Field(description="Atomic momenta [n_nodes, 3], in units set by positions."), PlainSerializer(_tensor_serialization, when_used="json"), ] = None kinetic_energies: Annotated[ t.NodeKineticEnergies | None, Field( description="Per-atom kinetic energies [n_nodes, 1], with the same units as energy." ), PlainSerializer(_tensor_serialization, when_used="json"), ] = None info: dict[str, torch.Tensor] = Field(default_factory=dict) # "Node key" means dim(0) == num_nodes; tensors may have any rank. 
_default_node_keys: ClassVar[frozenset[str]] = frozenset( { "atomic_masses", "positions", "forces", "charges", "node_embeddings", "atomic_numbers", "node_attrs", "node_alpha_spins", "node_beta_spins", "atom_categories", "velocities", "momenta", "kinetic_energies", "neighbor_matrix", "neighbor_matrix_shifts", "num_neighbors", } ) _default_edge_keys: ClassVar[frozenset[str]] = frozenset( {"shifts", "neighbor_list_shifts", "neighbor_list", "edge_embeddings"} ) _default_system_keys: ClassVar[frozenset[str]] = frozenset( { "energy", "stress", "virial", "dipole", "charge", "graph_embeddings", "cell", "pbc", "spin", } ) # Pydantic configuration model_config: ClassVar[ConfigDict] = ConfigDict( arbitrary_types_allowed=True, validate_assignment=True, extra="allow" )
[docs] def model_post_init(self, __context: Any) -> None: """Create per-instance mutable copies of the key sets. The class-level defaults are frozen to prevent accidental mutation. Each instance gets its own mutable set so that ``add_node_property`` and friends only affect the instance they are called on. Uses ``model_post_init`` rather than ``model_validator`` because ``validate_assignment=True`` causes model validators to re-run on every ``setattr`` call, which would reset the key sets and lose previously added custom keys. """ # Merge defaults with any dynamically-added keys passed during # construction (e.g. via model_validate or from_data_list round-trips). existing_node = set(getattr(self, "__node_keys__", ())) existing_edge = set(getattr(self, "__edge_keys__", ())) existing_system = set(getattr(self, "__system_keys__", ())) object.__setattr__( self, "__node_keys__", set(self._default_node_keys) | existing_node ) object.__setattr__( self, "__edge_keys__", set(self._default_edge_keys) | existing_edge ) object.__setattr__( self, "__system_keys__", set(self._default_system_keys) | existing_system )
[docs] @model_validator(mode="after") def check_node_consistency(self) -> AtomicData: """Validate that all node-level properties have consistent atom counts. This validator runs after all field validators and checks that any node-level property that is set has the same number of nodes as atomic_numbers. Returns ------- Self Returns self if validation passes. Raises ------ ValueError If any node-level property has an inconsistent number of nodes. """ num_atoms = len(self.atomic_numbers) node_keys = self.__dict__.get("__node_keys__", self._default_node_keys) for key in node_keys: tensor = getattr(self, key, None) if isinstance(tensor, torch.Tensor): if tensor.size(0) != num_atoms: raise ValueError( f"Inconsistent number of atoms in {key}: " f"expected {num_atoms}, got {tensor.shape[0]}" ) return self
[docs] @model_validator(mode="after") def check_edge_consistency(self) -> AtomicData: """Validate that all edge-level properties have consistent atom counts. This validator runs after all field validators and checks that any edge-level property that is set has the same number of edges as neighbor_list. Returns ------- Self Returns self if validation passes. Raises ------ ValueError If any edge-level property has an inconsistent number of edges. """ if not isinstance(self.neighbor_list, torch.Tensor): return self num_edges = self.neighbor_list.size(0) edge_keys = self.__dict__.get("__edge_keys__", self._default_edge_keys) for key in edge_keys: tensor = getattr(self, key, None) if isinstance(tensor, torch.Tensor): if tensor.size(0) != num_edges: raise ValueError( f"Inconsistent number of edges in {key}: " f"expected {num_edges}, got {tensor.shape[0]}" ) return self
[docs] @model_validator(mode="after") def check_fp_dtype_consistency(self) -> AtomicData: """ Ensures all floating point tensors are at the same precision as the positions tensor. """ dtype = self.positions.dtype casted: list[str] = [] for key in self.model_dump().keys(): value = getattr(self, key) if isinstance(value, torch.Tensor): tensor_dtype = value.dtype if tensor_dtype.is_floating_point and tensor_dtype != dtype: # using __dict__ to avoid re-validation self.__dict__[key] = value.to(dtype) casted.append(key) if casted: casted.sort() # Keep the warning attributed to the user's AtomicData(...) call # instead of Pydantic's internal validation frames. This may need # adjustment if Pydantic's construction stack changes. warnings.warn( f"AtomicData fields {casted} were cast from their original " f"dtypes to {dtype} to match positions. " f"Pass tensors with matching dtypes to silence this warning.", UserWarning, stacklevel=3, ) return self
[docs] @model_validator(mode="after") def use_default_masses(self) -> AtomicData: """ If no atomic masses are set, automatically fill in with default masses from ``periodictable``. Returns ------- Self Returns self if validation passes. """ if self.atomic_masses is None: masses_list = [pt.elements[int(n)].mass for n in self.atomic_numbers] # skip re-validation self.__dict__["atomic_masses"] = torch.as_tensor( masses_list, device=self.atomic_numbers.device, dtype=self.positions.dtype, ) return self
[docs] @model_validator(mode="after") def use_default_categories(self) -> AtomicData: """ Check to make sure categories for atoms are set. In the case that a list is passed, which should be validated by ``pydantic``, we will convert it to a tensor. """ if self.atom_categories is None: self.__dict__["atom_categories"] = torch.zeros_like( self.atomic_numbers, dtype=torch.long ) elif isinstance(self.atom_categories, list): if not isinstance(self.atom_categories[0], t.AtomCategory): raise ValueError( "Atom categories must be a list of `AtomCategory` enums" ) self.atom_categories = torch.as_tensor( [cat.value for cat in self.atom_categories], dtype=torch.long ) return self
[docs] @model_validator(mode="after") def use_default_velocities(self) -> AtomicData: """ If no velocities are set, initialize as zeros with proper shape and dtype. Returns ------- Self Returns self if validation passes. """ if self.velocities is None: # skip re-validation self.__dict__["velocities"] = torch.zeros_like(self.positions) return self
[docs] @model_validator(mode="after") def enforce_device_consistency(self) -> AtomicData: """ Enforces all tensors to be on the same device. In instances where the devices of atomic numbers and positions are different, we will try and promote them to offload over host CPU. """ # we will use atomic numbers and positions as the "ground truth" as # they are required fields base_devices = list( {self.atomic_numbers.device.type, self.positions.device.type} ) # sort the devices to be usable in a match statement base_devices = list(sorted(base_devices)) match base_devices: case ["cuda"]: target_device = torch.device("cuda") case ["mps"]: target_device = torch.device("mps") case ["cpu", "cuda"]: target_device = torch.device("cuda") case ["cpu", "mps"]: target_device = torch.device("mps") # fall back to CPU for all other cases case _: target_device = torch.device("cpu") tensor_devices = [ value.device.type for value in self.model_dump().values() if isinstance(value, torch.Tensor) ] if set(tensor_devices) != {target_device.type}: for key in ( self.__node_keys__ | self.__edge_keys__ | self.__system_keys__ | {"info"} ): value = getattr(self, key, None) if ( isinstance(value, torch.Tensor) and value.device.type != target_device.type ): # using __dict__ to avoid re-validation self.__dict__[key] = value.to(target_device, non_blocking=False) return self
def __getitem__(self, key: str) -> Any: return getattr(self, key) def __setitem__(self, key: str, value: Any) -> None: setattr(self, key, value) @property def device(self) -> torch.device: """Get the device of the positions tensor.""" return self.positions.device @property def dtype(self) -> torch.dtype: """Get the dtype of the positions tensor.""" return self.positions.dtype @property def node_properties(self) -> dict[str, Any]: """Get the node properties of the graph.""" return self.model_dump(include=self.__node_keys__, exclude_none=True) @property def edge_properties(self) -> dict[str, Any]: """Get the edge properties of the graph.""" return self.model_dump(include=self.__edge_keys__, exclude_none=True) @property def system_properties(self) -> dict[str, Any]: """Get the system properties of the graph.""" return self.model_dump(include=self.__system_keys__, exclude_none=True)
[docs] def add_node_property( self, key: str, value: torch.Tensor, node_dim: int = 0 ) -> None: """Add a node property to the graph.""" setattr(self, key, value) self.__node_keys__.add(key)
[docs] def add_edge_property(self, key: str, value: Any) -> None: """Add an edge property to the graph.""" setattr(self, key, value) self.__edge_keys__.add(key)
[docs] def add_system_property(self, key: str, value: Any) -> None: """Add a system property to the graph.""" setattr(self, key, value) self.__system_keys__.add(key)
@property def chemical_hash(self) -> str: """Generate a unique hash for the chemical system using the blake2s hashing algorithm. The hash is unique to a given atomic composition and structure, invariant to the ordering of atoms in the data. The hash also differentiates between periodic and non-periodic systems, and for the former, lattice vectors and directions of periodicity. Returns ------- str A ``blake2s`` hash string representing the chemical system. Notes ----- The hash is generated by: 1. Sorting atoms by atomic number to ensure invariance to atom ordering 2. Including atomic numbers and positions of sorted atoms 3. Including periodic boundary conditions and cell parameters if present 4. Computing a BLAKE2s hash of the formatted string representation """ atomic_numbers = self.atomic_numbers.cpu().numpy() sorted_idx = np.argsort(atomic_numbers) atomic_numbers = atomic_numbers[sorted_idx].tolist() positions = self.positions.cpu()[sorted_idx].tolist() # differentiate between periodic and non-periodic systems if self.pbc is not None and self.cell is not None: pbc = self.pbc.cpu().tolist() cell = self.cell.cpu().tolist() else: pbc = "" cell = "" formatted_str = f"{atomic_numbers}\n{positions}\n{pbc}\n{cell}" return blake2s(formatted_str.encode("utf-8"), digest_size=32).hexdigest() def __eq__(self, other: Any) -> bool: """ Checks if two objects are indeed ``AtomicData``, and if so, returns if their chemical hashes are equal. Parameters ---------- other : Any The object to compare with. Returns ------- bool True if the chemical hashes are equal, False otherwise. """ if not isinstance(other, AtomicData): return False return self.chemical_hash == other.chemical_hash
[docs] @classmethod @OptionalDependency.ASE.require def from_atoms( cls, atoms: Atoms, energy_key: str = "energy", forces_key: str = "forces", stress_key: str = "stress", virials_key: str = "virials", dipole_key: str = "dipole", charges_key: str = "charges", device: str | torch.device = "cpu", dtype: torch.dtype = torch.float32, z_table: AtomicNumberTable | None = None, ) -> AtomicData: """Create an AtomicData from an ASE-like Atoms object. Only fields that are actually present in the input object are populated; absent optional fields (energy, forces, stress, virials, dipole, charges) remain ``None``. The input ``atoms`` object is **not** mutated. The returned ``info`` dict contains only tensor-convertible entries from ``atoms.info`` (``np.ndarray``, ``list``, ``int``, ``float``, and their numpy equivalents). ``bool``, ``np.bool_``, strings, and other types are dropped. Parameters ---------- atoms : ase.Atoms An ASE Atoms object. energy_key : str Key in ``atoms.info`` for total energy. forces_key : str Key in ``atoms.arrays`` for atomic forces. stress_key : str Key in ``atoms.info`` for the stress tensor. virials_key : str Key in ``atoms.info`` for the virial tensor. dipole_key : str Key in ``atoms.info`` for the dipole moment. charges_key : str Key in ``atoms.arrays`` for per-atom partial charges. device : str | torch.device Target device for all output tensors. dtype : torch.dtype Target floating-point dtype for all output tensors. z_table : AtomicNumberTable | None Atomic number table used to build one-hot node attributes. 
Returns ------- AtomicData """ # convert device to torch.device if isinstance(device, str): device = torch.device(device) # Get base components from ase.Atoms object atomic_numbers = torch.as_tensor( atoms.arrays["numbers"], device=device, dtype=torch.int32 ) positions = torch.as_tensor( atoms.arrays["positions"], device=device, dtype=dtype ) pbc_array = atoms.get_pbc() if not pbc_array.any(): pbc = None cell = None else: cell = torch.as_tensor( atoms.get_cell().array.reshape(1, 3, 3), device=device, dtype=dtype, ) if torch.det(cell.squeeze(0)) <= 0.0: raise ValueError( "Cell has undefined (zero) lattice vectors. " "Please set the cell for all directions, " "e.g. using atoms.center(vacuum=10.0)." ) pbc = torch.as_tensor(pbc_array.reshape(1, 3), device=device) # Extract optional fields — absent fields remain None instead of # being fabricated as zero tensors. raw_energy = atoms.info.get(energy_key) energy = ( torch.as_tensor(raw_energy, device=device, dtype=dtype).reshape(1, 1) if raw_energy is not None else None ) raw_forces = atoms.arrays.get(forces_key) forces = ( torch.as_tensor(raw_forces, device=device, dtype=dtype) if raw_forces is not None else None ) raw_stress = atoms.info.get(stress_key) stress = ( voigt_to_matrix( torch.as_tensor(raw_stress, device=device, dtype=dtype) ).unsqueeze(0) if raw_stress is not None else None ) raw_virials = atoms.info.get(virials_key) virials = ( voigt_to_matrix( torch.as_tensor(raw_virials, device=device, dtype=dtype) ).unsqueeze(0) if raw_virials is not None else None ) raw_dipole = atoms.info.get(dipole_key) dipole = ( torch.as_tensor(raw_dipole, device=device, dtype=dtype).reshape(1, 3) if raw_dipole is not None else None ) raw_charges = atoms.arrays.get(charges_key) node_charges = ( torch.as_tensor(raw_charges, device=device, dtype=dtype) if raw_charges is not None else None ) # Read raw charge from original atoms.info before building local_info, # so it cannot be lost during normalization. 
raw_charge = atoms.info.get("charge") # Build local info dict with tensor-convertible entries only. # Do not mutate the caller's atoms.info. # Skip keys already consumed into dedicated AtomicData fields. _consumed_info_keys = { energy_key, stress_key, virials_key, dipole_key, "charge", } local_info: dict[str, torch.Tensor] = {} for key, value in atoms.info.items(): if key in _consumed_info_keys: continue if isinstance(value, (np.ndarray, list)): local_info[key] = torch.as_tensor(value, device=device, dtype=dtype) elif isinstance( value, (int, float, np.integer, np.floating) ) and not isinstance(value, (bool, np.bool_)): local_info[key] = torch.as_tensor([value], device=device, dtype=dtype) # Derive graph-level charge if raw_charge is not None: if not isinstance(raw_charge, numbers.Integral): raise ValueError( f"atoms.info['charge'] must be an integer, " f"got {type(raw_charge).__name__}: {raw_charge}" ) charge = torch.as_tensor([[int(raw_charge)]], device=device, dtype=dtype) elif node_charges is not None: _charge_f = torch.sum(node_charges) _charge = int(_charge_f.round().item()) if (_charge_f - _charge).abs() >= 1.0e-2: raise ValueError(f"Non-integer sum of atomic charges: {_charge_f}") charge = torch.as_tensor([[_charge]], device=device, dtype=dtype) else: charge = None node_attrs = None if z_table is not None: indices = torch.as_tensor( atomic_numbers_to_indices(atoms.arrays["numbers"], z_table=z_table), device=device, ) node_attrs = to_one_hot( indices.unsqueeze(-1), num_classes=len(z_table), ).to(dtype) masses_tensor = torch.from_numpy(atoms.get_masses()).to(device, dtype) return cls( atomic_masses=masses_tensor, atomic_numbers=atomic_numbers, positions=positions, cell=cell, pbc=pbc, node_attrs=node_attrs, # type: ignore forces=forces, energy=energy, stress=stress, virial=virials, dipole=dipole, charges=node_charges, charge=charge, info=local_info, )
[docs] @classmethod @OptionalDependency.PYMATGEN.require def from_structure( cls, structure: Structure | Molecule, energy_key: str = "energy", forces_key: str = "forces", stress_key: str = "stress", virials_key: str = "virials", dipole_key: str = "dipole", charges_key: str = "charges", device: str | torch.device = "cpu", dtype: torch.dtype = torch.float32, z_table: AtomicNumberTable | None = None, ) -> AtomicData: """Create an AtomicData from a pymatgen Structure or Molecule. Only fields that are actually present in the input are populated; absent optional fields (energy, forces, stress, virials, dipole, charges) remain ``None``. The input object is **not** mutated. The returned ``info`` dict contains tensor-convertible entries from ``structure.properties`` (``np.ndarray``, ``list``, ``int``, ``float``, and their numpy equivalents), excluding keys already consumed into dedicated fields. Unsupported types raise ``TypeError``. Stress and virials accept 3×3 matrices, 6-component Voigt vectors, or 9-component flat vectors (see :func:`voigt_to_matrix`). Parameters ---------- structure : pymatgen.core.Structure | pymatgen.core.Molecule A pymatgen Structure (periodic) or Molecule (non-periodic). For Molecule, ``cell`` and ``pbc`` are set to ``None``. energy_key : str Key in ``structure.properties`` for total energy. forces_key : str Key in ``structure.site_properties`` for atomic forces. stress_key : str Key in ``structure.properties`` for the stress tensor. virials_key : str Key in ``structure.properties`` for the virial tensor. dipole_key : str Key in ``structure.properties`` for the dipole moment. charges_key : str Key in ``structure.site_properties`` for per-atom partial charges. device : str | torch.device Target device for all output tensors. dtype : torch.dtype Target floating-point dtype for all output tensors. z_table : AtomicNumberTable | None Atomic number table used to build one-hot node attributes. 
Returns ------- AtomicData """ if isinstance(device, str): device = torch.device(device) atomic_numbers = torch.as_tensor( structure.atomic_numbers, device=device, dtype=torch.int32 ) positions = torch.as_tensor(structure.cart_coords, device=device, dtype=dtype) # Cell and pbc handling if hasattr(structure, "lattice"): pbc_tuple = structure.pbc if not any(pbc_tuple): pbc = None cell = None else: cell = torch.as_tensor( structure.lattice.matrix.copy().reshape(1, 3, 3), device=device, dtype=dtype, ) pbc = torch.as_tensor(pbc_tuple, device=device).reshape(1, 3) else: pbc = None cell = None # Extract optional fields from properties (system-level) # and site_properties (per-atom). raw_energy = structure.properties.get(energy_key) energy = ( torch.as_tensor([[raw_energy]], device=device, dtype=dtype) if raw_energy is not None else None ) raw_forces = structure.site_properties.get(forces_key) forces = ( torch.as_tensor(raw_forces, device=device, dtype=dtype) if raw_forces is not None else None ) raw_stress = structure.properties.get(stress_key) stress = ( voigt_to_matrix( torch.as_tensor(raw_stress, device=device, dtype=dtype) ).unsqueeze(0) if raw_stress is not None else None ) raw_virials = structure.properties.get(virials_key) virials = ( voigt_to_matrix( torch.as_tensor(raw_virials, device=device, dtype=dtype) ).unsqueeze(0) if raw_virials is not None else None ) raw_dipole = structure.properties.get(dipole_key) dipole = ( torch.as_tensor(raw_dipole, device=device, dtype=dtype).reshape(1, 3) if raw_dipole is not None else None ) raw_charges = structure.site_properties.get(charges_key) node_charges = ( torch.as_tensor(raw_charges, device=device, dtype=dtype) if raw_charges is not None else None ) # Build local info dict from remaining structure.properties. 
_consumed_props_keys = { energy_key, stress_key, virials_key, dipole_key, } local_info: dict[str, torch.Tensor] = {} for key, value in structure.properties.items(): if key in _consumed_props_keys: continue if isinstance(value, (np.ndarray, list)): local_info[key] = torch.as_tensor(value, device=device, dtype=dtype) elif isinstance( value, (int, float, np.integer, np.floating) ) and not isinstance(value, (bool, np.bool_)): local_info[key] = torch.as_tensor([value], device=device, dtype=dtype) else: raise TypeError( f"Cannot convert structure.properties['{key}'] of type " f"{type(value).__name__} to a tensor." ) # Derive graph-level charge. # pymatgen stores charge as float (e.g. 2 → 2.0); round before int cast. if structure._charge is not None: _charge = structure.charge if abs(_charge - round(_charge)) >= 1e-2: raise ValueError(f"Structure charge must be an integer, got {_charge}") charge = torch.as_tensor( [[int(round(_charge))]], device=device, dtype=dtype ) elif node_charges is not None: _charge_f = torch.sum(node_charges) _charge_i = int(_charge_f.round().item()) if (_charge_f - _charge_i).abs() >= 1.0e-2: raise ValueError(f"Non-integer sum of atomic charges: {_charge_f}") charge = torch.as_tensor([[_charge_i]], device=device, dtype=dtype) else: charge = None node_attrs = None if z_table is not None: indices = torch.as_tensor( atomic_numbers_to_indices( list(structure.atomic_numbers), z_table=z_table ), device=device, ) node_attrs = to_one_hot( indices.unsqueeze(-1), num_classes=len(z_table), ).to(dtype) masses = torch.tensor( [float(sp.atomic_mass) for sp in structure.species], device=device, dtype=dtype, ) return cls( atomic_masses=masses, atomic_numbers=atomic_numbers, positions=positions, cell=cell, pbc=pbc, node_attrs=node_attrs, # type: ignore forces=forces, energy=energy, stress=stress, virial=virials, dipole=dipole, charges=node_charges, charge=charge, info=local_info, )
@property def num_nodes(self) -> int: """Return the number of nodes in the graph.""" return len(self.atomic_numbers) @property def num_edges(self) -> int: """Return the number of edges in the graph.""" if self.neighbor_list is None: return 0 return self.neighbor_list.shape[0]
def to_one_hot(indices: torch.Tensor, num_classes: int) -> torch.Tensor:
    """
    Generates one-hot encoding

    Parameters
    ----------
    indices : torch.Tensor
        Integer class indices whose last dimension has size 1.
    num_classes : int
        Size of the one-hot dimension.

    Returns
    -------
    torch.Tensor
        Tensor of zeros and ones (default floating dtype) on the device of
        *indices*, with the last dimension replaced by ``num_classes``.
    """
    shape = indices.shape[:-1] + (num_classes,)
    oh = torch.zeros(shape, device=indices.device)
    # scatter_ is the in-place version of scatter; it requires int64 indices,
    # so narrower integer inputs are widened first.
    oh.scatter_(dim=-1, index=indices.long(), value=1)
    return oh


def voigt_to_matrix(t: torch.Tensor) -> torch.Tensor:
    """
    Convert voigt notation to matrix notation

    A ``(3, 3)`` input is returned as is, a ``(9,)`` input is reshaped
    row-major, and a ``(6,)`` input ``[xx, yy, zz, yz, xz, xy]`` is expanded
    to the symmetric matrix.

    Raises
    ------
    ValueError
        If *t* has any other shape.
    """
    if t.shape == (3, 3):
        return t
    if t.shape == (6,):
        # Gather by index instead of rebuilding with torch.tensor(...), which
        # would copy the values out of the autograd graph.
        return t[[0, 5, 4, 5, 1, 3, 4, 3, 2]].reshape(3, 3)
    if t.shape == (9,):
        return t.reshape(3, 3)
    raise ValueError(
        f"Stress tensor must be of shape (6,) or (3, 3), or (9,) but has shape {t.shape}"
    )


def atomic_numbers_to_indices(
    atomic_numbers: np.ndarray, z_table: AtomicNumberTable
) -> np.ndarray:
    """
    Convert atomic numbers to indices

    Applies ``z_table.z_to_index`` element-wise and returns an ``int64``
    array with the same shape as *atomic_numbers*.
    """
    # otypes fixes the output dtype up front; without it np.vectorize has to
    # probe the first element and raises on empty input.
    to_index_fn = np.vectorize(z_table.z_to_index, otypes=[np.int64])
    return to_index_fn(atomic_numbers)