Runtime Reference¶
This section describes the Warp Python runtime API, how to manage memory, launch kernels, and high-level functionality for dealing with objects such as meshes and volumes. The APIs described in this section are intended to be used at the Python Scope and run inside the CPython interpreter. For a comprehensive list of functions available at the Kernel Scope, please see the Kernel Reference section.
Kernels¶
Kernels are launched with the wp.launch() function on a specific device (CPU/GPU):
wp.launch(simple_kernel, dim=1024, inputs=[a, b, c], device="cuda")
Note that all the kernel inputs must live on the target device or a runtime exception will be raised. Kernels may be launched with multi-dimensional grid bounds. In this case, threads are not assigned a single index, but a coordinate in an n-dimensional grid, e.g.:
wp.launch(complex_kernel, dim=(128, 128, 3), ...)
This launches a 3D grid of threads with dimensions 128 x 128 x 3. To retrieve the 3D index for each thread, use the following syntax:
i,j,k = wp.tid()
Note
Currently, kernels launched on CPU devices will be executed in serial. Kernels launched on CUDA devices will be launched in parallel with a fixed block-size.
In the Warp Compilation Model, kernels are just-in-time compiled into dynamic libraries and PTX using C++/CUDA as an intermediate representation. To avoid excessive runtime recompilation of kernel code, these files are stored in a cache directory named with a module-dependent hash to allow for the reuse of previously compiled modules. The location of the kernel cache is printed when Warp is initialized. wp.clear_kernel_cache() can be used to clear the kernel cache of previously generated compilation artifacts, as Warp does not automatically try to keep the cache below a certain size.
- warp.launch(kernel, dim, inputs=[], outputs=[], adj_inputs=[], adj_outputs=[], device=None, stream=None, adjoint=False, record_tape=True, record_cmd=False, max_blocks=0)[source]¶
Launch a Warp kernel on the target device
Kernel launches are asynchronous with respect to the calling Python thread.
- Parameters:
kernel – The Warp kernel function to launch, decorated with the @wp.kernel decorator
dim (Tuple[int]) – The number of threads to launch the kernel with; can be an integer or a tuple of up to 4 ints
inputs (Sequence) – The input parameters to the kernel (optional)
outputs (Sequence) – The output parameters (optional)
adj_inputs (Sequence) – The adjoint inputs (optional)
adj_outputs (Sequence) – The adjoint outputs (optional)
device (Device | str | None) – The device to launch on (optional)
stream (Stream | None) – The stream to launch on (optional)
adjoint – Whether to run forward or backward pass (typically use False)
record_tape – When True, the launch will be recorded on the global wp.Tape() object, if one is present
record_cmd – When True, the launch will be returned as a Launch command object; the launch will not occur until the user calls cmd.launch()
max_blocks – The maximum number of CUDA thread blocks to use. Only has an effect for CUDA kernel launches. If negative or zero, the maximum hardware value will be used.
- warp.clear_kernel_cache()[source]¶
Clear the kernel cache directory of previously generated source code and compiler artifacts.
Only directories beginning with wp_ will be deleted. This function only clears the cache for the current Warp version.
- Return type:
None
Runtime Kernel Creation¶
Warp allows generating kernels on-the-fly with various customizations, including closure support. Refer to the Code Generation section for the latest features.
Arrays¶
Arrays are the fundamental memory abstraction in Warp; they are created through the following global constructors:
wp.empty(shape=1024, dtype=wp.vec3, device="cpu")
wp.zeros(shape=1024, dtype=float, device="cuda")
wp.full(shape=1024, value=10, dtype=int, device="cuda")
Arrays can also be constructed directly from NumPy ndarrays as follows:
r = np.random.rand(1024)
# copy to Warp owned array
a = wp.array(r, dtype=float, device="cpu")
# return a Warp array wrapper around the NumPy data (zero-copy)
a = wp.array(r, dtype=float, copy=False, device="cpu")
# return a Warp copy of the array data on the GPU
a = wp.array(r, dtype=float, device="cuda")
Note that for multi-dimensional data the dtype parameter must be specified explicitly, e.g.:
r = np.random.rand(1024, 3)
# initialize as an array of vec3 objects
a = wp.array(r, dtype=wp.vec3, device="cuda")
If the shapes are incompatible, an error will be raised.
Warp arrays can also be constructed from objects that define the __cuda_array_interface__ attribute. For example:
import cupy
import warp as wp
device = wp.get_cuda_device()
r = cupy.arange(10)
# return a Warp array wrapper around the cupy data (zero-copy)
a = wp.array(r, device=device)
Arrays can be moved between devices using the array.to() method:
host_array = wp.array(a, dtype=float, device="cpu")
# allocate and copy to GPU
device_array = host_array.to("cuda")
Additionally, arrays can be copied directly between memory spaces:
src_array = wp.array(a, dtype=float, device="cpu")
dest_array = wp.empty_like(src_array, device="cuda")
# copy from source CPU buffer to GPU
wp.copy(dest_array, src_array)
- class warp.array(*args, **kwargs)[source]¶
- __init__(data=None, dtype=Any, shape=None, strides=None, length=None, ptr=None, capacity=None, device=None, pinned=False, copy=True, owner=False, deleter=None, ndim=None, grad=None, requires_grad=False)[source]¶
Constructs a new Warp array object
When the data argument is a valid list, tuple, or ndarray, the array will be constructed from this object’s data. For objects that are not stored sequentially in memory (e.g., a list), the data will first be flattened before being transferred to the memory space given by device.
The second construction path occurs when the ptr argument is a non-zero uint64 value representing the start address in memory where existing array data resides, e.g., from an external or C library. The memory allocation should reside on the same device given by the device argument, and the user should set the length and dtype parameters appropriately.
If neither data nor ptr is specified, the shape or length arguments are checked next. This construction path can be used to create new uninitialized arrays, but users are encouraged to call wp.empty(), wp.zeros(), or wp.full() instead to create new arrays.
If none of the above arguments are specified, a simple type annotation is constructed. This is used when annotating kernel arguments or struct members (e.g., arr: wp.array(dtype=float)). In this case, only dtype and ndim are taken into account and no memory is allocated for the array.
- Parameters:
data (Union[list, tuple, ndarray]) – An object to construct the array from, can be a Tuple, List, or generally any type convertible to an np.array
dtype (Union) – One of the available data types, such as warp.float32, warp.mat33, or a custom struct. If dtype is Any and data is an ndarray, then it will be inferred from the array data type
shape (tuple) – Dimensions of the array
strides (tuple) – Number of bytes in each dimension between successive elements of the array
length (int) – Number of elements of the data type (deprecated, users should use shape argument)
ptr (uint64) – Address of an external memory address to alias (data should be None)
capacity (int) – Maximum size in bytes of the ptr allocation (data should be None)
device (Devicelike) – Device the array lives on
copy (bool) – Whether the incoming data will be copied or aliased; aliasing is only possible when the incoming data already lives on the specified device and the types match
owner (bool) – Should the array object try to deallocate memory when it is deleted (deprecated, pass deleter if you wish to transfer ownership to Warp)
deleter (Callable) – Function to be called when deallocating the array, taking two arguments, pointer and size
requires_grad (bool) – Whether or not gradients will be tracked for this array, see warp.Tape for details
grad (array) – The gradient array to use
pinned (bool) – Whether to allocate pinned host memory, which allows asynchronous host-device transfers (only applicable with device="cpu")
- mark_read()[source]¶
Marks this array as having been read from in a kernel or recorded function on the tape.
- fill_(value)[source]¶
Set all array entries to value
- Parameters:
value – The value to set every array entry to. Must be convertible to the array’s dtype.
- Raises:
ValueError – If value cannot be converted to the array’s dtype.
Examples
fill_() can take lists or other sequences when filling arrays of vectors or matrices.
>>> arr = wp.zeros(2, dtype=wp.mat22)
>>> arr.numpy()
array([[[0., 0.],
        [0., 0.]],

       [[0., 0.],
        [0., 0.]]], dtype=float32)
>>> arr.fill_([[1, 2], [3, 4]])
>>> arr.numpy()
array([[[1., 2.],
        [3., 4.]],

       [[1., 2.],
        [3., 4.]]], dtype=float32)
- assign(src)[source]¶
Wraps src in a warp.array if it is not already one and copies the contents to self.
- numpy()[source]¶
Converts the array to a numpy.ndarray (aliasing memory through the array interface protocol). If the array is on the GPU, a synchronous device-to-host copy (on the CUDA default stream) will be automatically performed to ensure that any outstanding work is completed.
- cptr()[source]¶
Return a ctypes cast of the array address.
Notes:
Only CPU arrays support this method.
The array must be contiguous.
Accesses to this object are not bounds checked.
For float16 types, a pointer to the internal uint16 representation is returned.
- to(device, requires_grad=None)[source]¶
Returns a Warp array with this array’s data moved to the specified device, no-op if already on device.
- flatten()[source]¶
Returns a zero-copy view of the array collapsed to 1-D. Only supported for contiguous arrays.
- reshape(shape)[source]¶
Returns a reshaped array. Only supported for contiguous arrays.
- Parameters:
shape – An int or tuple of ints specifying the shape of the returned array.
- view(dtype)[source]¶
Returns a zero-copy view of this array’s memory with a different data type.
dtype must have the same byte size as the array’s native dtype.
- contiguous()[source]¶
Returns a contiguous array with this array’s data. No-op if array is already contiguous.
- transpose(axes=None)[source]¶
Returns a zero-copy view of the array with axes transposed.
Note: The transpose operation will return an array with a non-contiguous access pattern.
- Parameters:
axes (optional) – Specifies how the axes are permuted. If not specified, the axes order will be reversed.
Multi-dimensional Arrays¶
Multi-dimensional arrays can be constructed by passing a tuple of sizes for each dimension, e.g.: the following constructs a 2d array of size 1024x16:
wp.zeros(shape=(1024, 16), dtype=float, device="cuda")
When passing multi-dimensional arrays to kernels, users must specify the expected array dimension inside the kernel signature, e.g., to pass a 2D array to a kernel, the number of dims is specified using the ndim=2 parameter:
@wp.kernel
def test(input: wp.array(dtype=float, ndim=2)):
    ...
Type-hint helpers are provided for common array sizes, e.g., array2d() and array3d(), which are equivalent to calling array(..., ndim=2), array(..., ndim=3), etc. To index a multi-dimensional array, use the following kernel syntax:
# returns a float from the 2d array
value = input[i,j]
To create an array slice use the following syntax, where the number of indices is less than the array dimensions:
# returns a 1D array slice representing a row of the 2D array
row = input[i]
Slice operators can be concatenated, e.g., s = array[i][j][k]. Slices can be passed to wp.func user functions provided the function also declares the expected array dimension. Currently only single-index slicing is supported.
Note
Currently Warp limits arrays to 4 dimensions maximum. This is in addition to the contained datatype, which may be 1-2 dimensional for vector and matrix types such as vec3 and mat33.
The following construction methods are provided for allocating zero-initialized and empty (non-initialized) arrays:
- warp.zeros(shape=None, dtype=float, device=None, requires_grad=False, pinned=False, **kwargs)[source]¶
Return a zero-initialized array
- Parameters:
shape (Tuple | None) – Array dimensions
dtype – Type of each element, e.g.: warp.vec3, warp.mat33, etc
device (Device | str | None) – Device that array will live on
requires_grad (bool) – Whether the array will be tracked for back propagation
pinned (bool) – Whether the array uses pinned host memory (only applicable to CPU arrays)
- Returns:
A warp.array object representing the allocation
- Return type:
array
- warp.zeros_like(src, device=None, requires_grad=None, pinned=None)[source]¶
Return a zero-initialized array with the same type and dimension of another array
- Parameters:
src (array) – The template array to use for shape, data type, and device
device (Device | str | None) – The device where the new array will be created (defaults to src.device)
requires_grad (bool | None) – Whether the array will be tracked for back propagation
pinned (bool | None) – Whether the array uses pinned host memory (only applicable to CPU arrays)
- Returns:
A warp.array object representing the allocation
- Return type:
array
- warp.ones(shape=None, dtype=float, device=None, requires_grad=False, pinned=False, **kwargs)[source]¶
Return a one-initialized array
- Parameters:
shape (Tuple | None) – Array dimensions
dtype – Type of each element, e.g.: warp.vec3, warp.mat33, etc
device (Device | str | None) – Device that array will live on
requires_grad (bool) – Whether the array will be tracked for back propagation
pinned (bool) – Whether the array uses pinned host memory (only applicable to CPU arrays)
- Returns:
A warp.array object representing the allocation
- Return type:
array
- warp.ones_like(src, device=None, requires_grad=None, pinned=None)[source]¶
Return a one-initialized array with the same type and dimension of another array
- Parameters:
src (array) – The template array to use for shape, data type, and device
device (Device | str | None) – The device where the new array will be created (defaults to src.device)
requires_grad (bool | None) – Whether the array will be tracked for back propagation
pinned (bool | None) – Whether the array uses pinned host memory (only applicable to CPU arrays)
- Returns:
A warp.array object representing the allocation
- Return type:
array
- warp.full(shape=None, value=0, dtype=Any, device=None, requires_grad=False, pinned=False, **kwargs)[source]¶
Return an array with all elements initialized to the given value
- Parameters:
shape (Tuple | None) – Array dimensions
value – Element value
dtype – Type of each element, e.g.: float, warp.vec3, warp.mat33, etc
device (Device | str | None) – Device that array will live on
requires_grad (bool) – Whether the array will be tracked for back propagation
pinned (bool) – Whether the array uses pinned host memory (only applicable to CPU arrays)
- Returns:
A warp.array object representing the allocation
- Return type:
array
- warp.full_like(src, value, device=None, requires_grad=None, pinned=None)[source]¶
Return an array with all elements initialized to the given value with the same type and dimension of another array
- Parameters:
src (array) – The template array to use for shape, data type, and device
value (Any) – Element value
device (Device | str | None) – The device where the new array will be created (defaults to src.device)
requires_grad (bool | None) – Whether the array will be tracked for back propagation
pinned (bool | None) – Whether the array uses pinned host memory (only applicable to CPU arrays)
- Returns:
A warp.array object representing the allocation
- Return type:
array
- warp.empty(shape=None, dtype=float, device=None, requires_grad=False, pinned=False, **kwargs)[source]¶
Returns an uninitialized array
- Parameters:
shape (Tuple | None) – Array dimensions
dtype – Type of each element, e.g.: warp.vec3, warp.mat33, etc
device (Device | str | None) – Device that array will live on
requires_grad (bool) – Whether the array will be tracked for back propagation
pinned (bool) – Whether the array uses pinned host memory (only applicable to CPU arrays)
- Returns:
A warp.array object representing the allocation
- Return type:
array
- warp.empty_like(src, device=None, requires_grad=None, pinned=None)[source]¶
Return an uninitialized array with the same type and dimension of another array
- Parameters:
src (array) – The template array to use for shape, data type, and device
device (Device | str | None) – The device where the new array will be created (defaults to src.device)
requires_grad (bool | None) – Whether the array will be tracked for back propagation
pinned (bool | None) – Whether the array uses pinned host memory (only applicable to CPU arrays)
- Returns:
A warp.array object representing the allocation
- Return type:
array
- warp.copy(dest, src, dest_offset=0, src_offset=0, count=0, stream=None)[source]¶
Copy array contents from src to dest.
- Parameters:
dest (array) – Destination array, must be at least as big as source buffer
src (array) – Source array
dest_offset (int) – Element offset in the destination array
src_offset (int) – Element offset in the source array
count (int) – Number of array elements to copy (will copy all elements if set to 0)
stream (Stream | None) – The stream on which to perform the copy (optional)
The stream, if specified, can be from any device. If the stream is omitted, then Warp selects a stream based on the following rules: (1) If the destination array is on a CUDA device, use the current stream on the destination device. (2) Otherwise, if the source array is on a CUDA device, use the current stream on the source device.
If neither source nor destination are on a CUDA device, no stream is used for the copy.
- warp.clone(src, device=None, requires_grad=None, pinned=None)[source]¶
Clone an existing array by allocating a copy of the src memory.
- Parameters:
src (array) – The source array to copy
device (Device | str | None) – The device where the new array will be created (defaults to src.device)
requires_grad (bool | None) – Whether the array will be tracked for back propagation
pinned (bool | None) – Whether the array uses pinned host memory (only applicable to CPU arrays)
- Returns:
A warp.array object representing the allocation
- Return type:
array
Matrix Multiplication¶
Warp 2D array multiplication is built on NVIDIA’s CUTLASS library, which enables fast matrix multiplication of large arrays on the GPU. If no GPU is detected, matrix multiplication falls back to NumPy’s implementation on the CPU.
Matrix multiplication is fully differentiable, and can be recorded on the tape like so:
tape = wp.Tape()
with tape:
    wp.matmul(A, B, C, D)
    wp.launch(loss_kernel, dim=(m, n), inputs=[D, loss])
tape.backward(loss=loss)
A_grad = A.grad.numpy()
Using the @ operator (D = A @ B) will default to the same CUTLASS algorithm used in wp.matmul.
- warp.matmul(a, b, c, d, alpha=1.0, beta=0.0, allow_tf32x3_arith=False)[source]¶
Computes a generic matrix-matrix multiplication (GEMM) of the form: d = alpha * (a @ b) + beta * c.
- Parameters:
a (array2d) – two-dimensional array containing matrix A
b (array2d) – two-dimensional array containing matrix B
c (array2d) – two-dimensional array containing matrix C
d (array2d) – two-dimensional array to which output D is written
alpha (float) – parameter alpha of GEMM
beta (float) – parameter beta of GEMM
allow_tf32x3_arith (bool) – whether to use CUTLASS’s 3xTF32 GEMMs, which enable accuracy similar to FP32 while using Tensor Cores
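As a worked illustration of d = alpha * (a @ b) + beta * c, the NumPy sketch below evaluates the same formula on small matrices. It is a reference for the math only: it does not call wp.matmul, and the helper name gemm_reference is hypothetical.

```python
import numpy as np

def gemm_reference(a, b, c, alpha=1.0, beta=0.0):
    # a: (M, K), b: (K, N), c: (M, N) -> d: (M, N)
    return alpha * (a @ b) + beta * c

a = np.array([[1.0, 2.0], [3.0, 4.0]])
b = np.array([[5.0, 6.0], [7.0, 8.0]])
c = np.ones((2, 2))

# a @ b = [[19, 22], [43, 50]], so d = 2 * (a @ b) + 0.5
d = gemm_reference(a, b, c, alpha=2.0, beta=0.5)
```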
- warp.batched_matmul(a, b, c, d, alpha=1.0, beta=0.0, allow_tf32x3_arith=False)[source]¶
Computes a batched generic matrix-matrix multiplication (GEMM) of the form: d = alpha * (a @ b) + beta * c.
- Parameters:
a (array3d) – three-dimensional array containing A matrices. Overall array dimension is {batch_count, M, K}
b (array3d) – three-dimensional array containing B matrices. Overall array dimension is {batch_count, K, N}
c (array3d) – three-dimensional array containing C matrices. Overall array dimension is {batch_count, M, N}
d (array3d) – three-dimensional array to which output D is written. Overall array dimension is {batch_count, M, N}
alpha (float) – parameter alpha of GEMM
beta (float) – parameter beta of GEMM
allow_tf32x3_arith (bool) – whether to use CUTLASS’s 3xTF32 GEMMs, which enable accuracy similar to FP32 while using Tensor Cores
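The batched form applies the same formula independently to each index of the leading batch axis. A NumPy reference sketch (not a call into Warp; batched_gemm_reference is a hypothetical helper) using the {batch_count, M, K} layout described above:

```python
import numpy as np

def batched_gemm_reference(a, b, c, alpha=1.0, beta=0.0):
    # a: (batch, M, K), b: (batch, K, N), c: (batch, M, N)
    # np.matmul treats the leading axis as the batch dimension
    return alpha * np.matmul(a, b) + beta * c

batch, M, K, N = 3, 2, 4, 5
rng = np.random.default_rng(0)
a = rng.standard_normal((batch, M, K))
b = rng.standard_normal((batch, K, N))
c = rng.standard_normal((batch, M, N))

d = batched_gemm_reference(a, b, c, alpha=1.0, beta=1.0)
```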
Data Types¶
Scalar Types¶
The following scalar storage types are supported for array structures:
Type | Description |
---|---|
bool | boolean |
int8 | signed byte |
uint8 | unsigned byte |
int16 | signed short |
uint16 | unsigned short |
int32 | signed integer |
uint32 | unsigned integer |
int64 | signed long integer |
uint64 | unsigned long integer |
float16 | half-precision float |
float32 | single-precision float |
float64 | double-precision float |
Warp supports float and int as aliases for wp.float32 and wp.int32, respectively.
Vectors¶
Warp provides built-in math and geometry types for common simulation and graphics problems. A full reference for operators and functions for these types is available in the Kernel Reference.
Warp supports vectors of numbers with an arbitrary length/numeric type. The built-in concrete types are as follows:
Type | Description |
---|---|
vec2, vec3, vec4 | 2D, 3D, 4D vector of single-precision floats |
vec2b, vec3b, vec4b | 2D, 3D, 4D vector of signed bytes |
vec2ub, vec3ub, vec4ub | 2D, 3D, 4D vector of unsigned bytes |
vec2s, vec3s, vec4s | 2D, 3D, 4D vector of signed shorts |
vec2us, vec3us, vec4us | 2D, 3D, 4D vector of unsigned shorts |
vec2i, vec3i, vec4i | 2D, 3D, 4D vector of signed integers |
vec2ui, vec3ui, vec4ui | 2D, 3D, 4D vector of unsigned integers |
vec2l, vec3l, vec4l | 2D, 3D, 4D vector of signed long integers |
vec2ul, vec3ul, vec4ul | 2D, 3D, 4D vector of unsigned long integers |
vec2h, vec3h, vec4h | 2D, 3D, 4D vector of half-precision floats |
vec2f, vec3f, vec4f | 2D, 3D, 4D vector of single-precision floats |
vec2d, vec3d, vec4d | 2D, 3D, 4D vector of double-precision floats |
spatial_vector | 6D vector of single-precision floats |
spatial_vectorf | 6D vector of single-precision floats |
spatial_vectord | 6D vector of double-precision floats |
spatial_vectorh | 6D vector of half-precision floats |
Vectors support most standard linear algebra operations, e.g.:
@wp.kernel
def compute( ... ):
    # basis vectors
    a = wp.vec3(1.0, 0.0, 0.0)
    b = wp.vec3(0.0, 1.0, 0.0)
    # take the cross product
    c = wp.cross(a, b)
    # compute the squared length of c
    r = wp.dot(c, c)
    ...
It’s possible to declare additional vector types with different lengths and data types. This is done outside of kernels in Python scope using warp.types.vector(), for example:
# declare a new vector type for holding 5 double precision floats:
vec5d = wp.types.vector(length=5, dtype=wp.float64)
Once declared, the new type can be used when allocating arrays or inside kernels:
# create an array of vec5d
arr = wp.zeros(10, dtype=vec5d)
# use inside a kernel
@wp.kernel
def compute( ... ):
    # zero initialize a custom named vector type
    v = vec5d()
    ...
    # component-wise initialize a named vector type
    v = vec5d(wp.float64(1.0),
              wp.float64(2.0),
              wp.float64(3.0),
              wp.float64(4.0),
              wp.float64(5.0))
    ...
In addition, it’s possible to directly create anonymously typed instances of these vectors without declaring their type in advance. In this case the type will be inferred by the constructor arguments. For example:
@wp.kernel
def compute( ... ):
    # zero initialize vector of 5 doubles:
    v = wp.vector(dtype=wp.float64, length=5)
    # scalar initialize a vector of 5 doubles to the same value:
    v = wp.vector(wp.float64(1.0), length=5)
    # component-wise initialize a vector of 5 doubles
    v = wp.vector(wp.float64(1.0),
                  wp.float64(2.0),
                  wp.float64(3.0),
                  wp.float64(4.0),
                  wp.float64(5.0))
These can be used with all the standard vector arithmetic operators, e.g., +, -, scalar multiplication, and can also be transformed using matrices with compatible dimensions, potentially returning vectors with a different length.
Matrices¶
Matrices with arbitrary shapes/numeric types are also supported. The built-in concrete matrix types are as follows:
Type | Description |
---|---|
mat22, mat33, mat44 | 2x2, 3x3, 4x4 matrix of single-precision floats |
mat22f, mat33f, mat44f | 2x2, 3x3, 4x4 matrix of single-precision floats |
mat22d, mat33d, mat44d | 2x2, 3x3, 4x4 matrix of double-precision floats |
mat22h, mat33h, mat44h | 2x2, 3x3, 4x4 matrix of half-precision floats |
spatial_matrix | 6x6 matrix of single-precision floats |
spatial_matrixf | 6x6 matrix of single-precision floats |
spatial_matrixd | 6x6 matrix of double-precision floats |
spatial_matrixh | 6x6 matrix of half-precision floats |
Matrices are stored in row-major format and support most standard linear algebra operations:
@wp.kernel
def compute( ... ):
    # initialize matrix
    m = wp.mat22(1.0, 2.0,
                 3.0, 4.0)
    # compute inverse
    minv = wp.inverse(m)
    # transform vector
    v = minv * wp.vec2(0.5, 0.3)
    ...
In a similar manner to vectors, it’s possible to declare new matrix types with arbitrary shapes and data types using wp.types.matrix(), for example:
# declare a new 3x2 half precision float matrix type:
mat32h = wp.types.matrix(shape=(3,2), dtype=wp.float16)
# create an array of this type
a = wp.zeros(10, dtype=mat32h)
These can be used inside a kernel:
@wp.kernel
def compute( ... ):
    ...
    # initialize a mat32h matrix
    m = mat32h(wp.float16(1.0), wp.float16(2.0),
               wp.float16(3.0), wp.float16(4.0),
               wp.float16(5.0), wp.float16(6.0))
    # declare a 2 component half precision vector
    v2 = wp.vec2h(wp.float16(1.0), wp.float16(1.0))
    # multiply by the matrix, returning a 3 component vector:
    v3 = m * v2
    ...
It’s also possible to directly create anonymously typed instances inside kernels where the type is inferred from constructor arguments as follows:
@wp.kernel
def compute( ... ):
    ...
    # create a 3x2 half precision matrix from components (row major ordering):
    m = wp.matrix(
        wp.float16(1.0), wp.float16(2.0),
        wp.float16(1.0), wp.float16(2.0),
        wp.float16(1.0), wp.float16(2.0),
        shape=(3,2))
    # zero initialize a 3x2 half precision matrix:
    m = wp.matrix(wp.float16(0.0), shape=(3,2))
    # create a 5x5 double precision identity matrix:
    m = wp.identity(n=5, dtype=wp.float64)
As with vectors, you can do standard matrix arithmetic with these variables, along with multiplying matrices with compatible shapes and potentially returning a matrix with a new shape.
Quaternions¶
Warp supports quaternions with the layout i, j, k, w where w is the real part. Here are the built-in concrete quaternion types:
Type | Description |
---|---|
quat | Single-precision floating point quaternion |
quatf | Single-precision floating point quaternion |
quatd | Double-precision floating point quaternion |
quath | Half-precision floating point quaternion |
Quaternions can be used to transform vectors as follows:
@wp.kernel
def compute( ... ):
    ...
    # construct a 30 degree rotation around the x-axis (the angle is in radians)
    q = wp.quat_from_axis_angle(wp.vec3(1.0, 0.0, 0.0), wp.radians(30.0))
    # rotate an axis by this quaternion
    v = wp.quat_rotate(q, wp.vec3(0.0, 1.0, 0.0))
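To make the (i, j, k, w) layout concrete, here is a NumPy-only reference of the same rotation. The helpers are hypothetical re-implementations written for illustration, not Warp's code. They take the angle in radians and use the standard identity v' = v + 2w(u x v) + 2u x (u x v) for a unit quaternion with vector part u and real part w.

```python
import numpy as np

def quat_from_axis_angle(axis, angle):
    # layout (i, j, k, w): imaginary part first, real part w last
    axis = np.asarray(axis, dtype=float)
    axis = axis / np.linalg.norm(axis)
    half = 0.5 * angle
    return np.concatenate([axis * np.sin(half), [np.cos(half)]])

def quat_rotate(q, v):
    # v' = v + w * t + u x t, where t = 2 (u x v)
    u, w = q[:3], q[3]
    t = 2.0 * np.cross(u, v)
    return v + w * t + np.cross(u, t)

q = quat_from_axis_angle([1.0, 0.0, 0.0], np.radians(30.0))
v = quat_rotate(q, np.array([0.0, 1.0, 0.0]))
```

Rotating the y-axis by 30 degrees about x gives (0, cos 30°, sin 30°), and the real part q[3] equals cos 15°, half the rotation angle.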
As with vectors and matrices, you can declare quaternion types with an arbitrary numeric type like so:
quatd = wp.types.quaternion(dtype=wp.float64)
You can also create identity quaternions and anonymously typed instances inside a kernel like so:
@wp.kernel
def compute( ... ):
    ...
    # create a double precision identity quaternion:
    qd = wp.quat_identity(dtype=wp.float64)
    # precision defaults to wp.float32 so this creates a single precision identity quaternion:
    qf = wp.quat_identity()
    # create a half precision quaternion from components, or a vector/scalar:
    qh = wp.quaternion(wp.float16(0.0),
                       wp.float16(0.0),
                       wp.float16(0.0),
                       wp.float16(1.0))
    qh = wp.quaternion(
        wp.vector(wp.float16(0.0), wp.float16(0.0), wp.float16(0.0)),
        wp.float16(1.0))
Transforms¶
Transforms are 7D vectors of floats representing a spatial rigid body transformation in format (p, q) where p is a 3D vector, and q is a quaternion.
Type | Description |
---|---|
transform | Single-precision floating point transform |
transformf | Single-precision floating point transform |
transformd | Double-precision floating point transform |
transformh | Half-precision floating point transform |
Transforms can be constructed inside kernels from translation and rotation parts:
@wp.kernel
def compute( ... ):
    ...
    # create a transform from a vector/quaternion (the angle is in radians):
    t = wp.transform(
        wp.vec3(1.0, 2.0, 3.0),
        wp.quat_from_axis_angle(wp.vec3(0.0, 1.0, 0.0), wp.radians(30.0)))
    # transform a point
    p = wp.transform_point(t, wp.vec3(10.0, 0.5, 1.0))
    # transform a vector (ignore translation)
    v = wp.transform_vector(t, wp.vec3(10.0, 0.5, 1.0))
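The difference between transforming a point and a vector follows from the (p, q) layout: a point is rotated by q and then translated by p, while a vector is only rotated. A NumPy-only sketch of that math; the helper functions are hypothetical illustrations, not Warp's implementation.

```python
import numpy as np

def quat_rotate(q, v):
    # q in (i, j, k, w) layout; rotates v by the unit quaternion q
    u, w = q[:3], q[3]
    t = 2.0 * np.cross(u, v)
    return v + w * t + np.cross(u, t)

def transform_point(p, q, x):
    # points are rotated, then translated
    return quat_rotate(q, x) + p

def transform_vector(p, q, x):
    # vectors (directions) are only rotated; the translation p is ignored
    return quat_rotate(q, x)

# 90 degree rotation about the y-axis, translation (1, 2, 3)
angle = np.radians(90.0)
q = np.array([0.0, np.sin(angle / 2), 0.0, np.cos(angle / 2)])
p = np.array([1.0, 2.0, 3.0])

x = np.array([1.0, 0.0, 0.0])
xp = transform_point(p, q, x)
xv = transform_vector(p, q, x)
```

The x-axis rotates to (0, 0, -1), so the point lands at (1, 2, 2) while the vector stays at (0, 0, -1).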
As with vectors and matrices, you can declare transform types with an arbitrary numeric type using wp.types.transformation(), for example:
transformd = wp.types.transformation(dtype=wp.float64)
You can also create identity transforms and anonymously typed instances inside a kernel like so:
@wp.kernel
def compute( ... ):
    # create double precision identity transform:
    td = wp.transform_identity(dtype=wp.float64)
Structs¶
Users can define custom structure types using the @wp.struct decorator as follows:
@wp.struct
class MyStruct:
    param1: int
    param2: float
    param3: wp.array(dtype=wp.vec3)
Struct attributes must be annotated with their respective type. They can be constructed in Python scope and then passed to kernels as arguments:
@wp.kernel
def compute(args: MyStruct):
    tid = wp.tid()
    print(args.param1)
    print(args.param2)
    print(args.param3[tid])
# construct an instance of the struct in Python
s = MyStruct()
s.param1 = 10
s.param2 = 2.5
s.param3 = wp.zeros(shape=10, dtype=wp.vec3)
# pass to our compute kernel
wp.launch(compute, dim=10, inputs=[s])
An array of structs can be zero-initialized as follows:
a = wp.zeros(shape=10, dtype=MyStruct)
An array of structs can also be initialized from a list of struct objects:
a = wp.array([MyStruct(), MyStruct(), MyStruct()], dtype=MyStruct)
Example: Using a struct in gradient computation¶
import numpy as np
import warp as wp
@wp.struct
class TestStruct:
    x: wp.vec3
    a: wp.array(dtype=wp.vec3)
    b: wp.array(dtype=wp.vec3)
@wp.kernel
def test_kernel(s: TestStruct):
    tid = wp.tid()
    s.b[tid] = s.a[tid] + s.x
@wp.kernel
def loss_kernel(s: TestStruct, loss: wp.array(dtype=float)):
    tid = wp.tid()
    v = s.b[tid]
    wp.atomic_add(loss, 0, float(tid + 1) * (v[0] + 2.0 * v[1] + 3.0 * v[2]))
# create struct
ts = TestStruct()
# set members
ts.x = wp.vec3(1.0, 2.0, 3.0)
ts.a = wp.array(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), dtype=wp.vec3, requires_grad=True)
ts.b = wp.zeros(2, dtype=wp.vec3, requires_grad=True)
loss = wp.zeros(1, dtype=float, requires_grad=True)
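tape = wp.Tape()
with tape:
    wp.launch(test_kernel, dim=2, inputs=[ts])
    wp.launch(loss_kernel, dim=2, inputs=[ts, loss])
tape.backward(loss)
# total loss: 1 * (2 + 8 + 18) + 2 * (5 + 14 + 27) = 120.0
print(loss)
# gradient of the loss with respect to ts.a: [[1, 2, 3], [2, 4, 6]]
print(ts.a.grad)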
Type Conversions¶
Warp is particularly strict regarding type conversions and does not perform any implicit conversion between numeric types.
The user is responsible for ensuring types for most arithmetic operators match, e.g.: x = float(0.0) + int(4)
will result in an error.
This can be surprising for users that are accustomed to C-style conversions but avoids a class of common bugs that result from implicit conversions.
Note
Warp does not currently perform implicit type conversions between numeric types. Users should explicitly cast variables to compatible types using constructors like int(), float(), wp.float16(), wp.uint8(), etc.
Note
For performance reasons, Warp relies on native compilers to perform numeric conversions (e.g., LLVM for CPU and NVRTC for CUDA). This is generally not a problem, but in some cases the results may vary on different devices. For example, the conversion wp.uint8(-1.0)
results in undefined behavior, since the floating point value -1.0 is out of range for unsigned integer types. C++ compilers are free to handle such cases as they see fit. Numeric conversions are only guaranteed to produce correct results when the value being converted is in the range supported by the target data type.
Constants¶
A Warp kernel can access Python variables defined outside of the kernel, which are treated as compile-time constants inside of the kernel.
TYPE_SPHERE = wp.constant(0)
TYPE_CUBE = wp.constant(1)
TYPE_CAPSULE = wp.constant(2)
@wp.kernel
def collide(geometry: wp.array(dtype=int)):
    t = geometry[wp.tid()]
    if t == TYPE_SPHERE:
        print("sphere")
    elif t == TYPE_CUBE:
        print("cube")
    elif t == TYPE_CAPSULE:
        print("capsule")
Note that using wp.constant() is no longer required, but it performs some type checking and can serve as a reminder that the variables are meant to be used as Warp constants.
The behavior is simple and intuitive when the referenced Python variables never change. For details and more complex scenarios, refer to External References and Constants. The Code Generation section contains additional information and tips for advanced usage.
Predefined Constants¶
For convenience, Warp has a number of predefined mathematical constants that
may be used both inside and outside Warp kernels.
The constants in the following table also have lowercase versions defined, e.g. wp.E and wp.e are equivalent.
| Name | Value |
|---|---|
| wp.E | 2.71828182845904523536 |
| wp.LOG2E | 1.44269504088896340736 |
| wp.LOG10E | 0.43429448190325182765 |
| wp.LN2 | 0.69314718055994530942 |
| wp.LN10 | 2.30258509299404568402 |
| wp.PHI | 1.61803398874989484820 |
| wp.PI | 3.14159265358979323846 |
| wp.HALF_PI | 1.57079632679489661923 |
| wp.TAU | 6.28318530717958647692 |
| wp.INF | math.inf |
| wp.NAN | float('nan') |
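The finite entries in the table are standard mathematical constants, so they can be cross-checked against Python's math module. The sketch below is a plain-Python sanity check of the listed values only. It does not import Warp, and the TABLE and DEFINITIONS names are illustrative.

```python
import math

# Values copied from the table above. As Python floats, digits beyond
# roughly 16 significant figures are rounded away.
TABLE = {
    "E": 2.71828182845904523536,
    "LOG2E": 1.44269504088896340736,
    "LOG10E": 0.43429448190325182765,
    "LN2": 0.69314718055994530942,
    "LN10": 2.30258509299404568402,
    "PHI": 1.61803398874989484820,
    "PI": 3.14159265358979323846,
    "HALF_PI": 1.57079632679489661923,
    "TAU": 6.28318530717958647692,
}

# Each constant expressed through its defining formula.
DEFINITIONS = {
    "E": math.e,
    "LOG2E": math.log2(math.e),    # log base 2 of e
    "LOG10E": math.log10(math.e),  # log base 10 of e
    "LN2": math.log(2.0),
    "LN10": math.log(10.0),
    "PHI": (1.0 + math.sqrt(5.0)) / 2.0,  # golden ratio
    "PI": math.pi,
    "HALF_PI": math.pi / 2.0,
    "TAU": 2.0 * math.pi,
}


def mismatches(rel_tol: float = 1e-12) -> list:
    """Return the names whose table value disagrees with its definition."""
    return [
        name for name, value in TABLE.items()
        if not math.isclose(value, DEFINITIONS[name], rel_tol=rel_tol)
    ]
```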
The wp.NAN constant may only be used with floating-point types.
Comparisons involving wp.NAN follow the IEEE 754 standard, e.g. wp.float32(wp.NAN) == wp.float32(wp.NAN) returns False.
The wp.isnan() built-in function can be used to determine whether a value is a NaN (or if a vector, matrix, or quaternion contains a NaN entry).
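Python floats follow the same IEEE 754 comparison rules, so the NaN behavior described above can be reproduced on the host without Warp. The any_nan() helper is a hypothetical plain-Python analogue of what wp.isnan() reports for a vector, matrix, or quaternion (true if any component is NaN). It is not Warp's implementation.

```python
import math

nan = float("nan")

# IEEE 754: NaN compares unequal to everything, including itself.
nan_equals_itself = (nan == nan)      # False
nan_not_equals_itself = (nan != nan)  # True


def any_nan(values) -> bool:
    """Hypothetical helper: True if any component (e.g. of a vec3) is NaN."""
    return any(math.isnan(v) for v in values)
```

Because equality never matches NaN, a dedicated test such as math.isnan() (or wp.isnan() inside kernels) is the reliable way to detect it.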
The following example shows how positive and negative infinity can be used with floating-point types in Warp using the wp.inf constant:
@wp.kernel
def test_infinity(outputs: wp.array(dtype=wp.float32)):
    outputs[0] = wp.float32(wp.inf)         # inf
    outputs[1] = wp.float32(-wp.inf)        # -inf
    outputs[2] = wp.float32(2.0 * wp.inf)   # inf
    outputs[3] = wp.float32(-2.0 * wp.inf)  # -inf
    outputs[4] = wp.float32(2.0 / 0.0)      # inf
    outputs[5] = wp.float32(-2.0 / 0.0)     # -inf
Operators¶
Boolean Operators¶
| Operator | Description |
|---|---|
| a and b | True if a and b are True |
| a or b | True if a or b is True |
| not a | True if a is False, otherwise False |
Note
Expressions such as if (a and b): currently do not perform short-circuit evaluation.
In this case b will also be evaluated even when a is False.
Users should take care to ensure that secondary conditions are safe to evaluate (e.g. do not index out of bounds) in all cases.
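The plain-Python sketch below models the consequence. Both operands are computed before they are combined, as in a Warp kernel. The unsafe version therefore reads past the end of the list. The safe version clamps the index so that the secondary condition can always be evaluated, and it lets the bounds check decide the result. All function names are hypothetical.

```python
def eager_and(a: bool, b: bool) -> bool:
    # Both arguments are evaluated by the caller before this function runs,
    # which mimics the non-short-circuit behavior described above.
    return a and b


def is_positive_unsafe(values: list, i: int) -> bool:
    # values[i] is evaluated even when i is out of range -> IndexError.
    return eager_and(i < len(values), values[i] > 0.0)


def is_positive_safe(values: list, i: int) -> bool:
    # Assumes values is non-empty. Clamp the index so the secondary condition
    # is always safe to evaluate; the bounds check still decides the result.
    in_bounds = 0 <= i < len(values)
    j = min(max(i, 0), len(values) - 1)
    return eager_and(in_bounds, values[j] > 0.0)
```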
Comparison Operators¶
| Operator | Description |
|---|---|
| a > b | True if a is strictly greater than b |
| a < b | True if a is strictly less than b |
| a >= b | True if a is greater than or equal to b |
| a <= b | True if a is less than or equal to b |
| a == b | True if a equals b |
| a != b | True if a is not equal to b |
Arithmetic Operators¶
| Operator | Description |
|---|---|
| a + b | Addition |
| a - b | Subtraction |
| a * b | Multiplication |
| a / b | Floating point division |
| a // b | Floored division |
| a ** b | Exponentiation |
| a % b | Modulus |
Note
Since implicit conversions are not performed, the argument types passed to operators should match.
Users should use type constructors, e.g. float(), int(), wp.int64(), etc., to cast variables to the correct type. Also note that the multiplication expression a * b is used to represent both scalar multiplication and matrix multiplication. The @ operator is not currently supported.
Streams¶
A CUDA stream is a sequence of operations that execute in order on the GPU. Operations from different streams may run concurrently and may be interleaved by the device scheduler. See the Streams documentation for more information on using streams.
- class warp.Stream(*args, **kwargs)[source]¶
- __init__(device=None, priority=0, **kwargs)[source]¶
Initialize the stream on a device with an optional specified priority.
- Parameters:
device (Device | str | None) – The CUDA device on which this stream will be created.
priority (int) – An optional integer specifying the requested stream priority. Can be -1 (high priority) or 0 (low/default priority). Values outside this range will be clamped.
cuda_stream (int) – An optional external stream handle passed as an integer. The caller is responsible for ensuring that the external stream does not get destroyed while it is referenced by this object.
- Raises:
RuntimeError – If function is called before Warp has completed initialization with a device that is not an instance of Device.
RuntimeError – device is not a CUDA Device.
RuntimeError – The stream could not be created on the device.
TypeError – The requested stream priority is not an integer.
- record_event(event=None)[source]¶
Record an event onto the stream.
- Parameters:
event (Event | None) – A warp.Event instance to be recorded onto the stream. If not provided, an Event on the same device will be created.
- Raises:
RuntimeError – The provided Event is from a different device than the recording stream.
- Return type:
Event
- wait_event(event)[source]¶
Makes all future work in this stream wait until event has completed.
This function does not block the host thread.
- Parameters:
event (Event)
- wait_stream(other_stream, event=None)[source]¶
Records an event on other_stream and makes this stream wait on it.
All work added to this stream after this function has been called will delay its execution until all preceding commands in other_stream have completed.
This function does not block the host thread.
- Parameters:
other_stream (Stream) – The stream on which the calling stream will wait for previously issued commands to complete before executing subsequent commands.
event (Event | None) – An optional Event instance that will be used to record an event onto other_stream. If None, an internally managed Event instance will be used.
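The toy timeline model below makes these ordering rules concrete. It treats each stream as an in-order queue and mimics record_event(), wait_event(), and wait_stream() as described above. Work in different streams overlaps unless a wait introduces a dependency. SimStream and the operation durations are hypothetical. This is plain Python, not Warp's implementation, and it ignores priorities and the host thread entirely.

```python
class SimStream:
    """Toy model of a CUDA stream: operations run in issue order."""

    def __init__(self, name: str):
        self.name = name
        self.ready_at = 0.0  # time at which all previously issued work finishes

    def launch(self, duration: float) -> tuple:
        """Issue an operation; it starts once earlier work in this stream is done."""
        start = self.ready_at
        self.ready_at = start + duration
        return start, self.ready_at

    def record_event(self) -> float:
        """An event completes when all work issued before it has finished."""
        return self.ready_at

    def wait_event(self, event_time: float) -> None:
        """Future work in this stream may not start before the event completes."""
        self.ready_at = max(self.ready_at, event_time)

    def wait_stream(self, other: "SimStream") -> None:
        """Record an event on `other` and make this stream wait on it."""
        self.wait_event(other.record_event())


a = SimStream("a")
b = SimStream("b")
a.launch(5.0)               # stream a is busy from t=0 to t=5
b.launch(2.0)               # runs concurrently on stream b from t=0 to t=2
b.wait_stream(a)            # later work on b must wait for a's earlier work
start, end = b.launch(1.0)  # starts at t=5 rather than t=2
```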
- warp.get_stream(device=None)[source]¶
Return the stream currently used by the given device.
- Parameters:
device (Device | str | None) – An optional Device instance or device alias (e.g. "cuda:0") for which the current stream will be returned. If None, the default device will be used.
- Raises:
RuntimeError – The device is not a CUDA device.
- Return type:
Stream
- warp.set_stream(stream, device=None, sync=False)[source]¶
Convenience function for calling Device.set_stream() on the given device.
- Parameters:
device (Device | str | None) – An optional Device instance or device alias (e.g. "cuda:0") for which the current stream is to be replaced with stream. If None, the default device will be used.
stream (Stream) – The stream to set as this device’s current stream.
sync (bool) – If True, then stream will perform a device-side synchronization with the device’s previous current stream.
- Return type:
None
- warp.wait_stream(other_stream, event=None)[source]¶
Convenience function for calling Stream.wait_stream() on the current stream.
- Parameters:
other_stream (Stream) – The stream on which the calling stream will wait for previously issued commands to complete before executing subsequent commands.
event (Event | None) – An optional Event instance that will be used to record an event onto other_stream. If None, an internally managed Event instance will be used.
- warp.synchronize_stream(stream_or_device=None)[source]¶
Synchronize the calling CPU thread with any outstanding CUDA work on the specified stream.
This function allows the host application code to ensure that all kernel launches and memory copies have completed on the stream.
- class warp.ScopedStream(stream, sync_enter=True, sync_exit=False)[source]¶
A context manager to temporarily change the current stream on a device.
- stream¶
The stream that will temporarily become the device’s default stream within the context.
- Type:
Stream or None
- saved_stream¶
The device’s previous current stream. This is restored as the device’s current stream on exiting the context.
- Type:
Stream
- sync_enter¶
Whether to synchronize this context’s stream with the device’s previous current stream on entering the context.
- Type:
bool
- sync_exit¶
Whether to synchronize the device’s previous current stream with this context’s stream on exiting the context.
- Type:
bool
- __init__(stream, sync_enter=True, sync_exit=False)[source]¶
Initializes the context manager with a stream and synchronization options.
- Parameters:
stream (Stream | None) – The stream that will temporarily become the device’s default stream within the context.
sync_enter (bool) – Whether to synchronize this context’s stream with the device’s previous current stream on entering the context.
sync_exit (bool) – Whether to synchronize the device’s previous current stream with this context’s stream on exiting the context.
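The save-and-restore behavior of ScopedStream follows the usual Python context-manager pattern. The sketch below is a plain-Python model only. FakeDevice and scoped_stream() are hypothetical stand-ins, and the sync_enter and sync_exit synchronization steps are omitted.

```python
import contextlib


class FakeDevice:
    """Hypothetical stand-in for a device that tracks a 'current stream'."""

    def __init__(self, stream: str):
        self.stream = stream


@contextlib.contextmanager
def scoped_stream(device: FakeDevice, stream: str):
    saved_stream = device.stream  # remember the previous current stream
    device.stream = stream        # make the new stream current
    try:
        yield stream
    finally:
        # Restored even if the body raises, as ScopedStream does on exit.
        device.stream = saved_stream
```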
Events¶
Events can be inserted into streams and used to synchronize a stream with a different one. See the Events documentation for information on how to use events for cross-stream synchronization or the CUDA Events Timing documentation for information on how to use events for measuring GPU performance.
- class warp.Event(*args, **kwargs)[source]¶
A CUDA event that can be recorded onto a stream.
Events can be used for device-side synchronization, which do not block the host thread.
- __init__(device=None, cuda_event=None, enable_timing=False)[source]¶
Initializes the event on a CUDA device.
- Parameters:
device (Device | str | None) – The CUDA device whose streams this event may be recorded onto. If None, then the current default device will be used.
cuda_event – A pointer to a previously allocated CUDA event. If None, then a new event will be allocated on the associated device.
enable_timing (bool) – If True, this event will record timing data. get_event_elapsed_time() can be used to measure the time between two events created with enable_timing=True and recorded onto streams.
- warp.record_event(event=None)[source]¶
Convenience function for calling Stream.record_event() on the current stream.
- warp.wait_event(event)[source]¶
Convenience function for calling Stream.wait_event() on the current stream.
- warp.synchronize_event(event)[source]¶
Synchronize the calling CPU thread with an event recorded on a CUDA stream.
This function allows the host application code to ensure that a specific synchronization point was reached.
- Parameters:
event (Event) – Event to wait for.
- warp.get_event_elapsed_time(start_event, end_event, synchronize=True)[source]¶
Get the elapsed time between two recorded events.
Both events must have been previously recorded with record_event() or warp.Stream.record_event().
If synchronize is False, the caller must ensure that device execution has reached end_event prior to calling get_event_elapsed_time().
Graphs¶
Launching kernels from Python introduces significant additional overhead compared to C++ or native programs. To address this, Warp exposes the concept of CUDA graphs to allow recording large batches of kernels and replaying them with very little CPU overhead.
To record a series of kernel launches, use the wp.capture_begin() and wp.capture_end() API as follows:
# begin capture
wp.capture_begin(device="cuda")

try:
    # record launches (dim is assumed here to match the size of a)
    for i in range(100):
        wp.launch(kernel=compute1, dim=a.shape, inputs=[a, b], device="cuda")
finally:
    # end capture and return a graph object
    graph = wp.capture_end(device="cuda")
We strongly recommend the use of the try-finally pattern when capturing graphs because the finally statement will ensure wp.capture_end gets called even if an exception occurs during capture, which would otherwise trap the stream in a capturing state.
Once a graph has been constructed it can be executed:
wp.capture_launch(graph)
The wp.ScopedCapture context manager can be used to simplify the code and ensure that wp.capture_end is called regardless of exceptions:
with wp.ScopedCapture(device="cuda") as capture:
    # record launches (dim is assumed here to match the size of a)
    for i in range(100):
        wp.launch(kernel=compute1, dim=a.shape, inputs=[a, b], device="cuda")

wp.capture_launch(capture.graph)
Note that only the work issued to the device during capture, such as kernel launches, is recorded in the graph; any Python code executed outside of the kernels will not be recorded. Typically, it is only beneficial to use CUDA graphs when the graph will be reused or launched multiple times.
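The record-and-replay idea, and why host-side Python is not part of a graph, can be modeled in a few lines of plain Python. Recorder, replay(), and add_one() are hypothetical names used for illustration. A real CUDA graph records device work such as kernel launches, not Python callables.

```python
class Recorder:
    """Toy model of graph capture: records launches and replays them later."""

    def __init__(self):
        self.graph = []
        self.capturing = False

    def begin(self):
        self.capturing = True
        self.graph = []

    def end(self) -> list:
        self.capturing = False
        return list(self.graph)

    def launch(self, kernel, *args):
        if self.capturing:
            self.graph.append((kernel, args))  # recorded, not executed yet
        else:
            kernel(*args)


def replay(graph: list) -> None:
    for kernel, args in graph:
        kernel(*args)


counter = {"launches": 0, "host_calls": 0}


def add_one(state):
    state["launches"] += 1


rec = Recorder()
rec.begin()
try:
    for _ in range(3):
        counter["host_calls"] += 1  # plain Python: runs once, not recorded
        rec.launch(add_one, counter)
finally:
    graph = rec.end()  # try/finally guarantees capture always ends

replay(graph)  # each replay re-runs only the recorded launches
replay(graph)
```

After two replays the recorded "kernel" has run six times, while the host-side counter was incremented only three times during capture.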
- warp.capture_begin(device=None, stream=None, force_module_load=None, external=False)[source]¶
Begin capture of a CUDA graph
Captures all subsequent kernel launches and memory operations on CUDA devices. This can be used to record large numbers of kernels and replay them with low overhead.
If device is specified, the capture will begin on the CUDA stream currently associated with the device. If stream is specified, the capture will begin on the given stream. If both are omitted, the capture will begin on the current stream of the current device.
- Parameters:
device (Device | str | None) – The CUDA device to capture on
stream – The CUDA stream to capture on
force_module_load – Whether to force loading of all kernels before capture. In general it is better to use load_module() to selectively load kernels. When running with CUDA drivers that support CUDA 12.3 or newer, this option is not recommended to be set to True because kernels can be loaded during graph capture on more recent drivers. If this argument is None, then the behavior inherits from wp.config.enable_graph_capture_module_load_by_default if the driver is older than CUDA 12.3.
external – Whether the capture was already started externally
- warp.capture_end(device=None, stream=None)[source]¶
Ends the capture of a CUDA graph
- Returns:
A Graph object that can be launched with capture_launch()
- Return type:
Graph
- warp.capture_launch(graph, stream=None)[source]¶
Launch a previously captured CUDA graph
- Parameters:
graph (Graph) – A Graph as returned by capture_end()
stream (Stream | None) – A Stream to launch the graph on (optional)
Meshes¶
Warp provides a wp.Mesh class to manage triangle mesh data. To create a mesh, users provide a points array, an indices array, and optionally a velocities array:
mesh = wp.Mesh(points, indices, velocities)
Note
Mesh objects maintain references to their input geometry buffers. All buffers should live on the same device.
Meshes can be passed to kernels using their id attribute, a uint64 value that uniquely identifies the mesh.
Once inside a kernel you can perform geometric queries against the mesh such as ray-casts or closest point lookups:
@wp.kernel
def raycast(mesh: wp.uint64,
            ray_origin: wp.array(dtype=wp.vec3),
            ray_dir: wp.array(dtype=wp.vec3),
            ray_hit: wp.array(dtype=wp.vec3)):

    tid = wp.tid()

    t = float(0.0)      # hit distance along ray
    u = float(0.0)      # hit face barycentric u
    v = float(0.0)      # hit face barycentric v
    sign = float(0.0)   # hit face sign
    n = wp.vec3()       # hit face normal
    f = int(0)          # hit face index

    color = wp.vec3()

    # ray cast against the mesh
    if wp.mesh_query_ray(mesh, ray_origin[tid], ray_dir[tid], 1.e+6, t, u, v, sign, n, f):

        # if we got a hit then set color to the face normal
        color = n * 0.5 + wp.vec3(0.5, 0.5, 0.5)

    ray_hit[tid] = color
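The plain-Python sketch below gives some intuition for the outputs of wp.mesh_query_ray(): the hit distance t, the barycentric coordinates u and v, and a face sign. It intersects a ray with a single triangle using the Möller–Trumbore algorithm. This is an illustration under stated assumptions, not Warp's implementation. Warp traverses the mesh BVH instead of testing triangles one at a time. The sign convention used here (+1 when the ray hits the side that the counter-clockwise winding normal points toward) is chosen for this example only.

```python
def sub(a, b):
    return (a[0] - b[0], a[1] - b[1], a[2] - b[2])


def dot(a, b):
    return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]


def cross(a, b):
    return (a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0])


def ray_triangle(origin, direction, p0, p1, p2, max_t=1.0e6, eps=1.0e-12):
    """Möller–Trumbore ray/triangle test.

    Returns (t, u, v, sign) on a hit, or None. The hit point satisfies
    origin + t * direction == (1 - u - v) * p0 + u * p1 + v * p2.
    """
    e1 = sub(p1, p0)
    e2 = sub(p2, p0)
    pvec = cross(direction, e2)
    det = dot(e1, pvec)  # equals -dot(direction, cross(e1, e2))
    if abs(det) < eps:   # ray is parallel to the triangle plane
        return None
    inv_det = 1.0 / det
    tvec = sub(origin, p0)
    u = dot(tvec, pvec) * inv_det
    if u < 0.0 or u > 1.0:
        return None
    qvec = cross(tvec, e1)
    v = dot(direction, qvec) * inv_det
    if v < 0.0 or u + v > 1.0:
        return None
    t = dot(e2, qvec) * inv_det
    if t < 0.0 or t > max_t:
        return None
    # det > 0 means the ray travels against the CCW normal (front-face hit).
    sign = 1.0 if det > 0.0 else -1.0
    return t, u, v, sign


# Ray straight down onto the unit right triangle in the z=0 plane.
hit = ray_triangle((0.25, 0.25, 1.0), (0.0, 0.0, -1.0),
                   (0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0))
```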
Users may update mesh vertex positions at runtime simply by modifying the points buffer.
After modifying point locations, users should call Mesh.refit() to rebuild the bounding volume hierarchy (BVH) structure and ensure that queries work correctly.
Note
Updating Mesh topology (indices) at runtime is not currently supported. Users should instead recreate a new Mesh object.
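A BVH stores bounding boxes computed from the vertex positions, so moving vertices makes those boxes stale. This is why refit() is needed. The plain-Python sketch below (hypothetical helper names, not Warp's implementation) computes per-triangle boxes from a points list and a flat indices list with three entries per triangle. This is the same (num_tris * 3) layout used by the indices parameter. The sketch also shows the root box changing after a vertex moves.

```python
def triangle_bounds(points, indices):
    """Return one (lower, upper) axis-aligned box per triangle.

    points:  list of (x, y, z) vertex positions
    indices: flat list of vertex indices, three per triangle
    """
    if len(indices) % 3 != 0:
        raise ValueError("indices must contain 3 entries per triangle")
    bounds = []
    for tri in range(len(indices) // 3):
        corners = [points[indices[3 * tri + k]] for k in range(3)]
        lower = tuple(min(c[axis] for c in corners) for axis in range(3))
        upper = tuple(max(c[axis] for c in corners) for axis in range(3))
        bounds.append((lower, upper))
    return bounds


def overall_bounds(bounds):
    """Combine per-triangle boxes into the box at the root of a hierarchy."""
    lower = tuple(min(b[0][axis] for b in bounds) for axis in range(3))
    upper = tuple(max(b[1][axis] for b in bounds) for axis in range(3))
    return lower, upper


points = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (1.0, 1.0, 1.0)]
indices = [0, 1, 2, 1, 3, 2]  # two triangles sharing the edge (1, 2)

before = overall_bounds(triangle_bounds(points, indices))
points[3] = (2.0, 1.0, 1.0)  # move a vertex, as when editing points in place
after = overall_bounds(triangle_bounds(points, indices))  # the "refit" step
```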
- class warp.Mesh(*args, **kwargs)[source]¶
- __init__(points=None, indices=None, velocities=None, support_winding_number=False)[source]¶
Class representing a triangle mesh.
- id¶
Unique identifier for this mesh object, can be passed to kernels.
- device¶
Device this object lives on, all buffers must live on the same device.
- Parameters:
points (warp.array) – Array of vertex positions of type warp.vec3
indices (warp.array) – Array of triangle indices of type warp.int32, should be a 1d array with shape (num_tris * 3)
velocities (warp.array) – Array of vertex velocities of type warp.vec3 (optional)
support_winding_number (bool) – If true the mesh will build additional data structures to support wp.mesh_query_point_sign_winding_number() queries
- property points[source]¶
The array of the mesh’s vertex positions of type warp.vec3.
The Mesh.points property has a custom setter method. Users can modify the vertex positions in-place, but the refit() method must be called manually after such modifications. Alternatively, assigning a new array to this property is also supported. The new array must have the same shape as the original, and once assigned, the Mesh class will automatically perform a refit operation based on the new vertex positions.
Hash Grids¶
Many particle-based simulation methods, such as the Discrete Element Method (DEM) or Smoothed Particle Hydrodynamics (SPH), involve iterating over spatial neighbors to compute force interactions. Hash grids are a well-established data structure to accelerate these nearest neighbor queries and are particularly well suited to the GPU.
To support spatial neighbor queries, Warp provides a HashGrid object that may be created as follows:
grid = wp.HashGrid(dim_x=128, dim_y=128, dim_z=128, device="cuda")
grid.build(points=p, radius=r)
p is an array of wp.vec3 point positions, and r is the radius to use when building the grid.
Neighbors can then be iterated over inside the kernel code using wp.hash_grid_query() and wp.hash_grid_query_next() as follows:
@wp.kernel
def sum_neighbors(grid: wp.uint64,
                  points: wp.array(dtype=wp.vec3),
                  output: wp.array(dtype=wp.vec3),
                  radius: float):

    tid = wp.tid()

    # query point
    p = points[tid]

    # create grid query around point
    query = wp.hash_grid_query(grid, p, radius)
    index = int(0)

    neighbor_sum = wp.vec3()

    while wp.hash_grid_query_next(query, index):

        neighbor = points[index]

        # compute distance to neighbor point
        dist = wp.length(p - neighbor)
        if dist <= radius:
            neighbor_sum += neighbor

    output[tid] = neighbor_sum
Note
The HashGrid query returns all points in cells that fall inside the query radius.
When there are hash collisions, some points outside of the query radius will also be returned, so users should check the distance themselves inside their kernels. The query does not perform this check for each returned point because kernels commonly compute the distance themselves, so checking it twice would be redundant.
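The plain-Python model below illustrates the behavior described in the note. Points are bucketed into cells whose width equals the build radius, and cell coordinates are wrapped into a fixed number of buckets. A query returns every point from the buckets around the query cell, so the caller must filter by distance. ToyHashGrid and its wrap-around hashing scheme are assumptions for illustration; Warp's actual hashing and memory layout may differ. The model also assumes the query radius does not exceed the build radius.

```python
import math
from collections import defaultdict


class ToyHashGrid:
    """Plain-Python model of a hash grid for fixed-radius neighbor queries."""

    def __init__(self, dim_x: int, dim_y: int, dim_z: int):
        self.dims = (dim_x, dim_y, dim_z)
        self.cell_size = 1.0
        self.buckets = defaultdict(list)
        self.points = []

    def _cell(self, p):
        return tuple(math.floor(x / self.cell_size) for x in p)

    def _bucket(self, cell):
        # Wrap integer cell coordinates into the fixed bucket range. Distinct
        # cells can collide here, which is why queries return extra points.
        return tuple(c % d for c, d in zip(cell, self.dims))

    def build(self, points, radius: float):
        self.points = list(points)
        self.cell_size = radius
        self.buckets.clear()
        for index, p in enumerate(self.points):
            self.buckets[self._bucket(self._cell(p))].append(index)

    def query(self, p):
        """Yield candidate indices from the 3x3x3 block of cells around p."""
        cx, cy, cz = self._cell(p)
        seen = set()
        for dx in (-1, 0, 1):
            for dy in (-1, 0, 1):
                for dz in (-1, 0, 1):
                    key = self._bucket((cx + dx, cy + dy, cz + dz))
                    if key in seen:  # wrapped neighbors may share a bucket
                        continue
                    seen.add(key)
                    yield from self.buckets.get(key, [])


def neighbor_sum(grid: ToyHashGrid, p, radius: float):
    """Sum neighbor positions within radius, filtering candidates by distance."""
    total = [0.0, 0.0, 0.0]
    for index in grid.query(p):
        q = grid.points[index]
        if math.dist(p, q) <= radius:  # the distance check the note asks for
            total = [t + c for t, c in zip(total, q)]
    return tuple(total)


grid = ToyHashGrid(4, 4, 4)
pts = [(0.0, 0.0, 0.0), (0.5, 0.0, 0.0), (3.0, 0.0, 0.0), (4.0, 0.0, 0.0)]
grid.build(pts, radius=1.0)
candidates = sorted(grid.query(pts[0]))  # includes far-away colliding points
total = neighbor_sum(grid, pts[0], radius=1.0)
```

With only four buckets per axis, the points at x=3 and x=4 land in buckets adjacent to or shared with the query cell. They appear as candidates even though only the first two points are within the radius.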
- class warp.HashGrid(*args, **kwargs)[source]¶
- __init__(dim_x, dim_y, dim_z, device=None)[source]¶
Class representing a hash grid object for accelerated point queries.
- id¶
Unique identifier for this hash grid object, can be passed to kernels.
- device¶
Device this object lives on, all buffers must live on the same device.
- build(points, radius)[source]¶
Updates the hash grid data structure.
This method rebuilds the underlying data structure and should be called any time the set of points changes.
- Parameters:
points (warp.array) – Array of points of type warp.vec3
radius (float) – The cell size to use for bucketing points; cells are cubes with edges of this width. For best performance, the radius used to construct the grid should closely match the radius used when performing queries.
Volumes¶
Sparse volumes are incredibly useful for representing grid data over large domains, such as signed distance fields (SDFs) for complex objects, or velocities for large-scale fluid flow. Warp supports reading sparse volumetric grids stored using the NanoVDB standard. Users can access voxels directly or use built-in closest-point or trilinear interpolation to sample grid data from world or local space.
Volume objects can be created directly from Warp arrays containing a NanoVDB grid, from the contents of a standard .nvdb file using load_from_nvdb(), from an uncompressed in-memory buffer using load_from_address(), or from a dense 3D NumPy array using load_from_numpy().
Volumes can also be created using allocate(), allocate_by_tiles(), or allocate_by_voxels().
The values for a Volume object can be modified in a Warp kernel using wp.volume_store().
Note
Warp does not currently support modifying the topology of sparse volumes at runtime.
Below we give an example of creating a Volume object from an existing NanoVDB file:
# open NanoVDB file on disk; the with statement closes it afterwards
with open("mygrid.nvdb", "rb") as file:
    # create Volume object
    volume = wp.Volume.load_from_nvdb(file, device="cpu")
Note
Files written by the NanoVDB library, commonly marked by the .nvdb extension, can contain multiple grids with various compression methods, but a Volume object represents a single NanoVDB grid. The first grid is loaded by default; Warp volumes corresponding to the other grids in the file can then be created using repeated calls to load_next_grid().
NanoVDB’s uncompressed and zip-compressed file formats are supported out of the box; blosc-compressed files require the blosc Python package to be installed.
To sample the volume inside a kernel we pass a reference to it by ID, and use the built-in sampling modes:
@wp.kernel
def sample_grid(volume: wp.uint64,
                points: wp.array(dtype=wp.vec3),
                samples: wp.array(dtype=float)):
    tid = wp.tid()

    # load sample point in world-space
    p = points[tid]

    # transform position to the volume's local-space
    q = wp.volume_world_to_index(volume, p)

    # sample volume with trilinear interpolation
    f = wp.volume_sample(volume, q, wp.Volume.LINEAR, dtype=float)

    # write result
    samples[tid] = f
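The sketch below shows, in plain Python, what the two sampling steps in the kernel compute for a simple case. It assumes a world-to-index transform with a uniform voxel size and a translation but no rotation. It stores a sparse grid as a dictionary in which missing voxels read as a background value, and it blends the eight surrounding voxels trilinearly. The helper names and data layout are hypothetical; NanoVDB's storage and Warp's sampling code differ.

```python
import math


def world_to_index(p, translation, voxel_size):
    """Assumes index = (world - translation) / voxel_size, with no rotation."""
    return tuple((x - t) / voxel_size for x, t in zip(p, translation))


def sample_linear(voxels: dict, q, background: float = 0.0) -> float:
    """Trilinearly interpolate a sparse grid stored as {(i, j, k): value}."""
    i0, j0, k0 = (math.floor(c) for c in q)
    fx, fy, fz = q[0] - i0, q[1] - j0, q[2] - k0
    result = 0.0
    for di in (0, 1):
        wx = fx if di else 1.0 - fx
        for dj in (0, 1):
            wy = fy if dj else 1.0 - fy
            for dk in (0, 1):
                wz = fz if dk else 1.0 - fz
                # Unallocated voxels contribute the background value.
                value = voxels.get((i0 + di, j0 + dj, k0 + dk), background)
                result += wx * wy * wz * value
    return result


# Corner values of the linear field f(i, j, k) = i + 2j + 3k.
voxels = {(i, j, k): float(i + 2 * j + 3 * k)
          for i in (0, 1) for j in (0, 1) for k in (0, 1)}
q = world_to_index((2.0, 0.5, 1.5), translation=(1.0, 0.0, 0.0), voxel_size=2.0)
value = sample_linear(voxels, q)  # trilinear interpolation reproduces linear fields
```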
Warp also supports NanoVDB index grids, which provide a memory-efficient linearization of voxel indices that can refer to values in arbitrarily shaped arrays:
from typing import Any

@wp.kernel
def sample_index_grid(volume: wp.uint64,
                      points: wp.array(dtype=wp.vec3),
                      voxel_values: wp.array(dtype=Any)):
    tid = wp.tid()

    # load sample point in world-space
    p = points[tid]

    # transform position to the volume's local-space
    q = wp.volume_world_to_index(volume, p)

    # sample volume with trilinear interpolation
    background_value = voxel_values.dtype(0.0)
    f = wp.volume_sample_index(volume, q, wp.Volume.LINEAR, voxel_values, background_value)
The coordinates of all indexable voxels can be recovered using get_voxels().
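An index grid stores no values itself: each active voxel maps to a position in a separate, linearized value array. The plain-Python model below deduplicates voxel coordinates and assigns each one a linear index. It looks values up with a background fallback for inactive voxels, and it recovers the coordinates in index order, loosely mirroring get_voxels(). The helper names and the sorted ordering are assumptions for illustration, not NanoVDB's actual layout.

```python
def build_index(voxel_points):
    """Map each unique active voxel coordinate to a linear index."""
    unique = sorted(set(tuple(p) for p in voxel_points))  # repeats are merged
    return {coord: n for n, coord in enumerate(unique)}


def lookup(index_map: dict, values: list, coord, background):
    """Read the value for `coord` from an external array, or the background."""
    n = index_map.get(tuple(coord))
    return background if n is None else values[n]


def active_voxels(index_map: dict) -> list:
    """Recover the voxel coordinates in index order (cf. get_voxels())."""
    return sorted(index_map, key=index_map.get)


index_map = build_index([(0, 0, 0), (1, 0, 0), (0, 0, 0), (0, 2, 0)])
values = [10.0, 20.0, 30.0]  # one entry per active voxel, in index order
```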
NanoVDB grids may also contain embedded blind data arrays; those can be accessed with the feature_array() function.
- class warp.Volume(*args, **kwargs)[source]¶
- CLOSEST = 0¶
Enum value to specify nearest-neighbor interpolation during sampling
- LINEAR = 1¶
Enum value to specify trilinear interpolation during sampling
- __init__(data, copy=True)[source]¶
Class representing a sparse grid.
- Parameters:
data (warp.array) – Array of bytes representing the volume in NanoVDB format
copy (bool) – Whether the incoming data will be copied or aliased
- get_tile_count()[source]¶
Returns the number of tiles (NanoVDB leaf nodes) of the volume
- Return type:
int
- get_tiles(out=None)[source]¶
Returns the integer coordinates of all allocated tiles for this volume.
- Parameters:
out (warp.array, optional) – If provided, use the out array to store the tile coordinates, otherwise a new array will be allocated. out must be a contiguous array of tile_count vec3i or tile_count x 3 int32 on the same device as this volume.
- Return type:
warp.array
- get_voxel_count()[source]¶
Returns the total number of allocated voxels for this volume
- Return type:
int
- get_voxels(out=None)[source]¶
Returns the integer coordinates of all allocated voxels for this volume.
- Parameters:
out (warp.array, optional) – If provided, use the out array to store the voxel coordinates, otherwise a new array will be allocated. out must be a contiguous array of voxel_count vec3i or voxel_count x 3 int32 on the same device as this volume.
- Return type:
warp.array
- class GridInfo(name, size_in_bytes, grid_index, grid_count, type_str, translation, transform_matrix)[source]¶
Grid metadata
- Parameters:
- translation: vec3f¶
Index-to-world translation
- transform_matrix: mat33f¶
Linear part of the index-to-world transform
- property dtype: type[source]¶
Type of the Volume’s values as a Warp type.
If the grid does not contain values (e.g. index grids) or if the NanoVDB type is not representable as a Warp type, returns None.
- property is_index: bool[source]¶
Whether this Volume contains an index grid, that is, a type of grid that does not explicitly store values but associates each voxel with a linearized index.
- get_feature_array_count()[source]¶
Returns the number of supplemental data arrays stored alongside the grid
- Return type:
int
- class FeatureArrayInfo(name, ptr, value_size, value_count, type_str)[source]¶
Metadata for a supplemental data array
- get_feature_array_info(feature_index)[source]¶
Returns the metadata associated with the feature array at feature_index
- Parameters:
feature_index (int)
- Return type:
FeatureArrayInfo
- feature_array(feature_index, dtype=None)[source]¶
Returns one of the grid’s feature data arrays as a Warp array
- classmethod load_from_nvdb(file_or_buffer, device=None)[source]¶
Creates a Volume object from a serialized NanoVDB file or in-memory buffer.
- Returns:
A warp.Volume object.
- Return type:
Volume
- classmethod load_from_address(grid_ptr, buffer_size=0, device=None)[source]¶
Creates a new Volume aliasing an in-memory grid buffer.
In contrast to load_from_nvdb(), which should be used to load serialized NanoVDB grids, here the buffer must be uncompressed and must not contain file header information. If the passed address does not contain a NanoVDB grid, the behavior of this function is undefined.
- Parameters:
grid_ptr (int) – Integer address of the start of the grid buffer
buffer_size (int) – Size of the buffer, in bytes. If not provided, the size will be assumed to be that of the single grid starting at grid_ptr.
device – Device of the buffer, and of the returned Volume. If not provided, the current Warp device is assumed.
- Returns:
The newly created Volume.
- Return type:
Volume
- load_next_grid()[source]¶
Tries to create a new warp Volume for the next grid that is linked to by this Volume.
The existence of a next grid is deduced from the grid_index and grid_count metadata as well as the size of this Volume’s in-memory buffer.
Returns the newly created Volume, or None if there is no next grid.
- Return type:
Volume | None
- classmethod load_from_numpy(ndarray, min_world=(0.0, 0.0, 0.0), voxel_size=1.0, bg_value=0.0, device=None)[source]¶
Creates a Volume object from a dense 3D NumPy array.
This function is only supported for CUDA devices.
- Parameters:
min_world – The 3D coordinate of the lower corner of the volume.
voxel_size – The size of each voxel in spatial coordinates.
bg_value – Background value
device – The CUDA device to create the volume on, e.g.: “cuda” or “cuda:0”.
ndarray (array)
- Returns:
A warp.Volume object.
- Return type:
Volume
- classmethod allocate(min, max, voxel_size, bg_value=0.0, translation=(0.0, 0.0, 0.0), points_in_world_space=False, device=None)[source]¶
Allocate a new Volume based on the bounding box defined by min and max.
This function is only supported for CUDA devices.
Allocate a volume that is large enough to contain voxels [min[0], min[1], min[2]] - [max[0], max[1], max[2]], inclusive. If points_in_world_space is true, then min and max are first converted to index space with the given voxel size and translation, and the volume is allocated with those.
The smallest unit of allocation is a dense tile of 8x8x8 voxels, the requested bounding box is rounded up to tiles, and the resulting tiles will be available in the new volume.
- Parameters:
min (array-like) – Lower 3D coordinates of the bounding box in index space or world space, inclusive.
max (array-like) – Upper 3D coordinates of the bounding box in index space or world space, inclusive.
voxel_size (float) – Voxel size of the new volume.
bg_value (float or array-like) – Value of unallocated voxels of the volume, also defines the volume’s type: a warp.vec3 volume is created if this is array-like, otherwise a float volume is created.
translation (array-like) – Translation between the index and world spaces.
device (Devicelike) – The CUDA device to create the volume on, e.g.: “cuda” or “cuda:0”.
- Return type:
Volume
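The tile rounding rule for allocate() can be worked through in plain Python. The hypothetical helper below only lists which 8x8x8 tiles cover an inclusive index-space box; it does not model the points_in_world_space conversion. It uses floor division so that negative coordinates round toward negative infinity, and it does not allocate anything.

```python
TILE = 8  # tiles are dense 8x8x8 blocks of voxels


def covering_tiles(min_ijk, max_ijk):
    """Return the tile coordinates covering the inclusive voxel box."""
    lo = [c // TILE for c in min_ijk]  # floor division also floors negatives
    hi = [c // TILE for c in max_ijk]
    return [
        (tx, ty, tz)
        for tx in range(lo[0], hi[0] + 1)
        for ty in range(lo[1], hi[1] + 1)
        for tz in range(lo[2], hi[2] + 1)
    ]
```

For example, the box from voxel (0, 0, 0) to voxel (8, 0, 0) spans two tiles along x, so 2 x 512 voxels become available even though only nine were requested.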
- classmethod allocate_by_tiles(tile_points, voxel_size=None, bg_value=0.0, translation=(0.0, 0.0, 0.0), device=None, transform=None)[source]¶
Allocate a new Volume with active tiles for each point in tile_points.
This function is only supported for CUDA devices.
The smallest unit of allocation is a dense tile of 8x8x8 voxels. This is the primary method for allocating sparse volumes. It uses an array of points indicating the tiles that must be allocated.
- Example use cases:
tile_points can mark tiles directly in index space, as in the case when this method is called by allocate().
tile_points can be a list of points used in a simulation that needs to transfer data to a volume.
- Parameters:
tile_points (warp.array) – Array of positions that define the tiles to be allocated. The array may use an integer scalar type (2D N-by-3 array of warp.int32 or 1D array of warp.vec3i values), indicating index space positions, or a floating point scalar type (2D N-by-3 array of warp.float32 or 1D array of warp.vec3f values), indicating world space positions. Repeated points per tile are allowed and will be efficiently deduplicated.
voxel_size (float or array-like) – Voxel size(s) of the new volume. Ignored if transform is given.
bg_value (array-like, float, int or None) – Value of unallocated voxels of the volume, also defines the volume’s type. A warp.vec3 volume is created if this is array-like, an index volume will be created if bg_value is None.
translation (array-like) – Translation between the index and world spaces.
transform (array-like) – Linear transform between the index and world spaces. If None, deduced from voxel_size.
device (Devicelike) – The CUDA device to create the volume on, e.g.: “cuda” or “cuda:0”.
- Return type:
Volume
- classmethod allocate_by_voxels(voxel_points, voxel_size=None, translation=(0.0, 0.0, 0.0), device=None, transform=None)[source]¶
Allocate a new Volume with an active voxel for each point in voxel_points.
This function creates an index Volume, a special kind of volume that does not store any explicit payload but encodes a linearized index for each active voxel, allowing data to be looked up and sampled from arbitrary external arrays.
This function is only supported for CUDA devices.
- Parameters:
voxel_points (warp.array) – Array of positions that define the voxels to be allocated. The array may use an integer scalar type (2D N-by-3 array of warp.int32 or 1D array of warp.vec3i values), indicating index space positions, or a floating point scalar type (2D N-by-3 array of warp.float32 or 1D array of warp.vec3f values), indicating world space positions. Repeated points per tile are allowed and will be efficiently deduplicated.
voxel_size (float or array-like) – Voxel size(s) of the new volume. Ignored if transform is given.
translation (array-like) – Translation between the index and world spaces.
transform (array-like) – Linear transform between the index and world spaces. If None, deduced from voxel_size.
device (Devicelike) – The CUDA device to create the volume on, e.g.: “cuda” or “cuda:0”.
- Return type:
Volume
See also
Reference for the volume functions available in kernels.
Bounding Volume Hierarchies (BVH)¶
The wp.Bvh class can be used to create a BVH for a group of bounding volumes. This object can then be traversed to determine which parts are intersected by a ray using bvh_query_ray() and which parts overlap with a certain bounding volume using bvh_query_aabb().
The following snippet demonstrates how to create a wp.Bvh object from 100 random bounding volumes:
import numpy as np
import warp as wp

rng = np.random.default_rng(123)

num_bounds = 100
lowers = rng.random(size=(num_bounds, 3)) * 5.0
uppers = lowers + rng.random(size=(num_bounds, 3)) * 5.0

device_lowers = wp.array(lowers, dtype=wp.vec3, device="cuda:0")
device_uppers = wp.array(uppers, dtype=wp.vec3, device="cuda:0")

bvh = wp.Bvh(device_lowers, device_uppers)
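When testing BVH queries, it can help to compare the results against a brute-force reference. The plain-Python sketch below (hypothetical helper names) generates bounds in the same way as the snippet above but uses Python's random module, so the values differ from NumPy's default_rng(123). It then lists every box that overlaps a query box using a closed-interval overlap test. The sketch makes no assumption about how bvh_query_aabb() treats boxes that merely touch, so compare edge cases with care.

```python
import random


def aabb_overlap(lo_a, hi_a, lo_b, hi_b) -> bool:
    """Closed-interval overlap test for two axis-aligned boxes."""
    return all(lo_a[k] <= hi_b[k] and lo_b[k] <= hi_a[k] for k in range(3))


def make_bounds(num_bounds: int, seed: int = 123):
    """Random boxes: lowers in [0, 5), extents in [0, 5) along each axis."""
    rng = random.Random(seed)
    lowers, uppers = [], []
    for _ in range(num_bounds):
        lo = [rng.random() * 5.0 for _ in range(3)]
        hi = [c + rng.random() * 5.0 for c in lo]
        lowers.append(lo)
        uppers.append(hi)
    return lowers, uppers


def brute_force_overlaps(lowers, uppers, query_lo, query_hi):
    """Indices of all boxes overlapping the query box, in order."""
    return [i for i, (lo, hi) in enumerate(zip(lowers, uppers))
            if aabb_overlap(lo, hi, query_lo, query_hi)]


lowers, uppers = make_bounds(100)
hits = brute_force_overlaps(lowers, uppers, (4.0, 4.0, 4.0), (6.0, 6.0, 6.0))
```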
- class warp.Bvh(*args, **kwargs)[source]¶
- __init__(lowers, uppers)[source]¶
Class representing a bounding volume hierarchy.
- id¶
Unique identifier for this bvh object, can be passed to kernels.
- device¶
Device this object lives on, all buffers must live on the same device.
- Parameters:
lowers (warp.array) – Array of lower bounds of type warp.vec3
uppers (warp.array) – Array of upper bounds of type warp.vec3
Example: BVH Ray Traversal¶
An example of performing a ray traversal on the data structure is as follows:
@wp.kernel
def bvh_query_ray(
    bvh_id: wp.uint64,
    start: wp.vec3,
    dir: wp.vec3,
    bounds_intersected: wp.array(dtype=wp.bool),
):
    query = wp.bvh_query_ray(bvh_id, start, dir)
    bounds_nr = wp.int32(0)

    while wp.bvh_query_next(query, bounds_nr):
        # The ray intersects the volume with index bounds_nr
        bounds_intersected[bounds_nr] = True


bounds_intersected = wp.zeros(shape=(num_bounds), dtype=wp.bool, device="cuda:0")
query_start = wp.vec3(0.0, 0.0, 0.0)
query_dir = wp.normalize(wp.vec3(1.0, 1.0, 1.0))

wp.launch(
    kernel=bvh_query_ray,
    dim=1,
    inputs=[bvh.id, query_start, query_dir, bounds_intersected],
    device="cuda:0",
)
The Warp kernel bvh_query_ray is launched with a single thread and is passed the unique uint64 identifier of the wp.Bvh object, the parameters describing the ray, and an array to store the results.
In bvh_query_ray, wp.bvh_query_ray() is called once to obtain a query object, which is stored in the variable query. An integer bounds_nr is also allocated to receive the index of each intersected bounding volume during the traversal.
The traversal itself is a while loop over wp.bvh_query_next(), which returns True as long as there are further intersecting bounds.
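When debugging a ray query, it can help to compute the expected hits by brute force on the host and compare them with the kernel output. The sketch below uses the standard slab test in NumPy. It assumes a bound counts as hit when the ray start + t * dir enters its box for some t >= 0 (treat the exact comparison Warp uses internally as an assumption), and that every component of the direction is non-zero. ray_hits_aabbs is a hypothetical helper, not part of Warp.

```python
import numpy as np


def ray_hits_aabbs(start, direction, lowers, uppers):
    """Brute-force slab test against N axis-aligned boxes.

    Returns a boolean mask that is True for each box hit by the ray
    start + t * direction with t >= 0. Assumes no zero direction components.
    """
    start = np.asarray(start, dtype=np.float64)
    inv_dir = 1.0 / np.asarray(direction, dtype=np.float64)
    lowers = np.asarray(lowers, dtype=np.float64)
    uppers = np.asarray(uppers, dtype=np.float64)
    # Ray parameters at which the ray crosses the two planes of each slab
    t1 = (lowers - start) * inv_dir
    t2 = (uppers - start) * inv_dir
    t_enter = np.minimum(t1, t2).max(axis=1)  # latest entry over the 3 axes
    t_exit = np.maximum(t1, t2).min(axis=1)  # earliest exit over the 3 axes
    # Hit if the slab intervals overlap and the box is not entirely behind the start
    return (t_exit >= t_enter) & (t_exit >= 0.0)
```

With the NumPy arrays lowers and uppers from the construction snippet, ray_hits_aabbs((0.0, 0.0, 0.0), np.ones(3) / np.sqrt(3.0), lowers, uppers) could be compared with bounds_intersected.numpy(); occasional disagreements for rays that graze a box edge are possible because Warp computes in float32.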
Example: BVH Volume Traversal¶
Similar to the ray-traversal example, we can perform a volume traversal to find the bounding volumes that overlap a specified bounding box.
@wp.kernel
def bvh_query_aabb(
    bvh_id: wp.uint64,
    lower: wp.vec3,
    upper: wp.vec3,
    bounds_intersected: wp.array(dtype=wp.bool),
):
    query = wp.bvh_query_aabb(bvh_id, lower, upper)
    bounds_nr = wp.int32(0)

    while wp.bvh_query_next(query, bounds_nr):
        # The volume with index bounds_nr overlaps
        # the (lower, upper) bounding box
        bounds_intersected[bounds_nr] = True


bounds_intersected = wp.zeros(shape=(num_bounds), dtype=wp.bool, device="cuda:0")
query_lower = wp.vec3(4.0, 4.0, 4.0)
query_upper = wp.vec3(6.0, 6.0, 6.0)

wp.launch(
    kernel=bvh_query_aabb,
    dim=1,
    inputs=[bvh.id, query_lower, query_upper, bounds_intersected],
    device="cuda:0",
)
The kernel is nearly identical to the ray-traversal example, except that query is obtained using wp.bvh_query_aabb().
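A brute-force NumPy reference can be used to validate the result of a box query on the host. This sketch assumes a bound is reported when its box overlaps the query box, with touching faces counted as overlapping; that comparison rule is an assumption, not something this page specifies. Both helpers are hypothetical. aabb_contained shows how overlap results could be narrowed down further if only fully contained boxes are wanted.

```python
import numpy as np


def aabb_overlaps(query_lower, query_upper, lowers, uppers):
    """Mask of boxes whose [lower, upper] extent overlaps the query box."""
    query_lower = np.asarray(query_lower, dtype=np.float64)
    query_upper = np.asarray(query_upper, dtype=np.float64)
    # Two boxes overlap iff their intervals overlap on every axis
    return np.all((lowers <= query_upper) & (uppers >= query_lower), axis=1)


def aabb_contained(query_lower, query_upper, lowers, uppers):
    """Mask of boxes lying entirely inside the query box."""
    query_lower = np.asarray(query_lower, dtype=np.float64)
    query_upper = np.asarray(query_upper, dtype=np.float64)
    return np.all((lowers >= query_lower) & (uppers <= query_upper), axis=1)
```

With the NumPy arrays from the construction snippet, aabb_overlaps((4.0, 4.0, 4.0), (6.0, 6.0, 6.0), lowers, uppers) could be compared element-wise with bounds_intersected.numpy() after the launch.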
Marching Cubes¶
The wp.MarchingCubes class can be used to extract a 2D mesh approximating an isosurface of a 3D scalar field. The resulting triangle mesh can be saved to a USD file using the warp.renderer.UsdRenderer.
See warp/examples/core/example_marching_cubes.py for a usage example.
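MarchingCubes.surface() takes a 3D scalar field whose shape should match the (nx, ny, nz) shape of the marching cubes object. As a minimal sketch of preparing such a field, the NumPy code below samples the signed distance to a sphere on a regular grid. The helper sphere_sdf_grid, the sphere parameters, and the use of threshold 0.0 (the zero level set of a signed distance field) are illustrative assumptions.

```python
import numpy as np


def sphere_sdf_grid(nx, ny, nz, center, radius, spacing=1.0):
    """Sample the signed distance to a sphere on an (nx, ny, nz) grid.

    Values are negative inside the sphere, zero on its surface and positive
    outside, so the isosurface at threshold 0.0 is the sphere itself.
    """
    # Coordinates of every grid point; indexing="ij" keeps the shape (nx, ny, nz)
    x, y, z = np.meshgrid(
        np.arange(nx) * spacing,
        np.arange(ny) * spacing,
        np.arange(nz) * spacing,
        indexing="ij",
    )
    dist = np.sqrt((x - center[0]) ** 2 + (y - center[1]) ** 2 + (z - center[2]) ** 2)
    return (dist - radius).astype(np.float32)
```

The result could then be uploaded with wp.array(field, dtype=wp.float32, device="cuda:0") and surfaced by a wp.MarchingCubes created on the same device with nx, ny, nz equal to field.shape and suitably large max_verts and max_tris, by calling surface(field, threshold=0.0).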
- class warp.MarchingCubes(*args, **kwargs)[source]¶
- __init__(nx, ny, nz, max_verts, max_tris, device=None)[source]¶
CUDA-based Marching Cubes algorithm to extract a 2D surface mesh from a 3D volume.
- id¶
Unique identifier for this object.
- verts¶
Array of vertex positions of type warp.vec3f for the output surface mesh. This is populated after running surface().
- Type:
- indices¶
Array containing indices of type warp.int32 defining triangles for the output surface mesh. This is populated after running surface().
Each set of three consecutive integers in the array represents a single triangle, in which each integer is an index referring to a vertex in the verts array.
- Type:
- Parameters:
nx (int) – Number of cubes in the x-direction.
ny (int) – Number of cubes in the y-direction.
nz (int) – Number of cubes in the z-direction.
max_verts (int) – Maximum expected number of vertices (used for array preallocation).
max_tris (int) – Maximum expected number of triangles (used for array preallocation).
device (Devicelike) – CUDA device on which to run marching cubes and allocate memory.
- Raises:
RuntimeError – device is not a CUDA device.
Note
The shape of the marching cubes should match the shape of the scalar field being surfaced.
- resize(nx, ny, nz, max_verts, max_tris)[source]¶
Update the expected input and maximum output sizes for the marching cubes calculation.
This function has no immediate effect on the underlying buffers. The new values take effect on the next surface() call.
- Parameters:
nx (int) – Number of cubes in the x-direction.
ny (int) – Number of cubes in the y-direction.
nz (int) – Number of cubes in the z-direction.
max_verts (int) – Maximum expected number of vertices (used for array preallocation).
max_tris (int) – Maximum expected number of triangles (used for array preallocation).
- Return type:
None
- surface(field, threshold)[source]¶
Compute a 2D surface mesh of a given isosurface from a 3D scalar field.
The triangles and vertices defining the output mesh are written to the indices and verts arrays.
- Parameters:
- Raises:
ValueError – field is not a 3D array.
ValueError – Marching cubes shape does not match the shape of field.
RuntimeError – max_verts and/or max_tris might be too small to hold the surface mesh.
- Return type:
None
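The indices layout described above (three consecutive entries per triangle, each indexing into verts) maps directly onto a (num_tris, 3) array. As a minimal sketch of working with that layout, the NumPy code below computes per-triangle areas from host copies of the two arrays; triangle_areas is a hypothetical helper, not part of Warp.

```python
import numpy as np


def triangle_areas(verts, indices):
    """Per-triangle areas for a mesh stored as flat index triples.

    verts:   (V, 3) array of vertex positions
    indices: flat (3 * T,) integer array; entries 3t, 3t+1, 3t+2 form triangle t
    """
    tris = np.asarray(indices).reshape(-1, 3)  # one row per triangle
    v = np.asarray(verts, dtype=np.float64)
    a, b, c = v[tris[:, 0]], v[tris[:, 1]], v[tris[:, 2]]
    # Half the length of the cross product of two edges is the triangle area
    return 0.5 * np.linalg.norm(np.cross(b - a, c - a), axis=1)
```

After surface() has run, the inputs could be obtained with mc.verts.numpy() and mc.indices.numpy(), assuming mc is the MarchingCubes instance and that both arrays hold only the generated mesh.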
Profiling¶
wp.ScopedTimer objects can be used to gain some basic insight into the performance of Warp applications:
with wp.ScopedTimer("grid build"):
    self.grid.build(self.x, self.point_radius)
This results in a printout at runtime to the standard output stream like:
grid build took 0.06 ms
See Profiling documentation for more information.
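Passing the same dict to several timers through the dict parameter of wp.ScopedTimer collects one list of elapsed times per timer name, which is handy for timing repeated steps in a loop. A minimal sketch of summarizing such a dict with the standard library follows; summarize_timings is a hypothetical helper, and the values are assumed to be in milliseconds, as in the printout above.

```python
from statistics import mean


def summarize_timings(timings):
    """Reduce {name: [elapsed, ...]} to {name: (count, mean, min, max)}."""
    summary = {}
    for name, samples in timings.items():
        if not samples:
            continue  # nothing recorded under this name
        summary[name] = (len(samples), mean(samples), min(samples), max(samples))
    return summary
```

In a Warp program, the dict would be filled by a loop such as `with wp.ScopedTimer("step", dict=timings, print=False):`; adding synchronize=True makes the recorded times include outstanding GPU work.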
- class warp.ScopedTimer(name, active=True, print=True, detailed=False, dict=None, use_nvtx=False, color='rapids', synchronize=False, cuda_filter=0, report_func=None, skip_tape=False)[source]
- indent = -1
- enabled = True
- __init__(name, active=True, print=True, detailed=False, dict=None, use_nvtx=False, color='rapids', synchronize=False, cuda_filter=0, report_func=None, skip_tape=False)[source]
Context manager object for a timer
- Parameters:
name (str) – Name of timer
active (bool) – Enables this timer
print (bool) – At context manager exit, print elapsed time to sys.stdout
detailed (bool) – Collects additional profiling data using cProfile and calls print_stats() at context exit
dict (dict) – A dictionary of lists to which the elapsed time will be appended using name as a key
use_nvtx (bool) – If true, timing functionality is replaced by an NVTX range
color (int or str) – ARGB value (e.g. 0x00FFFF) or color name (e.g. ‘cyan’) associated with the NVTX range
synchronize (bool) – Synchronize the CPU thread with any outstanding CUDA work to return accurate GPU timings
cuda_filter (int) – Filter flags for CUDA activity timing, e.g. warp.TIMING_KERNEL or warp.TIMING_ALL
report_func (Callable) – A callback function to print the activity report (wp.timing_print() is used by default)
skip_tape (bool) – If true, the timer will not be recorded in the tape
- extra_msg
Can be set to a string that will be added to the printout at context exit.
- Type:
- elapsed
The duration of the with block used with this object
- Type:
- timing_results
The list of activity timing results, if collection was requested using cuda_filter
- Type: