apex.parallel

class apex.parallel.DistributedDataParallel(module, message_size=10000000, delay_allreduce=False, shared_param=None, allreduce_trigger_params=None, retain_allreduce_buffers=False, allreduce_always_fp32=False, num_allreduce_streams=1, allreduce_communicators=None, gradient_average=True, gradient_predivide_factor=1.0, gradient_average_split_factor=None, prof=False)

apex.parallel.DistributedDataParallel is a module wrapper that enables easy multiprocess distributed data parallel training, similar to torch.nn.parallel.DistributedDataParallel. Parameters are broadcast across participating processes on initialization, and gradients are allreduced and averaged over processes during backward().

DistributedDataParallel is optimized for use with NCCL. It achieves high performance by overlapping communication with computation during backward() and bucketing smaller gradient transfers to reduce the total number of transfers required.

DistributedDataParallel is designed to work with the upstream launch utility script torch.distributed.launch with --nproc_per_node <= number of GPUs per node. When used with this launcher, DistributedDataParallel assumes a 1:1 mapping of processes to GPUs. It also assumes that your script calls torch.cuda.set_device(args.local_rank) before creating the model, where args.local_rank holds the --local_rank argument that the launcher passes to each process.

https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed shows detailed usage. https://github.com/NVIDIA/apex/tree/master/examples/imagenet shows another example that combines DistributedDataParallel with mixed precision training.

Parameters
  • module – Network definition to be run in multi-gpu/distributed mode.

  • message_size (int, default=1e7) – Minimum number of elements in a communication bucket.

  • delay_allreduce (bool, default=False) – Delay all communication to the end of the backward pass. This disables overlapping communication with computation.

  • allreduce_trigger_params (list, optional, default=None) – If supplied, should contain a list of parameters drawn from the model. Allreduces will be kicked off whenever one of these parameters receives its gradient (as opposed to when a bucket of size message_size is full). At the end of backward(), a cleanup allreduce to catch any remaining gradients will also be performed automatically. If allreduce_trigger_params is supplied, the message_size argument will be ignored.

  • allreduce_always_fp32 (bool, default=False) – Convert any FP16 gradients to FP32 before allreducing. This can improve stability for widely scaled-out runs.

  • gradient_average (bool, default=True) – Option to toggle whether or not DDP averages the allreduced gradients over processes. For proper scaling, the default value of True is recommended.

  • gradient_predivide_factor (float, default=1.0) – Allows performing the average of gradients over processes partially before and partially after the allreduce. Before allreduce: grads.mul_(1.0/gradient_predivide_factor). After allreduce: grads.mul_(gradient_predivide_factor/world size). This can reduce the stress on the dynamic range of FP16 allreduces for widely scaled-out runs.
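The bucketing behavior described by message_size and allreduce_trigger_params can be illustrated with a small pure-Python sketch. It only plans which gradients would share an allreduce, given the order in which gradients become ready. The function name plan_buckets, the (name, numel) input format, and the layer names are hypothetical, and the sketch omits the real implementation's machinery (CUDA streams, dtype handling, flattening of tensors).

```python
from typing import List, Optional, Sequence, Set, Tuple


def plan_buckets(
    grads: Sequence[Tuple[str, int]],
    message_size: int,
    trigger_params: Optional[Set[str]] = None,
) -> List[List[str]]:
    """Group gradients into allreduce buckets, in the order they become ready.

    grads: (parameter name, number of elements) pairs, in gradient-ready order.
    Without trigger_params, a bucket is closed once it holds at least
    message_size elements. With trigger_params, a bucket is closed whenever
    one of those parameters receives its gradient, and message_size is ignored.
    """
    buckets: List[List[str]] = []
    current: List[str] = []
    current_numel = 0
    for name, numel in grads:
        current.append(name)
        current_numel += numel
        if trigger_params is not None:
            should_flush = name in trigger_params
        else:
            should_flush = current_numel >= message_size
        if should_flush:
            # In DDP this is where an allreduce for the bucket would be launched.
            buckets.append(current)
            current, current_numel = [], 0
    if current:
        # Cleanup allreduce for any gradients left over at the end of backward().
        buckets.append(current)
    return buckets


# Gradients typically become ready in roughly reverse layer order during backward().
ready_order = [("fc2.weight", 6), ("fc2.bias", 2), ("fc1.weight", 9), ("fc1.bias", 3)]
print(plan_buckets(ready_order, message_size=8))
# [['fc2.weight', 'fc2.bias'], ['fc1.weight'], ['fc1.bias']]
print(plan_buckets(ready_order, message_size=8, trigger_params={"fc2.bias"}))
# [['fc2.weight', 'fc2.bias'], ['fc1.weight', 'fc1.bias']]
```

Larger buckets mean fewer, larger transfers, but an allreduce cannot start until its bucket is full, which delays the overlap with the rest of backward().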

Warning

If gradient_average=False, the pre-allreduce division (grads.mul_(1.0/gradient_predivide_factor)) will still be applied, but the post-allreduce gradient averaging (grads.mul_(gradient_predivide_factor/world size)) will be omitted.
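The arithmetic of gradient_predivide_factor, and the effect of gradient_average=False described in the warning above, can be checked with a pure-Python simulation that rounds every intermediate value to IEEE half precision using the struct module's "e" format. The helper names are hypothetical, and the reduction is modeled as a sequential running sum; a real NCCL allreduce sums in a different order, but a total well above the FP16 maximum (65504) overflows in any order.

```python
import struct
from typing import Sequence

FP16_MAX = 65504.0  # largest finite IEEE 754 half-precision value


def to_fp16(x: float) -> float:
    """Round x to the nearest half-precision value; overflow becomes +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return float("inf") if x > 0 else float("-inf")


def fp16_allreduce_sum(values: Sequence[float]) -> float:
    """Model an FP16 allreduce as a running sum rounded to FP16 after each add."""
    total = 0.0
    for v in values:
        total = to_fp16(total + v)
    return total


def averaged_gradient(per_rank_grads, predivide_factor=1.0, gradient_average=True):
    world_size = len(per_rank_grads)
    # Before allreduce: grads.mul_(1.0 / gradient_predivide_factor)
    scaled = [to_fp16(g * (1.0 / predivide_factor)) for g in per_rank_grads]
    total = fp16_allreduce_sum(scaled)
    if not gradient_average:
        # The pre-division above was still applied; no post-allreduce scaling.
        return total
    # After allreduce: grads.mul_(gradient_predivide_factor / world_size)
    return to_fp16(total * (predivide_factor / world_size))


grads = [1024.0] * 64  # 64 ranks, same gradient value on each
print(averaged_gradient(grads))                        # inf: 64 * 1024 = 65536 > FP16_MAX
print(averaged_gradient(grads, predivide_factor=8.0))  # 1024.0
print(averaged_gradient(grads, predivide_factor=8.0, gradient_average=False))  # 8192.0
```

With a predivide factor of 8, the running sum peaks at 8192 instead of overflowing, and the post-allreduce scaling by 8/64 restores the average of 1024.0. With gradient_average=False the result stays divided by the predivide factor, which is the behavior the warning describes.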

forward(*inputs, **kwargs)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, you should call the module instance instead of calling forward() directly: calling the instance runs the registered hooks, while calling forward() directly silently skips them.

class apex.parallel.Reducer(module_or_grads_list)

apex.parallel.Reducer is a simple class that helps allreduce a module's parameter gradients across processes. Reducer is intended to give the user additional control: unlike DistributedDataParallel, Reducer does not automatically allreduce gradients during backward(). Instead, Reducer waits for the user to call <reducer_instance>.reduce() manually. This enables, for example, carrying out the allreduce once every several iterations instead of every iteration.

Like DistributedDataParallel, Reducer averages any tensors it allreduces over the number of participating processes.
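Because averaging is linear, delaying the allreduce does not change the result: averaging gradients accumulated locally over k iterations gives the same value as summing the k per-iteration averages, while performing k times fewer allreduces. The pure-Python simulation below illustrates this with scalar gradients. The function name and input layout are hypothetical; in a real script, the communication step would be a call to <reducer_instance>.reduce() after backward() on the chosen iterations.

```python
from typing import List, Sequence


def delayed_allreduce(local_grads: Sequence[Sequence[float]], reduce_every: int) -> List[float]:
    """Simulate ranks that accumulate gradients locally and allreduce periodically.

    local_grads[t][r] is rank r's (scalar) gradient at iteration t. Every
    reduce_every iterations, the accumulated gradients are averaged across
    ranks, which is the value every rank would hold after the allreduce.
    Returns one averaged value per allreduce; iterations after the last
    multiple of reduce_every are left unreduced in this sketch.
    """
    world_size = len(local_grads[0])
    accumulated = [0.0] * world_size
    results = []
    for t, grads in enumerate(local_grads, start=1):
        for r, g in enumerate(grads):
            accumulated[r] += g  # local accumulation, no communication
        if t % reduce_every == 0:
            results.append(sum(accumulated) / world_size)  # one allreduce + average
            accumulated = [0.0] * world_size
    return results


# 4 iterations on 2 ranks.
grads = [[1.0, 3.0], [2.0, 4.0], [5.0, 7.0], [6.0, 8.0]]
every_step = delayed_allreduce(grads, reduce_every=1)  # [2.0, 3.0, 6.0, 7.0], 4 allreduces
every_two = delayed_allreduce(grads, reduce_every=2)   # [5.0, 13.0], 2 allreduces
```

Each entry of every_two equals the sum of the corresponding pair in every_step, at half the communication cost.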

Reducer is designed to work with the upstream launch utility script torch.distributed.launch with --nproc_per_node <= number of GPUs per node. When used with this launcher, Reducer assumes a 1:1 mapping of processes to GPUs. It also assumes that your script calls torch.cuda.set_device(args.local_rank) before creating the model, where args.local_rank holds the --local_rank argument that the launcher passes to each process.

Parameters

module_or_grads_list – Either a network definition (module) being run in multi-gpu/distributed mode, or an iterable of gradients to be reduced. If a module is passed in, the Reducer constructor will sync the parameters across processes (broadcasting from rank 0) to make sure they’re all initialized with the same values. If a list of gradients (that came from some module) is passed in, the user is responsible for manually syncing that module’s parameters at the beginning of training.

class apex.parallel.SyncBatchNorm(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False, fuse_relu=False)

Synchronized batch normalization module extended from torch.nn.BatchNormNd, with the added reduction of statistics across multiple processes. apex.parallel.SyncBatchNorm is designed to work with DistributedDataParallel.

When running in training mode, the layer reduces statistics across all processes to increase the effective batch size for the normalization layer. This is useful when the per-process batch size is small enough that normalizing with per-process statistics would reduce the converged accuracy of the model. The layer uses the collective communication package torch.distributed.

When running in evaluation mode, the layer falls back to torch.nn.functional.batch_norm.
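The cross-process reduction of statistics can be sketched for a single channel: each process summarizes its local values as (count, mean, sum of squared deviations), and the summaries are merged pairwise with the standard parallel-variance formula, which also handles uneven per-process batch sizes. This is a pure-Python illustration with hypothetical names, not apex's CUDA implementation or its exact communication pattern. Following torch.nn.BatchNorm conventions, normalization uses the biased variance, and the unbiased variance feeds running_var.

```python
from dataclasses import dataclass
from typing import Sequence, Tuple


@dataclass
class Moments:
    count: int
    mean: float
    m2: float  # sum of squared deviations from the mean


def local_moments(values: Sequence[float]) -> Moments:
    """Statistics one process computes from its own slice of the batch."""
    n = len(values)
    mean = sum(values) / n
    return Moments(n, mean, sum((v - mean) ** 2 for v in values))


def merge(a: Moments, b: Moments) -> Moments:
    """Combine two partial summaries (pairwise parallel-variance update)."""
    n = a.count + b.count
    delta = b.mean - a.mean
    mean = a.mean + delta * b.count / n
    m2 = a.m2 + b.m2 + delta * delta * a.count * b.count / n
    return Moments(n, mean, m2)


def synced_stats(per_process: Sequence[Sequence[float]]) -> Tuple[float, float, float]:
    """Return (mean, biased variance, unbiased variance) over all processes."""
    total = local_moments(per_process[0])
    for values in per_process[1:]:
        total = merge(total, local_moments(values))
    return total.mean, total.m2 / total.count, total.m2 / (total.count - 1)


# One channel's values on three processes with uneven local batch sizes.
mean, var_biased, var_unbiased = synced_stats([[1.0, 2.0], [3.0, 4.0, 5.0], [6.0]])
# mean == 3.5, var_unbiased == 3.5, var_biased == 17.5 / 6
```

The merged result matches the statistics of the concatenated batch [1.0, ..., 6.0], which is what a single process would compute if it saw the whole batch.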

Parameters
  • num_features – \(C\) from an expected input of size \((N, C, L)\), or \(L\) from an input of size \((N, L)\)

  • eps – a value added to the denominator for numerical stability. Default: 1e-5

  • momentum – the value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). Default: 0.1

  • affine – a boolean value that when set to True, this module has learnable affine parameters. Default: True

  • track_running_stats – a boolean value that when set to True, this module tracks the running mean and variance, and when set to False, this module does not track such statistics and always uses batch statistics in both training and eval modes. Default: True

  • process_group – the process group within which the mini-batch statistics are synchronized. Default: None (use the default process group)

  • channel_last – a boolean value that, when set to True, makes this module treat the last dimension of the input tensor as the channel dimension. Default: False
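The momentum parameter controls how batch statistics are folded into running_mean and running_var. The sketch below follows the torch.nn.BatchNorm update rule, running = (1 - momentum) * running + momentum * batch_stat, and uses a factor of 1/n on the n-th batch when momentum is None, which yields the cumulative (simple) average. The function name is hypothetical, and it tracks a single scalar statistic rather than a per-channel tensor.

```python
from typing import Optional, Sequence


def track_running_stat(batch_stats: Sequence[float], momentum: Optional[float], init: float = 0.0) -> float:
    """Fold a sequence of per-batch statistics into a running estimate.

    With a float momentum: running = (1 - momentum) * running + momentum * stat.
    With momentum=None: the n-th batch uses factor 1/n, a cumulative average.
    """
    running = init
    for n, stat in enumerate(batch_stats, start=1):
        factor = 1.0 / n if momentum is None else momentum
        running = (1.0 - factor) * running + factor * stat
    return running


print(track_running_stat([1.0], momentum=0.1))             # 0.1
print(track_running_stat([2.0, 4.0, 6.0], momentum=None))  # approximately 4.0, the mean of the batches
```

A small momentum makes the running estimate change slowly, while momentum=None weights every batch seen so far equally.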

Examples:
>>> # channel first tensor
>>> sbn = apex.parallel.SyncBatchNorm(100).cuda()
>>> inp = torch.randn(10, 100, 14, 14).cuda()
>>> out = sbn(inp)
>>> inp = torch.randn(3, 100, 20).cuda()
>>> out = sbn(inp)
>>> # channel last tensor
>>> sbn = apex.parallel.SyncBatchNorm(100, channel_last=True).cuda()
>>> inp = torch.randn(10, 14, 14, 100).cuda()
>>> out = sbn(inp)

forward(input, z=None)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, you should call the module instance instead of calling forward() directly: calling the instance runs the registered hooks, while calling forward() directly silently skips them.

Utility functions

apex.parallel.convert_syncbn_model(module, process_group=None, channel_last=False)

Recursively traverse module and its children to replace all instances of torch.nn.modules.batchnorm._BatchNorm with apex.parallel.SyncBatchNorm.

All torch.nn.BatchNormNd layers (BatchNorm1d, BatchNorm2d, BatchNorm3d) subclass torch.nn.modules.batchnorm._BatchNorm, so this function lets you easily switch a model to synchronized batch normalization.

Parameters

module (torch.nn.Module) – input module

Example:

>>> # model is an instance of torch.nn.Module
>>> import apex
>>> sync_bn_model = apex.parallel.convert_syncbn_model(model)
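The recursive replacement performed by convert_syncbn_model can be illustrated with a toy module tree. The classes below are hypothetical stand-ins for torch.nn.Module, torch.nn.modules.batchnorm._BatchNorm, and apex.parallel.SyncBatchNorm, so the sketch runs without PyTorch. A faithful conversion would also carry over the layer's other settings and state (eps, momentum, affine parameters, running statistics), which this toy omits.

```python
from typing import Dict


class Module:
    """Hypothetical stand-in for torch.nn.Module: just holds named children."""

    def __init__(self, **children: "Module") -> None:
        self.children: Dict[str, "Module"] = dict(children)


class BatchNorm(Module):
    """Hypothetical stand-in for torch.nn.modules.batchnorm._BatchNorm."""

    def __init__(self, num_features: int) -> None:
        super().__init__()
        self.num_features = num_features


class SyncBN(Module):
    """Hypothetical stand-in for apex.parallel.SyncBatchNorm."""

    def __init__(self, num_features: int) -> None:
        super().__init__()
        self.num_features = num_features


def convert(module: Module) -> Module:
    """Return module with every BatchNorm in its subtree replaced by SyncBN."""
    if isinstance(module, BatchNorm):
        return SyncBN(module.num_features)
    for name, child in list(module.children.items()):
        module.children[name] = convert(child)  # recurse into each child
    return module


model = Module(features=Module(conv=Module(), bn=BatchNorm(16)), head=BatchNorm(4))
model = convert(model)
```

As in the documented example, the return value should be used: if the root module is itself a batch normalization layer, it is replaced by a new object rather than modified in place.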