Stochastic Gradient Descent (SGD)

API

class warp_nn.optimizers.SGD(
parameters: list[array],
*,
lr: float = 0.001,
device: str | Device | None = None,
max_norm: float | None = None,
disable_graph: bool = False,
momentum: float = 0,
dampening: float = 0,
weight_decay: float = 0,
)

Bases: Optimizer

Stochastic gradient descent (SGD) optimizer with optional momentum, dampening, and weight decay.

Parameters:
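  • parameters – Model parameters to optimize, given as a list of arrays.

  • lr – Learning rate. Defaults to 0.001.

  • device – Device to use for the optimizer. Defaults to None.

  • max_norm – Optional float. Defaults to None.

  • disable_graph – Whether to disable graph capture. Defaults to False.

  • momentum – Momentum factor. Defaults to 0.

  • dampening – Dampening factor. Defaults to 0.

  • weight_decay – Weight decay factor. Defaults to 0.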
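
The page does not spell out how momentum, dampening, and weight_decay combine. The following is a minimal pure-Python sketch of one common convention (the one used by PyTorch's SGD), offered as an assumption about the intended semantics rather than as the warp_nn implementation. The function name sgd_update and the list-of-floats layout are hypothetical and chosen only for illustration.

```python
def sgd_update(params, grads, bufs, *, lr=0.001, momentum=0.0,
               dampening=0.0, weight_decay=0.0):
    """Apply one SGD step in place to flat lists of floats.

    Assumed convention (PyTorch-style, not confirmed for warp_nn):
      g   <- g + weight_decay * p
      buf <- momentum * buf + (1 - dampening) * g   (buf starts as g)
      p   <- p - lr * buf
    """
    for i, (p, g) in enumerate(zip(params, grads)):
        # L2-style weight decay is folded into the gradient.
        g = g + weight_decay * p
        if momentum != 0.0:
            if bufs[i] is None:
                # First step: the momentum buffer starts as the gradient.
                bufs[i] = g
            else:
                bufs[i] = momentum * bufs[i] + (1.0 - dampening) * g
            g = bufs[i]
        params[i] = p - lr * g
    return params, bufs


# Usage: two steps on a single scalar parameter with a constant gradient.
params, bufs = [1.0], [None]
for _ in range(2):
    sgd_update(params, [0.5], bufs, lr=0.1, momentum=0.9)
```

With momentum = 0 the buffer is never touched and the update reduces to plain gradient descent, which is why the defaults in the signature give vanilla SGD under this convention.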

load_state_dict(state_dict: dict[str, Any]) → None

state_dict() → dict[str, Any]

step(*, lr: float | None = None) → None

Perform an optimization step to update parameters.

Parameters:
  • lr – Learning rate. Optional; defaults to None.
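
The keyword-only lr argument suggests that a learning rate can be supplied per call, for example from a schedule. The source does not say what lr=None means. The sketch below assumes it falls back to the learning rate given to the constructor. SketchSGD and decayed_lr are hypothetical stand-ins written for this example. They are not part of warp_nn and only mirror the shape of the step signature.

```python
class SketchSGD:
    """Hypothetical stand-in for an optimizer with step(*, lr=None)."""

    def __init__(self, parameters, *, lr=0.001):
        self.parameters = parameters          # flat list of floats
        self.grads = [0.0] * len(parameters)  # filled in by the caller
        self.lr = lr

    def step(self, *, lr=None):
        # Assumption: lr=None means "use the constructor's learning rate".
        effective_lr = self.lr if lr is None else lr
        for i, g in enumerate(self.grads):
            self.parameters[i] -= effective_lr * g


def decayed_lr(base_lr, step_index, decay=0.5, every=2):
    """Step-decay schedule: multiply by `decay` once per `every` steps."""
    return base_lr * decay ** (step_index // every)


opt = SketchSGD([1.0], lr=0.1)
opt.grads = [1.0]

opt.step()  # no override: uses the constructor's lr
for t in range(4):
    opt.step(lr=decayed_lr(0.1, t))  # per-step override from the schedule
```

Passing the rate at call time keeps the schedule outside the optimizer, so the same optimizer object can be driven by any schedule function.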