warp.jax_experimental

Experimental JAX integration for calling Warp kernels from JAX.

This module enables using Warp kernels as JAX primitives, so they can be called inside jit-compiled JAX functions. The jax_kernel function wraps a single Warp kernel, while jax_callable wraps a Python function that may launch multiple kernels. Both support automatic differentiation, custom launch dimensions, and CUDA graph capture.

Caution

This module is experimental and less stable than the core Warp API. The interface may change as new functionality is added and to accommodate changes in upcoming JAX library versions.

Usage:

This module must be explicitly imported:

import warp.jax_experimental

See also

See the "Using Warp kernels as JAX primitives" section of the user guide for detailed examples and usage patterns.

Additional Submodules

These modules must be explicitly imported (e.g., import warp.jax_experimental.custom_call).

API