Create a JAX callback from a Python function.
The Python function must have the form func(inputs, outputs, attrs, ctx).
NOTE: This is an experimental feature under development.
- Parameters:
name (str) – A unique FFI callback name.
func (Callable) – The Python function to call.
graph_compatible (bool) – Whether the function can be called during CUDA graph capture.
- Return type:
None
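The sketch below shows a callback written in the documented four-argument form func(inputs, outputs, attrs, ctx). It does not call the registration API. Several details are assumptions that the description above does not state:

- inputs and outputs are sequences of array-like buffers that np.asarray can view without copying.
- The outputs are preallocated and meant to be written in place.
- attrs is a dict-like mapping.
- The "scale" attribute and the name scale_add are hypothetical.

```python
import numpy as np


def scale_add(inputs, outputs, attrs, ctx):
    """Hypothetical callback in the form func(inputs, outputs, attrs, ctx).

    ASSUMPTIONS (not confirmed by the API description): `inputs` and `outputs`
    are sequences of array-like buffers viewable via np.asarray, outputs are
    preallocated and writable in place, and `attrs` is a dict-like mapping.
    """
    x, y = (np.asarray(buf) for buf in inputs)
    out = np.asarray(outputs[0])
    # "scale" is a hypothetical attribute name used only for illustration.
    scale = attrs.get("scale", 1.0)
    # Write into the preallocated output buffer instead of returning a value.
    out[...] = scale * x + y
    # `ctx` is unused here; its contents are not described in the source.


# Local smoke test with plain NumPy arrays standing in for device buffers.
x = np.array([1.0, 2.0])
y = np.array([3.0, 4.0])
result = np.empty(2)
scale_add([x, y], [result], {"scale": 2.0}, None)
```

Writing into outputs, rather than returning new arrays, matches the func(inputs, outputs, attrs, ctx) form: the output buffers are passed in alongside the inputs.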