Interoperability#

Warp can interoperate with other Python-based frameworks such as NumPy through standard interface protocols.

NumPy#

Warp arrays may be converted to NumPy arrays through the warp.array.numpy() method. When the Warp array lives on the cpu device, this returns a zero-copy view onto the underlying Warp allocation. If the array lives on a cuda device, it is first copied back to a temporary buffer on the host and then converted to a NumPy array.
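For example (a minimal sketch; the second branch assumes a CUDA-capable device is available):

import numpy as np
import warp as wp

w_cpu = wp.array([1.0, 2.0, 3.0], dtype=float, device="cpu")
a = w_cpu.numpy()    # zero-copy view onto the Warp allocation
a[0] = 5.0           # also modifies the Warp array
print(w_cpu)         # [5. 2. 3.]

if wp.get_cuda_device_count() > 0:
    w_gpu = wp.array([1.0, 2.0, 3.0], dtype=float, device="cuda:0")
    b = w_gpu.numpy()    # copies the data back to the host before conversion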

Warp CPU arrays also implement the __array_interface__ protocol and so can be used to construct NumPy arrays directly:

w = wp.array([1.0, 2.0, 3.0], dtype=float, device="cpu")
a = np.array(w)
print(a)
> [1. 2. 3.]

Data type conversion utilities are also available for convenience:

warp_type = wp.float32
...
numpy_type = wp.dtype_to_numpy(warp_type)
...
a = wp.zeros(n, dtype=warp_type)
b = np.zeros(n, dtype=numpy_type)

To create Warp arrays from NumPy arrays, use warp.from_numpy() or pass the NumPy array as the data argument of the warp.array constructor directly.
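For instance, both approaches below produce equivalent CPU arrays (a minimal sketch):

a = np.arange(10, dtype=np.float32)

w1 = wp.from_numpy(a, dtype=wp.float32, device="cpu")   # explicit conversion
w2 = wp.array(a, device="cpu")                          # dtype inferred from the NumPy array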

warp.from_numpy(arr, dtype=None, shape=None, device=None, requires_grad=False)#

Returns a Warp array created from a NumPy array.

Parameters:
  • arr (ndarray) – The NumPy array providing the data to construct the Warp array.

  • dtype (type | None) – The data type of the new Warp array. If this is not provided, the data type will be inferred.

  • shape (Sequence[int] | None) – The shape of the Warp array.

  • device (Device | str | None) – The device on which the Warp array will be constructed.

  • requires_grad (bool) – Whether or not gradients will be tracked for this array.

Raises:

RuntimeError – The data type of the NumPy array is not supported.

Return type:

array

warp.dtype_from_numpy(numpy_dtype)#

Return the Warp dtype corresponding to a NumPy dtype.

warp.dtype_to_numpy(warp_dtype)#

Return the NumPy dtype corresponding to a Warp dtype.
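A minimal sketch of a round trip between the two (scalar types only):

wp_type = wp.dtype_from_numpy(np.dtype(np.float32))   # warp.float32
np_type = wp.dtype_to_numpy(wp.int64)                 # numpy.int64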

PyTorch#

Warp provides helper functions to convert arrays to/from PyTorch:

w = wp.array([1.0, 2.0, 3.0], dtype=float, device="cpu")

# convert to Torch tensor
t = wp.to_torch(w)

# convert from Torch tensor
w = wp.from_torch(t)

These helper functions allow the conversion of Warp arrays to/from PyTorch tensors without copying the underlying data. At the same time, if available, gradient arrays and tensors are converted to/from PyTorch autograd tensors, allowing the use of Warp arrays in PyTorch autograd computations.
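For example, the following sketch illustrates the gradient handling described above (assuming default arguments; the gradient buffers are shared rather than copied):

t = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
w = wp.from_torch(t)    # w.requires_grad is True; w.grad wraps the tensor's gradient

a = wp.zeros(3, dtype=wp.float32, requires_grad=True)
u = wp.to_torch(a)      # u.requires_grad is True; u.grad shares memory with a.grad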

warp.from_torch(t, dtype=None, requires_grad=None, grad=None)#

Convert a Torch tensor to a Warp array without copying the data.

Parameters:
  • t (torch.Tensor) – The torch tensor to wrap.

  • dtype (warp.dtype, optional) – The target data type of the resulting Warp array. Defaults to the tensor value type mapped to a Warp array value type.

  • requires_grad (bool, optional) – Whether the resulting array should wrap the tensor’s gradient, if it exists (the grad tensor will be allocated otherwise). Defaults to the tensor’s requires_grad value.

Returns:

The wrapped array.

Return type:

warp.array

warp.to_torch(a, requires_grad=None)#

Convert a Warp array to a Torch tensor without copying the data.

Parameters:
  • a (warp.array) – The Warp array to convert.

  • requires_grad (bool, optional) – Whether the resulting tensor should convert the array’s gradient, if it exists, to a grad tensor. Defaults to the array’s requires_grad value.

Returns:

The converted tensor.

Return type:

torch.Tensor

warp.device_from_torch(torch_device)#

Return the Warp device corresponding to a Torch device.

Return type:

Device

warp.device_to_torch(warp_device)#

Return the Torch device string corresponding to a Warp device.

Parameters:

warp_device (Device | str | None) – An identifier that can be resolved to a warp.context.Device.

Raises:

RuntimeError – The Warp device is not compatible with PyTorch.

Return type:

str
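A minimal sketch of the device mapping helpers (assuming a CUDA device is present):

torch_device = wp.device_to_torch(wp.get_device())           # e.g. the string "cuda:0"
warp_device = wp.device_from_torch(torch.device("cuda:0"))   # the corresponding Warp Device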

warp.dtype_from_torch(torch_dtype)#

Return the Warp dtype corresponding to a Torch dtype.

Parameters:

torch_dtype – A torch.dtype that has a corresponding Warp data type. Currently torch.bfloat16, torch.complex64, and torch.complex128 are not supported.

Raises:

TypeError – Unable to find a corresponding Warp data type.

warp.dtype_to_torch(warp_dtype)#

Return the Torch dtype corresponding to a Warp dtype.

Parameters:

warp_dtype – A Warp data type that has a corresponding torch.dtype. warp.uint16, warp.uint32, and warp.uint64 are mapped to the signed integer torch.dtype of the same width.

Raises:

TypeError – Unable to find a corresponding PyTorch data type.
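A minimal sketch of the dtype mapping helpers:

wp_dtype = wp.dtype_from_torch(torch.float16)   # warp.float16
torch_dtype = wp.dtype_to_torch(wp.uint32)      # torch.int32 (unsigned types map to the signed dtype of the same width)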

To convert a PyTorch CUDA stream to a Warp CUDA stream and vice versa, Warp provides the following functions:

warp.stream_from_torch(stream_or_device=None)#

Convert from a Torch CUDA stream to a Warp CUDA stream.

warp.stream_to_torch(stream_or_device=None)#

Convert from a Warp CUDA stream to a Torch CUDA stream.
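A minimal sketch of wrapping streams in both directions (assuming a CUDA device is present):

torch_stream = torch.cuda.current_stream()
warp_stream = wp.stream_from_torch(torch_stream)   # wrap the current Torch stream as a Warp stream

new_stream = wp.Stream(device="cuda:0")
torch_view = wp.stream_to_torch(new_stream)        # expose a Warp stream to Torch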

Example: Optimization using warp.from_torch()#

The following example shows how to minimize a loss function over an array of 2D points written in Warp, using PyTorch’s Adam optimizer together with warp.from_torch():

import warp as wp
import torch

wp.init()

@wp.kernel()
def loss(xs: wp.array(dtype=float, ndim=2), l: wp.array(dtype=float)):
    tid = wp.tid()
    wp.atomic_add(l, 0, xs[tid, 0] ** 2.0 + xs[tid, 1] ** 2.0)

# indicate requires_grad so that Warp can accumulate gradients in the grad buffers
xs = torch.randn(100, 2, requires_grad=True)
l = torch.zeros(1, requires_grad=True)
opt = torch.optim.Adam([xs], lr=0.1)

wp_xs = wp.from_torch(xs)
wp_l = wp.from_torch(l)

tape = wp.Tape()
with tape:
    # record the loss function kernel launch on the tape
    wp.launch(loss, dim=len(xs), inputs=[wp_xs], outputs=[wp_l], device=wp_xs.device)

for i in range(500):
    tape.zero()
    tape.backward(loss=wp_l)  # compute gradients
    # now xs.grad will be populated with the gradients computed by Warp
    opt.step()  # update xs (and thereby wp_xs)

    # these lines are only needed for evaluating the loss
    # (the optimization just needs the gradient, not the loss value)
    wp_l.zero_()
    wp.launch(loss, dim=len(xs), inputs=[wp_xs], outputs=[wp_l], device=wp_xs.device)
    print(f"{i}\tloss: {l.item()}")

Example: Optimization using warp.to_torch()#

Less code is needed when we declare the optimization variables directly in Warp and use warp.to_torch() to convert them to PyTorch tensors. Here, we revisit the example from above, where now only a single conversion to a Torch tensor is needed to supply Adam with the optimization variables:

import warp as wp
import numpy as np
import torch

wp.init()

@wp.kernel()
def loss(xs: wp.array(dtype=float, ndim=2), l: wp.array(dtype=float)):
    tid = wp.tid()
    wp.atomic_add(l, 0, xs[tid, 0] ** 2.0 + xs[tid, 1] ** 2.0)

# initialize the optimization variables in Warp
xs = wp.array(np.random.randn(100, 2), dtype=wp.float32, requires_grad=True)
l = wp.zeros(1, dtype=wp.float32, requires_grad=True)
# just a single wp.to_torch call is needed, Adam optimizes using the Warp array gradients
opt = torch.optim.Adam([wp.to_torch(xs)], lr=0.1)

tape = wp.Tape()
with tape:
    wp.launch(loss, dim=len(xs), inputs=[xs], outputs=[l], device=xs.device)

for i in range(500):
    tape.zero()
    tape.backward(loss=l)
    opt.step()

    l.zero_()
    wp.launch(loss, dim=len(xs), inputs=[xs], outputs=[l], device=xs.device)
    print(f"{i}\tloss: {l.numpy()[0]}")

Example: Optimization using torch.autograd.Function#

One can insert Warp kernel launches into a PyTorch graph by defining a torch.autograd.Function class, which requires forward and backward functions to be defined. After mapping the incoming Torch tensors to Warp arrays, a Warp kernel may be launched in the usual way. In the backward pass, the same kernel’s adjoint may be launched by setting adjoint=True in wp.launch(). Alternatively, the user may choose to rely on Warp’s tape. In the following example, we demonstrate how Warp may be used to evaluate the Rosenbrock function in an optimization context:

import warp as wp
import numpy as np
import torch

wp.init()

pvec2 = wp.types.vector(length=2, dtype=wp.float32)

# Define the Rosenbrock function
@wp.func
def rosenbrock(x: float, y: float):
    return (1.0 - x) ** 2.0 + 100.0 * (y - x**2.0) ** 2.0

@wp.kernel
def eval_rosenbrock(
    xs: wp.array(dtype=pvec2),
    # outputs
    z: wp.array(dtype=float),
):
    i = wp.tid()
    x = xs[i]
    z[i] = rosenbrock(x[0], x[1])


class Rosenbrock(torch.autograd.Function):
    @staticmethod
    def forward(ctx, xy, num_points):
        # ensure Torch operations complete before running Warp
        wp.synchronize_device()

        ctx.xy = wp.from_torch(xy, dtype=pvec2, requires_grad=True)
        ctx.num_points = num_points

        # allocate output
        ctx.z = wp.zeros(num_points, requires_grad=True)

        wp.launch(
            kernel=eval_rosenbrock,
            dim=ctx.num_points,
            inputs=[ctx.xy],
            outputs=[ctx.z]
        )

        # ensure Warp operations complete before returning data to Torch
        wp.synchronize_device()

        return wp.to_torch(ctx.z)

    @staticmethod
    def backward(ctx, adj_z):
        # ensure Torch operations complete before running Warp
        wp.synchronize_device()

        # map incoming Torch grads to our output variables
        ctx.z.grad = wp.from_torch(adj_z)

        wp.launch(
            kernel=eval_rosenbrock,
            dim=ctx.num_points,
            inputs=[ctx.xy],
            outputs=[ctx.z],
            adj_inputs=[ctx.xy.grad],
            adj_outputs=[ctx.z.grad],
            adjoint=True
        )

        # ensure Warp operations complete before returning data to Torch
        wp.synchronize_device()

        # return adjoint w.r.t. inputs
        return (wp.to_torch(ctx.xy.grad), None)


num_points = 1500
learning_rate = 5e-2

torch_device = wp.device_to_torch(wp.get_device())

rng = np.random.default_rng(42)
xy = torch.tensor(rng.normal(size=(num_points, 2)), dtype=torch.float32, requires_grad=True, device=torch_device)
opt = torch.optim.Adam([xy], lr=learning_rate)

for _ in range(10000):
    # step
    opt.zero_grad()
    z = Rosenbrock.apply(xy, num_points)
    z.backward(torch.ones_like(z))

    opt.step()

# minimum at (1, 1)
xy_np = xy.numpy(force=True)
print(np.mean(xy_np, axis=0))

Note that if Warp code is wrapped in a torch.autograd.Function that gets called in torch.compile(), it will automatically exclude that function from compiler optimizations. If your script uses torch.compile(), we recommend using PyTorch version 2.3.0+, which has improvements that address this scenario.

CuPy/Numba#

Warp GPU arrays support the __cuda_array_interface__ protocol for sharing data with other Python GPU frameworks. Currently this is one-directional: Warp arrays can be used as input to any framework that also supports the __cuda_array_interface__ protocol, but Warp cannot consume arrays from other frameworks through this interface.
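For example, CuPy and Numba can create zero-copy views of a Warp GPU array through this protocol (a minimal sketch, assuming CuPy and Numba are installed and a CUDA device is available):

import warp as wp
import cupy as cp
from numba import cuda

w = wp.array([1.0, 2.0, 3.0], dtype=wp.float32, device="cuda:0")

c = cp.asarray(w)           # zero-copy CuPy view via __cuda_array_interface__
d = cuda.as_cuda_array(w)   # zero-copy Numba device array view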

JAX#

Interoperability with JAX arrays is supported through the following methods. Internally these use the DLPack protocol to exchange data in a zero-copy way with JAX:

warp_array = wp.from_jax(jax_array)
jax_array = wp.to_jax(warp_array)

It may be preferable to use the DLPack protocol directly for better performance and control over stream synchronization behaviour.

warp.from_jax(jax_array, dtype=None)#

Convert a Jax array to a Warp array without copying the data.

Parameters:
  • jax_array (jax.Array) – The Jax array to convert.

  • dtype (optional) – The target data type of the resulting Warp array. Defaults to the Jax array’s data type mapped to a Warp data type.

Returns:

The converted Warp array.

Return type:

warp.array

warp.to_jax(warp_array)#

Convert a Warp array to a Jax array without copying the data.

Parameters:

warp_array (warp.array) – The Warp array to convert.

Returns:

The converted Jax array.

Return type:

jax.Array

warp.device_from_jax(jax_device)#

Return the Warp device corresponding to a Jax device.

Parameters:

jax_device (jax.Device) – A Jax device descriptor.

Raises:

RuntimeError – The Jax device is neither a CPU nor GPU device.

Return type:

Device

warp.device_to_jax(warp_device)#

Return the Jax device corresponding to a Warp device.

Parameters:

warp_device (Device | str | None) – An identifier that can be resolved to a warp.context.Device.

Raises:

RuntimeError – Failed to find the corresponding Jax device.

Return type:

jax.Device

warp.dtype_from_jax(jax_dtype)#

Return the Warp dtype corresponding to a Jax dtype.

Raises:

TypeError – Unable to find a corresponding Warp data type.

warp.dtype_to_jax(warp_dtype)#

Return the Jax dtype corresponding to a Warp dtype.

Parameters:

warp_dtype – A Warp data type that has a corresponding Jax data type.

Raises:

TypeError – Unable to find a corresponding Jax data type.
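A minimal sketch of the Jax device and dtype helpers (assuming JAX is installed; behaviour for non-scalar Warp types is not shown):

import jax
import jax.numpy as jp
import warp as wp

jax_device = wp.device_to_jax(wp.get_device())    # the jax.Device backing the current Warp device
warp_device = wp.device_from_jax(jax.devices()[0])

wp_dtype = wp.dtype_from_jax(jp.float32)          # warp.float32
jax_dtype = wp.dtype_to_jax(wp.float32)           # jax.numpy.float32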

Using Warp kernels as JAX primitives#

Note

This is an experimental feature under development.

Warp kernels can be wrapped as JAX primitives, which allows them to be called inside of jitted JAX functions:

import warp as wp
import jax
import jax.numpy as jp

# import experimental feature
from warp.jax_experimental import jax_kernel

@wp.kernel
def triple_kernel(input: wp.array(dtype=float), output: wp.array(dtype=float)):
    tid = wp.tid()
    output[tid] = 3.0 * input[tid]

wp.init()

# create a Jax primitive from a Warp kernel
jax_triple = jax_kernel(triple_kernel)

# use the Warp kernel in a Jax jitted function
@jax.jit
def f():
    x = jp.arange(0, 64, dtype=jp.float32)
    return jax_triple(x)

print(f())

Since this is an experimental feature, there are some limitations:

  • All kernel arguments must be arrays.

  • Kernel launch dimensions are inferred from the shape of the first argument.

  • Input arguments are followed by output arguments in the Warp kernel definition.

  • There must be at least one input argument and at least one output argument.

  • Output shapes must match the launch dimensions (i.e., output shapes must match the shape of the first argument).

  • All arrays must be contiguous.

  • Only the CUDA backend is supported.

Here is an example of an operation with three inputs and two outputs:

import warp as wp
import jax
import jax.numpy as jp

# import experimental feature
from warp.jax_experimental import jax_kernel

# kernel with multiple inputs and outputs
@wp.kernel
def multiarg_kernel(
    # inputs
    a: wp.array(dtype=float),
    b: wp.array(dtype=float),
    c: wp.array(dtype=float),
    # outputs
    ab: wp.array(dtype=float),
    bc: wp.array(dtype=float),
):
    tid = wp.tid()
    ab[tid] = a[tid] + b[tid]
    bc[tid] = b[tid] + c[tid]

wp.init()

# create a Jax primitive from a Warp kernel
jax_multiarg = jax_kernel(multiarg_kernel)

# use the Warp kernel in a Jax jitted function with three inputs and two outputs
@jax.jit
def f():
    a = jp.full(64, 1, dtype=jp.float32)
    b = jp.full(64, 2, dtype=jp.float32)
    c = jp.full(64, 3, dtype=jp.float32)
    return jax_multiarg(a, b, c)

x, y = f()

print(x)
print(y)

DLPack#

Warp supports the DLPack protocol included in the Python Array API standard v2022.12. See the Python Specification for DLPack for reference.

The canonical way to import an external array into Warp is using the warp.from_dlpack() function:

warp_array = wp.from_dlpack(external_array)

The external array can be a PyTorch tensor, Jax array, or any other array type compatible with this version of the DLPack protocol. For CUDA arrays, this approach requires the producer to perform stream synchronization which ensures that operations on the array are ordered correctly. The warp.from_dlpack() function asks the producer to synchronize the current Warp stream on the device where the array resides. Thus it should be safe to use the array in Warp kernels on that device without any additional synchronization.

The canonical way to export a Warp array to an external framework is to use the from_dlpack() function in that framework:

jax_array = jax.dlpack.from_dlpack(warp_array)
torch_tensor = torch.utils.dlpack.from_dlpack(warp_array)

For CUDA arrays, this will synchronize the current stream of the consumer framework with the current Warp stream on the array’s device. Thus it should be safe to use the wrapped array in the consumer framework, even if the array was previously used in a Warp kernel on the device.

Alternatively, arrays can be shared by explicitly creating PyCapsules using a to_dlpack() function provided by the producer framework. This approach may be used for older versions of frameworks that do not support the v2022.12 standard:

warp_array1 = wp.from_dlpack(jax.dlpack.to_dlpack(jax_array))
warp_array2 = wp.from_dlpack(torch.utils.dlpack.to_dlpack(torch_tensor))

jax_array = jax.dlpack.from_dlpack(wp.to_dlpack(warp_array))
torch_tensor = torch.utils.dlpack.from_dlpack(wp.to_dlpack(warp_array))

This approach is generally faster because it skips any stream synchronization, but correct ordering of operations must then be ensured by other means. It may be a good choice in situations like these (see the sketch after this list for one way to arrange the second case):

  • The external framework is using the synchronous CUDA default stream.

  • Warp and the external framework are using the same CUDA stream.

  • Another synchronization mechanism is already in place.
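As a hedged sketch of the second case, Warp’s launches can be placed on PyTorch’s current CUDA stream so that the capsule exchange needs no additional synchronization (my_kernel and t are hypothetical placeholders for a Warp kernel and a CUDA Torch tensor):

s = wp.stream_from_torch(torch.cuda.current_stream())

with wp.ScopedStream(s):
    # Warp now launches on the same CUDA stream as PyTorch
    w = wp.from_dlpack(torch.utils.dlpack.to_dlpack(t))  # hypothetical tensor t
    wp.launch(my_kernel, dim=len(w), inputs=[w])         # hypothetical kernel my_kernel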

warp.from_dlpack(source, dtype=None)#

Convert a source array or DLPack capsule into a Warp array without copying.

Parameters:
  • source – A DLPack-compatible array or PyCapsule

  • dtype – An optional Warp data type to interpret the source data.

Returns:

A new Warp array that uses the same underlying memory as the source array or capsule.

Return type:

array

warp.to_dlpack(wp_array)#

Convert a Warp array to another type of DLPack-compatible array.

Parameters:

wp_array (array) – The source Warp array that will be converted.

Returns:

A capsule containing a DLManagedTensor that can be converted to another array type without copying the underlying memory.