warp.jax_experimental.ffi.jax_callable

warp.jax_experimental.ffi.jax_callable(
func,
num_outputs=1,
graph_mode=GraphMode.JAX,
vmap_method='broadcast_all',
output_dims=None,
in_out_argnames=None,
stage_in_argnames=None,
stage_out_argnames=None,
graph_cache_max=None,
module_preload_mode=ModulePreloadMode.CURRENT_DEVICE,
)

Create a JAX callback from an annotated Python function.

The arguments of the Python function must have type annotations, as Warp kernel arguments do.

NOTE: This is an experimental feature under development.

Parameters:
  • func (Callable) – The Python function to call.

  • num_outputs (int) – The number of output arguments. Specify this if it is greater than 1. The count must include the in-out arguments (see in_out_argnames).

  • graph_mode (GraphMode) – CUDA graph capture mode.

      • GraphMode.JAX (default): Let JAX capture the graph, which may be used as a subgraph in an enclosing JAX capture.

      • GraphMode.WARP: Let Warp capture the graph. Use this mode when the callable cannot be used as a subgraph, such as when the callable uses conditional graph nodes.

      • GraphMode.NONE: Disable graph capture. Use this mode when the callable performs operations that are not legal in a graph, such as host synchronization.

  • vmap_method (str | None) – String specifying how the callback transforms under vmap(). This argument can also be specified for individual calls.

  • output_dims – Specify the default dimensions of output arrays. If None, output dimensions are inferred from the launch dimensions. This argument can also be specified for individual calls.

  • in_out_argnames – Names of arguments that are both inputs and outputs (aliased buffers). These must be array arguments that appear before any pure output arguments in the function signature. The number of in-out arguments is included in num_outputs.

  • stage_in_argnames – Names of input arguments that need to be copied when a GraphMode.WARP_STAGED* mode is used. If None, all input arguments are copied.

  • stage_out_argnames – Names of output arguments that need to be copied when a GraphMode.WARP_STAGED* mode is used. If None, all output arguments are copied.

  • graph_cache_max (int | None) – Maximum number of cached graphs captured using GraphMode.WARP. If None, use warp.jax_experimental.get_jax_callable_default_graph_cache_max().

  • module_preload_mode (ModulePreloadMode) – Specify the devices where the module should be preloaded.

Limitations:
  • All kernel arguments must be contiguous arrays or scalars.

  • Scalars must be static arguments in JAX.

  • Input and input-output arguments must precede the pure output arguments in the definition of func.

  • There must be at least one output or input-output argument.

  • Only the CUDA backend is supported.
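
The argument-order rules above can be made concrete with a small, standard-library-only sketch. The helper check_argument_layout is hypothetical and is not part of Warp; it only illustrates how num_outputs, in_out_argnames, and the position of arguments in the signature relate to each other. It assumes that the pure outputs are the trailing num_outputs - len(in_out_argnames) parameters, which is one reading of the rules above rather than a description of Warp's actual validation. String annotations stand in for real Warp types such as wp.array(dtype=float), so the sketch runs without Warp or a GPU.

```python
import inspect


def check_argument_layout(func, num_outputs=1, in_out_argnames=None):
    """Hypothetical pre-check of the documented argument rules (not a Warp API).

    Returns a tuple (input_names, in_out_names, pure_output_names).
    """
    in_out = list(in_out_argnames or [])
    params = inspect.signature(func).parameters
    names = list(params)

    # Every argument needs a type annotation, as for Warp kernels.
    unannotated = [n for n in names if params[n].annotation is inspect.Parameter.empty]
    if unannotated:
        raise TypeError(f"arguments without annotations: {unannotated}")

    # There must be at least one output or in-out argument.
    if num_outputs < 1:
        raise ValueError("at least one output or in-out argument is required")

    unknown = [n for n in in_out if n not in names]
    if unknown:
        raise ValueError(f"in_out_argnames not found in signature: {unknown}")

    # num_outputs counts the in-out arguments as well as the pure outputs.
    if not len(in_out) <= num_outputs <= len(names):
        raise ValueError("num_outputs must include the in-out arguments")

    # Assumption: the pure outputs are the trailing parameters.
    split = len(names) - (num_outputs - len(in_out))
    leading, outputs = names[:split], names[split:]

    # In-out arguments must appear before any pure output argument.
    misplaced = [n for n in in_out if n not in leading]
    if misplaced:
        raise ValueError(f"in-out arguments placed after pure outputs: {misplaced}")

    inputs = [n for n in leading if n not in in_out]
    return inputs, in_out, outputs


# Inputs and the in-out array "y" come first, the pure output "out" comes last.
def axpy(x: "wp.array", y: "wp.array", alpha: "float", out: "wp.array"):
    pass


# num_outputs=2 counts one in-out argument (y) plus one pure output (out).
layout = check_argument_layout(axpy, num_outputs=2, in_out_argnames=["y"])
print(layout)  # (['x', 'alpha'], ['y'], ['out'])
```

With this layout, a signature that places an in-out argument after a pure output, for example (x, out, y) with in_out_argnames=["y"] and num_outputs=2, is rejected with a ValueError by the sketch.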