# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch utilities for neighbor list construction.
This module contains PyTorch-specific helper functions for neighbor list operations.
"""
from __future__ import annotations
import torch
import warp as wp
from nvalchemiops.neighbors.neighbor_utils import (
NeighborOverflowError,
estimate_max_neighbors,
)
from nvalchemiops.neighbors.neighbor_utils import (
compute_naive_num_shifts as wp_compute_naive_num_shifts,
)
from nvalchemiops.torch.types import get_wp_dtype, get_wp_mat_dtype
__all__ = [
"compute_naive_num_shifts",
"get_neighbor_list_from_neighbor_matrix",
"prepare_batch_idx_ptr",
"allocate_cell_list",
"estimate_max_neighbors",
"NeighborOverflowError",
]
def compute_naive_num_shifts(
cell: torch.Tensor,
cutoff: float,
pbc: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, int]:
"""Compute periodic image shifts needed for neighbor searching.
Parameters
----------
cell : torch.Tensor, shape (num_systems, 3, 3)
Cell matrices defining lattice vectors in Cartesian coordinates.
Each 3x3 matrix represents one system's periodic cell.
cutoff : float
Cutoff distance for neighbor searching in Cartesian units.
Must be positive and typically less than half the minimum cell dimension.
pbc : torch.Tensor, shape (num_systems, 3), dtype=bool
Periodic boundary condition flags for each dimension.
True enables periodicity in that direction.
Returns
-------
shift_range : torch.Tensor, shape (num_systems, 3), dtype=int32
Maximum shift indices in each dimension for each system.
num_shifts : torch.Tensor, shape (num_systems,), dtype=int32
Number of periodic shifts for each system.
max_shifts : int
Maximum per-system shift count across all systems.
Raises
------
ValueError
If any per-system shift count exceeds int32 range.
See Also
--------
nvalchemiops.neighbors.neighbor_utils.compute_naive_num_shifts : Core warp launcher
"""
num_systems = cell.shape[0]
device = cell.device
num_shifts_i32 = torch.empty(num_systems, dtype=torch.int32, device=device)
shift_range = torch.empty((num_systems, 3), dtype=torch.int32, device=device)
wp_dtype = get_wp_dtype(cell.dtype)
wp_mat_dtype = get_wp_mat_dtype(cell.dtype)
wp_device = wp.device_from_torch(device)
wp_cell = wp.from_torch(cell, dtype=wp_mat_dtype)
wp_pbc = wp.from_torch(pbc, dtype=wp.bool)
wp_num_shifts = wp.from_torch(num_shifts_i32, dtype=wp.int32)
wp_shift_range = wp.from_torch(shift_range, dtype=wp.vec3i)
wp_compute_naive_num_shifts(
cell=wp_cell,
cutoff=cutoff,
pbc=wp_pbc,
num_shifts=wp_num_shifts,
shift_range=wp_shift_range,
wp_dtype=wp_dtype,
device=str(wp_device),
)
s = shift_range.to(torch.int64)
k1 = 2 * s[:, 1] + 1
k2 = 2 * s[:, 2] + 1
num_shifts_i64 = s[:, 0] * k1 * k2 + s[:, 1] * k2 + s[:, 2] + 1
max_shifts_i64 = num_shifts_i64.max().item() if num_systems > 0 else 0
if max_shifts_i64 > 2**31 - 1:
raise ValueError(
f"Per-system shift count ({max_shifts_i64}) exceeds int32 max "
f"(2^31 - 1). Reduce the cutoff, increase cell size, or use a "
f"cell-list method for very small cells."
)
num_shifts = num_shifts_i64.to(torch.int32)
return shift_range, num_shifts, int(max_shifts_i64)
def get_neighbor_list_from_neighbor_matrix(
neighbor_matrix: torch.Tensor,
num_neighbors: torch.Tensor,
neighbor_shift_matrix: torch.Tensor | None = None,
fill_value: int = -1,
) -> (
tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor]
):
"""Convert neighbor matrix format to neighbor list format.
Parameters
----------
neighbor_matrix : torch.Tensor, shape (total_atoms, max_neighbors), dtype=int32
The neighbor matrix with neighbor atom indices.
num_neighbors : torch.Tensor, shape (total_atoms,), dtype=int32
The number of neighbors for each atom.
neighbor_shift_matrix : torch.Tensor | None, shape (total_atoms, max_neighbors, 3), dtype=int32
Optional neighbor shift matrix with periodic shift vectors.
fill_value : int, default=-1
The fill value used in the neighbor matrix to indicate empty slots.
This is used to create a mask from the neighbor matrix.
Returns
-------
neighbor_list : torch.Tensor, shape (2, num_pairs), dtype=int32
The neighbor list in COO format [source_atoms, target_atoms].
neighbor_ptr : torch.Tensor, shape (total_atoms + 1,), dtype=int32
CSR-style pointer array where neighbor_ptr[i]:neighbor_ptr[i+1] gives the range of
neighbors for atom i in the flattened neighbor list.
neighbor_list_shifts : torch.Tensor, shape (num_pairs, 3), dtype=int32
The neighbor shift vectors (only returned if neighbor_shift_matrix is not None).
Raises
------
ValueError
If the max number of neighbors is larger than the neighbor matrix width.
Notes
-----
This is a pure PyTorch utility function with no warp dependencies. It converts
from the fixed-width matrix format to the variable-width list format by masking
out fill values and flattening the result.
See Also
--------
nvalchemiops.torch.neighbors.naive_neighbor_list : Uses this for format conversion
nvalchemiops.torch.neighbors.cell_list : Uses this for format conversion
"""
# Handle empty case
if num_neighbors.shape[0] == 0:
neighbor_list = torch.zeros(
2, 0, dtype=neighbor_matrix.dtype, device=neighbor_matrix.device
)
neighbor_ptr = torch.zeros(1, dtype=torch.int32, device=neighbor_matrix.device)
if neighbor_shift_matrix is not None:
neighbor_shift_list = torch.empty(
0,
3,
dtype=neighbor_shift_matrix.dtype,
device=neighbor_shift_matrix.device,
)
return neighbor_list, neighbor_ptr, neighbor_shift_list
else:
return neighbor_list, neighbor_ptr
# Validate that the neighbor matrix is large enough
max_found = num_neighbors.max()
if max_found > neighbor_matrix.shape[1]:
raise NeighborOverflowError(
neighbor_matrix.shape[1],
max_found.item() if hasattr(max_found, "item") else int(max_found),
)
# Create mask and extract neighbor pairs
mask = neighbor_matrix != fill_value
dtype = neighbor_matrix.dtype
i_idx = torch.where(mask)[0].to(dtype)
j_idx = neighbor_matrix[mask].to(dtype)
neighbor_list = torch.stack([i_idx, j_idx], dim=0)
# Create CSR-style pointer array
neighbor_ptr = torch.zeros(
num_neighbors.shape[0] + 1, dtype=torch.int32, device=neighbor_matrix.device
)
torch.cumsum(num_neighbors, dim=0, out=neighbor_ptr[1:])
if neighbor_shift_matrix is not None:
neighbor_list_shifts = neighbor_shift_matrix[mask]
return neighbor_list, neighbor_ptr, neighbor_list_shifts
else:
return neighbor_list, neighbor_ptr
[docs]
@torch.compile
def prepare_batch_idx_ptr(
batch_idx: torch.Tensor | None,
batch_ptr: torch.Tensor | None,
num_atoms: int,
device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Prepare batch index and pointer tensors from either representation.
Utility function to ensure both batch_idx and batch_ptr are available,
computing one from the other if needed.
Parameters
----------
batch_idx : torch.Tensor | None, shape (total_atoms,), dtype=int32
Tensor indicating the batch index for each atom.
batch_ptr : torch.Tensor | None, shape (num_systems + 1,), dtype=int32
Tensor indicating the start index of each batch in the atom list.
num_atoms : int
Total number of atoms across all systems.
device : torch.device
Device on which to create tensors if needed.
Returns
-------
batch_idx : torch.Tensor, shape (total_atoms,), dtype=int32
Prepared batch index tensor.
batch_ptr : torch.Tensor, shape (num_systems + 1,), dtype=int32
Prepared batch pointer tensor.
Raises
------
ValueError
If both batch_idx and batch_ptr are None.
RuntimeError
If batch_idx length does not match num_atoms (only checked in eager mode).
Notes
-----
This is a pure PyTorch utility function with no warp dependencies. It provides
convenience for batch operations by converting between dense (batch_idx) and
sparse (batch_ptr) batch representations.
The batch_idx size validation is only performed in eager mode to avoid graph
breaks during torch.compile tracing. During compiled execution, mismatched
sizes will result in undefined behavior.
See Also
--------
nvalchemiops.torch.neighbors.batch_naive_neighbor_list : Uses this for batch setup
nvalchemiops.torch.neighbors.batch_cell_list : Uses this for batch setup
"""
if batch_idx is None and batch_ptr is None:
raise ValueError("Either batch_idx or batch_ptr must be provided.")
# Validate batch_idx size in eager mode only to avoid graph breaks
if not torch.compiler.is_compiling():
if batch_idx is not None and batch_idx.shape[0] != num_atoms:
raise RuntimeError(
f"batch_idx length ({batch_idx.shape[0]}) does not match "
f"num_atoms ({num_atoms}). batch_idx must have one entry per atom."
)
if batch_idx is None:
num_systems = batch_ptr.shape[0] - 1
num_atoms_per_system = batch_ptr[1:] - batch_ptr[:-1]
batch_idx = torch.repeat_interleave(
torch.arange(num_systems, dtype=torch.int32, device=device),
num_atoms_per_system,
)
elif batch_ptr is None:
num_systems = batch_idx.max() + 1
num_atoms_per_system = torch.bincount(batch_idx, minlength=num_systems)
batch_ptr = torch.zeros(num_systems + 1, dtype=torch.int32, device=device)
torch.cumsum(num_atoms_per_system, dim=0, out=batch_ptr[1:])
return batch_idx, batch_ptr
[docs]
def allocate_cell_list(
total_atoms: int,
max_total_cells: int,
neighbor_search_radius: torch.Tensor,
device: torch.device,
) -> tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
"""Allocate memory tensors for cell list data structures.
Parameters
----------
total_atoms : int
Total number of atoms across all systems.
max_total_cells : int
Maximum number of cells to allocate.
neighbor_search_radius : torch.Tensor, shape (3,) or (num_systems, 3), dtype=int32
Radius of neighboring cells to search in each dimension.
device : torch.device
Device on which to create tensors.
Returns
-------
cells_per_dimension : torch.Tensor, shape (3,) or (num_systems, 3), dtype=int32
Number of cells in x, y, z directions (to be filled by build_cell_list).
neighbor_search_radius : torch.Tensor, shape (3,) or (num_systems, 3), dtype=int32
Radius of neighboring cells to search (passed through for convenience).
atom_periodic_shifts : torch.Tensor, shape (total_atoms, 3), dtype=int32
Periodic boundary crossings for each atom (to be filled by build_cell_list).
atom_to_cell_mapping : torch.Tensor, shape (total_atoms, 3), dtype=int32
3D cell coordinates for each atom (to be filled by build_cell_list).
atoms_per_cell_count : torch.Tensor, shape (max_total_cells,), dtype=int32
Number of atoms in each cell (to be filled by build_cell_list).
cell_atom_start_indices : torch.Tensor, shape (max_total_cells,), dtype=int32
Starting index in cell_atom_list for each cell (to be filled by build_cell_list).
cell_atom_list : torch.Tensor, shape (total_atoms,), dtype=int32
Flattened list of atom indices organized by cell (to be filled by build_cell_list).
Notes
-----
This is a pure PyTorch utility function with no warp dependencies. It pre-allocates
all tensors needed for cell list construction, supporting both single-system and
batched operations based on the shape of neighbor_search_radius.
See Also
--------
nvalchemiops.neighbors.cell_list.build_cell_list : Warp launcher that uses these tensors
nvalchemiops.torch.neighbors.cell_list.build_cell_list : High-level PyTorch wrapper
nvalchemiops.torch.neighbors.batch_cell_list.batch_build_cell_list : Batched version
"""
# Detect number of systems from neighbor_search_radius shape
is_batched = neighbor_search_radius.ndim == 2
num_systems = neighbor_search_radius.shape[0] if is_batched else 1
cells_per_dimension = torch.zeros(
(3,) if not is_batched else (num_systems, 3),
dtype=torch.int32,
device=device,
)
atom_periodic_shifts = torch.zeros(
(total_atoms, 3), dtype=torch.int32, device=device
)
atom_to_cell_mapping = torch.zeros(
(total_atoms, 3), dtype=torch.int32, device=device
)
atoms_per_cell_count = torch.zeros(
(max_total_cells,), dtype=torch.int32, device=device
)
cell_atom_start_indices = torch.zeros(
(max_total_cells,), dtype=torch.int32, device=device
)
cell_atom_list = torch.zeros((total_atoms,), dtype=torch.int32, device=device)
return (
cells_per_dimension,
neighbor_search_radius,
atom_periodic_shifts,
atom_to_cell_mapping,
atoms_per_cell_count,
cell_atom_start_indices,
cell_atom_list,
)