warp.optim.Adam
- class warp.optim.Adam(params=None, lr=0.001, betas=(0.9, 0.999), eps=1e-08)
Adaptive Moment Estimation (Adam) optimizer.
Adam is an adaptive learning rate optimization algorithm that computes individual learning rates for different parameters from estimates of the first and second moments of the gradients. This implementation is designed for GPU-accelerated parameter updates using Warp kernels.
The algorithm maintains exponential moving averages of the gradient (first moment) and the squared gradient (second moment), using bias correction to account for their initialization at zero.
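For reference, the standard Adam update is written out below. It uses step count t = 1, 2, …, learning rate α (lr), decay rates β1 and β2 (betas), stability term ε (eps), and moment buffers initialized to m_0 = v_0 = 0. All operations are elementwise. The source text does not say how Warp indexes the step (for example, whether the t passed to step_detail is zero- or one-based), so treat that mapping as an assumption to check.

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 \\
\hat{m}_t &= \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t} \\
\theta_t &= \theta_{t-1} - \alpha\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{aligned}
$$

In early steps, the zero initialization pulls m_t and v_t toward zero. Dividing by 1 - β^t undoes that shrinkage.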
The interface is similar to PyTorch’s torch.optim.Adam.
- Parameters:
params – List of warp.array objects to optimize. Can be None and set later via set_params(). Supported dtypes are warp.float16, warp.float32, and warp.vec3.
lr – Learning rate (step size).
betas – Coefficients for computing running averages of the gradient and its square: a tuple of two floats (beta1, beta2), where beta1 is the exponential decay rate for the first moment and beta2 is the decay rate for the second moment.
eps – Small constant added to the denominator for numerical stability.
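The sketch below makes the role of betas concrete. It tracks the first moment for a constant gradient using the default beta1 = 0.9.

- It is plain Python with no Warp dependency.
- first_moment_history is a hypothetical helper written only for illustration.
- For a constant gradient g, the raw moving average after t steps is g(1 - beta1**t). It therefore starts at about one tenth of g and only creeps upward.
- The bias-corrected value equals g from the first step.

```python
def first_moment_history(g, beta1=0.9, steps=5):
    """Hypothetical helper: raw and bias-corrected first moments for a constant gradient g."""
    m = 0.0  # the first-moment buffer starts at zero
    raw, corrected = [], []
    for t in range(1, steps + 1):  # t counts steps from 1 in this sketch
        m = beta1 * m + (1.0 - beta1) * g  # exponential moving average
        raw.append(m)
        corrected.append(m / (1.0 - beta1 ** t))  # undo the pull toward zero
    return raw, corrected


raw, corrected = first_moment_history(2.0)
print("raw:      ", [round(x, 4) for x in raw])
print("corrected:", [round(x, 4) for x in corrected])
```

The effect is much stronger for the second moment. With beta2 = 0.999, the raw v after one step is only 0.001 * g**2, which is why both moments are corrected.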
Methods
__init__([params, lr, betas, eps]) – Reset moment buffers and timestep to zero.
set_params(params) – Set parameters to optimize and allocate moment buffers.
step(grad) – Apply one Adam step using the provided gradients.
step_detail(g, m, v, lr, beta1, beta2, t, ...) – Apply an Adam update to a single parameter array.
- set_params(params)
Set parameters to optimize and allocate moment buffers.
- Parameters:
params – List of warp.array objects to optimize, or None.
- step(grad)
Apply one Adam step using the provided gradients.
- Parameters:
grad – List of gradient arrays matching params.
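The interplay of set_params and step can be illustrated with a small pure-Python stand-in. ToyAdam is hypothetical and differs from Warp in several ways:

- It stores parameters and gradients as lists of Python floats rather than warp.array objects.
- It runs on the CPU.
- It resets its step counter in set_params. That is a choice of this sketch, not documented Warp behavior.

Like the documented interface, it allocates one zero-filled m and v buffer per parameter array and expects grad to hold one gradient array per parameter array. It advances the step counter once per call to step.

```python
import math


class ToyAdam:
    """Hypothetical pure-Python stand-in mirroring the documented interface.

    Parameters and gradients are lists of lists of floats, not warp.array
    objects. This illustrates the bookkeeping; it is not Warp's implementation.
    """

    def __init__(self, params=None, lr=0.001, betas=(0.9, 0.999), eps=1e-8):
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.set_params(params)

    def set_params(self, params):
        # Allocate one zero-filled first- and second-moment buffer per array.
        # Resetting the step counter here is a choice of this sketch.
        self.params = params
        self.m = [[0.0] * len(p) for p in params] if params else []
        self.v = [[0.0] * len(p) for p in params] if params else []
        self.t = 0  # number of completed steps

    def step(self, grad):
        # grad must contain one gradient array per parameter array.
        assert self.params is not None and len(grad) == len(self.params)
        n = self.t + 1  # 1-based step number used for bias correction
        c1 = 1.0 - self.beta1 ** n
        c2 = 1.0 - self.beta2 ** n
        for p, g, m, v in zip(self.params, grad, self.m, self.v):
            for i in range(len(p)):
                m[i] = self.beta1 * m[i] + (1.0 - self.beta1) * g[i]
                v[i] = self.beta2 * v[i] + (1.0 - self.beta2) * g[i] * g[i]
                p[i] -= self.lr * (m[i] / c1) / (math.sqrt(v[i] / c2) + self.eps)
        self.t = n


# Minimize f(x) = sum((x_i - 3)^2); its gradient is 2 * (x_i - 3).
x = [0.0, 10.0]
opt = ToyAdam([x], lr=0.1)
for _ in range(1000):
    opt.step([[2.0 * (xi - 3.0) for xi in x]])
print(x)
```

Here the gradient is computed by hand and x is updated in place, so both entries end up close to 3.0. In a real Warp program the gradients would come from elsewhere, for example Warp's tape-based automatic differentiation. The call pattern is the same: construct the optimizer with a list of arrays, then call step with a matching list of gradients.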
- static step_detail(g, m, v, lr, beta1, beta2, t, eps, params)
Apply an Adam update to a single parameter array.
- Parameters:
g – Gradient array.
m – First-moment buffer.
v – Second-moment buffer.
lr – Learning rate.
beta1 – Exponential decay rate for the first moment.
beta2 – Exponential decay rate for the second moment.
t – Current step index.
eps – Numerical stability term.
params – Parameter array to update in-place.
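A per-element reference for what step_detail computes can be sketched in plain Python. step_detail_reference is a hypothetical name. It takes the same arguments as step_detail, but as equal-length lists of floats instead of Warp arrays.

It also assumes t is a zero-based step index, so the bias-correction exponent is t + 1. The parameter description above does not confirm that indexing.

The example shows a property of Adam's first step from zero-initialized buffers. The corrected moments reduce to g and g**2, so every entry moves by almost exactly lr against the sign of its gradient, whether that gradient is 0.5 or -40.

```python
import math


def step_detail_reference(g, m, v, lr, beta1, beta2, t, eps, params):
    """Hypothetical pure-Python reference for one Adam update of one array.

    g, m, v and params are equal-length lists of floats; m, v and params are
    modified in place. Assumes t is a zero-based step index, so the
    bias-correction exponent is t + 1.
    """
    c1 = 1.0 - beta1 ** (t + 1)  # first-moment bias correction
    c2 = 1.0 - beta2 ** (t + 1)  # second-moment bias correction
    for i in range(len(params)):
        m[i] = beta1 * m[i] + (1.0 - beta1) * g[i]
        v[i] = beta2 * v[i] + (1.0 - beta2) * g[i] * g[i]
        m_hat = m[i] / c1
        v_hat = v[i] / c2
        params[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)


# First step from zero-initialized buffers.
params = [1.0, 1.0, 1.0]
g = [0.5, -40.0, 3.0]
m = [0.0] * 3
v = [0.0] * 3
step_detail_reference(g, m, v, lr=0.01, beta1=0.9, beta2=0.999, t=0, eps=1e-8, params=params)
print(params)
```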