warp.jax_experimental
Experimental JAX integration for calling Warp kernels from JAX.
This module enables using Warp kernels as JAX primitives, allowing them to be called inside jitted JAX functions. The jax_kernel function wraps individual Warp kernels, while jax_callable wraps Python functions that launch multiple kernels. Both support automatic differentiation, custom launch dimensions, and CUDA graph capture.
Caution
This module is experimental and less stable than the core Warp API. The interface may change as new functionality is added and to accommodate changes in upcoming JAX library versions.
Usage:
This module must be explicitly imported:
import warp.jax_experimental
See also
The user guide section "Using Warp kernels as JAX primitives" for detailed examples and usage patterns.
Additional Submodules
These modules must be explicitly imported (e.g., import warp.jax_experimental.custom_call).