# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Cell Utilities for NPT/NPH Simulations.
This module provides utilities for manipulating simulation cells (periodic boxes)
in molecular dynamics simulations with variable cell volume/shape (NPT, NPH ensembles).
The cells are stored as a Warp array of shape (B,) whose dtype is wp.mat33f or
wp.mat33d (conceptually a (B, 3, 3) stack of matrices). Each matrix holds the
lattice vectors as columns:
cell[i] = [a, b, c] (lattice vectors a, b, c as columns, for system i)
Even single-system simulations use an array of shape (1,).
Fractional coordinates s relate to Cartesian coordinates r by:
r = cell @ s
s = cell_inv @ r
Key concepts:
- Cell volume: V = |det(cell)|
- Cell inverse: cell_inv = cell^-1, used for coordinate transformations
- Strain tensor: eps = cell @ cell_ref_inv - I, the deformation relative to a reference cell
- Position remapping: Maintain fractional coordinates when cell changes
All kernels are dtype-agnostic and support both float32 and float64 cell matrices.
Functions that require cell_inv accept it as a required parameter; callers
must pre-compute via ``compute_cell_inverse`` to avoid redundant inverse
computations in MD loops.
"""
from __future__ import annotations
from typing import Any
import warp as wp
__all__ = [
    # Per-system cell properties
    "compute_cell_volume", "compute_cell_inverse",
    # Strain: measure it against a reference cell / apply it to a cell
    "compute_strain_tensor", "apply_strain_to_cell",
    # Per-atom position transforms (in place where applicable)
    "scale_positions_with_cell", "wrap_positions_to_cell",
    "cartesian_to_fractional", "fractional_to_cartesian",
    # Variants that write into a caller-provided output array
    "scale_positions_with_cell_out", "wrap_positions_to_cell_out",
]
# ==============================================================================
# Cell Property Kernels
# ==============================================================================
@wp.kernel
def _compute_cell_volume_kernel(
    cells: wp.array(dtype=Any),
    volumes: wp.array(dtype=Any),
):
    """Per-system cell volume, V = |a . (b x c)| = |det(cell)|.

    Launch Grid
    -----------
    dim = [num_systems]

    Parameters
    ----------
    cells : wp.array(dtype=wp.mat33f or wp.mat33d)
        Cell matrices whose columns are the lattice vectors. Shape (B,).
    volumes : wp.array(dtype=wp.float32 or wp.float64)
        Output: absolute determinant of each cell. Shape (B,).
    """
    sid = wp.tid()
    m = cells[sid]

    # Lattice vectors are stored column-wise: a = m[:, 0], b = m[:, 1], c = m[:, 2].
    ax = m[0, 0]
    ay = m[1, 0]
    az = m[2, 0]
    bx = m[0, 1]
    by = m[1, 1]
    bz = m[2, 1]
    cx = m[0, 2]
    cy = m[1, 2]
    cz = m[2, 2]

    # 2x2 minors of the first column (rows of the other two entries, columns b and c).
    minor_x = by * cz - bz * cy
    minor_y = bx * cz - bz * cx
    minor_z = bx * cy - by * cx

    # Laplace expansion along the first column equals the triple product a . (b x c).
    det = ax * minor_x - ay * minor_y + az * minor_z
    volumes[sid] = wp.abs(det)
@wp.kernel
def _compute_cell_inverse_kernel(
    cells: wp.array(dtype=Any),
    cells_inv: wp.array(dtype=Any),
):
    """Compute cell inverse for coordinate transformations.

    Uses the closed-form adjugate formula ``inv = adj(cell) / det(cell)``.
    To avoid dividing by zero for (near-)singular cells, the magnitude of the
    determinant is clamped to at least 1e-10 while its sign is preserved, so
    left-handed cells (det < 0) still get their correct inverse.

    Launch Grid
    -----------
    dim = [num_systems]

    Parameters
    ----------
    cells : wp.array(dtype=wp.mat33f or wp.mat33d)
        Cell matrices. Shape (B,).
    cells_inv : wp.array(dtype=wp.mat33f or wp.mat33d)
        Output cell inverses. Shape (B,).
    """
    sys_id = wp.tid()
    cell = cells[sys_id]

    # Cell elements (row, column)
    a00 = cell[0, 0]
    a01 = cell[0, 1]
    a02 = cell[0, 2]
    a10 = cell[1, 0]
    a11 = cell[1, 1]
    a12 = cell[1, 2]
    a20 = cell[2, 0]
    a21 = cell[2, 1]
    a22 = cell[2, 2]

    # Determinant (cofactor expansion along the first row)
    det = (
        a00 * (a11 * a22 - a12 * a21)
        - a01 * (a10 * a22 - a12 * a20)
        + a02 * (a10 * a21 - a11 * a20)
    )

    # Clamp |det| away from zero but keep its sign. Clamping the signed value
    # (max(det, eps)) would replace every negative determinant by eps and
    # return a garbage inverse for left-handed cells.
    det_eps = type(a00)(1e-10)
    det_safe = wp.sign(det) * wp.max(wp.abs(det), det_eps)
    inv_det = type(a00)(1.0) / det_safe

    # Adjugate (transposed cofactor matrix) scaled by 1 / det
    inv00 = (a11 * a22 - a12 * a21) * inv_det
    inv01 = (a02 * a21 - a01 * a22) * inv_det
    inv02 = (a01 * a12 - a02 * a11) * inv_det
    inv10 = (a12 * a20 - a10 * a22) * inv_det
    inv11 = (a00 * a22 - a02 * a20) * inv_det
    inv12 = (a02 * a10 - a00 * a12) * inv_det
    inv20 = (a10 * a21 - a11 * a20) * inv_det
    inv21 = (a01 * a20 - a00 * a21) * inv_det
    inv22 = (a00 * a11 - a01 * a10) * inv_det

    cells_inv[sys_id] = type(cell)(
        inv00, inv01, inv02, inv10, inv11, inv12, inv20, inv21, inv22
    )
@wp.kernel
def _compute_strain_tensor_kernel(
    cells: wp.array(dtype=Any),
    cells_ref_inv: wp.array(dtype=Any),
    strains: wp.array(dtype=Any),
):
    """Per-system strain relative to a reference cell: eps = cell @ cell_ref_inv - I.

    Launch Grid
    -----------
    dim = [num_systems]

    Parameters
    ----------
    cells : wp.array(dtype=wp.mat33f or wp.mat33d)
        Current cell matrices. Shape (B,).
    cells_ref_inv : wp.array(dtype=wp.mat33f or wp.mat33d)
        Inverses of the reference cell matrices. Shape (B,).
    strains : wp.array(dtype=wp.mat33f or wp.mat33d)
        Output strain tensors. Shape (B,).
    """
    i = wp.tid()
    current = cells[i]
    eye = wp.identity(3, dtype=current.dtype)
    # Deformation gradient F = cell @ cell_ref^-1; the strain is F minus identity.
    strains[i] = wp.mul(current, cells_ref_inv[i]) - eye
@wp.kernel
def _apply_strain_to_cell_kernel(
    cells: wp.array(dtype=Any),
    strains: wp.array(dtype=Any),
    cells_out: wp.array(dtype=Any),
):
    """Deform each cell by its strain: cell_new = (I + strain) @ cell.

    Launch Grid
    -----------
    dim = [num_systems]

    Parameters
    ----------
    cells : wp.array(dtype=wp.mat33f or wp.mat33d)
        Current cell matrices. Shape (B,).
    strains : wp.array(dtype=wp.mat33f or wp.mat33d)
        Strain tensors to apply. Shape (B,).
    cells_out : wp.array(dtype=wp.mat33f or wp.mat33d)
        Output deformed cell matrices. Shape (B,).
    """
    i = wp.tid()
    current = cells[i]
    deformation = wp.identity(3, dtype=current.dtype) + strains[i]
    cells_out[i] = wp.mul(deformation, current)
# ==============================================================================
# Position Transformation Kernels
# ==============================================================================
@wp.kernel
def _scale_positions_single_kernel(
    positions: wp.array(dtype=Any),
    cell_old_inv: wp.array(dtype=Any),
    cell_new: wp.array(dtype=Any),
):
    """In-place affine rescaling of positions for one system (no batch_idx).

    r_new = cell_new @ cell_old_inv @ r_old, i.e. fractional coordinates
    are kept fixed while the cell changes. Only index 0 of each cell array
    is read.

    Launch Grid
    -----------
    dim = [num_atoms]
    """
    i = wp.tid()
    transform = wp.mul(cell_new[0], cell_old_inv[0])
    positions[i] = wp.mul(transform, positions[i])
@wp.kernel
def _scale_positions_kernel(
    positions: wp.array(dtype=Any),
    batch_idx: wp.array(dtype=wp.int32),
    cells_old_inv: wp.array(dtype=Any),
    cells_new: wp.array(dtype=Any),
):
    """In-place affine rescaling of positions across a batch of systems.

    r_new = cell_new @ cell_old_inv @ r_old, using the cells of the system
    given by ``batch_idx`` for each atom.

    Launch Grid
    -----------
    dim = [num_atoms]
    """
    i = wp.tid()
    b = batch_idx[i]
    transform = wp.mul(cells_new[b], cells_old_inv[b])
    positions[i] = wp.mul(transform, positions[i])
@wp.kernel
def _scale_positions_out_single_kernel(
    positions: wp.array(dtype=Any),
    cell_old_inv: wp.array(dtype=Any),
    cell_new: wp.array(dtype=Any),
    positions_out: wp.array(dtype=Any),
):
    """Out-of-place affine rescaling of positions for one system.

    positions_out = cell_new @ cell_old_inv @ positions, reading only index 0
    of each cell array. ``positions`` is not written.

    Launch Grid
    -----------
    dim = [num_atoms]
    """
    i = wp.tid()
    transform = wp.mul(cell_new[0], cell_old_inv[0])
    positions_out[i] = wp.mul(transform, positions[i])
@wp.kernel
def _scale_positions_out_kernel(
    positions: wp.array(dtype=Any),
    batch_idx: wp.array(dtype=wp.int32),
    cells_old_inv: wp.array(dtype=Any),
    cells_new: wp.array(dtype=Any),
    positions_out: wp.array(dtype=Any),
):
    """Out-of-place affine rescaling of positions across a batch of systems.

    positions_out = cell_new @ cell_old_inv @ positions, with the cells
    selected per atom through ``batch_idx``. ``positions`` is not written.

    Launch Grid
    -----------
    dim = [num_atoms]
    """
    i = wp.tid()
    b = batch_idx[i]
    transform = wp.mul(cells_new[b], cells_old_inv[b])
    positions_out[i] = wp.mul(transform, positions[i])
# ==============================================================================
# Wrapping Kernels
# ==============================================================================
@wp.kernel
def _wrap_positions_single_kernel(
    positions: wp.array(dtype=Any),
    cell_inv: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
):
    """Wrap positions in place for a single system (no batch_idx).

    Each position is converted to fractional coordinates, reduced to the
    half-open interval [0, 1) per component, and converted back to Cartesian
    coordinates. Only ``cell[0]`` and ``cell_inv[0]`` are read.

    Launch Grid
    -----------
    dim = [num_atoms]

    Parameters
    ----------
    positions : wp.array(dtype=wp.vec3f or wp.vec3d)
        Cartesian positions, overwritten with the wrapped positions.
    cell_inv : wp.array(dtype=wp.mat33f or wp.mat33d)
        Inverse cell matrix at index 0.
    cell : wp.array(dtype=wp.mat33f or wp.mat33d)
        Cell matrix at index 0.
    """
    atom_idx = wp.tid()
    r = positions[atom_idx]
    ci = cell_inv[0]
    c = cell[0]
    # Convert to fractional: s = cell_inv @ r
    s = wp.mul(ci, r)
    # First reduction: s - floor(s) is in [0, 1) mathematically, but for a tiny
    # negative s (e.g. -1e-20) the subtraction rounds to exactly 1.0.
    f0 = s[0] - wp.floor(s[0])
    f1 = s[1] - wp.floor(s[1])
    f2 = s[2] - wp.floor(s[2])
    # Second reduction maps such a 1.0 to 0.0 and leaves values already in
    # [0, 1) bit-identical (their floor is 0), so the result is truly in [0, 1).
    s_wrapped = type(s)(f0 - wp.floor(f0), f1 - wp.floor(f1), f2 - wp.floor(f2))
    # Convert back to Cartesian: r_new = cell @ s_wrapped
    positions[atom_idx] = wp.mul(c, s_wrapped)
@wp.kernel
def _wrap_positions_kernel(
    positions: wp.array(dtype=Any),
    batch_idx: wp.array(dtype=wp.int32),
    cells_inv: wp.array(dtype=Any),
    cells: wp.array(dtype=Any),
):
    """Wrap positions in place into the primary cell, [0, 1) in fractional coordinates.

    The cell used for each atom is selected through ``batch_idx``.

    Launch Grid
    -----------
    dim = [num_atoms]

    Parameters
    ----------
    positions : wp.array(dtype=wp.vec3f or wp.vec3d)
        Cartesian positions, overwritten with the wrapped positions. Shape (N,).
    batch_idx : wp.array(dtype=wp.int32)
        System index of each atom. Shape (N,).
    cells_inv : wp.array(dtype=wp.mat33f or wp.mat33d)
        Inverse cell matrices. Shape (B,).
    cells : wp.array(dtype=wp.mat33f or wp.mat33d)
        Cell matrices. Shape (B,).
    """
    atom_idx = wp.tid()
    sys_id = batch_idx[atom_idx]
    r = positions[atom_idx]
    cell_inv = cells_inv[sys_id]
    cell = cells[sys_id]
    # Convert to fractional: s = cell_inv @ r
    s = wp.mul(cell_inv, r)
    # First reduction: s - floor(s) is in [0, 1) mathematically, but for a tiny
    # negative s (e.g. -1e-20) the subtraction rounds to exactly 1.0.
    f0 = s[0] - wp.floor(s[0])
    f1 = s[1] - wp.floor(s[1])
    f2 = s[2] - wp.floor(s[2])
    # Second reduction maps such a 1.0 to 0.0 and leaves values already in
    # [0, 1) bit-identical (their floor is 0), so the result is truly in [0, 1).
    s_wrapped = type(s)(f0 - wp.floor(f0), f1 - wp.floor(f1), f2 - wp.floor(f2))
    # Convert back to Cartesian: r_new = cell @ s_wrapped
    positions[atom_idx] = wp.mul(cell, s_wrapped)
@wp.kernel
def _wrap_positions_out_single_kernel(
    positions: wp.array(dtype=Any),
    cell_inv: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    positions_out: wp.array(dtype=Any),
):
    """Wrap positions into an output array for a single system.

    Fractional coordinates are reduced to [0, 1) per component. Only
    ``cell[0]`` and ``cell_inv[0]`` are read; ``positions`` is not written.

    Launch Grid
    -----------
    dim = [num_atoms]

    Parameters
    ----------
    positions : wp.array(dtype=wp.vec3f or wp.vec3d)
        Input Cartesian positions.
    cell_inv : wp.array(dtype=wp.mat33f or wp.mat33d)
        Inverse cell matrix at index 0.
    cell : wp.array(dtype=wp.mat33f or wp.mat33d)
        Cell matrix at index 0.
    positions_out : wp.array(dtype=wp.vec3f or wp.vec3d)
        Output wrapped Cartesian positions.
    """
    atom_idx = wp.tid()
    r = positions[atom_idx]
    ci = cell_inv[0]
    c = cell[0]
    # Convert to fractional: s = cell_inv @ r
    s = wp.mul(ci, r)
    # First reduction: s - floor(s) is in [0, 1) mathematically, but for a tiny
    # negative s (e.g. -1e-20) the subtraction rounds to exactly 1.0.
    f0 = s[0] - wp.floor(s[0])
    f1 = s[1] - wp.floor(s[1])
    f2 = s[2] - wp.floor(s[2])
    # Second reduction maps such a 1.0 to 0.0 and leaves values already in
    # [0, 1) bit-identical (their floor is 0), so the result is truly in [0, 1).
    s_wrapped = type(s)(f0 - wp.floor(f0), f1 - wp.floor(f1), f2 - wp.floor(f2))
    # Convert back to Cartesian: r_new = cell @ s_wrapped
    positions_out[atom_idx] = wp.mul(c, s_wrapped)
@wp.kernel
def _wrap_positions_out_kernel(
    positions: wp.array(dtype=Any),
    batch_idx: wp.array(dtype=wp.int32),
    cells_inv: wp.array(dtype=Any),
    cells: wp.array(dtype=Any),
    positions_out: wp.array(dtype=Any),
):
    """Wrap positions into an output array, [0, 1) in fractional coordinates.

    The cell used for each atom is selected through ``batch_idx``;
    ``positions`` is not written.

    Launch Grid
    -----------
    dim = [num_atoms]

    Parameters
    ----------
    positions : wp.array(dtype=wp.vec3f or wp.vec3d)
        Input Cartesian positions. Shape (N,).
    batch_idx : wp.array(dtype=wp.int32)
        System index of each atom. Shape (N,).
    cells_inv : wp.array(dtype=wp.mat33f or wp.mat33d)
        Inverse cell matrices. Shape (B,).
    cells : wp.array(dtype=wp.mat33f or wp.mat33d)
        Cell matrices. Shape (B,).
    positions_out : wp.array(dtype=wp.vec3f or wp.vec3d)
        Output wrapped Cartesian positions. Shape (N,).
    """
    atom_idx = wp.tid()
    sys_id = batch_idx[atom_idx]
    r = positions[atom_idx]
    cell_inv = cells_inv[sys_id]
    cell = cells[sys_id]
    # Convert to fractional: s = cell_inv @ r
    s = wp.mul(cell_inv, r)
    # First reduction: s - floor(s) is in [0, 1) mathematically, but for a tiny
    # negative s (e.g. -1e-20) the subtraction rounds to exactly 1.0.
    f0 = s[0] - wp.floor(s[0])
    f1 = s[1] - wp.floor(s[1])
    f2 = s[2] - wp.floor(s[2])
    # Second reduction maps such a 1.0 to 0.0 and leaves values already in
    # [0, 1) bit-identical (their floor is 0), so the result is truly in [0, 1).
    s_wrapped = type(s)(f0 - wp.floor(f0), f1 - wp.floor(f1), f2 - wp.floor(f2))
    # Convert back to Cartesian: r_new = cell @ s_wrapped
    positions_out[atom_idx] = wp.mul(cell, s_wrapped)
# ==============================================================================
# Coordinate Transformation Kernels
# ==============================================================================
@wp.kernel
def _cartesian_to_fractional_single_kernel(
    positions: wp.array(dtype=Any),
    cell_inv: wp.array(dtype=Any),
    fractional: wp.array(dtype=Any),
):
    """Map Cartesian positions to fractional coordinates for one system.

    s = cell_inv[0] @ r

    Launch Grid
    -----------
    dim = [num_atoms]
    """
    i = wp.tid()
    fractional[i] = wp.mul(cell_inv[0], positions[i])
@wp.kernel
def _cartesian_to_fractional_kernel(
    positions: wp.array(dtype=Any),
    batch_idx: wp.array(dtype=wp.int32),
    cells_inv: wp.array(dtype=Any),
    fractional: wp.array(dtype=Any),
):
    """Map Cartesian positions to fractional coordinates across a batch.

    s = cell_inv[batch_idx[i]] @ r[i]

    Launch Grid
    -----------
    dim = [num_atoms]
    """
    i = wp.tid()
    inv = cells_inv[batch_idx[i]]
    fractional[i] = wp.mul(inv, positions[i])
@wp.kernel
def _fractional_to_cartesian_single_kernel(
    fractional: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    positions: wp.array(dtype=Any),
):
    """Map fractional coordinates to Cartesian positions for one system.

    r = cell[0] @ s

    Launch Grid
    -----------
    dim = [num_atoms]
    """
    i = wp.tid()
    positions[i] = wp.mul(cell[0], fractional[i])
@wp.kernel
def _fractional_to_cartesian_kernel(
    fractional: wp.array(dtype=Any),
    batch_idx: wp.array(dtype=wp.int32),
    cells: wp.array(dtype=Any),
    positions: wp.array(dtype=Any),
):
    """Map fractional coordinates to Cartesian positions across a batch.

    r = cell[batch_idx[i]] @ s[i]

    Launch Grid
    -----------
    dim = [num_atoms]
    """
    i = wp.tid()
    lattice = cells[batch_idx[i]]
    positions[i] = wp.mul(lattice, fractional[i])
# ==============================================================================
# Kernel Overloads for Explicit Typing
# ==============================================================================
# Supported element types. Entry k of each list shares one precision
# (k = 0 -> 32-bit, k = 1 -> 64-bit), so they are zipped together below.
_T = [wp.float32, wp.float64]  # scalars
_V = [wp.vec3f, wp.vec3d]  # 3-vectors
_M = [wp.mat33f, wp.mat33d]  # 3x3 matrices

# Registries of concrete kernel overloads. Per-system (cell) kernels are
# keyed by matrix dtype, per-atom kernels by position vector dtype.
_compute_cell_volume_kernel_overload = dict()
_compute_cell_inverse_kernel_overload = dict()
_compute_strain_tensor_kernel_overload = dict()
_apply_strain_to_cell_kernel_overload = dict()

_scale_positions_single_kernel_overload = dict()
_scale_positions_kernel_overload = dict()
_scale_positions_out_single_kernel_overload = dict()
_scale_positions_out_kernel_overload = dict()

_wrap_positions_single_kernel_overload = dict()
_wrap_positions_kernel_overload = dict()
_wrap_positions_out_single_kernel_overload = dict()
_wrap_positions_out_kernel_overload = dict()

_cartesian_to_fractional_single_kernel_overload = dict()
_cartesian_to_fractional_kernel_overload = dict()
_fractional_to_cartesian_single_kernel_overload = dict()
_fractional_to_cartesian_kernel_overload = dict()
# Register one concrete overload per precision for every generic kernel.
# Each argument list is positional and must follow the kernel's parameter
# order exactly. t / v / m are the scalar / vec3 / mat33 types of one precision.
for t, v, m in zip(_T, _V, _M):
    # Cell property kernels: keyed by matrix dtype m.
    # The volume output is a scalar array of the matching precision t.
    _compute_cell_volume_kernel_overload[m] = wp.overload(
        _compute_cell_volume_kernel,
        [wp.array(dtype=m), wp.array(dtype=t)],
    )
    _compute_cell_inverse_kernel_overload[m] = wp.overload(
        _compute_cell_inverse_kernel,
        [wp.array(dtype=m), wp.array(dtype=m)],
    )
    _compute_strain_tensor_kernel_overload[m] = wp.overload(
        _compute_strain_tensor_kernel,
        [wp.array(dtype=m), wp.array(dtype=m), wp.array(dtype=m)],
    )
    _apply_strain_to_cell_kernel_overload[m] = wp.overload(
        _apply_strain_to_cell_kernel,
        [wp.array(dtype=m), wp.array(dtype=m), wp.array(dtype=m)],
    )
    # Position scaling kernels: keyed by position vector dtype v.
    # Batched variants take an int32 batch_idx as their second argument.
    _scale_positions_single_kernel_overload[v] = wp.overload(
        _scale_positions_single_kernel,
        [wp.array(dtype=v), wp.array(dtype=m), wp.array(dtype=m)],
    )
    _scale_positions_kernel_overload[v] = wp.overload(
        _scale_positions_kernel,
        [
            wp.array(dtype=v),
            wp.array(dtype=wp.int32),
            wp.array(dtype=m),
            wp.array(dtype=m),
        ],
    )
    _scale_positions_out_single_kernel_overload[v] = wp.overload(
        _scale_positions_out_single_kernel,
        [wp.array(dtype=v), wp.array(dtype=m), wp.array(dtype=m), wp.array(dtype=v)],
    )
    _scale_positions_out_kernel_overload[v] = wp.overload(
        _scale_positions_out_kernel,
        [
            wp.array(dtype=v),
            wp.array(dtype=wp.int32),
            wp.array(dtype=m),
            wp.array(dtype=m),
            wp.array(dtype=v),
        ],
    )
    # Wrapping kernels: keyed by position vector dtype v.
    _wrap_positions_single_kernel_overload[v] = wp.overload(
        _wrap_positions_single_kernel,
        [wp.array(dtype=v), wp.array(dtype=m), wp.array(dtype=m)],
    )
    _wrap_positions_kernel_overload[v] = wp.overload(
        _wrap_positions_kernel,
        [
            wp.array(dtype=v),
            wp.array(dtype=wp.int32),
            wp.array(dtype=m),
            wp.array(dtype=m),
        ],
    )
    _wrap_positions_out_single_kernel_overload[v] = wp.overload(
        _wrap_positions_out_single_kernel,
        [wp.array(dtype=v), wp.array(dtype=m), wp.array(dtype=m), wp.array(dtype=v)],
    )
    _wrap_positions_out_kernel_overload[v] = wp.overload(
        _wrap_positions_out_kernel,
        [
            wp.array(dtype=v),
            wp.array(dtype=wp.int32),
            wp.array(dtype=m),
            wp.array(dtype=m),
            wp.array(dtype=v),
        ],
    )
    # Coordinate conversion kernels: keyed by vector dtype v.
    # Input and output vectors share the same dtype.
    _cartesian_to_fractional_single_kernel_overload[v] = wp.overload(
        _cartesian_to_fractional_single_kernel,
        [wp.array(dtype=v), wp.array(dtype=m), wp.array(dtype=v)],
    )
    _cartesian_to_fractional_kernel_overload[v] = wp.overload(
        _cartesian_to_fractional_kernel,
        [
            wp.array(dtype=v),
            wp.array(dtype=wp.int32),
            wp.array(dtype=m),
            wp.array(dtype=v),
        ],
    )
    _fractional_to_cartesian_single_kernel_overload[v] = wp.overload(
        _fractional_to_cartesian_single_kernel,
        [wp.array(dtype=v), wp.array(dtype=m), wp.array(dtype=v)],
    )
    _fractional_to_cartesian_kernel_overload[v] = wp.overload(
        _fractional_to_cartesian_kernel,
        [
            wp.array(dtype=v),
            wp.array(dtype=wp.int32),
            wp.array(dtype=m),
            wp.array(dtype=v),
        ],
    )
# ==============================================================================
# Functional Interfaces
# ==============================================================================
[docs]
def compute_cell_volume(
    cells: wp.array,
    volumes: wp.array,
    device: str = None,
) -> wp.array:
    r"""
    Compute the absolute volume :math:`V = |\det(cell)|` of every cell.

    Parameters
    ----------
    cells : wp.array(dtype=wp.mat33f or wp.mat33d)
        Cell matrices, one per system. Shape (B,); a single system is passed
        as an array of shape (1,).
    volumes : wp.array(dtype=wp.float32 or wp.float64)
        Pre-allocated output of shape (B,). Its scalar precision must match
        that of ``cells``.
    device : str, optional
        Warp device to launch on. Defaults to ``cells.device``.

    Returns
    -------
    wp.array
        ``volumes``, filled in place.
    """
    kernel = _compute_cell_volume_kernel_overload[cells.dtype]
    wp.launch(
        kernel,
        dim=cells.shape[0],
        inputs=[cells, volumes],
        device=cells.device if device is None else device,
    )
    return volumes
[docs]
def compute_cell_inverse(
    cells: wp.array,
    cells_inv: wp.array,
    device: str = None,
) -> wp.array:
    """
    Invert every cell matrix, e.g. for Cartesian/fractional conversions.

    Parameters
    ----------
    cells : wp.array(dtype=wp.mat33f or wp.mat33d)
        Cell matrices, one per system. Shape (B,).
    cells_inv : wp.array(dtype=wp.mat33f or wp.mat33d)
        Pre-allocated output of shape (B,), same dtype as ``cells``.
    device : str, optional
        Warp device to launch on. Defaults to ``cells.device``.

    Returns
    -------
    wp.array
        ``cells_inv``, filled in place.
    """
    kernel = _compute_cell_inverse_kernel_overload[cells.dtype]
    wp.launch(
        kernel,
        dim=cells.shape[0],
        inputs=[cells, cells_inv],
        device=cells.device if device is None else device,
    )
    return cells_inv
[docs]
def compute_strain_tensor(
    cells: wp.array,
    cells_ref_inv: wp.array,
    strains: wp.array,
    device: str = None,
) -> wp.array:
    """
    Measure the strain of each cell relative to a reference cell.

    The strain ε is defined through ``cell = (I + ε) @ cell_ref``, which gives
    ``ε = cell @ cell_ref_inv - I``.

    Parameters
    ----------
    cells : wp.array(dtype=wp.mat33f or wp.mat33d)
        Current cell matrices. Shape (B,).
    cells_ref_inv : wp.array(dtype=wp.mat33f or wp.mat33d)
        Inverses of the reference cells, e.g. from ``compute_cell_inverse``.
        Shape (B,).
    strains : wp.array(dtype=wp.mat33f or wp.mat33d)
        Pre-allocated output of shape (B,).
    device : str, optional
        Warp device to launch on. Defaults to ``cells.device``.

    Returns
    -------
    wp.array
        ``strains``, filled in place.
    """
    kernel = _compute_strain_tensor_kernel_overload[cells.dtype]
    wp.launch(
        kernel,
        dim=cells.shape[0],
        inputs=[cells, cells_ref_inv, strains],
        device=cells.device if device is None else device,
    )
    return strains
[docs]
def apply_strain_to_cell(
    cells: wp.array,
    strains: wp.array,
    cells_out: wp.array,
    device: str = None,
) -> wp.array:
    """
    Deform each cell by a strain tensor: ``cell_new = (I + strain) @ cell``.

    Parameters
    ----------
    cells : wp.array(dtype=wp.mat33f or wp.mat33d)
        Current cell matrices. Shape (B,).
    strains : wp.array(dtype=wp.mat33f or wp.mat33d)
        Strain tensors to apply. Shape (B,).
    cells_out : wp.array(dtype=wp.mat33f or wp.mat33d)
        Pre-allocated output of shape (B,).
    device : str, optional
        Warp device to launch on. Defaults to ``cells.device``.

    Returns
    -------
    wp.array
        ``cells_out``, filled in place.
    """
    kernel = _apply_strain_to_cell_kernel_overload[cells.dtype]
    wp.launch(
        kernel,
        dim=cells.shape[0],
        inputs=[cells, strains, cells_out],
        device=cells.device if device is None else device,
    )
    return cells_out
[docs]
def scale_positions_with_cell(
    positions: wp.array,
    cells_new: wp.array,
    cells_old_inv: wp.array,
    batch_idx: wp.array = None,
    device: str = None,
) -> None:
    """
    Rescale atomic positions in place after a cell change.

    Fractional coordinates are preserved:
    ``r_new = cell_new @ cell_old_inv @ r_old``.

    Parameters
    ----------
    positions : wp.array(dtype=wp.vec3f or wp.vec3d)
        Atomic positions. Shape (N,). Overwritten with the rescaled positions.
    cells_new : wp.array(dtype=wp.mat33f or wp.mat33d)
        New cell matrices. Shape (B,).
    cells_old_inv : wp.array(dtype=wp.mat33f or wp.mat33d)
        Inverses of the old cells, e.g. from ``compute_cell_inverse``.
        Shape (B,).
    batch_idx : wp.array(dtype=wp.int32), optional
        System index of each atom. Shape (N,). When omitted, all atoms belong
        to one system and only index 0 of the cell arrays is used.
    device : str, optional
        Warp device to launch on. Defaults to ``positions.device``.
    """
    if device is None:
        device = positions.device

    if batch_idx is None:
        kernel = _scale_positions_single_kernel_overload[positions.dtype]
        args = [positions, cells_old_inv, cells_new]
    else:
        kernel = _scale_positions_kernel_overload[positions.dtype]
        args = [positions, batch_idx, cells_old_inv, cells_new]

    wp.launch(kernel, dim=positions.shape[0], inputs=args, device=device)
[docs]
def scale_positions_with_cell_out(
    positions: wp.array,
    cells_new: wp.array,
    cells_old_inv: wp.array,
    positions_out: wp.array,
    batch_idx: wp.array = None,
    device: str = None,
) -> wp.array:
    """
    Rescale atomic positions after a cell change into a separate array.

    Computes ``positions_out = cell_new @ cell_old_inv @ positions``;
    ``positions`` is left untouched.

    Parameters
    ----------
    positions : wp.array(dtype=wp.vec3f or wp.vec3d)
        Atomic positions. Shape (N,).
    cells_new : wp.array(dtype=wp.mat33f or wp.mat33d)
        New cell matrices. Shape (B,).
    cells_old_inv : wp.array(dtype=wp.mat33f or wp.mat33d)
        Inverses of the old cells, e.g. from ``compute_cell_inverse``.
        Shape (B,).
    positions_out : wp.array(dtype=wp.vec3f or wp.vec3d)
        Pre-allocated output of shape (N,).
    batch_idx : wp.array(dtype=wp.int32), optional
        System index of each atom. Shape (N,). When omitted, only index 0 of
        the cell arrays is used.
    device : str, optional
        Warp device to launch on. Defaults to ``positions.device``.

    Returns
    -------
    wp.array
        ``positions_out``, filled in place.
    """
    if device is None:
        device = positions.device

    if batch_idx is None:
        kernel = _scale_positions_out_single_kernel_overload[positions.dtype]
        args = [positions, cells_old_inv, cells_new, positions_out]
    else:
        kernel = _scale_positions_out_kernel_overload[positions.dtype]
        args = [positions, batch_idx, cells_old_inv, cells_new, positions_out]

    wp.launch(kernel, dim=positions.shape[0], inputs=args, device=device)
    return positions_out
[docs]
def wrap_positions_to_cell(
    positions: wp.array,
    cells: wp.array,
    cells_inv: wp.array,
    batch_idx: wp.array = None,
    device: str = None,
) -> None:
    """
    Wrap positions in place into the primary cell, [0, 1) in fractional coordinates.

    Parameters
    ----------
    positions : wp.array(dtype=wp.vec3f or wp.vec3d)
        Atomic positions. Shape (N,). Overwritten with the wrapped positions.
    cells : wp.array(dtype=wp.mat33f or wp.mat33d)
        Cell matrices. Shape (B,).
    cells_inv : wp.array(dtype=wp.mat33f or wp.mat33d)
        Inverses of ``cells``, e.g. from ``compute_cell_inverse``. Shape (B,).
    batch_idx : wp.array(dtype=wp.int32), optional
        System index of each atom. Shape (N,). When omitted, only index 0 of
        the cell arrays is used.
    device : str, optional
        Warp device to launch on. Defaults to ``positions.device``.
    """
    if device is None:
        device = positions.device

    if batch_idx is None:
        kernel = _wrap_positions_single_kernel_overload[positions.dtype]
        args = [positions, cells_inv, cells]
    else:
        kernel = _wrap_positions_kernel_overload[positions.dtype]
        args = [positions, batch_idx, cells_inv, cells]

    wp.launch(kernel, dim=positions.shape[0], inputs=args, device=device)
[docs]
def wrap_positions_to_cell_out(
    positions: wp.array,
    cells: wp.array,
    cells_inv: wp.array,
    positions_out: wp.array,
    batch_idx: wp.array = None,
    device: str = None,
) -> wp.array:
    """
    Wrap positions into the primary cell, writing to a separate array.

    ``positions`` is left untouched.

    Parameters
    ----------
    positions : wp.array(dtype=wp.vec3f or wp.vec3d)
        Atomic positions. Shape (N,).
    cells : wp.array(dtype=wp.mat33f or wp.mat33d)
        Cell matrices. Shape (B,).
    cells_inv : wp.array(dtype=wp.mat33f or wp.mat33d)
        Inverses of ``cells``, e.g. from ``compute_cell_inverse``. Shape (B,).
    positions_out : wp.array(dtype=wp.vec3f or wp.vec3d)
        Pre-allocated output of shape (N,).
    batch_idx : wp.array(dtype=wp.int32), optional
        System index of each atom. Shape (N,). When omitted, only index 0 of
        the cell arrays is used.
    device : str, optional
        Warp device to launch on. Defaults to ``positions.device``.

    Returns
    -------
    wp.array
        ``positions_out``, filled in place.
    """
    if device is None:
        device = positions.device

    if batch_idx is None:
        kernel = _wrap_positions_out_single_kernel_overload[positions.dtype]
        args = [positions, cells_inv, cells, positions_out]
    else:
        kernel = _wrap_positions_out_kernel_overload[positions.dtype]
        args = [positions, batch_idx, cells_inv, cells, positions_out]

    wp.launch(kernel, dim=positions.shape[0], inputs=args, device=device)
    return positions_out
[docs]
def cartesian_to_fractional(
    positions: wp.array,
    cells_inv: wp.array,
    fractional: wp.array,
    batch_idx: wp.array = None,
    device: str = None,
) -> wp.array:
    """
    Map Cartesian positions to fractional coordinates: ``s = cell_inv @ r``.

    Parameters
    ----------
    positions : wp.array(dtype=wp.vec3f or wp.vec3d)
        Cartesian positions. Shape (N,).
    cells_inv : wp.array(dtype=wp.mat33f or wp.mat33d)
        Inverse cell matrices, e.g. from ``compute_cell_inverse``. Shape (B,).
    fractional : wp.array(dtype=wp.vec3f or wp.vec3d)
        Pre-allocated output of shape (N,).
    batch_idx : wp.array(dtype=wp.int32), optional
        System index of each atom. Shape (N,). When omitted, only index 0 of
        ``cells_inv`` is used.
    device : str, optional
        Warp device to launch on. Defaults to ``positions.device``.

    Returns
    -------
    wp.array
        ``fractional``, filled in place.
    """
    if device is None:
        device = positions.device

    if batch_idx is None:
        kernel = _cartesian_to_fractional_single_kernel_overload[positions.dtype]
        args = [positions, cells_inv, fractional]
    else:
        kernel = _cartesian_to_fractional_kernel_overload[positions.dtype]
        args = [positions, batch_idx, cells_inv, fractional]

    wp.launch(kernel, dim=positions.shape[0], inputs=args, device=device)
    return fractional
[docs]
def fractional_to_cartesian(
    fractional: wp.array,
    cells: wp.array,
    positions: wp.array,
    batch_idx: wp.array = None,
    device: str = None,
) -> wp.array:
    """
    Map fractional coordinates to Cartesian positions: ``r = cell @ s``.

    Parameters
    ----------
    fractional : wp.array(dtype=wp.vec3f or wp.vec3d)
        Fractional coordinates. Shape (N,).
    cells : wp.array(dtype=wp.mat33f or wp.mat33d)
        Cell matrices. Shape (B,).
    positions : wp.array(dtype=wp.vec3f or wp.vec3d)
        Pre-allocated output of shape (N,).
    batch_idx : wp.array(dtype=wp.int32), optional
        System index of each atom. Shape (N,). When omitted, only index 0 of
        ``cells`` is used.
    device : str, optional
        Warp device to launch on. Defaults to ``fractional.device``.

    Returns
    -------
    wp.array
        ``positions``, filled in place.
    """
    if device is None:
        device = fractional.device

    if batch_idx is None:
        kernel = _fractional_to_cartesian_single_kernel_overload[fractional.dtype]
        args = [fractional, cells, positions]
    else:
        kernel = _fractional_to_cartesian_kernel_overload[fractional.dtype]
        args = [fractional, batch_idx, cells, positions]

    wp.launch(kernel, dim=fractional.shape[0], inputs=args, device=device)
    return positions