warp.jax_experimental.ffi.GraphMode

class warp.jax_experimental.ffi.GraphMode(*values)

CUDA graph capture modes for warp.jax_experimental.jax_callable().

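These modes control whether a CUDA graph is captured at all and, if so, whether JAX or Warp captures it. For Warp capture, they also control whether staging buffers are used and whether the staging copies run inside or outside the graph.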
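The members below map to the integers 0 through 4. Assuming `GraphMode` is an `enum.IntEnum` (its `(*values)` signature and integer member values suggest an integer enum), the sketch below mirrors those values in a hypothetical `GraphModeSketch` class that is not part of Warp. The `pick_mode` helper is likewise hypothetical: it encodes one reading of the per-mode descriptions and is not official Warp guidance. It never returns `WARP_STAGED_EX`, since choosing between the two staged variants depends on where you want the copies to run.

```python
from enum import IntEnum


class GraphModeSketch(IntEnum):
    """Hypothetical stand-in mirroring the documented GraphMode values."""

    NONE = 0
    JAX = 1
    WARP = 2
    WARP_STAGED = 3
    WARP_STAGED_EX = 4


def pick_mode(graph_safe: bool, nested_in_jax_capture: bool, stable_addresses: bool) -> GraphModeSketch:
    """Illustrative heuristic only; not Warp's selection logic."""
    # Work that is not CUDA-graph compatible (e.g. host synchronization): no capture.
    if not graph_safe:
        return GraphModeSketch.NONE
    # Let JAX own the capture so the callable can become a subgraph.
    if nested_in_jax_capture:
        return GraphModeSketch.JAX
    # A Warp-captured graph is replayed for matching buffer addresses.
    if stable_addresses:
        return GraphModeSketch.WARP
    # Otherwise route data through staging buffers (copies inside the graph).
    return GraphModeSketch.WARP_STAGED
```

Because the members are integers, they compare equal to their numeric values, so `GraphModeSketch.WARP == 2` holds.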

__init__(*args, **kwds)

Attributes

NONE

Disable graph capture.

JAX

Let JAX capture the graph so the callable can be used as a subgraph within a larger JAX capture.

WARP

Let Warp capture the graph and replay it when the buffer addresses match those seen at capture time.

WARP_STAGED

Capture a Warp graph using staging buffers and insert memcpy nodes inside the graph.

WARP_STAGED_EX

Capture a Warp graph using staging buffers and perform the memcpy operations outside the graph.

NONE = 0

Disable graph capture. Use when operations are not CUDA-graph compatible (for example, host synchronization).

JAX = 1

Let JAX capture the graph so the callable can be used as a subgraph within a larger JAX capture.

WARP = 2

Let Warp capture the graph and replay it when the buffer addresses match those seen at capture time.
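A CUDA graph records the buffer addresses its nodes use during capture, so a replay is only valid when the same addresses come back. The pure-Python model below illustrates that address-keyed caching idea; it is a conceptual sketch, not Warp's implementation. Assumptions: "device memory" is a dict from hypothetical integer addresses to lists, a "graph" is a closure bound to fixed addresses, and `capture_scale` and `AddressKeyedCache` are hypothetical names.

```python
from typing import Callable, Dict, List, Tuple

# Toy "device memory": hypothetical integer address -> buffer contents.
Memory = Dict[int, List[float]]


def capture_scale(mem: Memory, src: int, dst: int) -> Callable[[], None]:
    """Record a toy 'graph' whose single kernel is bound to fixed src/dst addresses."""

    def graph() -> None:
        # Reads and writes whatever currently lives at the captured addresses.
        mem[dst][:] = [2.0 * x for x in mem[src]]

    return graph


class AddressKeyedCache:
    """Capture once per distinct address tuple, then replay on a match."""

    def __init__(self, mem: Memory) -> None:
        self.mem = mem
        self.graphs: Dict[Tuple[int, int], Callable[[], None]] = {}
        self.captures = 0

    def launch(self, src: int, dst: int) -> None:
        key = (src, dst)
        if key not in self.graphs:
            # Unseen addresses: an existing graph cannot be reused, so capture again.
            self.graphs[key] = capture_scale(self.mem, src, dst)
            self.captures += 1
        # Matching addresses: replay the recorded graph.
        self.graphs[key]()
```

New contents at the same addresses reuse the recorded graph, while new addresses force another capture. That recapture cost is presumably the situation the staged modes address by copying through staging buffers.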

WARP_STAGED = 3

Capture a Warp graph using staging buffers and insert memcpy nodes inside the graph.

WARP_STAGED_EX = 4

Capture a Warp graph using staging buffers and perform the memcpy operations outside the graph.
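The two staged modes differ in where the staging copies run. The toy model below contrasts them under an assumption this page does not confirm: that `WARP_STAGED` updates the source and destination of its in-graph memcpy nodes on each call, while the compute node stays bound to fixed staging buffers. All names (`STAGE_IN`, `STAGE_OUT`, `MemcpyNode`, `make_staged_graph`, `make_staged_ex_graph`) are hypothetical, and "memory" is a dict of lists rather than real device allocations.

```python
from typing import Callable, Dict, List

Memory = Dict[int, List[float]]  # toy memory: hypothetical address -> contents

STAGE_IN, STAGE_OUT = 1000, 1001  # hypothetical fixed staging-buffer addresses


def kernel_node(mem: Memory) -> None:
    # The compute node only ever touches the staging buffers.
    mem[STAGE_OUT][:] = [2.0 * x for x in mem[STAGE_IN]]


class MemcpyNode:
    """Copy node whose src/dst parameters can be updated before each replay."""

    def __init__(self, mem: Memory, src: int, dst: int) -> None:
        self.mem, self.src, self.dst = mem, src, dst

    def __call__(self) -> None:
        self.mem[self.dst][:] = self.mem[self.src]


def make_staged_graph(mem: Memory) -> Callable[[int, int], None]:
    """WARP_STAGED-style: copy-in, kernel, and copy-out all live inside one graph."""
    copy_in = MemcpyNode(mem, src=-1, dst=STAGE_IN)  # src set per call
    copy_out = MemcpyNode(mem, src=STAGE_OUT, dst=-1)  # dst set per call
    nodes: List[Callable[[], None]] = [copy_in, lambda: kernel_node(mem), copy_out]

    def launch(src: int, dst: int) -> None:
        # Assumption: only the copy parameters change; the graph is not recaptured.
        copy_in.src, copy_out.dst = src, dst
        for node in nodes:
            node()

    return launch


def make_staged_ex_graph(mem: Memory) -> Callable[[int, int], None]:
    """WARP_STAGED_EX-style: the graph holds only the kernel; copies run outside it."""
    graph: List[Callable[[], None]] = [lambda: kernel_node(mem)]

    def launch(src: int, dst: int) -> None:
        mem[STAGE_IN][:] = mem[src]  # copy in, outside the graph
        for node in graph:
            node()
        mem[dst][:] = mem[STAGE_OUT]  # copy out, outside the graph

    return launch
```

In both variants the kernel sees the same staging addresses on every call, so caller buffers can move without invalidating the recorded kernel work.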