warp.jax_experimental.ffi.jax_kernel

warp.jax_experimental.ffi.jax_kernel(
kernel,
num_outputs=1,
vmap_method='broadcast_all',
launch_dims=None,
output_dims=None,
in_out_argnames=None,
module_preload_mode=ModulePreloadMode.CURRENT_DEVICE,
enable_backward=False,
)

Create a JAX callback from a Warp kernel.

NOTE: This is an experimental feature under development.

Parameters:
  • kernel – The Warp kernel to launch.

  • num_outputs – Number of output arguments; specify it when the kernel has more than one. This count must include the in-out arguments listed in in_out_argnames.

  • vmap_method – String specifying how the callback transforms under jax.vmap() (default: 'broadcast_all'). This argument can also be specified for individual calls.

  • launch_dims – Specify the default kernel launch dimensions. If None, launch dimensions are inferred from the shape of the first array argument. This argument can also be specified for individual calls.

  • output_dims – Specify the default dimensions of output arrays. If None, output dimensions are inferred from the launch dimensions. This argument can also be specified for individual calls.

  • in_out_argnames – Names of arguments that are both inputs and outputs (aliased buffers). These must be array arguments that appear before any pure output arguments in the kernel signature. The number of in-out arguments is included in num_outputs. Not supported when enable_backward=True.

  • module_preload_mode – Specify the devices on which the Warp module should be preloaded (default: ModulePreloadMode.CURRENT_DEVICE).

  • enable_backward (bool) – Enable automatic differentiation for this kernel.

Limitations:
  • All kernel arguments must be contiguous arrays or scalars.

  • Scalars must be static arguments in JAX.

  • Input and input-output arguments must precede the output arguments in the kernel definition.

  • There must be at least one output or input-output argument.

  • Only the CUDA backend is supported.