apex.fp16_utils¶
This submodule contains utilities designed to streamline the mixed precision training recipe presented by NVIDIA on Parallel Forall and in the GTC 2018 sessions Training Neural Networks with Mixed Precision: Theory and Practice and Training Neural Networks with Mixed Precision: Real Examples. For Pytorch users, Real Examples in particular is recommended.
Full runnable Python scripts demonstrating apex.fp16_utils can be found on the Github page.
Automatic management of master params + loss scaling¶
-
class apex.fp16_utils.FP16_Optimizer(init_optimizer, static_loss_scale=1.0, dynamic_loss_scale=False, dynamic_loss_args=None, verbose=True)[source]¶

FP16_Optimizer is designed to wrap an existing PyTorch optimizer, and manage static or dynamic loss scaling and master weights in a manner transparent to the user. For standard use, only two lines must be changed: creating the FP16_Optimizer instance, and changing the call to backward.

Example:
    model = torch.nn.Linear(D_in, D_out).cuda().half()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    # Name the FP16_Optimizer instance to replace the existing optimizer
    # (recommended but not required):
    optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)
    ...
    # loss.backward() becomes:
    optimizer.backward(loss)
    ...
Example with dynamic loss scaling:
    ...
    optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
    # Optional arg to control dynamic loss scaling behavior:
    # dynamic_loss_args={'scale_window': 500}
    # Usually, dynamic_loss_args is not necessary.
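For orientation, here is a hedged sketch of a complete training loop built from the two-line change above (D_in, D_out, data_loader, and loss_fn are placeholder names; the static loss scale value is only an example):

    model = torch.nn.Linear(D_in, D_out).cuda().half()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)

    for inputs, targets in data_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
        optimizer.backward(loss)   # replaces loss.backward()
        optimizer.step()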
- Parameters
init_optimizer (torch.optim.optimizer) – Existing optimizer created with the parameters to optimize. Internally, FP16_Optimizer replaces the passed optimizer's fp16 parameters, if any, with fp32 master parameters copied from the original ones. FP16_Optimizer also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy at the end of each step.
static_loss_scale (float, optional, default=1.0) – Loss scale used internally to scale gradients computed by the model. Any fp16 gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so static_loss_scale should not affect learning rate.
dynamic_loss_scale (bool, optional, default=False) – Use dynamic loss scaling. If True, this will override any static_loss_scale option.
dynamic_loss_args (dict, optional, default=None) – Dict of kwargs that will be forwarded to the internal LossScaler instance's constructor. Keys of this dict must match kwargs accepted by LossScaler's constructor. If dynamic_loss_args is unspecified, LossScaler's defaults will be used.
verbose (bool, optional, default=True) – By default, FP16_Optimizer's constructor prints out the parameters and parameter groups it is ingesting, as a sanity check. If this becomes annoying (e.g. for large models), it can be disabled by passing verbose=False. verbose=False will not disable printing when the loss scale is readjusted during dynamic loss scaling.
init_optimizer is expected to have been constructed in the ordinary way. It is recommended (although not required) that the newly constructed FP16_Optimizer instance be named to replace init_optimizer, for two reasons: First, it means that references to the same name later in the file will not have to change. Second, FP16_Optimizer reserves the right (as an implementation detail) to modify init_optimizer. If you do choose a unique name for the new FP16_Optimizer instance, you should only work with this new instance, because the preexisting optimizer might no longer behave as expected.

init_optimizer may be any Pytorch optimizer. It may contain a mixture of fp16 and fp32 parameters organized into any number of param_groups with different hyperparameters. The FP16_Optimizer constructor will ingest these param_groups and remember them.

Calls to loss.backward() must be replaced with optimizer.backward(loss) because FP16_Optimizer requires ownership of the backward pass to implement loss scaling and copies to master gradients.

Note
Loss scaling, either static or dynamic, is orthogonal to learning rate, because gradients are downscaled before being applied. This means that adjusting the loss scale, or using dynamic loss scaling, should not require retuning the learning rate or any other hyperparameters.
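As a quick standalone illustration of that claim (a toy, not FP16_Optimizer internals): scaling the loss by a factor scales the gradients by the same factor, and dividing the gradients by that factor before the update recovers exactly the gradient of the unscaled loss.

    import torch

    w = torch.ones(3, requires_grad=True)
    loss = (2.0 * w).sum()          # d(loss)/dw = 2 for each element
    scale = 128.0
    (loss * scale).backward()       # grads are now 2 * 128 = 256
    print(w.grad / scale)           # tensor([2., 2., 2.]) -- same as the unscaled gradient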
Advanced options
Closures: FP16_Optimizer can wrap a Pytorch optimizer that receives a closure. See docstring for step.

Gradient clipping: Use clip_master_grads.

Multiple losses: If your model accumulates gradients from multiple losses, this can be made more efficient by supplying update_master_grads=False to backward. See docstring for backward.

Manually adjusting loss scale: The current loss scale can be retrieved or set via

    print(optimizer.loss_scale)
    optimizer.loss_scale = new_loss_scale

For static loss scaling, manually adjusting the loss scale over time is a reasonable thing to do. During later epochs, gradients may become smaller, and a higher loss scale may be required, analogous to scheduling the learning rate. Dynamic loss scaling is more subtle (see DynamicLossScaler), and in this case manually adjusting the loss scale is not recommended.

Multi-GPU training: If the wrapped init_optimizer was created from a model wrapped in Pytorch DistributedDataParallel or Apex DistributedDataParallel, FP16_Optimizer should still work as intended.

-
backward(loss, update_master_grads=True, retain_graph=False)[source]¶

backward performs the following conceptual steps:

1. fp32_loss = loss.float() (see first Note below)
2. scaled_loss = fp32_loss * loss_scale
3. scaled_loss.backward(), which accumulates scaled gradients into the .grad attributes of the model's leaves (which may be fp16, fp32, or a mixture, depending how your model was defined).
4. fp16 grads are then copied to the master params' .grad attributes (see second Note), which are guaranteed to be fp32.
5. Finally, master grads are divided by loss_scale.

In this way, after backward, the master params have fresh gradients, and step may be called.

Note
backward internally converts the loss to fp32 before applying the loss scale. This provides some additional safety against overflow if the user has supplied an fp16 loss value. However, for maximum overflow safety, the user should compute the loss criterion (MSE, cross entropy, etc.) in fp32 before supplying it to backward.

Warning

The gradients found in a model's leaves after the call to backward should not be regarded as valid in general, because it's possible they have been scaled (and in the case of dynamic loss scaling, the scale factor may change over time). If the user wants to inspect gradients after a call to backward, only the master gradients should be regarded as valid. These can be retrieved via inspect_master_grad_data().

- Parameters
loss – The loss output by the user’s model. loss may be either float or half (but see first Note above).
update_master_grads (bool, optional, default=True) – Option to copy fp16 grads to fp32 grads on this call. By setting this to False, the user can delay the copy, which is useful to eliminate redundant fp16->fp32 grad copies if backward is being called on multiple losses in one iteration. If set to False, the user becomes responsible for calling update_master_grads before calling step.
retain_graph (bool, optional, default=False) – Forwards the usual retain_graph=True option to the internal call to loss.backward. If retain_graph is being used to accumulate gradient values from multiple backward passes before calling optimizer.step, passing update_master_grads=False is also recommended (see Example below).
Example:
    # Ordinary operation:
    optimizer.backward(loss)

    # Naive operation with multiple losses (technically valid, but less efficient):
    # fp32 grads will be correct after the second call, but
    # the first call incurs an unnecessary fp16->fp32 grad copy.
    optimizer.backward(loss1)
    optimizer.backward(loss2)

    # More efficient way to handle multiple losses:
    # The fp16->fp32 grad copy is delayed until fp16 grads from all
    # losses have been accumulated.
    optimizer.backward(loss1, update_master_grads=False)
    optimizer.backward(loss2, update_master_grads=False)
    optimizer.update_master_grads()
-
clip_master_grads(max_norm, norm_type=2)[source]¶

Clips fp32 master gradients via torch.nn.utils.clip_grad_norm.

- Parameters

max_norm (float or int) – Max norm of the gradients.
norm_type (float or int, optional, default=2) – Type of the used p-norm. Can be 'inf' for infinity norm.

- Returns

Total norm of the current fp32 gradients (viewed as a single vector).

Warning

Returns -1 if the most recently computed fp16 gradients overflowed (that is, if self.overflow is True).
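A hedged usage sketch (assuming optimizer is an FP16_Optimizer and loss has already been computed): clip the fp32 master gradients between backward and step, and check for the overflow sentinel.

    optimizer.backward(loss)
    grad_norm = optimizer.clip_master_grads(1.0)   # max_norm=1.0; norm_type defaults to 2
    if grad_norm == -1:
        print("most recent fp16 gradients overflowed")
    optimizer.step()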
-
inspect_master_grad_data()[source]¶

When running with FP16_Optimizer, .grad attributes of a model's fp16 leaves should not be regarded as truthful, because they might be scaled. After a call to fp16_optimizer_obj.backward(loss), if no overflow was encountered, the fp32 master params' .grad attributes will contain valid gradients properly divided by the loss scale. However, because FP16_Optimizer flattens some parameters, accessing them may be nonintuitive. inspect_master_grad_data allows those gradients to be viewed with shapes corresponding to their associated model leaves.

- Returns

List of lists (one list for each parameter group). The list for each parameter group is a list of the .grad.data attributes of the fp32 master params belonging to that group.
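A hedged usage sketch (assuming optimizer is an FP16_Optimizer and optimizer.backward(loss) has already been called without overflow): print the shape and norm of each master gradient, grouped by parameter group.

    master_grads = optimizer.inspect_master_grad_data()
    for group_idx, group in enumerate(master_grads):
        for param_idx, grad in enumerate(group):
            if grad is not None:
                print(group_idx, param_idx, tuple(grad.shape), grad.norm().item())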
-
load_state_dict(state_dict)[source]¶

Loads a state_dict created by an earlier call to state_dict(). If fp16_optimizer_instance was constructed from some init_optimizer, whose parameters in turn came from model, it is expected that the user will call model.load_state_dict() before fp16_optimizer_instance.load_state_dict() is called.

Example:
    model = torch.nn.Linear(D_in, D_out).cuda().half()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)
    ...
    checkpoint = torch.load("saved.pth")
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
-
state_dict()[source]¶

Returns a dict containing the current state of this FP16_Optimizer instance. This dict contains attributes of FP16_Optimizer, as well as the state_dict of the contained Pytorch optimizer.

Example:

    checkpoint = {}
    checkpoint['model'] = model.state_dict()
    checkpoint['optimizer'] = optimizer.state_dict()
    torch.save(checkpoint, "saved.pth")
-
step(closure=None)[source]¶

If no closure is supplied, step should be called after fp16_optimizer_obj.backward(loss). step updates the fp32 master copy of parameters using the optimizer supplied to FP16_Optimizer's constructor, then copies the updated fp32 params into the fp16 params originally referenced by FP16_Optimizer's constructor, so the user may immediately run another forward pass using their model.

If a closure is supplied, step may be called without a prior call to backward(loss). This control flow is identical to ordinary Pytorch optimizer use with closures. However, the user should take care that any loss.backward() call within the closure has been replaced by fp16_optimizer_obj.backward(loss).

- Parameters

closure (optional) – Closure that will be supplied to the underlying optimizer originally passed to FP16_Optimizer's constructor. closure should call zero_grad() on the FP16_Optimizer object, compute the loss, call backward(loss), and return the loss.
Example with closure:
    # optimizer is assumed to be an FP16_Optimizer object, previously constructed
    # from an existing pytorch optimizer.
    for input, target in dataset:
        def closure():
            optimizer.zero_grad()
            output = model(input)
            loss = loss_fn(output, target)
            # loss.backward() becomes:
            optimizer.backward(loss)
            return loss
        optimizer.step(closure)
Warning
Currently, calling step with a closure is not compatible with dynamic loss scaling.
-
update_master_grads()[source]¶

Copy the .grad attribute from stored references to fp16 parameters to the .grad attribute of the fp32 master parameters that are directly updated by the optimizer. update_master_grads only needs to be called if fp16_optimizer_obj.backward was called with update_master_grads=False.
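A hedged sketch of this pattern for gradient accumulation over several micro-batches (micro_batches, model, and loss_fn are placeholder names):

    optimizer.zero_grad()
    for inputs, targets in micro_batches:
        loss = loss_fn(model(inputs), targets)
        # Accumulate fp16 grads; skip the fp16->fp32 copy for now.
        optimizer.backward(loss, update_master_grads=False)
    optimizer.update_master_grads()   # single fp16->fp32 copy of the accumulated grads
    optimizer.step()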
-
class apex.fp16_utils.LossScaler(scale=1)[source]¶

Class that manages a static loss scale. This class is intended to interact with FP16_Optimizer, and should not be directly manipulated by the user.

Use of LossScaler is enabled via the static_loss_scale argument to FP16_Optimizer's constructor.

- Parameters
scale (float, optional, default=1.0) – The loss scale.
-
class apex.fp16_utils.DynamicLossScaler(init_scale=4294967296, scale_factor=2.0, scale_window=1000)[source]¶

Class that manages dynamic loss scaling. It is recommended to use DynamicLossScaler indirectly, by supplying dynamic_loss_scale=True to the constructor of FP16_Optimizer. However, it's important to understand how DynamicLossScaler operates, because the default options can be changed using the dynamic_loss_args argument to FP16_Optimizer's constructor.

Loss scaling is designed to combat the problem of underflowing gradients encountered at long times when training fp16 networks. Dynamic loss scaling begins by attempting a very high loss scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are encountered, DynamicLossScaler informs FP16_Optimizer that an overflow has occurred. FP16_Optimizer then skips the update step for this particular iteration/minibatch, and DynamicLossScaler adjusts the loss scale to a lower value. If a certain number of iterations occur without overflowing gradients detected, DynamicLossScaler increases the loss scale once more. In this way DynamicLossScaler attempts to "ride the edge" of always using the highest loss scale possible without incurring overflow.

- Parameters

init_scale (float, optional, default=2**32) – Initial loss scale attempted by DynamicLossScaler.
scale_factor (float, optional, default=2.0) – Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss_scale / scale_factor. If scale_window consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale * scale_factor.
scale_window (int, optional, default=1000) – Number of consecutive iterations without an overflow to wait before increasing the loss scale.
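A conceptual sketch of the policy described above (a toy, not apex's implementation; all names are illustrative):

    class ToyDynamicScaler:
        def __init__(self, init_scale=2.0**32, scale_factor=2.0, scale_window=1000):
            self.loss_scale = init_scale
            self.scale_factor = scale_factor
            self.scale_window = scale_window
            self.iters_since_overflow = 0

        def update(self, overflow_detected):
            if overflow_detected:
                self.loss_scale /= self.scale_factor      # back off after an overflow
                self.iters_since_overflow = 0
            else:
                self.iters_since_overflow += 1
                if self.iters_since_overflow >= self.scale_window:
                    self.loss_scale *= self.scale_factor  # probe a higher scale
                    self.iters_since_overflow = 0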
Manual master parameter management¶
-
apex.fp16_utils.prep_param_lists(model, flat_master=False)[source]¶

Creates a list of FP32 master parameters for a given model, as in Training Neural Networks with Mixed Precision: Real Examples.
- Parameters
model (torch.nn.Module) – Existing Pytorch model
flat_master (bool, optional, default=False) – Flatten the master parameters into a single tensor, as a performance optimization.
- Returns
A tuple (model_params, master_params). model_params is a list of the model's parameters for later use with model_grads_to_master_grads() and master_params_to_model_params(). master_params is a list of FP32 master parameters. If flat_master=True, master_params will be a list with one element.
Example:
model_params, master_params = prep_param_lists(model)
Warning
Currently, if flat_master=True, all the model's parameters must be the same type. If the model has parameters of different types, use flat_master=False, or use FP16_Optimizer.
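A small sketch of the flat_master=True case (assuming, per the warning above, that every parameter of model is fp16):

    model_params, master_params = prep_param_lists(model, flat_master=True)
    print(len(master_params))   # 1: a single flattened fp32 master tensor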
-
apex.fp16_utils.master_params_to_model_params(model_params, master_params, flat_master=False)[source]¶

Copy master parameters to model parameters.

- Parameters

model_params – List of model parameters created by prep_param_lists().
master_params – List of FP32 master parameters created by prep_param_lists(). If master_params was created with flat_master=True, flat_master=True should also be supplied to master_params_to_model_params().
-
apex.fp16_utils.model_grads_to_master_grads(model_params, master_params, flat_master=False)[source]¶

Copy model gradients to master gradients.

- Parameters

model_params – List of model parameters created by prep_param_lists().
master_params – List of FP32 master parameters created by prep_param_lists(). If master_params was created with flat_master=True, flat_master=True should also be supplied to model_grads_to_master_grads().
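Putting the three utilities together, here is a hedged sketch of a manually managed mixed precision training loop (D_in, D_out, data_loader, and loss_fn are placeholder names; the loss scale value is only an example):

    import torch
    from apex.fp16_utils import (prep_param_lists,
                                 model_grads_to_master_grads,
                                 master_params_to_model_params)

    model = torch.nn.Linear(D_in, D_out).cuda().half()
    model_params, master_params = prep_param_lists(model)
    # The optimizer owns (and updates) the fp32 master copies.
    optimizer = torch.optim.SGD(master_params, lr=1e-3)
    loss_scale = 128.0

    for inputs, targets in data_loader:
        model.zero_grad()                                  # clear accumulated fp16 grads
        loss = loss_fn(model(inputs), targets)
        (loss.float() * loss_scale).backward()             # scaled grads land in the fp16 .grad attributes
        model_grads_to_master_grads(model_params, master_params)
        for param in master_params:
            if param.grad is not None:
                param.grad.div_(loss_scale)                # unscale in fp32
        optimizer.step()
        master_params_to_model_params(model_params, master_params)  # refresh fp16 weights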