warp.jax_experimental.ffi

Current FFI-based implementation of JAX integration.

This module provides the Foreign Function Interface (FFI) implementation that supports JAX 0.4.25 and later, including JAX 0.8.0+. It is the default implementation as of Warp 1.10.

For low-level use cases, register_ffi_callback() provides direct FFI callback registration for functions that don’t use Warp-style type annotations.

API

GraphMode

clear_jax_callable_graph_cache

Clear the graph cache of the given callable, or of all callables if None.

get_jax_callable_default_graph_cache_max

Get the default maximum size of the graph cache for graphs captured using GraphMode.WARP (None means unlimited).

jax_callable

Create a JAX callback from an annotated Python function.

jax_kernel

Create a JAX callback from a Warp kernel.

register_ffi_callback

Create a JAX callback from a Python function.

set_jax_callable_default_graph_cache_max

Set the default maximum size of the graph cache for graphs captured using GraphMode.WARP (None means unlimited).