warp.grad

warp.grad(func)

Return a callable that computes the gradient of the given function.

When called with the same arguments as the original function, the returned callable evaluates the derivative of func with respect to each of those inputs.

Parameters:

func (Callable) – A Warp function decorated with @wp.func, or a Warp builtin (for example, wp.atan2).

Returns:

A callable that, when called with the function’s inputs, returns the gradient with respect to each input. If the function has a single input, it returns a single gradient value. If the function has multiple inputs, it returns a tuple of gradients in the same order as the arguments.
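In terms of partial derivatives, the return convention can be summarized as below. The atan2 line assumes the standard definition of atan2(y, x) and uses the argument order from the multi-input example, where wp.atan2(b, a) yields (db, da).

```latex
\texttt{wp.grad}(f)(x_1,\dots,x_n) =
\begin{cases}
\dfrac{df}{dx_1}, & n = 1,\\[6pt]
\left(\dfrac{\partial f}{\partial x_1},\dots,\dfrac{\partial f}{\partial x_n}\right), & n > 1.
\end{cases}

\frac{d}{dx}\,x^2 = 2x, \qquad
\frac{\partial}{\partial y}\operatorname{atan2}(y,x) = \frac{x}{x^2+y^2}, \qquad
\frac{\partial}{\partial x}\operatorname{atan2}(y,x) = -\frac{y}{x^2+y^2}.
```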

Return type:

GradWrapper

Example:

import warp as wp


@wp.func
def square(x: float):
    return x * x


@wp.kernel
def my_kernel(x: wp.array(dtype=float), grad_x: wp.array(dtype=float)):
    tid = wp.tid()
    # Compute d(x*x)/d(x) = 2*x
    grad_x[tid] = wp.grad(square)(x[tid])


# For functions with multiple inputs:
@wp.kernel
def kernel2(
    a: wp.array(dtype=float),
    b: wp.array(dtype=float),
    grad_a: wp.array(dtype=float),
    grad_b: wp.array(dtype=float),
):
    tid = wp.tid()
    # wp.atan2(y, x): gradients come back in argument order, i.e. (d/db, d/da)
    db, da = wp.grad(wp.atan2)(b[tid], a[tid])
    grad_a[tid] = da
    grad_b[tid] = db

Note

When used in a regular Warp function or kernel, grad() calls are forward-only and do NOT participate in Warp’s automatic differentiation. If a kernel that calls grad() is compiled with enable_backward=True, the value returned by grad() is treated as a constant in the backward pass, so no gradients flow through it.

However, grad() can be used inside custom gradient functions decorated with @warp.func_grad, which participate in the backward pass.