warp.jax_experimental.ffi.register_ffi_callback

warp.jax_experimental.ffi.register_ffi_callback(name, func, graph_compatible=True)

Register a Python function as a JAX FFI callback under the given name, so that it can be invoked from JAX through jax.ffi.ffi_call.

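The Python function must have the form func(inputs, outputs, attrs, ctx), where inputs and outputs are lists of buffers for the call's input and output arrays, attrs is a dictionary of the call's attributes, and ctx is the execution context of the call.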

NOTE: This is an experimental feature under development.

Parameters:
  • name (str) – A unique FFI callback name.

  • func (Callable) – The Python function to call.

  • graph_compatible (bool) – Whether the function can be called during CUDA graph capture.

Return type:

None