nvalchemiops.torch.autograd: Autograd Utilities#
Autograd Utilities for Warp-PyTorch Integration#
This module provides utilities for integrating Warp’s automatic differentiation with PyTorch custom operators. It abstracts common patterns for:
Checking if any tensor requires gradients
Conditionally creating Warp tapes
Storing tape and warp arrays on output tensors
Retrieving them in backward passes
Decorator-based custom op registration with automatic backward generation
import warp as wp
import torch
from contextlib import contextmanager, nullcontext
from typing import Any, Optional, Sequence, Union
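The pattern these utilities abstract (conditional tape recording plus stashing runtime state on the output for backward) can be sketched in plain Python. FakeTape and FakeOutput below are stand-ins for illustration only, not the real wp.Tape or torch.Tensor:

```python
from contextlib import nullcontext


class FakeTape:
    """Stand-in for wp.Tape; the real tape records Warp kernel launches."""

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False


class FakeOutput:
    """Stand-in for the forward pass's output tensor."""


def forward(requires_grad: bool) -> FakeOutput:
    out = FakeOutput()
    # Only pay for tape recording when some input actually needs gradients.
    with (FakeTape() if requires_grad else nullcontext()) as tape:
        pass  # Warp kernel launches would be recorded here
    if requires_grad:
        # Mirrors attach_for_backward: stash runtime state on the output
        # so the backward pass can retrieve and replay the tape later.
        out._warp_tape = tape
    return out
```

When gradients are not required, the nullcontext branch means no tape object is ever created, which is the "zero overhead" path the utilities below provide.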
Custom Op Registration#
- nvalchemiops.torch.autograd.warp_custom_op(name, outputs, grad_arrays=None, mutates_args=())[source]#
Decorator to create a Warp-backed PyTorch op with compile-safe autograd.
This decorator eliminates boilerplate by automatically generating:
- A torch.library.custom_op forward registered with fake/meta support
- A hidden token input for runtime state handoff, while the public wrapper still exposes only the user-visible signature
- A traceable register_autograd wrapper that replays Warp tapes through an opaque backward custom op
- Stream binding so Warp launches execute on PyTorch’s current CUDA stream
- Parameters:
name (str) – Full custom op name (e.g., “alchemiops::_my_kernel”).
outputs (list[OutputSpec]) – Specifications for each output tensor.
grad_arrays (list[str], optional) – Names of warp arrays to track for gradients. Should include output names first, then differentiable input names. If None, auto-generated from outputs + all inputs that are likely differentiable (excludes common non-differentiable names like neighbor_list, batch_idx, etc.).
mutates_args (tuple, default=()) – Arguments that are mutated by the op (passed to custom_op).
- Returns:
Decorator function.
- Return type:
Callable
Examples
>>> @warp_custom_op(
...     name="alchemiops::_ewald_real_space_energy",
...     outputs=[
...         OutputSpec("energies", wp.float64, lambda pos, *_: (pos.shape[0],)),
...     ],
...     grad_arrays=["energies", "positions", "charges", "cell", "alpha"],
... )
... def _ewald_real_space_energy(
...     positions: torch.Tensor,
...     charges: torch.Tensor,
...     cell: torch.Tensor,
...     alpha: torch.Tensor,
...     neighbor_list: torch.Tensor,
...     neighbor_shifts: torch.Tensor,
... ) -> torch.Tensor:
...     # Implementation here - no boilerplate needed!
...     ...
...     return energies
Notes
The decorated function should still call attach_for_backward() at the end of grad-enabled forward execution so the registered forward op can collect the runtime Warp tape and arrays from the real output tensor.
retain_graph=True is supported: the Warp tape is preserved across backward passes and zeroed before each replay.
create_graph=True is not supported: Warp backward ops do not register a second-order autograd formula, so higher-order differentiation through them will raise. Use hybrid_forces=True in electrostatics APIs when you need to combine explicit Warp forces with autograd-based charge-gradient forces.
- class nvalchemiops.torch.autograd.OutputSpec(name, dtype, shape, torch_dtype=None)[source]#
Specification for a custom op output.
- Parameters:
name (str) – Name of the output (used for backward pass).
dtype (wp dtype) – Warp dtype (e.g., wp.float64, wp.vec3d).
shape (Callable or tuple) – Either a tuple of ints, or a callable that takes the input tensors and returns the shape. For callable, signature should match the custom op’s input signature.
torch_dtype (torch.dtype, optional) – PyTorch dtype override. If omitted (None), the dtype is inferred from the resolved Warp dtype via _wp_dtype_to_torch.
Examples
>>> OutputSpec("energies", wp.float64, lambda pos, *_: (pos.shape[0],))
>>> OutputSpec("forces", wp.vec3d, lambda pos, *_: (pos.shape[0], 3))
>>> OutputSpec("virial", wp.mat33d, (3, 3))  # Static shape
Warp-PyTorch Interop#
- nvalchemiops.torch.autograd.warp_stream_from_torch(*values)[source]#
Bind Warp launches to PyTorch’s current CUDA stream when any of the provided values is a CUDA tensor.
- Parameters:
values (Any)
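The stream-selection predicate presumably reduces to a check like the following. This is a sketch under assumptions: FakeTensor stands in for torch.Tensor, and the returned string is a placeholder for the real Warp stream scope:

```python
from contextlib import nullcontext


class FakeTensor:
    """Stand-in for torch.Tensor; only .is_cuda is consulted here."""

    def __init__(self, is_cuda: bool):
        self.is_cuda = is_cuda


def stream_scope(*values):
    # Enter a CUDA-stream scope only when some value is a CUDA tensor, so
    # Warp launches are ordered after pending PyTorch work on that stream;
    # non-tensor values and CPU tensors fall through to a no-op context.
    if any(getattr(v, "is_cuda", False) for v in values):
        return "bind-to-current-torch-stream"  # placeholder for the real scope
    return nullcontext()
```

Ordering Warp launches on PyTorch’s stream avoids a race where a Warp kernel reads tensor memory before a pending PyTorch kernel has finished writing it.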
- nvalchemiops.torch.autograd.warp_from_torch(tensor, warp_dtype, requires_grad=None)[source]#
Convert a PyTorch tensor to a Warp array with proper gradient tracking.
- Parameters:
tensor (torch.Tensor) – Input PyTorch tensor
warp_dtype (wp.dtype) – Warp data type for the array
requires_grad (bool | None, optional) – Override gradient tracking. If None, inherits from tensor.requires_grad
- Returns:
Warp array with gradient tracking if needed
- Return type:
wp.array
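The documented requires_grad rule (an explicit override wins; None means inherit from the tensor) reduces to the following sketch, where FakeTensor is a stand-in for torch.Tensor:

```python
class FakeTensor:
    """Stand-in for torch.Tensor; only .requires_grad matters here."""

    def __init__(self, requires_grad: bool):
        self.requires_grad = requires_grad


def resolve_requires_grad(tensor, override=None) -> bool:
    # warp_from_torch's documented behavior: an explicit override takes
    # precedence; None means inherit the tensor's own gradient setting.
    return tensor.requires_grad if override is None else override
```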
- nvalchemiops.torch.autograd.needs_grad(*tensors)[source]#
Check if any of the provided tensors requires gradients.
This is useful for conditionally enabling Warp gradient tracking and tape recording only when needed for backpropagation.
- Parameters:
*tensors (torch.Tensor) – Variable number of PyTorch tensors to check
- Returns:
True if any tensor requires gradients, False otherwise
- Return type:
bool
Examples
>>> positions = torch.randn(100, 3, requires_grad=True)
>>> charges = torch.randn(100, requires_grad=False)
>>> needs_grad(positions, charges)
True
>>> needs_grad(charges)
False
Autograd Context Manager#
- class nvalchemiops.torch.autograd.WarpAutogradContextManager(enable)[source]#
Conditionally create a Warp tape as a context manager.
Yields a wp.Tape for gradient recording if enable=True; otherwise yields a nullcontext (no-op) for zero overhead.
- Parameters:
enable (bool) – Whether to create a tape for gradient recording
- Yields:
wp.Tape or nullcontext – Active tape for recording if enabled, otherwise nullcontext
Examples
>>> needs_grad_flag = needs_grad(positions, charges)
>>> with WarpAutogradContextManager(needs_grad_flag) as tape:
...     wp.launch(kernel, ...)
>>> if needs_grad_flag:
...     # tape is a wp.Tape instance
...     tape.backward()
- nvalchemiops.torch.autograd.attach_for_backward(output, tape=None, **warp_arrays)[source]#
Attach Warp tape and arrays to a PyTorch tensor for later retrieval in backward.
This stores the tape and warp arrays as attributes on the output tensor, allowing them to be retrieved in the backward pass of a custom operator.
- Parameters:
output (torch.Tensor) – PyTorch tensor to attach attributes to (usually the output of forward)
tape (wp.Tape, optional) – Warp tape containing recorded operations for backward pass
**warp_arrays (wp.array) – Named warp arrays to store (e.g., positions=wp_positions, charges=wp_charges)
- Return type:
None
Examples
>>> attach_for_backward(
...     output,
...     tape=tape,
...     positions=wp_positions,
...     charges=wp_charges,
...     energies=wp_energies,
... )
>>> # Later in backward:
>>> tape = output._warp_tape
>>> wp_positions = output._wp_positions
- nvalchemiops.torch.autograd.retrieve_for_backward(output, *array_names)[source]#
Retrieve Warp tape and arrays from a PyTorch tensor in backward pass.
- Parameters:
output (torch.Tensor) – PyTorch tensor that has attached Warp objects (from attach_for_backward)
*array_names (str) – Names of warp arrays to retrieve (without ‘_wp_’ prefix)
- Returns:
tape (wp.Tape) – The stored Warp tape
arrays (dict[str, wp.array]) – Dictionary mapping names to warp arrays
- Return type:
tuple[wp.Tape, dict[str, wp.array]]
Examples
>>> tape, arrays = retrieve_for_backward(
...     ctx.output,
...     'positions', 'charges', 'energies'
... )
>>> wp_positions = arrays['positions']
>>> tape.backward()
- nvalchemiops.torch.autograd.extract_gradients(ctx, warp_arrays, input_names)[source]#
Extract gradients from warp arrays and return in correct order for PyTorch.
This helper extracts gradients from warp arrays and returns them in the same order as the forward pass inputs, with None for inputs that don’t require gradients.
- Parameters:
ctx (Any) – PyTorch autograd context with saved tensors (must have attributes matching input_names)
warp_arrays (dict[str, wp.array]) – Dictionary mapping input names to warp arrays with computed gradients
input_names (Sequence[str]) – Names of inputs in the order they appear in forward function signature
- Returns:
Gradients in order, with None for inputs without requires_grad
- Return type:
tuple[Optional[torch.Tensor], …]
Examples
>>> # In backward function:
>>> tape, arrays = retrieve_for_backward(ctx.output, 'positions', 'charges')
>>> tape.backward()
>>> return extract_gradients(
...     ctx,
...     arrays,
...     ['positions', 'charges', 'cell', 'alpha']
... )
>>> # Returns: (grad_pos, grad_charges, None, None)
- nvalchemiops.torch.autograd.standard_backward(ctx, grad_outputs, output_names, array_names, input_names, output_dtypes=None)[source]#
Standard backward implementation for Warp-PyTorch custom operators.
This function handles both single-output and multiple-output operators. It encapsulates the common backward pattern:
1. Retrieve tape and warp arrays from context
2. Set gradient(s) on output(s)
3. Run tape backward
4. Extract and return gradients
- Parameters:
ctx (Any) – PyTorch autograd context with saved tensors
grad_outputs (torch.Tensor or tuple[Optional[torch.Tensor], ...]) – Gradient(s) from upstream operations.
- Single output: pass the gradient tensor directly
- Multiple outputs: pass a tuple of gradient tensors (None if unused in the loss)
output_names (str or Sequence[str]) – Name(s) of the output array(s) stored in ctx.
- Single output: ‘output’ or ‘energies’
- Multiple outputs: [‘energies’, ‘forces’]
array_names (Sequence[str]) – Names of ALL warp arrays that were attached (outputs + inputs). MUST include all output array names first!
- Single output: [‘output’, ‘positions’, ‘charges’]
- Multiple outputs: [‘energies’, ‘forces’, ‘positions’]
input_names (Sequence[str]) – Names of all inputs in forward function signature order
output_dtypes (Any or Sequence[Any], optional) – Warp dtype(s) for each output. Required for multiple outputs or non-float32 outputs.
- Single output: wp.float32 (default) or wp.vec3f
- Multiple outputs: [wp.float32, wp.vec3f]
- Returns:
Gradients for all inputs (None for those without requires_grad)
- Return type:
tuple[Optional[torch.Tensor], …]
Examples
Single output operator:
>>> # In forward:
>>> attach_for_backward(output, tape=tape, output=wp_output,
...                     positions=wp_positions, charges=wp_charges)
>>>
>>> # In backward:
>>> def backward(ctx, grad_output):
...     return standard_backward(
...         ctx,
...         grad_outputs=grad_output,  # Single tensor (note: parameter name)
...         output_names='output',  # Single string
...         array_names=['output', 'positions', 'charges'],
...         input_names=['positions', 'charges', 'cell', 'alpha'],
...     )
Multiple output operator:
>>> # In forward:
>>> attach_for_backward(energies, tape=tape, energies=wp_energies,
...                     forces=wp_forces, positions=wp_positions)
>>> return energies, forces
>>>
>>> # In backward:
>>> def backward(ctx, grad_energies, grad_forces):
...     return standard_backward(
...         ctx,
...         grad_outputs=(grad_energies, grad_forces),  # Tuple
...         output_names=['energies', 'forces'],  # List
...         output_dtypes=[wp.float32, wp.vec3f],  # Required!
...         array_names=['energies', 'forces', 'positions'],
...         input_names=['positions'],
...     )