warp.grad
- warp.grad(func)
Return a callable that computes the gradient of the given function.
When called with the same arguments as the original function, the callable returns the gradient with respect to each input.
- Parameters:
func (Callable) – A Warp function (decorated with @wp.func) or a Warp builtin.
- Returns:
A callable that, when called with the function’s inputs, returns the gradient(s) for each input. If the function has a single input, returns a single gradient value. If the function has multiple inputs, returns a tuple of gradients.
- Return type:
GradWrapper
Example:
```python
import warp as wp

@wp.func
def square(x: float):
    return x * x

@wp.kernel
def my_kernel(x: wp.array(dtype=float), grad_x: wp.array(dtype=float)):
    tid = wp.tid()
    # Compute d(x*x)/d(x) = 2*x
    grad_x[tid] = wp.grad(square)(x[tid])

# For functions with multiple inputs:
@wp.kernel
def kernel2(
    a: wp.array(dtype=float),
    b: wp.array(dtype=float),
    grad_a: wp.array(dtype=float),
    grad_b: wp.array(dtype=float),
):
    tid = wp.tid()
    db, da = wp.grad(wp.atan2)(b[tid], a[tid])
    grad_a[tid] = da
    grad_b[tid] = db
```
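The values these kernels write can be cross-checked on the host. The following stdlib-only sketch is a verification aid, not part of Warp's API: `central_diff` and `analytic_grad_atan2` are hypothetical helpers. It compares central finite differences against the closed forms d(x*x)/dx = 2x, ∂atan2(y, x)/∂y = x/(x² + y²) and ∂atan2(y, x)/∂x = -y/(x² + y²), returning the atan2 gradients in input order to mirror how `kernel2` unpacks them.

```python
import math

# Hypothetical host-side helpers (not part of Warp) for sanity-checking the
# values that my_kernel and kernel2 should write into their output arrays.

def square(x: float) -> float:
    """Plain-Python counterpart of the @wp.func `square` above."""
    return x * x

def central_diff(f, args, i, h=1e-6):
    """Approximate the partial derivative of f w.r.t. args[i] by central differences."""
    plus = list(args)
    minus = list(args)
    plus[i] += h
    minus[i] -= h
    return (f(*plus) - f(*minus)) / (2.0 * h)

def analytic_grad_atan2(y: float, x: float):
    """Closed-form gradient of atan2(y, x), in input order: (d/dy, d/dx)."""
    r2 = x * x + y * y
    return x / r2, -y / r2

# square: d(x*x)/dx = 2*x
x0 = 1.5
print(central_diff(square, (x0,), 0), 2.0 * x0)

# atan2(b, a): b is the first argument, so its gradient comes first
b, a = 0.3, 0.8
db_fd = central_diff(math.atan2, (b, a), 0)
da_fd = central_diff(math.atan2, (b, a), 1)
db, da = analytic_grad_atan2(b, a)
print(db_fd, db)
print(da_fd, da)
```

Because `kernel2` passes `b[tid]` as the first argument to `wp.atan2`, the first element of the returned tuple is the gradient with respect to `b`, which is why it is unpacked as `db, da`.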
Note
When used in a regular Warp function or kernel, grad() calls are forward-only and do NOT participate in Warp's automatic differentiation. If you use grad() in a kernel with enable_backward=True, the gradient call is treated as a constant in the backward pass (no gradients flow through it).
However, grad() can be used inside custom gradient functions decorated with @warp.func_grad, which do participate in the backward pass.
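As a worked illustration of the forward-only behavior, assume a kernel compiled with enable_backward=True stores y = wp.grad(square)(x), using the square function from the example above. The forward value is correct, but because the call is treated as a constant, no adjoint reaches x through it:

$$
y = \frac{d}{dx}\left(x^{2}\right) = 2x,
\qquad
\frac{\partial y}{\partial x} = 2 \text{ mathematically, but }
\bar{x} \mathrel{+}= 0 \cdot \bar{y} \text{ in the backward pass.}
$$

Any dependence of a loss on x that passes only through y therefore contributes nothing to the gradient of x.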