Note
Go to the end to download the full example code.
Batch Neighbor List Example
This example demonstrates how to use the batch neighbor list functions in nvalchemiops with multiple molecular and crystalline systems. We’ll cover:
batch_cell_list: Batch O(N) processing with spatial cell lists
batch_naive_neighbor_list: Batch O(N²) processing for small systems
Using batch_idx to identify which system each atom belongs to
Processing heterogeneous batches with different sizes and parameters
Comparing batch vs single-system processing
Batch processing allows efficient computation of neighbor lists for multiple systems simultaneously, which is essential for high-throughput molecular screening and ensemble simulations.
import numpy as np
import torch
from system_utils import create_bulk_structure, create_molecule_structure
from nvalchemiops.neighborlist import (
batch_cell_list,
batch_naive_neighbor_list,
cell_list,
)
Set up the computation device
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# All floating-point tensors built in this example use single precision.
dtype = torch.float32
print(f"Using device: {device}")
print(f"Using dtype: {dtype}")
Using device: cuda
Using dtype: torch.float32
Create multiple systems
We’ll create a diverse set of molecular and crystalline systems.
banner = "=" * 70
print("\n" + banner)
print("CREATING SYSTEMS")
print(banner)

# Three small molecules, each placed in its own box of the given edge length.
water = create_molecule_structure("H2O", box_size=15.0)
co2 = create_molecule_structure("CO2", box_size=12.0)
methane = create_molecule_structure("CH4", box_size=10.0)

# Cubic FCC aluminium cell; make_supercell modifies fcc_al in place (no
# reassignment), expanding it to a 2x2x2 supercell.
fcc_al = create_bulk_structure("Al", "fcc", a=4.05, cubic=True)
fcc_al.make_supercell([2, 2, 2])

# Systems and their display labels, kept in the same order.
systems = [water, co2, methane, fcc_al]
system_names = ["H2O", "CO2", "CH4", "Al-fcc(2x2x2)"]

print(f"\nCreated {len(systems)} systems:")
for name, system in zip(system_names, systems):
    len_a, len_b, len_c = system.lattice.abc
    print(
        f" {name}: {len(system)} atoms, cell: [{len_a:.2f}, {len_b:.2f}, {len_c:.2f}]"
    )
======================================================================
CREATING SYSTEMS
======================================================================
Created 4 systems:
H2O: 3 atoms, cell: [15.00, 15.00, 15.00]
CO2: 3 atoms, cell: [12.00, 12.00, 12.00]
CH4: 5 atoms, cell: [10.00, 10.00, 10.00]
Al-fcc(2x2x2): 32 atoms, cell: [8.10, 8.10, 8.10]
Convert systems to batch format
Combine all systems into the batch format required by nvalchemiops.
banner = "=" * 70
print("\n" + banner)
print("CONVERTING TO BATCH FORMAT")
print(banner)

# Per-system Cartesian coordinates and 3x3 lattice matrices, in system order.
all_positions = [system.cart_coords for system in systems]
all_cells = [system.lattice.matrix for system in systems]
# pymatgen structures are always periodic, so every system is periodic on x, y, z.
all_pbc = [np.array([True, True, True]) for _ in systems]
# One entry per atom: the index of the system that atom belongs to.
batch_indices = [
    sys_idx for sys_idx, system in enumerate(systems) for _ in range(len(system))
]

# Flatten into the batched tensors: positions (total_atoms, 3),
# cells (num_systems, 3, 3), pbc (num_systems, 3), batch_idx (total_atoms,).
positions = torch.tensor(np.vstack(all_positions), dtype=dtype, device=device)
cells = torch.tensor(np.array(all_cells), dtype=dtype, device=device).reshape(-1, 3, 3)
pbc = torch.tensor(np.array(all_pbc), device=device).reshape(-1, 3)
batch_idx = torch.tensor(batch_indices, dtype=torch.int32, device=device)

# A single cutoff radius shared by every system in the batch.
cutoff = 5.0

print("\nBatch configuration:")
print(f" Total atoms: {positions.shape[0]}")
print(f" Number of systems: {len(systems)}")
print(f" batch_idx shape: {batch_idx.shape}")
print(f" Cutoff: {cutoff} Å")

# Cross-check the batch_idx bookkeeping against each system's atom count.
atom_counts = list(map(len, systems))
print(f"\n Atoms per system: {atom_counts}")
for sys_idx, name in enumerate(system_names):
    atoms_in_system = (batch_idx == sys_idx).sum()
    print(f" System {sys_idx} ({name}): {atoms_in_system} atoms (batch_idx={sys_idx})")
======================================================================
CONVERTING TO BATCH FORMAT
======================================================================
Batch configuration:
Total atoms: 43
Number of systems: 4
batch_idx shape: torch.Size([43])
Cutoff: 5.0 Å
Atoms per system: [3, 3, 5, 32]
System 0 (H2O): 3 atoms (batch_idx=0)
System 1 (CO2): 3 atoms (batch_idx=1)
System 2 (CH4): 5 atoms (batch_idx=2)
System 3 (Al-fcc(2x2x2)): 32 atoms (batch_idx=3)
Method 1: Batch Cell List Algorithm (O(N))
Process all systems simultaneously with the cell list algorithm.
banner = "=" * 70
print("\n" + banner)
print("METHOD 1: BATCH CELL LIST (O(N))")
print(banner)

# Default output: a neighbor matrix with one row per atom, the per-atom
# neighbor counts, and (presumably periodic-image) shifts.
neighbor_matrix_batch, num_neighbors_batch, shifts_batch = batch_cell_list(
    positions, cutoff, cells, pbc, batch_idx
)
print(f"\nReturned neighbor matrix: {neighbor_matrix_batch.shape}")
print(f" Total neighbor pairs: {num_neighbors_batch.sum()}")
print(f" Average neighbors per atom: {num_neighbors_batch.float().mean():.2f}")

# Alternative output: a COO neighbor list of shape (2, num_pairs)
# (row 0 = source atoms, row 1 = target atoms) plus a per-atom pointer array.
neighbor_list_batch, neighbor_ptr_batch, shifts_coo = batch_cell_list(
    positions, cutoff, cells, pbc, batch_idx, return_neighbor_list=True
)
print(f"\nReturned neighbor list (COO): {neighbor_list_batch.shape}")
print(f" Total pairs: {neighbor_list_batch.shape[1]}")
print(f" Neighbor ptr shape: {neighbor_ptr_batch.shape}")

# Per-system breakdown. Atoms are selected through batch_idx rather than by a
# running slice offset, so the result does not depend on each system's atoms
# being stored contiguously and in system order.
print("\nPairs per system:")
for sys_idx, name in enumerate(system_names):
    system_mask = batch_idx == sys_idx
    count = int(system_mask.sum().item())
    system_num_neighbors = num_neighbors_batch[system_mask].sum().item()
    avg_neighbors = system_num_neighbors / count if count > 0 else 0
    print(f" {name}: {system_num_neighbors} pairs, {avg_neighbors:.1f} neighbors/atom")
======================================================================
METHOD 1: BATCH CELL LIST (O(N))
======================================================================
Returned neighbor matrix: torch.Size([43, 928])
Total neighbor pairs: 1376
Average neighbors per atom: 32.00
Returned neighbor list (COO): torch.Size([2, 1376])
Total pairs: 1376
Neighbor ptr shape: torch.Size([44])
Pairs per system:
H2O: 6 pairs, 2.0 neighbors/atom
CO2: 6 pairs, 2.0 neighbors/atom
CH4: 20 pairs, 4.0 neighbors/atom
Al-fcc(2x2x2): 1344 pairs, 42.0 neighbors/atom
Method 2: Batch Naive Algorithm (O(N²))
For comparison, run the naive algorithm on a batch of the small molecular systems.
banner = "=" * 70
print("\n" + banner)
print("METHOD 2: BATCH NAIVE ALGORITHM (O(N²))")
print(banner)

# The O(N^2) search is demonstrated only on the three small molecules;
# the larger Al crystal is left out.
small_systems = [water, co2, methane]
small_system_names = ["H2O", "CO2", "CH4"]

# Build the batched inputs, keeping the same system order everywhere.
small_positions = torch.tensor(
    np.vstack([s.cart_coords for s in small_systems]), dtype=dtype, device=device
)
small_cells = torch.tensor(
    np.array([s.lattice.matrix for s in small_systems]), dtype=dtype, device=device
)
# Every molecule box is treated as periodic along all three axes.
small_pbc = torch.ones((len(small_systems), 3), dtype=torch.bool, device=device)
# System index of each atom, aligned with the rows of small_positions.
small_batch_idx = torch.tensor(
    [sys_i for sys_i, s in enumerate(small_systems) for _ in range(len(s))],
    dtype=torch.int32,
    device=device,
)
print(f"Small systems batch: {small_positions.shape[0]} total atoms")

neighbor_matrix_naive, num_neighbors_naive, shifts_naive = batch_naive_neighbor_list(
    small_positions,
    cutoff,
    batch_idx=small_batch_idx,
    cell=small_cells,
    pbc=small_pbc,
)
print(f"Returned neighbor matrix: {neighbor_matrix_naive.shape}")
print(f"Total neighbor pairs: {num_neighbors_naive.sum()}")

# Cross-check against the cell-list algorithm on the identical batch.
# Note: this compares per-atom neighbor counts only, not neighbor identities.
neighbor_matrix_cell, num_neighbors_cell, _ = batch_cell_list(
    small_positions, cutoff, small_cells, small_pbc, small_batch_idx
)
counts_agree = torch.equal(num_neighbors_naive, num_neighbors_cell)
print("\nVerification (naive vs cell list):")
print(f" Naive total pairs: {num_neighbors_naive.sum()}")
print(f" Cell list total pairs: {num_neighbors_cell.sum()}")
print(f" Results match: {counts_agree}")
======================================================================
METHOD 2: BATCH NAIVE ALGORITHM (O(N²))
======================================================================
Small systems batch: 11 total atoms
Returned neighbor matrix: torch.Size([11, 928])
Total neighbor pairs: 32
Verification (naive vs cell list):
Naive total pairs: 32
Cell list total pairs: 32
Results match: True
Extract individual system results from the batch
print("\n" + "=" * 70)
print("EXTRACTING INDIVIDUAL SYSTEM RESULTS")
print("=" * 70)


def extract_system_neighbors(system_idx, neighbor_list, batch_idx):
    """Extract the pairs belonging to one system from a batched COO neighbor list.

    Args:
        system_idx: Index of the system to extract (a value that may appear
            in ``batch_idx``).
        neighbor_list: Tensor of shape ``(2, num_pairs)``. Row 0 holds source
            atom indices and row 1 holds target atom indices; both are
            global indices into the batch.
        batch_idx: Tensor of shape ``(num_atoms,)`` giving the system index
            of every atom in the batch.

    Returns:
        ``(system_source, system_target, pair_mask)``:
        ``system_source`` / ``system_target`` are the pairs whose source atom
        belongs to ``system_idx``, re-indexed to positions local to that
        system (0 .. n_atoms_in_system - 1) and cast back to
        ``neighbor_list``'s dtype. ``pair_mask`` is the boolean mask over the
        columns of ``neighbor_list`` that were selected. A system with no
        atoms yields empty index tensors and an all-False mask instead of
        raising.
    """
    source_atoms = neighbor_list[0].long()
    target_atoms = neighbor_list[1].long()

    system_mask = batch_idx == system_idx
    # Local index of each atom inside its system: the k-th atom (in batch
    # order) that belongs to system_idx gets k. Only entries where
    # system_mask is True are meaningful. Unlike subtracting the first atom's
    # global index, this stays correct if the system's atoms are not
    # contiguous in the batch.
    local_index = torch.cumsum(system_mask.long(), dim=0) - 1

    # Select pairs by the system that owns the source atom.
    pair_mask = system_mask[source_atoms]
    # NOTE(review): assumes pairs never cross systems, i.e. a selected pair's
    # target atom also belongs to system_idx.
    system_source = local_index[source_atoms[pair_mask]].to(neighbor_list.dtype)
    system_target = local_index[target_atoms[pair_mask]].to(neighbor_list.dtype)
    return system_source, system_target, pair_mask
# Report each system's pairs, pulled out of the batched COO neighbor list.
print("\nPer-system analysis:")
for sys_idx, (system, name) in enumerate(zip(systems, system_names)):
    sys_source, sys_target, _ = extract_system_neighbors(
        sys_idx, neighbor_list_batch, batch_idx
    )
    n_atoms, n_pairs = len(system), len(sys_source)
    avg_neighbors = n_pairs / n_atoms if n_atoms > 0 else 0

    print(f"\n{name}:")
    print(f" Atoms: {n_atoms}")
    print(f" Neighbor pairs: {n_pairs}")
    print(f" Avg neighbors per atom: {avg_neighbors:.2f}")
    if not n_pairs:
        continue
    # Show at most the first three (source->target) pairs on one line.
    print(" Sample pairs: ", end="")
    for src, dst in zip(sys_source[:3].tolist(), sys_target[:3].tolist()):
        print(f"({src}->{dst})", end=" ")
    print()
======================================================================
EXTRACTING INDIVIDUAL SYSTEM RESULTS
======================================================================
Per-system analysis:
H2O:
Atoms: 3
Neighbor pairs: 6
Avg neighbors per atom: 2.00
Sample pairs: (0->1) (0->2) (1->0)
CO2:
Atoms: 3
Neighbor pairs: 6
Avg neighbors per atom: 2.00
Sample pairs: (0->1) (0->2) (1->0)
CH4:
Atoms: 5
Neighbor pairs: 20
Avg neighbors per atom: 4.00
Sample pairs: (0->1) (0->2) (0->3)
Al-fcc(2x2x2):
Atoms: 32
Neighbor pairs: 1344
Avg neighbors per atom: 42.00
Sample pairs: (0->22) (0->28) (0->26)
Compare batch vs single-system processing
banner = "=" * 70
print("\n" + banner)
print("BATCH VS SINGLE-SYSTEM COMPARISON")
print(banner)

# Re-run every system on its own with cell_list and check that its total
# neighbor count equals the total taken from the batched run.
print("\nVerifying batch results against single-system calculations:\n")
for sys_idx, (system, name) in enumerate(zip(systems, system_names)):
    sys_positions = torch.tensor(system.cart_coords, dtype=dtype, device=device)
    # The cell is passed with a leading dimension of size 1: shape (1, 3, 3).
    sys_cell = torch.tensor(system.lattice.matrix, dtype=dtype, device=device)[None]
    sys_pbc = torch.tensor([True, True, True], device=device)

    _, num_neighbors_single, _ = cell_list(sys_positions, cutoff, sys_cell, sys_pbc)
    single_total = num_neighbors_single.sum().item()
    batch_total = num_neighbors_batch[batch_idx == sys_idx].sum().item()

    if single_total == batch_total:
        match_status = "✓"
    else:
        match_status = "✗"
    print(f"{match_status} {name:15s}: single={single_total:4d}, batch={batch_total:4d}")
======================================================================
BATCH VS SINGLE-SYSTEM COMPARISON
======================================================================
Verifying batch results against single-system calculations:
✓ H2O : single= 6, batch= 6
✓ CO2 : single= 6, batch= 6
✓ CH4 : single= 20, batch= 20
✓ Al-fcc(2x2x2) : single=1344, batch=1344
Demonstrate heterogeneous batch parameters
Show that each system in the batch can have different properties, such as its atom count and unit cell.
banner = "=" * 70
print("\n" + banner)
print("HETEROGENEOUS BATCH PARAMETERS")
print(banner)

print("\nBatch supports different parameters per system:")
print(f" System sizes: {atom_counts}")

# Only the first lattice length is shown. The systems in this example have
# equal a, b, c lengths, so it stands for the box edge.
print(" Unit cells (box sizes):")
for name, system in zip(system_names, systems):
    cell_size = system.lattice.abc[0]
    print(f" {name}: {cell_size:.2f} Å")

# Report the periodicity flags from the pbc tensor actually handed to
# batch_cell_list (shape (num_systems, 3)), rather than a hard-coded string,
# so the printout stays truthful if those flags are changed.
print(" PBC settings:")
for name, system_pbc in zip(system_names, pbc.tolist()):
    pbc_str = "".join("T" if flag else "F" for flag in system_pbc)
    print(f" {name}: [{pbc_str}]")

print(f"\n Single cutoff used for all: {cutoff} Å")
print(" (Note: Currently all systems share the same cutoff)")
======================================================================
HETEROGENEOUS BATCH PARAMETERS
======================================================================
Batch supports different parameters per system:
System sizes: [3, 3, 5, 32]
Unit cells (box sizes):
H2O: 15.00 Å
CO2: 12.00 Å
CH4: 10.00 Å
Al-fcc(2x2x2): 8.10 Å
PBC settings:
H2O: [TTT]
CO2: [TTT]
CH4: [TTT]
Al-fcc(2x2x2): [TTT]
Single cutoff used for all: 5.0 Å
(Note: Currently all systems share the same cutoff)
# Final status line; it is reached only if every preceding section of the
# script ran without raising.
print("\nExample completed successfully!")
Example completed successfully!
Total running time of the script: (0 minutes 1.182 seconds)