Stochastic Gradient Descent (SGD)
API
class warp_nn.optimizers.SGD(
    parameters: list[array],
    *,
    lr: float = 0.001,
    device: str | Device | None = None,
    max_norm: float | None = None,
    disable_graph: bool = False,
    momentum: float = 0,
    dampening: float = 0,
    weight_decay: float = 0,
)
Bases: Optimizer

SGD optimizer.
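Parameters:
- parameters – Model parameters to optimize.
- lr – Learning rate (default: 0.001).
- device – Device to use for the optimizer (default: None).
- disable_graph – Whether to disable graph capture (default: False).
- momentum – Momentum factor (default: 0).
- dampening – Dampening factor applied to the gradient when momentum is used (default: 0).
- weight_decay – Weight decay factor (default: 0).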
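
To illustrate how lr, momentum, dampening, and weight_decay typically interact, here is a minimal pure-Python sketch of the conventional SGD-with-momentum update on scalar parameters. It is an assumption that warp_nn follows this common formulation; the page does not state the exact update rule. The helper name sgd_step and the bufs list are hypothetical and are not part of the warp_nn API.

```python
def sgd_step(params, grads, bufs, lr=0.001, momentum=0.0,
             dampening=0.0, weight_decay=0.0):
    """Apply one conventional SGD step in place to a list of scalar parameters.

    Hypothetical reference helper, not the warp_nn implementation.
    bufs holds one momentum buffer per parameter (None until first use).
    """
    for i in range(len(params)):
        g = grads[i]
        if weight_decay != 0:
            # Weight decay as an L2 penalty folded into the gradient.
            g = g + weight_decay * params[i]
        if momentum != 0:
            if bufs[i] is None:
                # First step: the buffer starts as the raw gradient.
                bufs[i] = g
            else:
                # Later steps: decay the buffer, add the dampened gradient.
                bufs[i] = momentum * bufs[i] + (1 - dampening) * g
            g = bufs[i]
        params[i] = params[i] - lr * g


# Two steps on one parameter with a constant gradient of 0.5.
w = [1.0]
bufs = [None]
sgd_step(w, [0.5], bufs, lr=0.1, momentum=0.9)
print(w[0])  # approximately 0.95
sgd_step(w, [0.5], bufs, lr=0.1, momentum=0.9)
print(w[0])  # approximately 0.855
```

With momentum left at 0 the buffer is never used and the update reduces to plain gradient descent, params[i] - lr * g.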