nvalchemiops.torch.autograd: Autograd Utilities#

Autograd Utilities for Warp-PyTorch Integration#

This module provides utilities for integrating Warp’s automatic differentiation with PyTorch custom operators. It abstracts common patterns for:

  1. Checking if any tensor requires gradients

  2. Conditionally creating Warp tapes

  3. Storing tape and warp arrays on output tensors

  4. Retrieving them in backward passes

  5. Decorator-based custom op registration with automatic backward generation

import warp as wp
import torch
from contextlib import contextmanager, nullcontext
from typing import Any, Optional, Sequence, Union

Custom Op Registration#

nvalchemiops.torch.autograd.warp_custom_op(name, outputs, grad_arrays=None, mutates_args=())[source]#

Decorator to create a Warp-backed PyTorch op with compile-safe autograd.

This decorator eliminates boilerplate by automatically generating:

  • A torch.library.custom_op forward registered with fake/meta support

  • A hidden token input for runtime state handoff, while the public wrapper still exposes only the user-visible signature

  • A traceable register_autograd wrapper that replays Warp tapes through an opaque backward custom op

  • Stream binding so Warp launches execute on PyTorch’s current CUDA stream

Parameters:
  • name (str) – Full custom op name (e.g., “alchemiops::_my_kernel”).

  • outputs (list[OutputSpec]) – Specifications for each output tensor.

  • grad_arrays (list[str], optional) – Names of warp arrays to track for gradients. Should include output names first, then differentiable input names. If None, auto-generated from outputs + all inputs that are likely differentiable (excludes common non-differentiable names like neighbor_list, batch_idx, etc.).

  • mutates_args (tuple, default=()) – Arguments that are mutated by the op (passed to custom_op).

Returns:

Decorator function.

Return type:

Callable

Examples

>>> @warp_custom_op(
...     name="alchemiops::_ewald_real_space_energy",
...     outputs=[
...         OutputSpec("energies", wp.float64, lambda pos, *_: (pos.shape[0],)),
...     ],
...     grad_arrays=["energies", "positions", "charges", "cell", "alpha"],
... )
... def _ewald_real_space_energy(
...     positions: torch.Tensor,
...     charges: torch.Tensor,
...     cell: torch.Tensor,
...     alpha: torch.Tensor,
...     neighbor_list: torch.Tensor,
...     neighbor_shifts: torch.Tensor,
... ) -> torch.Tensor:
...     # Implementation here - no boilerplate needed!
...     ...
...     return energies

Notes

The decorated function should still call attach_for_backward() at the end of grad-enabled forward execution so the registered forward op can collect the runtime Warp tape and arrays from the real output tensor.

retain_graph=True is supported: the Warp tape is preserved across backward passes and zeroed before each replay. create_graph=True is not supported: Warp backward ops do not register a second-order autograd formula, so higher-order differentiation through them will raise an error. Use hybrid_forces=True in electrostatics APIs when you need to combine explicit Warp forces with autograd-based charge-gradient forces.

class nvalchemiops.torch.autograd.OutputSpec(name, dtype, shape, torch_dtype=None)[source]#

Specification for a custom op output.

Parameters:
  • name (str) – Name of the output (used for backward pass).

  • dtype (wp dtype) – Warp dtype (e.g., wp.float64, wp.vec3d).

  • shape (Callable or tuple) – Either a tuple of ints, or a callable that takes the input tensors and returns the shape. For a callable, the signature should match the custom op’s input signature.

  • torch_dtype (torch.dtype, optional) – PyTorch dtype override. If omitted (None), the dtype is inferred from the resolved Warp dtype via _wp_dtype_to_torch.

Examples

>>> OutputSpec("energies", wp.float64, lambda pos, *_: (pos.shape[0],))
>>> OutputSpec("forces", wp.vec3d, lambda pos, *_: (pos.shape[0], 3))
>>> OutputSpec("virial", wp.mat33d, (3, 3))  # Static shape

Warp-PyTorch Interop#

nvalchemiops.torch.autograd.warp_stream_from_torch(*values)[source]#

Bind Warp launches to PyTorch’s current CUDA stream when tensors are CUDA.

Parameters:

values (Any)
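The docstring leaves the mechanics implicit; one plausible implementation (a sketch under assumptions, not the library's actual code, and `warp_stream_from_torch_sketch` is a hypothetical name) scopes Warp to the stream obtained from Warp's public `wp.stream_from_torch`:

```python
import torch
from contextlib import nullcontext

def warp_stream_from_torch_sketch(*values):
    """Hypothetical equivalent: scope Warp launches to PyTorch's current
    CUDA stream if any value is a CUDA tensor; otherwise do nothing."""
    for v in values:
        if isinstance(v, torch.Tensor) and v.is_cuda:
            import warp as wp
            torch_stream = torch.cuda.current_stream(v.device)
            return wp.ScopedStream(wp.stream_from_torch(torch_stream))
    # No CUDA tensors involved: a no-op context.
    return nullcontext()
```

On CPU-only inputs this degrades to a no-op, so it is safe to use unconditionally around wp.launch calls.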

nvalchemiops.torch.autograd.warp_from_torch(tensor, warp_dtype, requires_grad=None)[source]#

Convert a PyTorch tensor to a Warp array with proper gradient tracking.

Parameters:
  • tensor (torch.Tensor) – Input PyTorch tensor

  • warp_dtype (wp.dtype) – Warp data type for the array

  • requires_grad (bool | None, optional) – Override gradient tracking. If None, inherits from tensor.requires_grad

Returns:

Warp array with gradient tracking if needed

Return type:

wp.array

nvalchemiops.torch.autograd.needs_grad(*tensors)[source]#

Check if any of the provided tensors requires gradients.

This is useful for conditionally enabling Warp gradient tracking and tape recording only when needed for backpropagation.

Parameters:

*tensors (torch.Tensor) – Variable number of PyTorch tensors to check

Returns:

True if any tensor requires gradients, False otherwise

Return type:

bool

Examples

>>> positions = torch.randn(100, 3, requires_grad=True)
>>> charges = torch.randn(100, requires_grad=False)
>>> needs_grad(positions, charges)
True
>>> needs_grad(charges)
False
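Internally the check likely reduces to a single any() over the inputs; a minimal sketch (`needs_grad_sketch` is a hypothetical name):

```python
import torch

def needs_grad_sketch(*tensors):
    # True as soon as any argument is a tensor that tracks gradients.
    return any(isinstance(t, torch.Tensor) and t.requires_grad
               for t in tensors)
```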

Autograd Context Manager#

class nvalchemiops.torch.autograd.WarpAutogradContextManager(enable)[source]#

Conditionally create a Warp tape as a context manager.

If enable=True, yields a Warp Tape for gradient recording; otherwise yields a nullcontext (no-op) with zero overhead.

Parameters:

enable (bool) – Whether to create a tape for gradient recording

Yields:

wp.Tape or nullcontext – Active tape for recording if enabled, otherwise nullcontext

Examples

>>> needs_grad_flag = needs_grad(positions, charges)
>>> with WarpAutogradContextManager(needs_grad_flag) as tape:
...     wp.launch(kernel, ...)
>>> if needs_grad_flag:
...     # tape is a wp.Tape instance
...     tape.backward()
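The behavior can be approximated with contextlib (a sketch; `warp_tape_if` is a hypothetical name, and the real class may differ in detail):

```python
from contextlib import contextmanager

@contextmanager
def warp_tape_if(enable):
    # Record kernel launches on a wp.Tape only when gradients are needed;
    # otherwise yield None at essentially zero cost.
    if enable:
        import warp as wp
        tape = wp.Tape()
        with tape:
            yield tape
    else:
        yield None
```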
nvalchemiops.torch.autograd.attach_for_backward(output, tape=None, **warp_arrays)[source]#

Attach Warp tape and arrays to a PyTorch tensor for later retrieval in backward.

This stores the tape and warp arrays as attributes on the output tensor, allowing them to be retrieved in the backward pass of a custom operator.

Parameters:
  • output (torch.Tensor) – PyTorch tensor to attach attributes to (usually the output of forward)

  • tape (wp.Tape, optional) – Warp tape containing recorded operations for backward pass

  • **warp_arrays (wp.array) – Named warp arrays to store (e.g., positions=wp_positions, charges=wp_charges)

Return type:

None

Examples

>>> attach_for_backward(
...     output,
...     tape=tape,
...     positions=wp_positions,
...     charges=wp_charges,
...     energies=wp_energies,
... )
>>> # Later in backward:
>>> tape = output._warp_tape
>>> wp_positions = output._wp_positions
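The attribute names shown above (`_warp_tape`, `_wp_positions`) suggest a simple storage scheme; a sketch of it (`attach_for_backward_sketch` is a hypothetical name):

```python
import torch

def attach_for_backward_sketch(output, tape=None, **warp_arrays):
    # Store the tape on `_warp_tape` and each named array on `_wp_<name>`,
    # matching the retrieval pattern shown in the example above.
    output._warp_tape = tape
    for name, arr in warp_arrays.items():
        setattr(output, f"_wp_{name}", arr)
```

PyTorch tensors accept arbitrary Python attributes, which is what makes this handoff through the output tensor possible.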
nvalchemiops.torch.autograd.retrieve_for_backward(output, *array_names)[source]#

Retrieve Warp tape and arrays from a PyTorch tensor in backward pass.

Parameters:
  • output (torch.Tensor) – PyTorch tensor that has attached Warp objects (from attach_for_backward)

  • *array_names (str) – Names of warp arrays to retrieve (without ‘_wp_’ prefix)

Returns:

  • tape (wp.Tape) – The stored Warp tape

  • arrays (dict[str, wp.array]) – Dictionary mapping names to warp arrays

Return type:

tuple[Tape, dict[str, array]]

Examples

>>> tape, arrays = retrieve_for_backward(
...     ctx.output,
...     'positions', 'charges', 'energies'
... )
>>> wp_positions = arrays['positions']
>>> tape.backward()
nvalchemiops.torch.autograd.extract_gradients(ctx, warp_arrays, input_names)[source]#

Extract gradients from warp arrays and return in correct order for PyTorch.

This helper extracts gradients from warp arrays and returns them in the same order as the forward pass inputs, with None for inputs that don’t require gradients.

Parameters:
  • ctx (Any) – PyTorch autograd context with saved tensors (must have attributes matching input_names)

  • warp_arrays (dict[str, wp.array]) – Dictionary mapping input names to warp arrays with computed gradients

  • input_names (Sequence[str]) – Names of inputs in the order they appear in forward function signature

Returns:

Gradients in order, with None for inputs without requires_grad

Return type:

tuple[Optional[torch.Tensor], …]

Examples

>>> # In backward function:
>>> tape, arrays = retrieve_for_backward(ctx.output, 'positions', 'charges')
>>> tape.backward()
>>> return extract_gradients(
...     ctx,
...     arrays,
...     ['positions', 'charges', 'cell', 'alpha']
... )
>>> # Returns: (grad_pos, grad_charges, None, None)
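The ordering contract can be sketched as follows (a hypothetical re-implementation, not the module's code):

```python
import torch

def extract_gradients_sketch(ctx, warp_arrays, input_names):
    # One slot per forward input, in signature order; None when the input
    # did not require gradients or no Warp gradient was recorded for it.
    grads = []
    for name in input_names:
        tensor = getattr(ctx, name, None)
        arr = warp_arrays.get(name)
        if tensor is not None and tensor.requires_grad and arr is not None:
            import warp as wp
            grads.append(wp.to_torch(arr.grad))
        else:
            grads.append(None)
    return tuple(grads)
```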
nvalchemiops.torch.autograd.standard_backward(ctx, grad_outputs, output_names, array_names, input_names, output_dtypes=None)[source]#

Standard backward implementation for Warp-PyTorch custom operators.

This function handles both single-output and multiple-output operators. It encapsulates the common backward pattern:

  1. Retrieve the tape and warp arrays from the context

  2. Set the gradient(s) on the output(s)

  3. Run the tape backward

  4. Extract and return the input gradients

Parameters:
  • ctx (Any) – PyTorch autograd context with saved tensors

  • grad_outputs (torch.Tensor or tuple[Optional[torch.Tensor], ...]) – Gradient(s) from upstream operations. For a single output, pass the gradient tensor directly; for multiple outputs, pass a tuple of gradient tensors (None for outputs unused in the loss).

  • output_names (str or Sequence[str]) – Name(s) of the output array(s) stored in ctx. For a single output, a string such as ‘output’ or ‘energies’; for multiple outputs, a list such as [‘energies’, ‘forces’].

  • array_names (Sequence[str]) – Names of ALL warp arrays that were attached (outputs + inputs). Must list all output array names first. Examples: [‘output’, ‘positions’, ‘charges’] for a single output; [‘energies’, ‘forces’, ‘positions’] for multiple outputs.

  • input_names (Sequence[str]) – Names of all inputs in forward function signature order

  • output_dtypes (Any or Sequence[Any], optional) – Warp dtype(s) for each output. Required for multiple outputs or non-float32 outputs. For a single output, e.g. wp.float32 (the default) or wp.vec3f; for multiple outputs, e.g. [wp.float32, wp.vec3f].

Returns:

Gradients for all inputs (None for those without requires_grad)

Return type:

tuple[Optional[torch.Tensor], …]

Examples

Single output operator:

>>> # In forward:
>>> attach_for_backward(output, tape=tape, output=wp_output,
...                     positions=wp_positions, charges=wp_charges)
>>>
>>> # In backward:
>>> def backward(ctx, grad_output):
...     return standard_backward(
...         ctx,
...         grad_outputs=grad_output,  # a single tensor is accepted
...         output_names='output',  # Single string
...         array_names=['output', 'positions', 'charges'],
...         input_names=['positions', 'charges', 'cell', 'alpha'],
...     )

Multiple output operator:

>>> # In forward:
>>> attach_for_backward(energies, tape=tape, energies=wp_energies,
...                     forces=wp_forces, positions=wp_positions)
>>> return energies, forces
>>>
>>> # In backward:
>>> def backward(ctx, grad_energies, grad_forces):
...     return standard_backward(
...         ctx,
...         grad_outputs=(grad_energies, grad_forces),  # Tuple
...         output_names=['energies', 'forces'],  # List
...         output_dtypes=[wp.float32, wp.vec3f],  # Required!
...         array_names=['energies', 'forces', 'positions'],
...         input_names=['positions'],
...     )