warp.jax_experimental.ffi.jax_callable
- warp.jax_experimental.ffi.jax_callable(
- func,
- num_outputs=1,
- graph_mode=GraphMode.JAX,
- vmap_method='broadcast_all',
- output_dims=None,
- in_out_argnames=None,
- stage_in_argnames=None,
- stage_out_argnames=None,
- graph_cache_max=None,
- module_preload_mode=ModulePreloadMode.CURRENT_DEVICE,
- )
Create a JAX callback from an annotated Python function.
The Python function arguments must have type annotations like Warp kernels.
NOTE: This is an experimental feature under development.
- Parameters:
func (Callable) – The Python function to call.
num_outputs (int) – Specify the number of output arguments if greater than 1. This must include the number of in_out_arguments.
graph_mode (GraphMode) – CUDA graph capture mode.
  - GraphMode.JAX (default): Let JAX capture the graph, which may be used as a subgraph in an enclosing JAX capture.
  - GraphMode.WARP: Let Warp capture the graph. Use this mode when the callable cannot be used as a subgraph, such as when the callable uses conditional graph nodes.
  - GraphMode.NONE: Disable graph capture. Use when the callable performs operations that are not legal in a graph, such as host synchronization.
vmap_method (str | None) – String specifying how the callback transforms under vmap(). This argument can also be specified for individual calls.
output_dims – Specify the default dimensions of output arrays. If None, output dimensions are inferred from the launch dimensions. This argument can also be specified for individual calls.
in_out_argnames – Names of arguments that are both inputs and outputs (aliased buffers). These must be array arguments that appear before any pure output arguments in the function signature. The number of in-out arguments is included in num_outputs.
stage_in_argnames – Names of input arguments that need to be copied with GraphMode.WARP_STAGED*. If None, copy all input arguments.
stage_out_argnames – Names of output arguments that need to be copied with GraphMode.WARP_STAGED*. If None, copy all output arguments.
graph_cache_max (int | None) – Maximum number of cached graphs captured using GraphMode.WARP. If None, use warp.jax_experimental.get_jax_callable_default_graph_cache_max().
module_preload_mode (ModulePreloadMode) – Specify the devices where the module should be preloaded.
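The relationship between num_outputs, in_out_argnames, and the order of arguments in the function signature can be easier to see in code. The following is a minimal pure-Python sketch of the rules as documented above; it is not Warp's implementation. The helper split_arguments and the function example are hypothetical names introduced only for illustration.

```python
import inspect


def split_arguments(func, num_outputs=1, in_out_argnames=None):
    """Hypothetical helper: classify parameter names by the documented rules.

    This illustrates the documented argument layout only; it is not how
    Warp itself processes the signature.
    """
    in_out = list(in_out_argnames or [])
    names = list(inspect.signature(func).parameters)

    # num_outputs counts in-out arguments as well as pure outputs.
    num_pure_outputs = num_outputs - len(in_out)
    if num_outputs < 1 or num_pure_outputs < 0 or num_outputs > len(names):
        raise ValueError("num_outputs is inconsistent with the signature")

    # Pure outputs occupy the tail of the signature.
    split = len(names) - num_pure_outputs
    leading, outputs = names[:split], names[split:]

    # In-out arguments must appear before any pure output argument.
    for name in in_out:
        if name not in leading:
            raise ValueError(f"in-out argument {name!r} must precede the pure outputs")

    inputs = [name for name in leading if name not in in_out]
    return inputs, in_out, outputs


def example(a, b, s, c):
    """Hypothetical signature: a and s are inputs, b is in-out, c is an output."""


print(split_arguments(example, num_outputs=2, in_out_argnames=["b"]))
```

Here num_outputs=2 covers one in-out argument (b) and one pure output (c), so only c is treated as a trailing output and the call reports a and s as inputs.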
- Limitations:
All kernel arguments must be contiguous arrays or scalars.
Scalars must be static arguments in JAX.
Input and input-output arguments must precede the output arguments in the func definition.
There must be at least one output or input-output argument.
Only the CUDA backend is supported.