Advanced Amp Usage

GANs

GANs are an interesting synthesis of several of the topics below. A comprehensive example is under construction.

Gradient clipping

Amp calls the params owned directly by the optimizer’s param_groups the “master params.”

These master params may be fully or partially distinct from model.parameters(). For example, with opt_level="O2", amp.initialize casts most model params to FP16, creates an FP32 master param outside the model for each newly-FP16 model param, and updates the optimizer’s param_groups to point to these FP32 params.

The master params owned by the optimizer’s param_groups may also fully coincide with the model params, which is typically true for opt_levels O0, O1, and O3.

In all cases, correct practice is to clip the gradients of the params that are guaranteed to be owned by the optimizer’s param_groups, instead of those retrieved via model.parameters().

Also, if Amp uses loss scaling, gradients must be clipped after they have been unscaled (which occurs during exit from the amp.scale_loss context manager).
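The ordering requirement can be seen with a toy calculation. The sketch below uses plain Python floats instead of tensors. clip_grad_norm is a hypothetical, simplified stand-in for torch.nn.utils.clip_grad_norm_, and the loss scale of 2**16 is only an example value. Clipping while the gradients are still scaled measures a norm that is loss_scale times too large, so the gradients are shrunk far more than intended.

```python
import math

def clip_grad_norm(grads, max_norm):
    """Toy stand-in for torch.nn.utils.clip_grad_norm_ on a list of floats.

    Rescales grads in place so their L2 norm is at most max_norm and returns
    the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for i in range(len(grads)):
            grads[i] *= clip_coef
    return total_norm

loss_scale = 2.0 ** 16   # example loss scale
true_grads = [0.3, 0.4]  # real gradients, L2 norm 0.5
max_norm = 1.0

# Correct: unscale first (Amp does this when scale_loss exits), then clip.
good = [g * loss_scale for g in true_grads]  # what scaled_loss.backward() yields
good = [g / loss_scale for g in good]        # unscale
clip_grad_norm(good, max_norm)               # norm 0.5 <= 1.0, nothing is clipped

# Wrong: clip while gradients are still scaled, then unscale.
bad = [g * loss_scale for g in true_grads]
clip_grad_norm(bad, max_norm)                # sees a norm of about 32768 and clips
bad = [g / loss_scale for g in bad]          # norm is now about 1/65536, not 0.5
```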

The following pattern should be correct for any opt_level:

import torch
from apex import amp

with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
    # Gradients are unscaled during context manager exit.
# Now it's safe to clip.  Replace
# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
# with
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm)
# or
torch.nn.utils.clip_grad_value_(amp.master_params(optimizer), max_)

Note the use of the utility function amp.master_params(optimizer), which returns a generator that yields the params in the optimizer’s param_groups.
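A plain-Python sketch of what such a helper does is shown below. It is not Apex's source. ToyOptimizer is a hypothetical stand-in exposing a torch.optim-style param_groups list, and strings stand in for parameter tensors.

```python
class ToyOptimizer:
    """Hypothetical stand-in exposing param_groups like a torch.optim.Optimizer."""
    def __init__(self, param_groups):
        self.param_groups = param_groups

def master_params(optimizer):
    """Yield every param in every param group, in order."""
    for group in optimizer.param_groups:
        for p in group["params"]:
            yield p

opt = ToyOptimizer([
    {"params": ["w0", "b0"], "lr": 0.1},
    {"params": ["w1"], "lr": 0.01},
])

params = master_params(opt)
print(list(params))               # ['w0', 'b0', 'w1']
print(list(params))               # [] (the generator is already exhausted)
print(list(master_params(opt)))   # ['w0', 'b0', 'w1'] (a fresh call starts over)
```

Because the result is a generator, it is consumed after one pass. Calling amp.master_params(optimizer) afresh each time you need the params, as the clipping snippet does, avoids accidentally iterating over an empty sequence.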

Also note that clip_grad_norm_(amp.master_params(optimizer), max_norm) is invoked instead of, not in addition to, clip_grad_norm_(model.parameters(), max_norm).

Custom/user-defined autograd functions

The old Amp API for registering user functions is still considered correct. Functions must be registered before calling amp.initialize.

Forcing particular layers/functions to a desired type

I’m still working on a generalizable interface for this that won’t require user-side code to diverge across different opt_levels.

Multiple models/optimizers/losses

Initialization with multiple models/optimizers

amp.initialize’s optimizer argument may be a single optimizer or a list of optimizers, as long as the output you accept has the same type. Similarly, the model argument may be a single model or a list of models, as long as the accepted output matches. The following calls are all legal:

model, optim = amp.initialize(model, optim,...)
model, [optim0, optim1] = amp.initialize(model, [optim0, optim1],...)
[model0, model1], optim = amp.initialize([model0, model1], optim,...)
[model0, model1], [optim0, optim1] = amp.initialize([model0, model1], [optim0, optim1],...)

Backward passes with multiple optimizers

Whenever you invoke a backward pass, the amp.scale_loss context manager must receive all the optimizers that own any params for which the current backward pass is creating gradients. This is true even if each optimizer owns only some, but not all, of the params that are about to receive gradients.

If, for a given backward pass, there’s only one optimizer whose params are about to receive gradients, you may pass that optimizer directly to amp.scale_loss. Otherwise, you must pass the list of optimizers whose params are about to receive gradients. Example with 3 losses and 2 optimizers:

# loss0 accumulates gradients only into params owned by optim0:
with amp.scale_loss(loss0, optim0) as scaled_loss:
    scaled_loss.backward()

# loss1 accumulates gradients only into params owned by optim1:
with amp.scale_loss(loss1, optim1) as scaled_loss:
    scaled_loss.backward()

# loss2 accumulates gradients into some params owned by optim0
# and some params owned by optim1
with amp.scale_loss(loss2, [optim0, optim1]) as scaled_loss:
    scaled_loss.backward()
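The selection rule can be written down as a small, hypothetical helper. Param names stand in for tensors, and you would supply the ownership information from your own model and optimizer setup. An optimizer counts as soon as it owns any param that the backward pass touches, even if it owns only some of them.

```python
def optimizers_for_backward(params_receiving_grads, owners):
    """Return what to pass to amp.scale_loss for one backward pass.

    params_receiving_grads: names of params this backward pass will give gradients.
    owners: dict mapping each optimizer's name to the set of param names it owns.
    """
    needed = [name for name, owned in owners.items()
              if owned & params_receiving_grads]
    # A single optimizer may be passed directly; otherwise pass the whole list.
    return needed[0] if len(needed) == 1 else needed

owners = {"optim0": {"enc.w", "enc.b"}, "optim1": {"dec.w", "dec.b"}}

print(optimizers_for_backward({"enc.w", "enc.b"}, owners))  # optim0
print(optimizers_for_backward({"enc.w", "dec.b"}, owners))  # ['optim0', 'optim1']
```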

Optionally have Amp use a different loss scaler per-loss

By default, Amp maintains a single global loss scaler that will be used for all backward passes (all invocations of with amp.scale_loss(...)). No additional arguments to amp.initialize or amp.scale_loss are required to use the global loss scaler. The code snippets above with multiple optimizers/backward passes use the single global loss scaler under the hood, and they should “just work.”

However, you can optionally tell Amp to maintain a separate loss scaler for each loss, which gives Amp increased numerical flexibility: each scaler adjusts independently, so an overflow during one loss’s backward pass does not reduce the scale used for the others. This is accomplished by supplying the num_losses argument to amp.initialize (which tells Amp how many backward passes you plan to invoke, and therefore how many loss scalers Amp should create), then supplying the loss_id argument to each of your backward passes (which tells Amp which loss scaler to use for that particular backward pass):

model, [optim0, optim1] = amp.initialize(model, [optim0, optim1], ..., num_losses=3)

with amp.scale_loss(loss0, optim0, loss_id=0) as scaled_loss:
    scaled_loss.backward()

with amp.scale_loss(loss1, optim1, loss_id=1) as scaled_loss:
    scaled_loss.backward()

with amp.scale_loss(loss2, [optim0, optim1], loss_id=2) as scaled_loss:
    scaled_loss.backward()

num_losses and the loss_id values should be chosen purely based on the set of losses/backward passes, with loss_id running from 0 to num_losses - 1 as in the example. Whether you use multiple optimizers, and which optimizer or optimizers are associated with each backward pass, is unrelated.
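A simplified model of dynamic loss scaling shows what separate scalers buy you. The class below is a hypothetical toy, not Apex's loss scaler. It halves its scale when a backward pass overflows and doubles it after window consecutive overflow-free passes. The starting scale of 2**16, the factor of 2, and the window of 2000 are illustrative choices.

```python
class ToyDynamicScaler:
    """Simplified dynamic loss scaler (illustrative; not Apex's implementation)."""
    def __init__(self, init_scale=2.0 ** 16, factor=2.0, window=2000):
        self.scale = init_scale
        self.factor = factor
        self.window = window
        self._good_steps = 0

    def update(self, overflowed):
        if overflowed:
            # Overflow: shrink the scale and restart the clean-step count.
            self.scale /= self.factor
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps == self.window:
                # Enough clean passes in a row: try a larger scale again.
                self.scale *= self.factor
                self._good_steps = 0

overflows = (False, False, True)   # loss0 and loss1 are fine, loss2 overflows

# One global scaler shared by all three losses: loss2's overflow lowers the
# scale that loss0 and loss1 will also use on the next iteration.
shared = ToyDynamicScaler()
for overflowed in overflows:
    shared.update(overflowed)

# One scaler per loss (the num_losses / loss_id idea): only loss2's scale drops.
per_loss = [ToyDynamicScaler() for _ in range(3)]
for loss_id, overflowed in enumerate(overflows):
    per_loss[loss_id].update(overflowed)

print(shared.scale)                  # 32768.0
print([s.scale for s in per_loss])   # [65536.0, 65536.0, 32768.0]
```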

Gradient accumulation across iterations

The following should “just work,” and properly accommodate multiple models/optimizers/losses, as well as gradient clipping via the instructions above:

# If your intent is to simulate a larger batch size using gradient accumulation,
# you can divide the loss by the number of accumulation iterations (so that gradients
# will be averaged over that many iterations):
loss = loss/iters_to_accumulate

with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()

# Every iters_to_accumulate iterations, call step() and reset gradients
# (assuming iter is a 0-based iteration counter):
if (iter + 1) % iters_to_accumulate == 0:
    # Gradient clipping if desired:
    # torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm)
    optimizer.step()
    optimizer.zero_grad()
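Why dividing the loss by iters_to_accumulate reproduces the large-batch gradient can be checked by hand on a toy problem. The sketch below uses no Amp or PyTorch. It takes a one-weight linear model with a mean-squared-error loss, whose gradient is derived by hand in grad_mse (a hypothetical helper). Because the gradient is linear in the loss, dividing each micro-batch loss by iters_to_accumulate divides its gradient by the same factor, so the code divides the gradient directly. The equivalence assumes equally sized micro-batches and a mean-reduced loss.

```python
import math

def grad_mse(w, xs, ys):
    """d/dw of mean((w*x - y)**2) over a batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

w = 0.5
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
iters_to_accumulate = 2
micro = len(xs) // iters_to_accumulate

# One backward pass over the full batch of 4.
full = grad_mse(w, xs, ys)

# Two micro-batches of 2, each contribution divided by iters_to_accumulate,
# summed the way .grad accumulates across backward passes.
accumulated = 0.0
for k in range(iters_to_accumulate):
    mb_x = xs[k * micro:(k + 1) * micro]
    mb_y = ys[k * micro:(k + 1) * micro]
    accumulated += grad_mse(w, mb_x, mb_y) / iters_to_accumulate

print(full, accumulated, math.isclose(full, accumulated))  # -22.5 -22.5 True
```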

As a minor performance optimization, you can pass delay_unscale=True to amp.scale_loss on the iterations where you only accumulate gradients, until you’re ready to step(). You should only attempt delay_unscale=True if you’re sure you know what you’re doing, because the interaction with gradient clipping and multiple models/optimizers/losses can become tricky:

if (iter + 1) % iters_to_accumulate == 0:
    # Every iters_to_accumulate iterations (iter counts from 0), unscale and step
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
else:
    # Otherwise, accumulate gradients, don't unscale or step.
    with amp.scale_loss(loss, optimizer, delay_unscale=True) as scaled_loss:
        scaled_loss.backward()
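With a 0-based loop counter (for example, one produced by enumerate), the step condition decides which iterations accumulate with delay_unscale=True and which unscale and step. The hypothetical helper accumulation_schedule below lists that schedule and contrasts it with testing the counter itself, which steps after the very first batch.

```python
def accumulation_schedule(num_iters, iters_to_accumulate):
    """Label each 0-based iteration as 'accumulate' or 'step'."""
    return ["step" if (i + 1) % iters_to_accumulate == 0 else "accumulate"
            for i in range(num_iters)]

print(accumulation_schedule(6, 3))
# ['accumulate', 'accumulate', 'step', 'accumulate', 'accumulate', 'step']

# Testing i % iters_to_accumulate == 0 with a 0-based counter steps at i = 0,
# after only a single batch has been accumulated:
print([i for i in range(6) if i % 3 == 0])   # [0, 3]
```

If the number of iterations is not a multiple of iters_to_accumulate, the gradients from the trailing "accumulate" iterations are never stepped, so you may want an extra step() at the end of an epoch.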

Custom data batch types

The intention of Amp is that you never need to cast your input data manually, regardless of opt_level. Amp accomplishes this by patching each model’s forward method to cast incoming data appropriately for the opt_level. But to cast incoming data, Amp needs to know how. The patched forward will recognize and cast floating-point Tensors (non-floating-point Tensors like IntTensors are not touched) and Python containers of floating-point Tensors.

However, if you wrap your Tensors in a custom class, the casting logic doesn’t know how to drill through the tough custom shell to access and cast the juicy Tensor meat within. You need to tell Amp how to cast your custom batch class by giving it a to method that accepts a torch.dtype (e.g., torch.float16 or torch.float32) and returns an instance of the custom batch cast to that dtype. The patched forward checks for the presence of your to method, and will invoke it with the correct type for the opt_level.

Example:

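import torch

class CustomData(object):
    def __init__(self):
        # A float32 CUDA tensor; Amp's patched forward will cast it via to().
        self.tensor = torch.tensor([1.0, 2.0, 3.0], device="cuda")

    def to(self, dtype):
        # Cast the wrapped tensor and return the batch itself, since Amp
        # expects to() to return the cast batch.
        self.tensor = self.tensor.to(dtype)
        return self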
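The casting rules described above can be modeled as a small recursive dispatch. The sketch below is a toy that loosely follows those rules and is not Apex's source. Floating-point tensors are cast, integer tensors are left alone, dicts, lists, and tuples are recursed into, objects with a to method are delegated to, and everything else passes through unchanged. ToyTensor, cast_inputs, and CustomBatch are hypothetical names, and dtypes are plain strings rather than torch.dtype objects.

```python
from collections.abc import Mapping

class ToyTensor:
    """Hypothetical stand-in for a Tensor that only records its dtype."""
    def __init__(self, dtype, floating=True):
        self.dtype = dtype
        self.floating = floating

def cast_inputs(value, dtype):
    """Recursively cast floating-point ToyTensors found in value to dtype."""
    if isinstance(value, ToyTensor):
        # Floating-point tensors are cast; integer tensors are left alone.
        return ToyTensor(dtype) if value.floating else value
    if hasattr(value, "to"):
        # Custom batch type: trust its to(dtype) method.
        return value.to(dtype)
    if isinstance(value, Mapping):
        return {k: cast_inputs(v, dtype) for k, v in value.items()}
    if isinstance(value, list):
        return [cast_inputs(v, dtype) for v in value]
    if isinstance(value, tuple):
        return tuple(cast_inputs(v, dtype) for v in value)
    # Anything else (strings, numbers, unknown objects) passes through.
    return value

class CustomBatch:
    """A user-defined batch wrapper that opts in to casting via to()."""
    def __init__(self):
        self.features = ToyTensor("float32")
        self.labels = ToyTensor("int64", floating=False)

    def to(self, dtype):
        self.features = cast_inputs(self.features, dtype)
        return self

batch = {"images": ToyTensor("float32"),
         "ids": ToyTensor("int64", floating=False),
         "extra": [CustomBatch()]}
out = cast_inputs(batch, "float16")
print(out["images"].dtype, out["ids"].dtype)                          # float16 int64
print(out["extra"][0].features.dtype, out["extra"][0].labels.dtype)  # float16 int64
```

Without a to method, a custom wrapper falls through to the final branch and reaches forward uncast, which is the situation the paragraph above warns about.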

Warning

Amp also forwards numpy ndarrays without casting them. If you send input data as a raw, unwrapped ndarray, then later use it to create a Tensor within your model.forward, this Tensor’s type will not depend on the opt_level, and may or may not be correct. Users are encouraged to pass castable data inputs (Tensors, collections of Tensors, or custom classes with a to method) wherever possible.

Note

Amp does not call .cuda() on any Tensors for you. Amp assumes that your original script is already set up to move Tensors from the host to the device as needed.