apex.optimizers

class apex.optimizers.FusedAdam(params, lr=0.001, bias_correction=True, betas=(0.9, 0.999), eps=1e-08, adam_w_mode=True, weight_decay=0.0, amsgrad=False, set_grad_none=True)

Implements the Adam algorithm.

Currently GPU-only. Requires Apex to be installed via pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./.

This version of fused Adam implements 2 fusions.

  • Fusion of the Adam update’s elementwise operations

  • A multi-tensor apply launch that batches the elementwise updates applied to all the model’s parameters into one or a few kernel launches.
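The second fusion can be pictured with a simplified scheduling model. The sketch below is hypothetical: the function name `plan_launches`, the chunking scheme, and the per-launch limits `max_tensors` and `max_chunks` are illustrative assumptions, not Apex's internal values. It packs every chunk of every parameter tensor into as few launches as those limits allow, instead of issuing one launch per tensor.

```python
from typing import List, Set, Tuple


def plan_launches(
    numels: List[int], chunk_size: int, max_tensors: int, max_chunks: int
) -> List[List[Tuple[int, int]]]:
    """Pack (tensor_index, chunk_index) work items into as few launches as the limits allow.

    Hypothetical model: a launch may touch at most `max_tensors` distinct tensors
    and process at most `max_chunks` chunks of `chunk_size` elements in total.
    """
    launches: List[List[Tuple[int, int]]] = []
    current: List[Tuple[int, int]] = []
    tensors_in_launch: Set[int] = set()
    for t, n in enumerate(numels):
        n_chunks = (n + chunk_size - 1) // chunk_size  # ceil(n / chunk_size)
        for c in range(n_chunks):
            launch_full = len(current) == max_chunks
            no_tensor_slot = t not in tensors_in_launch and len(tensors_in_launch) == max_tensors
            if launch_full or no_tensor_slot:
                # Close the current launch and start a new one.
                launches.append(current)
                current, tensors_in_launch = [], set()
            current.append((t, c))
            tensors_in_launch.add(t)
    if current:
        launches.append(current)
    return launches


if __name__ == "__main__":
    # 1,000 small tensors, each fitting in a single chunk.
    small = plan_launches([1000] * 1000, chunk_size=65536, max_tensors=100, max_chunks=320)
    print(len(small))  # 10
    # One large tensor split into 10 chunks, at most 4 chunks per launch.
    big = plan_launches([10 * 65536], chunk_size=65536, max_tensors=100, max_chunks=4)
    print([len(launch) for launch in big])  # [4, 4, 2]
```

Under these assumed limits, 1,000 small tensors that would need 1,000 launches when updated one at a time are covered by 10 launches.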

apex.optimizers.FusedAdam may be used as a drop-in replacement for torch.optim.Adam:

opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....)
...
opt.step()

apex.optimizers.FusedAdam may be used with or without Amp. If you wish to use FusedAdam with Amp, you may choose any opt_level:

opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....)
model, opt = amp.initialize(model, opt, opt_level="O1")  # any of "O0", "O1", "O2"
...
opt.step()

In general, opt_level="O1" is recommended.

Warning

A previous version of FusedAdam allowed a number of additional arguments to step. These additional arguments are now deprecated and unnecessary.

Adam was proposed in Adam: A Method for Stochastic Optimization.

Parameters
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups.

  • lr (float, optional) – learning rate. (default: 1e-3)

  • betas (Tuple[float, float], optional) – coefficients used for computing running averages of gradient and its square. (default: (0.9, 0.999))

  • eps (float, optional) – term added to the denominator to improve numerical stability. (default: 1e-8)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • amsgrad (boolean, optional) – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond. Not supported by FusedAdam. (default: False)

  • adam_w_mode (boolean, optional) – whether to apply L2 regularization (False) or decoupled weight decay, also known as AdamW (True). (default: True)

  • set_grad_none (bool, optional) – whether to set gradients to None when zero_grad() is called. (default: True)
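For reference, the elementwise work that the first fusion combines can be written out in plain Python. This is a sketch of the standard Adam/AdamW update as the parameters above describe it, where `adam_w_mode` switches between L2 regularization folded into the gradient and decoupled weight decay. The function name `adam_step` is hypothetical, and the fused kernel may order or combine operations differently.

```python
import math
from typing import List, Tuple


def adam_step(
    p: List[float],
    g: List[float],
    m: List[float],
    v: List[float],
    step: int,
    lr: float = 1e-3,
    betas: Tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-8,
    weight_decay: float = 0.0,
    adam_w_mode: bool = True,
    bias_correction: bool = True,
) -> None:
    """Update p, m and v in place for one Adam(W) step; `step` starts at 1."""
    beta1, beta2 = betas
    bc1 = 1 - beta1 ** step if bias_correction else 1.0
    bc2 = 1 - beta2 ** step if bias_correction else 1.0
    for i in range(len(p)):
        grad = g[i]
        if not adam_w_mode:
            # L2 regularization: fold the decay term into the gradient.
            grad += weight_decay * p[i]
        m[i] = beta1 * m[i] + (1 - beta1) * grad
        v[i] = beta2 * v[i] + (1 - beta2) * grad * grad
        update = (m[i] / bc1) / (math.sqrt(v[i] / bc2) + eps)
        if adam_w_mode:
            # Decoupled weight decay (AdamW): applied directly to the parameter.
            update += weight_decay * p[i]
        p[i] -= lr * update


if __name__ == "__main__":
    for mode in (True, False):
        p, m, v = [1.0], [0.0], [0.0]
        adam_step(p, [0.5], m, v, step=1, lr=0.1, weight_decay=0.1, adam_w_mode=mode)
        print(mode, p[0])  # about 0.89 for AdamW, about 0.9 for L2
```

On the first step Adam's normalization maps each gradient element to roughly ±1, so in L2 mode the folded-in decay is absorbed, while in AdamW mode it still shrinks the parameter.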

step(closure=None, grads=None, output_params=None, scale=None, grad_norms=None)

Performs a single optimization step.

Parameters

closure (callable, optional) – A closure that reevaluates the model and returns the loss.

The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes.

zero_grad()

Clears the gradients of all optimized torch.Tensor objects.

class apex.optimizers.FusedLAMB(params, lr=0.001, bias_correction=True, betas=(0.9, 0.999), eps=1e-06, weight_decay=0.01, amsgrad=False, adam_w_mode=True, grad_averaging=True, set_grad_none=True, max_grad_norm=1.0)

Implements the LAMB algorithm.

Currently GPU-only. Requires Apex to be installed via pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./.

This version of fused LAMB implements 2 fusions.

  • Fusion of the LAMB update’s elementwise operations

  • A multi-tensor apply launch that batches the elementwise updates applied to all the model’s parameters into one or a few kernel launches.

apex.optimizers.FusedLAMB’s usage is identical to that of any ordinary PyTorch optimizer:

opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)
...
opt.step()

apex.optimizers.FusedLAMB may be used with or without Amp. If you wish to use FusedLAMB with Amp, you may choose any opt_level:

opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)
model, opt = amp.initialize(model, opt, opt_level="O1")  # any of "O0", "O1", "O2"
...
opt.step()

In general, opt_level="O1" is recommended.

LAMB was proposed in Large Batch Optimization for Deep Learning: Training BERT in 76 minutes.

Parameters
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups.

  • lr (float, optional) – learning rate. (default: 1e-3)

  • betas (Tuple[float, float], optional) – coefficients used for computing running averages of the gradient and its square. (default: (0.9, 0.999))

  • eps (float, optional) – term added to the denominator to improve numerical stability. (default: 1e-6)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0.01)

  • amsgrad (boolean, optional) – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond. Not currently supported. (default: False)

  • adam_w_mode (boolean, optional) – whether to apply L2 regularization (False) or decoupled weight decay, also known as AdamW (True). (default: True)

  • grad_averaging (bool, optional) – whether to apply (1 - beta1) to the gradient when computing the running average of the gradient. (default: True)

  • set_grad_none (bool, optional) – whether to set gradients to None when zero_grad() is called. (default: True)

  • max_grad_norm (float, optional) – value used to clip the global gradient norm. (default: 1.0)
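The layer-wise trust ratio is what distinguishes LAMB from Adam. The plain-Python sketch below follows the LAMB update with global gradient clipping by `max_grad_norm`, assuming decoupled weight decay (`adam_w_mode=True`) and bias correction always on. The names `lamb_step` and `l2` are hypothetical, and this is an illustration of the technique rather than a reproduction of Apex's kernel.

```python
import math
from typing import List, Tuple


def l2(xs: List[float]) -> float:
    """Euclidean norm of a flat list."""
    return math.sqrt(sum(x * x for x in xs))


def lamb_step(
    params: List[List[float]],
    grads: List[List[float]],
    ms: List[List[float]],
    vs: List[List[float]],
    step: int,
    lr: float = 1e-3,
    betas: Tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-6,
    weight_decay: float = 0.01,
    grad_averaging: bool = True,
    max_grad_norm: float = 1.0,
) -> None:
    """One LAMB step over a list of layers, updating params, ms and vs in place.

    Assumes decoupled weight decay and bias correction; `step` starts at 1.
    """
    beta1, beta2 = betas
    beta3 = 1 - beta1 if grad_averaging else 1.0
    # Clip by the global gradient norm taken across all layers.
    global_norm = math.sqrt(sum(l2(g) ** 2 for g in grads))
    clip = global_norm / max_grad_norm if global_norm > max_grad_norm else 1.0
    bc1, bc2 = 1 - beta1 ** step, 1 - beta2 ** step
    for p, g, m, v in zip(params, grads, ms, vs):
        update = []
        for i in range(len(p)):
            grad = g[i] / clip
            m[i] = beta1 * m[i] + beta3 * grad
            v[i] = beta2 * v[i] + (1 - beta2) * grad * grad
            adam_dir = (m[i] / bc1) / (math.sqrt(v[i] / bc2) + eps)
            update.append(adam_dir + weight_decay * p[i])
        # Layer-wise trust ratio ||p|| / ||u||; fall back to 1 if either norm is 0.
        p_norm, u_norm = l2(p), l2(update)
        trust = p_norm / u_norm if p_norm > 0 and u_norm > 0 else 1.0
        for i in range(len(p)):
            p[i] -= lr * trust * update[i]
```

Because each layer's update is rescaled by ||p|| / ||u||, every layer in this sketch moves by a step of norm `lr * ||p||` (when both norms are nonzero), independent of the raw gradient scale.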

step(closure=None)

Performs a single optimization step.

Parameters

closure (callable, optional) – A closure that reevaluates the model and returns the loss.

zero_grad()

Clears the gradients of all optimized torch.Tensor objects.

class apex.optimizers.FusedNovoGrad(params, lr=0.001, bias_correction=True, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, amsgrad=False, reg_inside_moment=False, grad_averaging=True, norm_type=2, init_zero=False, set_grad_none=True)

Implements the NovoGrad algorithm.

Currently GPU-only. Requires Apex to be installed via pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./.

This version of fused NovoGrad implements 2 fusions.

  • Fusion of the NovoGrad update’s elementwise operations

  • A multi-tensor apply launch that batches the elementwise updates applied to all the model’s parameters into one or a few kernel launches.

apex.optimizers.FusedNovoGrad’s usage is identical to that of any PyTorch optimizer:

opt = apex.optimizers.FusedNovoGrad(model.parameters(), lr = ....)
...
opt.step()

apex.optimizers.FusedNovoGrad may be used with or without Amp. If you wish to use FusedNovoGrad with Amp, you may choose any opt_level:

opt = apex.optimizers.FusedNovoGrad(model.parameters(), lr = ....)
model, opt = amp.initialize(model, opt, opt_level="O1")  # any of "O0", "O1", "O2"
...
opt.step()

In general, opt_level="O1" is recommended.

NovoGrad was proposed in Jasper: An End-to-End Convolutional Neural Acoustic Model. More info: https://nvidia.github.io/OpenSeq2Seq/html/optimizers.html#novograd

Parameters
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups.

  • lr (float, optional) – learning rate. (default: 1e-3)

  • betas (Tuple[float, float], optional) – coefficients used for computing running averages of gradient and its norm. (default: (0.9, 0.999))

  • eps (float, optional) – term added to the denominator to improve numerical stability. (default: 1e-8)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • amsgrad (boolean, optional) – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond. Not currently supported. (default: False)

  • reg_inside_moment (bool, optional) – whether to apply regularization (gradient normalization and L2) inside the momentum calculation. True includes it in the momentum; False leaves it out and applies it only to the update term. (default: False)

  • grad_averaging (bool, optional) – whether to apply (1 - beta1) to the gradient when computing the running average of the gradient. (default: True)

  • norm_type (int, optional) – which norm to compute for each layer: 2 for the L2 norm, 0 for the infinity norm. These are currently the only two supported types. (default: 2)

  • init_zero (bool, optional) – whether to initialize the norm with 0 (averaging starts on the first step) or with the first step’s norm (averaging starts on the second step). True initializes with 0. (default: False)

  • set_grad_none (bool, optional) – whether to set gradients to None when zero_grad() is called. (default: True)
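NovoGrad keeps one second-moment value per layer rather than per element. The plain-Python sketch below illustrates how `norm_type`, `init_zero`, `grad_averaging`, and `reg_inside_moment` interact, under these assumptions: bias correction is omitted, `eps` is added to the stored norm, and for the L2 norm the running average is taken over squared norms and stored as a norm. The names `layer_norm` and `novograd_step` are hypothetical, and Apex's kernel may differ in these details.

```python
import math
from typing import List, Tuple


def layer_norm(g: List[float], norm_type: int) -> float:
    """Per-layer gradient norm: 2 -> L2 norm, 0 -> infinity norm."""
    if norm_type == 2:
        return math.sqrt(sum(x * x for x in g))
    if norm_type == 0:
        return max((abs(x) for x in g), default=0.0)
    raise ValueError("norm_type must be 2 (L2) or 0 (infinity)")


def novograd_step(
    params: List[List[float]],
    grads: List[List[float]],
    ms: List[List[float]],
    norms: List[float],  # one running norm per layer; start at 0.0
    step: int,  # starts at 1
    lr: float = 1e-3,
    betas: Tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-8,
    weight_decay: float = 0.0,
    reg_inside_moment: bool = False,
    grad_averaging: bool = True,
    norm_type: int = 2,
    init_zero: bool = False,
) -> None:
    beta1, beta2 = betas
    beta3 = 1 - beta1 if grad_averaging else 1.0
    for k, (p, g, m) in enumerate(zip(params, grads, ms)):
        n = layer_norm(g, norm_type)
        if step == 1 and not init_zero:
            norms[k] = n  # start from the first step's norm
        elif norm_type == 2:
            # Running average of squared L2 norms, stored as a norm.
            norms[k] = math.sqrt(beta2 * norms[k] ** 2 + (1 - beta2) * n * n)
        else:
            norms[k] = beta2 * norms[k] + (1 - beta2) * n
        denom = norms[k] + eps
        for i in range(len(p)):
            if reg_inside_moment:
                # Normalization and L2 both go into the moment.
                m[i] = beta1 * m[i] + beta3 * (g[i] / denom + weight_decay * p[i])
                p[i] -= lr * m[i]
            else:
                # The moment sees the raw gradient; regularization only on the update.
                m[i] = beta1 * m[i] + beta3 * g[i]
                p[i] -= lr * (m[i] / denom + weight_decay * p[i])
```

With `init_zero=False` the stored norm after the first step equals that step's gradient norm; with `init_zero=True` it is only sqrt(1 - beta2) times that norm, since averaging starts from 0.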

load_state_dict(state_dict)

Loads the optimizer state.

Parameters

state_dict (dict) – optimizer state. Should be an object returned from a call to state_dict().

step(closure=None)

Performs a single optimization step.

Parameters

closure (callable, optional) – A closure that reevaluates the model and returns the loss.

zero_grad()

Clears the gradients of all optimized torch.Tensor objects.

class apex.optimizers.FusedSGD(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False, wd_after_momentum=False, materialize_master_grads=True)

Implements stochastic gradient descent (optionally with momentum).

Currently GPU-only. Requires Apex to be installed via pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./.

This version of fused SGD implements 2 fusions.

  • Fusion of the SGD update’s elementwise operations

  • A multi-tensor apply launch that batches the elementwise updates applied to all the model’s parameters into one or a few kernel launches.

apex.optimizers.FusedSGD may be used as a drop-in replacement for torch.optim.SGD:

opt = apex.optimizers.FusedSGD(model.parameters(), lr = ....)
...
opt.step()

apex.optimizers.FusedSGD may be used with or without Amp. If you wish to use FusedSGD with Amp, you may choose any opt_level:

opt = apex.optimizers.FusedSGD(model.parameters(), lr = ....)
model, opt = amp.initialize(model, opt, opt_level="O1")  # any of "O0", "O1", "O2"
...
opt.step()

In general, opt_level="O1" is recommended.

Nesterov momentum is based on the formula from On the importance of initialization and momentum in deep learning.

Parameters
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups

  • lr (float) – learning rate

  • momentum (float, optional) – momentum factor (default: 0)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • dampening (float, optional) – dampening for momentum (default: 0)

  • nesterov (bool, optional) – enables Nesterov momentum (default: False)
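A sketch, assuming FusedSGD follows the torch.optim.SGD update rule it is described as replacing, shows how the parameters above combine for one parameter list. The name `sgd_step` is hypothetical, and options specific to FusedSGD (such as `wd_after_momentum` and `materialize_master_grads`) are not modeled.

```python
from typing import List, Optional


def sgd_step(
    p: List[float],
    g: List[float],
    buf: Optional[List[float]],
    lr: float,
    momentum: float = 0.0,
    dampening: float = 0.0,
    weight_decay: float = 0.0,
    nesterov: bool = False,
) -> Optional[List[float]]:
    """Update p in place and return the new momentum buffer (None if momentum == 0)."""
    if nesterov and (momentum <= 0 or dampening != 0):
        raise ValueError("Nesterov momentum requires momentum > 0 and zero dampening")
    # Weight decay as an L2 penalty added to the gradient.
    d_p = [gi + weight_decay * pi for gi, pi in zip(g, p)]
    if momentum != 0:
        if buf is None:
            buf = list(d_p)  # first step: the buffer starts as the gradient
        else:
            buf = [momentum * b + (1 - dampening) * d for b, d in zip(buf, d_p)]
        if nesterov:
            # Look ahead along the updated velocity.
            d_p = [d + momentum * b for d, b in zip(d_p, buf)]
        else:
            d_p = buf
    for i in range(len(p)):
        p[i] -= lr * d_p[i]
    return buf
```

The caller carries `buf` from one step to the next, mirroring the per-parameter momentum buffer an optimizer keeps in its state.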

Example

>>> optimizer = apex.optimizers.FusedSGD(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()

Note

The implementation of SGD with Momentum/Nesterov subtly differs from that of Sutskever et al. and implementations in some other frameworks.

Considering the specific case of Momentum, the update can be written as

\[\begin{split}v = \rho * v + g \\ p = p - lr * v\end{split}\]

where p, g, v and \(\rho\) denote the parameters, gradient, velocity, and momentum respectively.

This is in contrast to Sutskever et al. and other frameworks, which employ an update of the form

\[\begin{split}v = \rho * v + lr * g \\ p = p - v\end{split}\]

The Nesterov version is analogously modified.
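The two formulations only behave differently when the learning rate changes during training. The plain-Python sketch below runs both on a toy scalar problem; the loss 0.5 * p**2, the starting point p = 1.0, and the step-decay schedule are made-up assumptions for illustration, and the names `toy_grad` and `run` are hypothetical. With a constant learning rate the trajectories agree up to rounding, because the second form's velocity is just lr times the first form's. After the learning rate drops they diverge, since the second form's velocity still carries the old learning rate.

```python
from typing import List


def toy_grad(p: float) -> float:
    """Gradient of the made-up loss 0.5 * p**2."""
    return p


def run(schedule: List[float], rho: float = 0.9, lr_on_update: bool = True) -> List[float]:
    """Return the parameter after each step for a per-step learning-rate schedule.

    lr_on_update=True:  v = rho * v + g;       p = p - lr * v
    lr_on_update=False: v = rho * v + lr * g;  p = p - v   (Sutskever et al. form)
    """
    p, v = 1.0, 0.0
    trajectory = []
    for lr in schedule:
        g = toy_grad(p)
        if lr_on_update:
            v = rho * v + g
            p = p - lr * v
        else:
            v = rho * v + lr * g
            p = p - v
        trajectory.append(p)
    return trajectory


if __name__ == "__main__":
    constant = [0.1] * 10
    a, b = run(constant, lr_on_update=True), run(constant, lr_on_update=False)
    print(max(abs(x - y) for x, y in zip(a, b)))  # rounding-level difference only
    decayed = [0.1] * 5 + [0.01] * 5
    a, b = run(decayed, lr_on_update=True), run(decayed, lr_on_update=False)
    print(abs(a[5] - b[5]))  # clearly nonzero right after the drop
```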

step(closure=None)

Performs a single optimization step.

Parameters

closure (callable, optional) – A closure that reevaluates the model and returns the loss.