warp.device_from_jax

warp.device_from_jax(jax_device)

Return the Warp device corresponding to a JAX device.

Parameters:

jax_device (jax.Device) – A JAX device descriptor.

Raises:

RuntimeError – If the JAX device is neither a CPU nor a GPU device.

Return type:

Device
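The mapping described above can be sketched without Warp or JAX installed. This is an illustrative sketch, not the library's implementation. `JaxDeviceStub` and `warp_alias_from_jax` are hypothetical names. The sketch assumes that a JAX device exposes a `platform` attribute (`"cpu"` or `"gpu"`) and an integer `id`, and that a JAX GPU's `id` equals its Warp CUDA ordinal, as in a single-process setup. Unlike the real function, which returns a `Device` object, the sketch returns a Warp device alias string.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class JaxDeviceStub:
    """Hypothetical stand-in for jax.Device, keeping only the fields used here."""
    platform: str  # e.g. "cpu", "gpu", "tpu"
    id: int        # device ordinal (assumed to match the CUDA ordinal for GPUs)


def warp_alias_from_jax(jax_device) -> str:
    """Hypothetical helper: map a JAX device to a Warp device alias string.

    Follows the documented contract: CPU and GPU devices are supported,
    and any other platform raises RuntimeError.
    """
    if jax_device.platform == "cpu":
        return "cpu"
    if jax_device.platform == "gpu":
        # Assumption: the JAX GPU id is the same ordinal Warp uses for CUDA.
        return f"cuda:{jax_device.id}"
    raise RuntimeError(f"Unsupported JAX device platform '{jax_device.platform}'")


if __name__ == "__main__":
    print(warp_alias_from_jax(JaxDeviceStub("cpu", 0)))  # cpu
    print(warp_alias_from_jax(JaxDeviceStub("gpu", 1)))  # cuda:1
```

With both libraries installed, the real call takes an actual JAX device, for example `warp.device_from_jax(jax.devices()[0])`. The returned `Device` can then be passed wherever Warp accepts a device argument.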