apex.fp16_utils

This submodule contains utilities designed to streamline the mixed precision training recipe presented by NVIDIA on Parallel Forall and in the GTC 2018 sessions "Training Neural Networks with Mixed Precision: Theory and Practice" and "Training Neural Networks with Mixed Precision: Real Examples". For PyTorch users, "Real Examples" in particular is recommended.

Full runnable Python scripts demonstrating apex.fp16_utils can be found on the GitHub page.

Automatic management of master params + loss scaling

class apex.fp16_utils.FP16_Optimizer(init_optimizer, static_loss_scale=1.0, dynamic_loss_scale=False, dynamic_loss_args=None, verbose=True)

FP16_Optimizer is designed to wrap an existing PyTorch optimizer, and manage static or dynamic loss scaling and master weights in a manner transparent to the user. For standard use, only two lines must be changed: creating the FP16_Optimizer instance, and changing the call to backward.

Example:

import torch
from apex.fp16_utils import FP16_Optimizer
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
# Name the FP16_Optimizer instance to replace the existing optimizer
# (recommended but not required):
optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)
...
# loss.backward() becomes:
optimizer.backward(loss)
...

Example with dynamic loss scaling:

...
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
# Optionally, dynamic_loss_args controls dynamic loss scaling behavior,
# although it is usually not necessary:
# optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True,
#                            dynamic_loss_args={'scale_window': 500})

Parameters
  • init_optimizer (torch.optim.Optimizer) – Existing optimizer created with the parameters to optimize. Internally, FP16_Optimizer replaces the passed optimizer’s fp16 parameters, if any, with fp32 master parameters copied from the original ones. FP16_Optimizer also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy at the end of each step.

  • static_loss_scale (float, optional, default=1.0) – Loss scale used internally to scale gradients computed by the model. Any fp16 gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so static_loss_scale should not affect learning rate.

  • dynamic_loss_scale (bool, optional, default=False) – Use dynamic loss scaling. If True, this will override any static_loss_scale option.

  • dynamic_loss_args (dict, optional, default=None) – Dict of kwargs that will be forwarded to the internal LossScaler instance’s constructor. Keys of this dict must match kwargs accepted by LossScaler’s constructor. If dynamic_loss_args is unspecified, LossScaler’s defaults will be used.

  • verbose (bool, optional, default=True) – By default, FP16_Optimizer’s constructor prints out the parameters and parameter groups it is ingesting, as a sanity check. If this becomes annoying (e.g. for large models), it can be disabled by passing verbose=False. verbose=False will not disable printing when the loss scale is readjusted during dynamic loss scaling.

init_optimizer is expected to have been constructed in the ordinary way. It is recommended (although not required) that the newly constructed FP16_Optimizer instance be named to replace init_optimizer, for two reasons: First, it means that references to the same name later in the file will not have to change. Second, FP16_Optimizer reserves the right (as an implementation detail) to modify init_optimizer. If you do choose a unique name for the new FP16_Optimizer instance, you should only work with this new instance, because the preexisting optimizer might no longer behave as expected.

init_optimizer may be any PyTorch optimizer. It may contain a mixture of fp16 and fp32 parameters organized into any number of param_groups with different hyperparameters. The FP16_Optimizer constructor will ingest these param_groups and remember them.

Calls to

loss.backward()

must be replaced with

optimizer.backward(loss)

because FP16_Optimizer requires ownership of the backward pass to implement loss scaling and copies to master gradients.

Note

Loss scaling, either static or dynamic, is orthogonal to learning rate, because gradients are downscaled before being applied. This means that adjusting the loss scale, or using dynamic loss scaling, should not require retuning the learning rate or any other hyperparameters.
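A minimal pure-Python sketch of why the loss scale drops out of the update. It is an illustration under stated assumptions, not apex code: Python floats (64-bit) stand in for fp32, struct's "e" (half-precision) format emulates fp16 storage, and to_fp16 and sgd_update are hypothetical helpers. With a power-of-two loss scale, multiplying and later dividing by the scale is exact as long as neither the scaled nor the unscaled gradient leaves fp16's normal range, so the resulting SGD update is identical for every scale.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round x to the nearest fp16 value; values too large for fp16 become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def sgd_update(weight: float, grad: float, lr: float, loss_scale: float) -> float:
    """One SGD step on a master weight using a scaled fp16 gradient."""
    fp16_grad = to_fp16(grad * loss_scale)  # scaled gradient as stored by an fp16 leaf
    master_grad = fp16_grad / loss_scale    # copied to "fp32" and downscaled
    return weight - lr * master_grad


weight, grad, lr = 0.5, 0.0123, 1e-3
updates = {scale: sgd_update(weight, grad, lr, scale) for scale in (1.0, 128.0, 1024.0)}
print(updates)  # the same new weight for every loss scale
```

With a scale that is not a power of two, the updates would differ only by rounding in the last bits, which is why power-of-two scales are the usual choice.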

Advanced options

Closures: FP16_Optimizer can wrap a PyTorch optimizer that receives a closure. See the docstring for step.

Gradient clipping: Use clip_master_grads.

Multiple losses: If your model accumulates gradients from multiple losses, this can be made more efficient by supplying update_master_grads=False to backward. See docstring for backward.

Manually adjusting loss scale: The current loss scale can be retrieved or set via

print(optimizer.loss_scale)
optimizer.loss_scale = new_loss_scale

For static loss scaling, manually adjusting the loss scale over time is a reasonable thing to do. During later epochs, gradients may become smaller, and a higher loss scale may be required, analogous to scheduling the learning rate. Dynamic loss scaling is more subtle (see DynamicLossScaler) and in this case, manually adjusting the loss scale is not recommended.

Multi-GPU training: If the wrapped init_optimizer was created from a model wrapped in PyTorch DistributedDataParallel or Apex DistributedDataParallel, FP16_Optimizer should still work as intended.

backward(loss, update_master_grads=True, retain_graph=False)

backward performs the following conceptual steps:

  1. fp32_loss = loss.float() (see first Note below)

  2. scaled_loss = fp32_loss*loss_scale

  3. scaled_loss.backward(), which accumulates scaled gradients into the .grad attributes of the model’s leaves (which may be fp16, fp32, or a mixture, depending on how your model was defined).

  4. fp16 grads are then copied to the master params’ .grad attributes (see the Warning below), which are guaranteed to be fp32.

  5. Finally, master grads are divided by loss_scale.

In this way, after backward, the master params have fresh gradients, and step may be called.
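The five steps can be mimicked in pure Python for a few scalar parameters. This is a sketch under stated assumptions, not apex's implementation. Autograd is replaced by a hypothetical list true_grads holding the gradients of the unscaled loss. By the chain rule, backpropagating loss * loss_scale multiplies each of those gradients by loss_scale. struct's "e" format emulates fp16 storage of the leaf gradients, and Python floats stand in for fp32 master gradients. The run shows three cases: a gradient of 1e-8 vanishing in fp16 without scaling, the same gradient surviving with a loss scale of 1024, and a scale of 2**20 overflowing.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round x to the nearest fp16 value; values too large for fp16 become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def backward_sketch(true_grads, loss_scale):
    """Return (master_grads, overflow) following the conceptual steps of backward."""
    # Steps 1-3: backprop through the fp32 loss times loss_scale; each fp16
    # leaf receives its gradient multiplied by loss_scale, rounded to fp16.
    leaf_grads = [to_fp16(g * loss_scale) for g in true_grads]
    # Overflowed gradients (inf/nan) are what dynamic loss scaling reacts to.
    overflow = any(math.isinf(g) or math.isnan(g) for g in leaf_grads)
    # Step 4: copy the fp16 leaf grads into the fp32 master grads.
    master_grads = [float(g) for g in leaf_grads]
    # Step 5: divide the master grads by loss_scale.
    return [g / loss_scale for g in master_grads], overflow


for scale in (1.0, 1024.0, 2.0 ** 20):
    grads, overflow = backward_sketch([1e-8, 0.25], scale)
    print(f"loss_scale={scale:g}: master_grads={grads}, overflow={overflow}")
```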

Note

backward internally converts the loss to fp32 before applying the loss scale. This provides some additional safety against overflow if the user has supplied an fp16 loss value. However, for maximum overflow safety, the user should compute the loss criterion (MSE, cross entropy, etc.) in fp32 before supplying it to backward.
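The following sketch illustrates the Note with a hypothetical mean-squared-error computation. Every intermediate result is rounded to fp16, emulated with struct's "e" format, while Python floats stand in for fp32. Errors of a few hundred are enough to cause trouble. Their squares exceed fp16's maximum finite value of 65504, so the fp16 loss is already inf before backward could convert it to fp32. The same loss computed without fp16 rounding stays finite.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round x to the nearest fp16 value; values too large for fp16 become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def mse_fp16(preds, targets):
    """MSE with every intermediate result rounded to fp16."""
    total = 0.0
    for p, t in zip(preds, targets):
        diff = to_fp16(p - t)
        total = to_fp16(total + to_fp16(diff * diff))
    return to_fp16(total / len(preds))


def mse_fp32(preds, targets):
    """The same criterion computed without fp16 rounding."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


preds, targets = [300.0, -280.0, 310.0], [0.0, 0.0, 0.0]
print(mse_fp16(preds, targets))  # inf: converting this to fp32 afterwards cannot help
print(mse_fp32(preds, targets))  # finite
```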

Warning

The gradients found in a model’s leaves after the call to backward should not be regarded as valid in general, because it’s possible they have been scaled (and in the case of dynamic loss scaling, the scale factor may change over time). If the user wants to inspect gradients after a call to backward, only the master gradients should be regarded as valid. These can be retrieved via inspect_master_grad_data().

Parameters
  • loss – The loss output by the user’s model. loss may be either float or half (but see the Note above).

  • update_master_grads (bool, optional, default=True) – Option to copy fp16 grads to fp32 grads on this call. By setting this to False, the user can delay the copy, which is useful to eliminate redundant fp16->fp32 grad copies if backward is being called on multiple losses in one iteration. If set to False, the user becomes responsible for calling update_master_grads before calling step.

  • retain_graph (bool, optional, default=False) – Forwards the usual retain_graph=True option to the internal call to loss.backward. If retain_graph is being used to accumulate gradient values from multiple backward passes before calling optimizer.step, passing update_master_grads=False is also recommended (see Example below).

Example:

# Ordinary operation:
optimizer.backward(loss)

# Naive operation with multiple losses (technically valid, but less efficient):
# fp32 grads will be correct after the second call, but
# the first call incurs an unnecessary fp16->fp32 grad copy.
optimizer.backward(loss1)
optimizer.backward(loss2)

# More efficient way to handle multiple losses:
# The fp16->fp32 grad copy is delayed until fp16 grads from all
# losses have been accumulated.
optimizer.backward(loss1, update_master_grads=False)
optimizer.backward(loss2, update_master_grads=False)
optimizer.update_master_grads()
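To see where the saving comes from, here is a toy model of the gradient bookkeeping only. ToyFP16Optimizer is a hypothetical class, not apex's implementation. Plain Python floats are used for both gradient copies, and fp16 rounding is omitted for brevity. Gradients from each loss accumulate in the scaled fp16 leaf gradients, and each call to update_master_grads performs one copy-and-unscale pass over all parameters. The naive pattern therefore pays for two passes where one suffices, and both patterns end with the same master gradients.

```python
class ToyFP16Optimizer:
    """Hypothetical stand-in modelling only the gradient bookkeeping of backward()."""

    def __init__(self, n_params, loss_scale=128.0):
        self.loss_scale = loss_scale
        self.leaf_grads = [0.0] * n_params    # scaled .grad of the fp16 model leaves
        self.master_grads = [0.0] * n_params  # unscaled .grad of the fp32 master params
        self.copy_passes = 0                  # number of fp16->fp32 copy passes

    def backward(self, loss_grads, update_master_grads=True):
        # Like autograd, accumulate the scaled gradients of this loss into the leaves.
        self.leaf_grads = [acc + g * self.loss_scale
                           for acc, g in zip(self.leaf_grads, loss_grads)]
        if update_master_grads:
            self.update_master_grads()

    def update_master_grads(self):
        # One pass over every parameter: copy to the master grads and unscale.
        self.master_grads = [g / self.loss_scale for g in self.leaf_grads]
        self.copy_passes += 1


grads1, grads2 = [0.5, 0.25], [0.125, 0.5]

naive = ToyFP16Optimizer(2)
naive.backward(grads1)
naive.backward(grads2)

efficient = ToyFP16Optimizer(2)
efficient.backward(grads1, update_master_grads=False)
efficient.backward(grads2, update_master_grads=False)
efficient.update_master_grads()

print(naive.master_grads, naive.copy_passes)          # [0.625, 0.75] 2
print(efficient.master_grads, efficient.copy_passes)  # [0.625, 0.75] 1
```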
clip_master_grads(max_norm, norm_type=2)

Clips fp32 master gradients via torch.nn.utils.clip_grad_norm.

Parameters
  • max_norm (float or int) – max norm of the gradients

  • norm_type (float or int) – type of the used p-norm. Can be 'inf' for infinity norm.

Returns

Total norm of the current fp32 gradients (viewed as a single vector).

Warning

Returns -1 if the most recently computed fp16 gradients overflowed (that is, if self.overflow is True).
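A pure-Python sketch of the documented behavior, assuming master gradients are represented as lists of floats (one list per parameter). clip_master_grads_sketch and its overflow argument are hypothetical stand-ins for the method and for self.overflow. The sketch computes one norm over all gradients viewed as a single vector. When max_norm / (total_norm + 1e-6) is below 1, it rescales every gradient by that factor; the small epsilon mirrors the one PyTorch's gradient clipping adds. It returns the pre-clipping norm, or -1 after an overflow.

```python
import math


def clip_master_grads_sketch(master_grads, max_norm, norm_type=2, overflow=False):
    """Clip lists of fp32 gradients in place by their global norm.

    Returns the total norm before clipping, or -1 if the latest fp16
    gradients overflowed.
    """
    if overflow:
        return -1
    p = float(norm_type)  # accepts 2, 2.0, "inf" or math.inf
    magnitudes = [abs(v) for grad in master_grads for v in grad]
    if math.isinf(p):
        total_norm = max(magnitudes, default=0.0)
    else:
        total_norm = sum(m ** p for m in magnitudes) ** (1.0 / p)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for grad in master_grads:
            for i in range(len(grad)):
                grad[i] *= clip_coef
    return total_norm


grads = [[3.0], [4.0]]                                # global 2-norm is 5
print(clip_master_grads_sketch(grads, max_norm=1.0))  # 5.0
print(grads)                                          # roughly [[0.6], [0.8]]
```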

inspect_master_grad_data()

When running with FP16_Optimizer, the .grad attributes of a model’s fp16 leaves should not be regarded as reliable, because they might be scaled. After a call to fp16_optimizer_obj.backward(loss), if no overflow was encountered, the .grad attributes of the fp32 master params will contain valid gradients, already divided by the loss scale. However, because FP16_Optimizer flattens some parameters, accessing these gradients directly may be unintuitive. inspect_master_grad_data allows those gradients to be viewed with shapes corresponding to their associated model leaves.

Returns

List of lists (one list for each parameter group). The list for each parameter group is a list of the .grad.data attributes of the fp32 master params belonging to that group.
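To make "viewed with shapes corresponding to their associated model leaves" concrete, suppose a tiny layer's weight of shape (2, 2) and bias of shape (2,) had their gradients flattened into one 6-element master buffer. The hypothetical helper unflatten_views below slices that flat buffer back into per-leaf nested lists. It is written with plain Python lists rather than tensors and is not taken from apex. With real tensors, the same slicing would typically use views of the flat tensor instead of copies.

```python
from math import prod


def reshape(values, shape):
    """Nest a flat list according to shape (a tuple of ints)."""
    if len(shape) <= 1:
        return list(values)
    step = len(values) // shape[0]
    return [reshape(values[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


def unflatten_views(flat, shapes):
    """Split a flat gradient buffer into one nested list per model leaf."""
    views, offset = [], 0
    for shape in shapes:
        numel = prod(shape)
        views.append(reshape(flat[offset:offset + numel], shape))
        offset += numel
    if offset != len(flat):
        raise ValueError("shapes do not cover the flat buffer exactly")
    return views


flat_grad = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
print(unflatten_views(flat_grad, [(2, 2), (2,)]))  # [[[1.0, 2.0], [3.0, 4.0]], [5.0, 6.0]]
```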

load_state_dict(state_dict)

Loads a state_dict created by an earlier call to state_dict(). If fp16_optimizer_instance was constructed from some init_optimizer whose parameters in turn came from model, the user is expected to call model.load_state_dict() before calling fp16_optimizer_instance.load_state_dict().

Example:

model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)
...
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])

state_dict()

Returns a dict containing the current state of this FP16_Optimizer instance. This dict contains attributes of FP16_Optimizer, as well as the state_dict of the contained PyTorch optimizer. Example:

checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")

step(closure=None)

If no closure is supplied, step should be called after fp16_optimizer_obj.backward(loss). step updates the fp32 master copy of parameters using the optimizer supplied to FP16_Optimizer’s constructor, then copies the updated fp32 params into the fp16 params originally referenced by FP16_Optimizer’s constructor, so the user may immediately run another forward pass using their model.

If a closure is supplied, step may be called without a prior call to backward(loss). This control flow is identical to ordinary PyTorch optimizer use with closures. However, the user should take care that any loss.backward() call within the closure has been replaced by fp16_optimizer_obj.backward(loss).

Parameters

closure (optional) – Closure that will be supplied to the underlying optimizer originally passed to FP16_Optimizer’s constructor. closure should call zero_grad() on the FP16_Optimizer object, compute the loss, call backward(loss), and return the loss.

Example with closure:

# optimizer is assumed to be an FP16_Optimizer object, previously constructed from an
# existing pytorch optimizer.
for input, target in dataset:
    def closure():
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        # loss.backward() becomes:
        optimizer.backward(loss)
        return loss
    optimizer.step(closure)

Warning

Currently, calling step with a closure is not compatible with dynamic loss scaling.

update_master_grads()

Copies the .grad attributes of the stored fp16 parameter references into the .grad attributes of the fp32 master parameters, which are the parameters the optimizer updates directly. update_master_grads only needs to be called if fp16_optimizer_obj.backward was called with update_master_grads=False.

zero_grad(set_grads_to_None=False)

Zero fp32 and fp16 parameter grads.

class apex.fp16_utils.LossScaler(scale=1)

Class that manages a static loss scale. This class is intended to interact with FP16_Optimizer, and should not be directly manipulated by the user.

Use of LossScaler is enabled via the static_loss_scale argument to FP16_Optimizer’s constructor.

Parameters

scale (float, optional, default=1.0) – The loss scale.

class apex.fp16_utils.DynamicLossScaler(init_scale=4294967296, scale_factor=2.0, scale_window=1000)

Class that manages dynamic loss scaling. It is recommended to use DynamicLossScaler indirectly, by supplying dynamic_loss_scale=True to the constructor of FP16_Optimizer. However, it’s important to understand how DynamicLossScaler operates, because the default options can be changed using the dynamic_loss_args argument to FP16_Optimizer’s constructor.

Loss scaling is designed to combat the problem of underflowing gradients, which tends to appear later in training when fp16 networks are trained for a long time. Dynamic loss scaling begins by attempting a very high loss scale. Ironically, this may result in overflowing gradients. If overflowing gradients are encountered, DynamicLossScaler informs FP16_Optimizer that an overflow has occurred. FP16_Optimizer then skips the update step for this particular iteration/minibatch, and DynamicLossScaler lowers the loss scale. If a certain number of iterations pass without overflowing gradients, DynamicLossScaler increases the loss scale again. In this way DynamicLossScaler attempts to “ride the edge” of always using the highest loss scale possible without incurring overflow.

Parameters
  • init_scale (float, optional, default=2**32) – Initial loss scale attempted by DynamicLossScaler.

  • scale_factor (float, optional, default=2.0) – Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss_scale / scale_factor. If scale_window consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale * scale_factor.

  • scale_window (int, optional, default=1000) – Number of consecutive iterations without an overflow to wait before increasing the loss scale.
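The policy described above can be sketched as a small state machine. DynamicLossScalerSketch, has_overflow and update_scale are hypothetical names chosen for illustration. Only the three constructor arguments come from the documentation above. apex's actual class may differ in details, such as how iterations are counted or whether the scale has a lower bound.

```python
import math


class DynamicLossScalerSketch:
    """Hypothetical re-creation of the policy described above, not apex's code."""

    def __init__(self, init_scale=2.0 ** 32, scale_factor=2.0, scale_window=1000):
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.iters_since_overflow = 0  # consecutive overflow-free iterations

    @staticmethod
    def has_overflow(fp16_grads):
        """True if any (scaled) fp16 gradient is inf or nan."""
        return any(math.isinf(g) or math.isnan(g) for g in fp16_grads)

    def update_scale(self, overflow):
        """Adjust the scale after one iteration; return True if the step should be skipped."""
        if overflow:
            self.loss_scale /= self.scale_factor
            self.iters_since_overflow = 0
            return True
        self.iters_since_overflow += 1
        if self.iters_since_overflow == self.scale_window:
            self.loss_scale *= self.scale_factor
            self.iters_since_overflow = 0
        return False


# A toy run with a short window: one overflow, then three clean iterations.
scaler = DynamicLossScalerSketch(init_scale=8.0, scale_window=3)
for grads in ([1.0, math.inf], [0.5], [0.25], [0.125]):
    skipped = scaler.update_scale(scaler.has_overflow(grads))
    print(f"skipped={skipped}, loss_scale={scaler.loss_scale}")
```

The overflow halves the scale to 4.0 and skips that step. After scale_window clean iterations the scale doubles back to 8.0.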

Manual master parameter management

apex.fp16_utils.prep_param_lists(model, flat_master=False)

Creates a list of FP32 master parameters for a given model, as in Training Neural Networks with Mixed Precision: Real Examples.

Parameters
  • model (torch.nn.Module) – Existing PyTorch model

  • flat_master (bool, optional, default=False) – Flatten the master parameters into a single tensor, as a performance optimization.

Returns

A tuple (model_params, master_params). model_params is a list of the model’s parameters for later use with model_grads_to_master_grads() and master_params_to_model_params(). master_params is a list of FP32 master parameters. If flat_master=True, master_params will be a list with one element.

Example:

model_params, master_params = prep_param_lists(model)

Warning

Currently, if flat_master=True, all the model’s parameters must be the same type. If the model has parameters of different types, use flat_master=False, or use FP16_Optimizer.
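Why keep FP32 master parameters at all? The pure-Python sketch below assumes struct's "e" format for fp16 storage and uses Python floats as a stand-in for fp32. A weight of 1.0 receives 100 SGD updates of size 1e-4. Near 1.0, adjacent fp16 values are about 1e-3 apart, so every update applied directly to an fp16 weight rounds away. The master copy, by contrast, accumulates the updates, and the fp16 model weight follows it because it is refreshed from the master after every step. The loop only mirrors the roles of model_grads_to_master_grads() and master_params_to_model_params() in comments; it does not call them.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round x to the nearest fp16 value; values too large for fp16 become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


lr, grad, steps = 1e-4, -1.0, 100  # each SGD step should raise the weight by 1e-4

# fp16 only: the update is applied to the half-precision weight itself.
w_fp16 = to_fp16(1.0)
for _ in range(steps):
    w_fp16 = to_fp16(w_fp16 - lr * grad)

# fp32 master copy: update the master, then copy it back into the fp16 model.
master = 1.0
model_w = to_fp16(master)
for _ in range(steps):
    master_grad = float(to_fp16(grad))  # role of model_grads_to_master_grads()
    master -= lr * master_grad          # optimizer step on the fp32 master
    model_w = to_fp16(master)           # role of master_params_to_model_params()

print(w_fp16, master, model_w)  # 1.0, about 1.01, 1.009765625
```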

apex.fp16_utils.master_params_to_model_params(model_params, master_params, flat_master=False)

Copy master parameters to model parameters.

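Parameters
  • model_params – List of model parameters, as returned by prep_param_lists().

  • master_params – List of FP32 master parameters, as returned by prep_param_lists().

  • flat_master (bool, optional, default=False) – Should match the flat_master value passed to prep_param_lists() when master_params was created.
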
apex.fp16_utils.model_grads_to_master_grads(model_params, master_params, flat_master=False)

Copy model gradients to master gradients.

Parameters