warp.jax_experimental.custom_call.jax_kernel

warp.jax_experimental.custom_call.jax_kernel(kernel, launch_dims=None, quiet=False)

Create a JAX primitive from a Warp kernel.

Deprecated since version 1.10.0: This version of jax_kernel() is deprecated for JAX >= 0.5.0 and is not supported with JAX >= 0.8.0. Use warp.jax_experimental.ffi.jax_kernel() instead, which is the default implementation as of Warp 1.10.

This implementation requires a JAX version from 0.4.25 through 0.7.x. For JAX 0.8.0 and later, use the FFI-based implementation at warp.jax_experimental.ffi.jax_kernel().

Parameters:
  • kernel – The Warp kernel to be wrapped.

  • launch_dims – The kernel launch dimensions. If None, the dimensions are inferred from the shape of the first argument. When set, this option also specifies the output dimensions.

  • quiet – If True, suppress the deprecation warnings emitted with newer JAX versions (>= 0.5.0).
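The launch_dims rule above can be pictured with a small pure-Python sketch. The helper resolve_launch_dims is hypothetical and not part of Warp; it only mirrors the documented behavior: an explicit launch_dims wins, otherwise the shape of the first argument is used. Accepting a bare int as a 1-D launch is an assumption added for convenience.

```python
from typing import Optional, Sequence, Tuple, Union


# Hypothetical helper: the name and signature are illustrative, not Warp API.
def resolve_launch_dims(
    arg_shapes: Sequence[Tuple[int, ...]],
    launch_dims: Optional[Union[int, Sequence[int]]] = None,
) -> Tuple[int, ...]:
    """Return the dims used for the launch (and, per the docs, the outputs)."""
    if launch_dims is not None:
        # Explicit dims take precedence; treat a bare int as a 1-D launch (assumption).
        if isinstance(launch_dims, int):
            return (launch_dims,)
        return tuple(launch_dims)
    if not arg_shapes:
        raise ValueError("cannot infer launch_dims without an input argument")
    # Default: infer the dims from the shape of the first argument.
    return tuple(arg_shapes[0])


print(resolve_launch_dims([(64,), (32,)]))               # inferred from the first argument
print(resolve_launch_dims([(64,)], launch_dims=(8, 8)))  # explicit dims
```

Note that only the first argument's shape matters when launch_dims is None; the shapes of later arguments are ignored in this sketch.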

Raises:
  • RuntimeError – If the JAX version is < 0.4.25 or >= 0.8.0.
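The version policy described above (hard error outside 0.4.25 to 0.7.x, deprecation warning from 0.5.0, silenced by quiet) can be summarized as a minimal sketch. The function check_jax_version and its constants are hypothetical and not taken from Warp's source; the use of DeprecationWarning as the warning category is also an assumption. The sketch assumes plain "X.Y.Z" version strings.

```python
import warnings

# Hypothetical constants mirroring the documented policy (not Warp's source).
MIN_JAX = (0, 4, 25)         # oldest supported JAX (inclusive)
DEPRECATED_FROM = (0, 5, 0)  # deprecation warning from this version on
MAX_JAX = (0, 8, 0)          # first unsupported JAX (exclusive)


def parse_version(version: str) -> tuple:
    """Parse a plain 'X.Y.Z' version string into a tuple of ints."""
    return tuple(int(part) for part in version.split(".")[:3])


def check_jax_version(version: str, quiet: bool = False) -> None:
    """Raise or warn according to the documented support range."""
    v = parse_version(version)
    if v < MIN_JAX or v >= MAX_JAX:
        raise RuntimeError(
            f"JAX {version} is not supported by this jax_kernel(); "
            "use warp.jax_experimental.ffi.jax_kernel() instead."
        )
    if v >= DEPRECATED_FROM and not quiet:
        # Category is an assumption; Warp may use a different warning class.
        warnings.warn(
            f"custom_call-based jax_kernel() is deprecated with JAX {version}",
            DeprecationWarning,
        )


for candidate in ["0.4.24", "0.4.25", "0.7.2", "0.8.0"]:
    try:
        check_jax_version(candidate, quiet=True)
        print(candidate, "accepted")
    except RuntimeError:
        print(candidate, "rejected")
```

Comparing version tuples makes the lower bound inclusive and the upper bound exclusive, matching the "< 0.4.25 or >= 0.8.0" condition.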

Limitations:
  • All kernel arguments must be contiguous arrays.

  • In the Warp kernel definition, all input arguments must come before all output arguments.

  • There must be at least one input argument and at least one output argument.

  • Only the CUDA backend is supported.
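The ordering and arity limitations can be illustrated with a small sketch. The helper split_kernel_params is hypothetical, and it assumes (not confirmed by this page) that the number of inputs equals the number of arrays passed when calling the wrapped primitive, with all remaining kernel parameters treated as outputs.

```python
from typing import List, Sequence, Tuple


# Hypothetical helper: assumes inputs are the first `num_inputs` parameters
# and every remaining kernel parameter is an output.
def split_kernel_params(
    param_names: Sequence[str], num_inputs: int
) -> Tuple[List[str], List[str]]:
    """Split kernel parameters positionally into (inputs, outputs)."""
    if num_inputs < 1:
        raise ValueError("at least one input argument is required")
    if num_inputs >= len(param_names):
        raise ValueError("at least one output argument is required")
    return list(param_names[:num_inputs]), list(param_names[num_inputs:])


# A kernel declared as my_kernel(a, b, out), called with two arrays:
inputs, outputs = split_kernel_params(["a", "b", "out"], num_inputs=2)
print(inputs, outputs)
```

Because the split is purely positional, declaring an output parameter before an input would misclassify both, which is why the kernel signature must list inputs first.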