warp.jax_experimental.custom_call.jax_kernel
- warp.jax_experimental.custom_call.jax_kernel(kernel, launch_dims=None, quiet=False)
Create a JAX primitive from a Warp kernel.
Deprecated since version 1.10.0: This version of jax_kernel() is deprecated for JAX >= 0.5.0 and is not supported with JAX >= 0.8.0. Use warp.jax_experimental.ffi.jax_kernel() instead, which is the default implementation as of Warp 1.10. This implementation requires JAX 0.4.25 through 0.7.x; for JAX 0.8.0 and later, only the FFI-based warp.jax_experimental.ffi.jax_kernel() can be used.
- Parameters:
kernel – The Warp kernel to be wrapped.
launch_dims – The kernel launch dimensions. If None, the dimensions are inferred from the shape of the first argument. When set, this option also specifies the output dimensions.
quiet – If True, suppress deprecation warnings with newer JAX versions.
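The launch_dims rule can be modeled in a few lines of plain Python. This is a hypothetical helper (resolve_launch_dims is not part of Warp), written under the assumption that launch_dims may be given either as an int or as a sequence of ints; it only illustrates the documented fallback to the first argument's shape.

```python
from typing import Optional, Sequence, Tuple, Union

# Hypothetical type: a scalar for 1-D launches or a sequence for N-D launches.
Dims = Union[int, Sequence[int]]


def resolve_launch_dims(first_arg_shape: Sequence[int],
                        launch_dims: Optional[Dims] = None) -> Tuple[int, ...]:
    """Hypothetical model of the documented launch_dims behavior (not a Warp API).

    If launch_dims is None, the launch (and output) dimensions come from the
    shape of the first kernel argument; otherwise the explicit value is used.
    """
    if launch_dims is None:
        return tuple(first_arg_shape)
    if isinstance(launch_dims, int):
        return (launch_dims,)  # normalize a scalar to a 1-D launch
    return tuple(launch_dims)


print(resolve_launch_dims((64,)))          # inferred from the first argument
print(resolve_launch_dims((64,), (8, 8)))  # explicit launch_dims takes precedence
```

The explicit value always wins, which matches the note that setting launch_dims also fixes the output dimensions.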
- Raises:
RuntimeError – If JAX version is < 0.4.25 or >= 0.8.0.
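A caller that must work across several JAX versions can check the documented range before choosing this implementation. The sketch below is hypothetical: check_custom_call_jax_version and its version parser are not Warp APIs, and the parser reads only the leading digits of each component, so pre-release suffixes such as rc1 are ignored.

```python
from typing import List, Tuple

MIN_VERSION = (0, 4, 25)   # lowest JAX version this implementation accepts
MAX_EXCLUSIVE = (0, 8, 0)  # first JAX version it no longer supports


def parse_version(version: str) -> Tuple[int, ...]:
    """Parse 'major.minor.patch' into a 3-tuple, keeping only leading digits."""
    parts: List[int] = []
    for piece in version.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at suffixes like 'rc1' or 'dev'
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)  # treat '1.0' as '1.0.0'
    return tuple(parts)


def check_custom_call_jax_version(version: str) -> None:
    """Hypothetical gate mirroring the documented RuntimeError condition."""
    v = parse_version(version)
    if v < MIN_VERSION or v >= MAX_EXCLUSIVE:
        raise RuntimeError(
            f"JAX {version} is outside 0.4.25 - 0.7.x; "
            "use warp.jax_experimental.ffi.jax_kernel() instead"
        )


check_custom_call_jax_version("0.4.30")  # within range, no error
```

In a real program the argument would typically be jax.__version__; when the check fails, the documented alternative is warp.jax_experimental.ffi.jax_kernel().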
- Limitations:
All kernel arguments must be contiguous arrays.
Input arguments are followed by output arguments in the Warp kernel definition.
There must be at least one input argument and at least one output argument.
Only the CUDA backend is supported.
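The argument limitations can be checked before launch. This is a minimal sketch, assuming NumPy arrays as stand-ins for the device arrays and a hypothetical helper, check_custom_call_args, that receives the kernel's input and output arrays as separate lists. It does not model the CUDA-only restriction.

```python
from typing import Sequence

import numpy as np


def check_custom_call_args(inputs: Sequence[np.ndarray],
                           outputs: Sequence[np.ndarray]) -> None:
    """Hypothetical check mirroring the documented limitations (not a Warp API)."""
    # At least one input and at least one output argument are required.
    if len(inputs) < 1:
        raise ValueError("at least one input argument is required")
    if len(outputs) < 1:
        raise ValueError("at least one output argument is required")
    # Inputs precede outputs in the kernel signature, so positions are
    # reported in that order. Every argument must be contiguous.
    for position, arr in enumerate([*inputs, *outputs]):
        if not arr.flags["C_CONTIGUOUS"]:
            raise ValueError(f"argument {position} is not contiguous")


x = np.arange(8, dtype=np.float32)
check_custom_call_args([x], [np.empty_like(x)])  # passes

strided = x[::2]  # a strided view, which is not contiguous
# Copy into a contiguous buffer before passing it on.
check_custom_call_args([np.ascontiguousarray(strided)], [np.empty(4, np.float32)])
```

Copying with np.ascontiguousarray is one simple way to satisfy the contiguity requirement, at the cost of an extra allocation.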