Adaptive Moment Estimation (Adam)#

API#

class warp_nn.optimizers.Adam(
parameters: list[array],
*,
lr: float = 0.001,
device: str | Device | None = None,
max_norm: float | None = None,
disable_graph: bool = False,
betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-08,
)[source]#

Bases: Optimizer

Adam optimizer.

Adapted from Warp’s Adam implementation to support CUDA graph capture, gradient clipping, and state dict serialization.

Parameters:
  • parameters – Model parameters.

  • lr – Learning rate.

  • device – Device to use for the optimizer.

  • max_norm – Maximum gradient norm used for clipping. If None, no clipping is applied.

  • disable_graph – Whether to disable graph capture.

  • betas – Coefficients for the running averages of the gradient and its square.

  • eps – Term added to the denominator to improve numerical stability.
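To make the roles of `betas`, `eps`, and `max_norm` concrete, here is a minimal pure-Python sketch of the standard Adam update for a scalar parameter. This restates the textbook algorithm only; it is not the warp_nn implementation (which operates on Warp arrays and supports CUDA graph capture):

```python
import math

def adam_step(param, grad, m, v, t, *, lr=0.001, betas=(0.9, 0.999),
              eps=1e-8, max_norm=None):
    """One Adam update for a scalar parameter (standard algorithm sketch)."""
    if max_norm is not None and abs(grad) > max_norm:
        grad = grad * (max_norm / abs(grad))        # gradient clipping (scalar "norm")
    b1, b2 = betas
    m = b1 * m + (1 - b1) * grad                    # first moment: running mean of grad
    v = b2 * v + (1 - b2) * grad * grad             # second moment: running mean of grad^2
    m_hat = m / (1 - b1 ** t)                       # bias correction (t starts at 1)
    v_hat = v / (1 - b2 ** t)
    param -= lr * m_hat / (math.sqrt(v_hat) + eps)  # eps stabilizes the division
    return param, m, v

# Minimize f(x) = x^2 (gradient 2x) starting from x = 1.0.
x, m, v = 1.0, 0.0, 0.0
for t in range(1, 501):
    x, m, v = adam_step(x, 2.0 * x, m, v, t, lr=0.05)
```

Larger `betas` values smooth the moment estimates over more steps, while `eps` guards against division by a near-zero second moment early in training.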

load_state_dict(state_dict: dict[str, Any]) → None[source]#
state_dict() → dict[str, Any][source]#
step(*, lr: float | None = None) → None[source]#

Perform an optimization step to update parameters.

Parameters:

lr – Learning rate for this step. If None, the learning rate set at construction is used.
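The per-call `lr` override is convenient for learning-rate schedules, since the value passed to `step` takes effect for that step only. A hypothetical stand-in class (not the warp_nn code) illustrating the pattern:

```python
class TinyOptimizer:
    """Hypothetical stand-in showing the per-step lr override pattern."""

    def __init__(self, lr=0.001):
        self.lr = lr  # learning rate fixed at construction

    def step(self, *, lr=None):
        effective_lr = self.lr if lr is None else lr  # override applies to this call only
        # ... the parameter update using effective_lr would happen here ...
        return effective_lr

opt = TinyOptimizer(lr=0.001)
default_lr = opt.step()          # falls back to the constructor lr
warmup_lr = opt.step(lr=0.0001)  # one-off override, e.g. during warmup
```

Because the override is not stored, a scheduler can simply compute and pass a fresh `lr` on every iteration without mutating optimizer state.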