import copy
import importlib
from collections import OrderedDict
from collections.abc import Mapping
from itertools import chain

import torch
import torch.distributed as dist
from torch.autograd import Variable
from torch.nn.modules import Module

from ..multi_tensor_apply import multi_tensor_applier
# Module-level flag checked by flatten()/unflatten(); import_flatten_impl() sets it
# to True once flatten_impl and unflatten_impl have been bound, so they are resolved once.
imported_flatten_impl = False
def import_flatten_impl():
    """Bind the module globals ``flatten_impl``/``unflatten_impl`` and mark them resolved.

    The compiled ``apex_C.flatten``/``apex_C.unflatten`` kernels are used when the
    ``apex_C`` extension can be imported. Otherwise a warning is printed and the
    pure-Python helpers from ``torch._utils`` are used instead. In both cases the
    global ``imported_flatten_impl`` is set to True afterwards.
    """
    global flatten_impl, unflatten_impl, imported_flatten_impl
    try:
        import apex_C
    except ImportError:
        # apex was built without its C++ extension; use torch's Python helpers.
        print("Warning: apex was installed without --cpp_ext. Falling back to Python flatten and unflatten.")
        fallback = torch._utils
        flatten_impl = fallback._flatten_dense_tensors
        unflatten_impl = fallback._unflatten_dense_tensors
    else:
        flatten_impl = apex_C.flatten
        unflatten_impl = apex_C.unflatten
    imported_flatten_impl = True
def flatten(bucket):
    """Coalesce the tensors in ``bucket`` into a single flat buffer.

    The underlying implementation (``apex_C`` or torch's Python fallback) is
    resolved lazily by :func:`import_flatten_impl` the first time it is needed.
    """
    if not imported_flatten_impl:
        import_flatten_impl()
    coalesced = flatten_impl(bucket)
    return coalesced
def unflatten(coalesced, bucket):
    """Split a buffer produced by :func:`flatten` back into per-tensor pieces.

    Returns one tensor per entry of ``bucket``, matching its shape. The
    implementation is resolved lazily by :func:`import_flatten_impl`.
    """
    if not imported_flatten_impl:
        import_flatten_impl()
    pieces = unflatten_impl(coalesced, bucket)
    return pieces
def apply_flat_dist_call(bucket, call, extra_args=None):
    """Run collective ``call`` once on a coalesced copy of ``bucket``.

    Every tensor in ``bucket`` must have the same type, because they are all
    flattened into one buffer. ``extra_args``, when given, is unpacked after the
    buffer (e.g. ``(0,)`` as the source rank for ``dist.broadcast``). When ``call``
    is ``dist.all_reduce``, the result is divided by the world size, i.e. averaged.
    The synchronized values are then copied back into each tensor of ``bucket``
    in place.
    """
    args = () if extra_args is None else extra_args
    coalesced = flatten(bucket)
    call(coalesced, *args)

    if call is dist.all_reduce:
        coalesced /= dist.get_world_size()

    synced_pieces = unflatten(coalesced, bucket)
    for dst, src in zip(bucket, synced_pieces):
        dst.copy_(src)
def split_half_float_double(tensors):
    """Group ``tensors`` into same-type buckets so each bucket can be flattened.

    Buckets for ``torch.cuda.HalfTensor``, ``torch.cuda.FloatTensor`` and
    ``torch.cuda.DoubleTensor`` always come first, in that order. Tensors of any
    other type (e.g. ``torch.cuda.BFloat16Tensor``) are placed in further
    buckets, ordered by first appearance, instead of being silently dropped.
    A dropped bucket would never be allreduced by the caller. Empty groups are
    omitted, and tensors keep their input order within a bucket.

    Args:
        tensors: Iterable of objects exposing ``.type()`` (tensors).

    Returns:
        list[list]: Non-empty same-type buckets.
    """
    canonical = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor"]
    by_type = {}
    for tensor in tensors:
        by_type.setdefault(tensor.type(), []).append(tensor)
    # dicts preserve insertion order, so extra types follow first-appearance order.
    extra = [tp for tp in by_type if tp not in canonical]
    return [by_type[tp] for tp in canonical + extra if tp in by_type]
def split_by_type(tensors):
    """Group ``tensors`` by their ``.type()`` string.

    Returns:
        OrderedDict: maps each type string to the list of tensors of that type.
        Keys appear in order of first occurrence, and each list keeps input order.
    """
    groups = OrderedDict()
    for t in tensors:
        groups.setdefault(t.type(), []).append(t)
    return groups
def flat_dist_call(tensors, call, extra_args=None):
    """Apply collective ``call`` to ``tensors``, using one coalesced buffer per tensor type.

    A flattened buffer must be homogeneous, so the tensors are first grouped with
    :func:`split_by_type`. Each group is then passed to :func:`apply_flat_dist_call`
    together with ``call`` and ``extra_args``.
    """
    for same_type_bucket in split_by_type(tensors).values():
        apply_flat_dist_call(same_type_bucket, call, extra_args)
def extract_tensors(maybe_tensor, tensor_list):
    """Recursively append every tensor found in ``maybe_tensor`` to ``tensor_list``.

    ``maybe_tensor`` may be a tensor or an arbitrarily nested iterable of tensors
    (lists, tuples, generators, ...). Mappings are walked by their values.
    Strings/bytes and other non-iterable leaves are ignored. Tensors are
    appended in traversal order.

    Args:
        maybe_tensor: A tensor, or a (nested) container that may hold tensors.
        tensor_list (list): Output list, extended in place.
    """
    if torch.is_tensor(maybe_tensor):
        tensor_list.append(maybe_tensor)
        return
    # str/bytes cannot hold tensors. Recursing into a str would never terminate,
    # because a one-character string iterates to itself.
    if isinstance(maybe_tensor, (str, bytes)):
        return
    if isinstance(maybe_tensor, Mapping):
        maybe_tensor = maybe_tensor.values()
    # Only probe iterability here, so TypeErrors raised while iterating still surface.
    try:
        items = iter(maybe_tensor)
    except TypeError:
        return
    for item in items:
        extract_tensors(item, tensor_list)
class Reducer(object):
    """
    :class:`apex.parallel.Reducer` is a simple class that helps allreduce a module's gradients
    across processes. :class:`Reducer` is intended to give the user additional control:
    Unlike :class:`DistributedDataParallel`, :class:`Reducer` will not automatically allreduce
    gradients during ``backward()``.
    Instead, :class:`Reducer` waits for the user to call ``<reducer_instance>.reduce()`` manually.
    This enables, for example, delaying the allreduce to be carried out every
    several iterations instead of every single iteration.

    Like :class:`DistributedDataParallel`, :class:`Reducer` averages any tensors it allreduces
    over the number of participating processes.

    :class:`Reducer` is designed to work with the upstream launch utility script
    ``torch.distributed.launch`` with ``--nproc_per_node <= number of gpus per node``.
    When used with this launcher, :class:`Reducer` assumes 1:1 mapping of processes to GPUs.
    It also assumes that your script calls ``torch.cuda.set_device(args.rank)`` before creating the model.

    Args:
        module_or_grads_list: Either a network definition (module) being run in multi-gpu/distributed mode, or an iterable of gradients to be reduced. If a module is passed in, the Reducer constructor will sync the parameters across processes (broadcasting from rank 0) to make sure they're all initialized with the same values. If a list of gradients (that came from some module) is passed in, the user is responsible for manually syncing that module's parameters at the beginning of training.
    """

    def __init__(self, module_or_grads_list):
        if isinstance(module_or_grads_list, Module):
            self.module = module_or_grads_list
            # Make every rank start from rank 0's parameter values.
            flat_dist_call([param.data for param in self.module.parameters()], dist.broadcast, (0,))
        else:
            self.module = None
            # Gradients may arrive as any nesting of iterables; collect the tensors once.
            self.grads = []
            extract_tensors(module_or_grads_list, self.grads)

    def reduce(self):
        """Allreduce and average the tracked gradients across processes.

        In module mode, the gradients are gathered from ``self.module.parameters()``
        on every call, and parameters whose ``.grad`` is ``None`` are skipped. In
        list mode, the tensors collected at construction time are reduced in place.
        """
        # Compare against None explicitly: container modules (e.g. an empty
        # nn.Sequential) define __len__ and could otherwise evaluate as falsy.
        if self.module is not None:
            grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
        else:
            grads = self.grads
        flat_dist_call(grads, dist.all_reduce)
class DistributedDataParallel(Module):
    """
    :class:`apex.parallel.DistributedDataParallel` is a module wrapper that enables
    easy multiprocess distributed data parallel training, similar to ``torch.nn.parallel.DistributedDataParallel``. Parameters are broadcast across participating processes on initialization, and gradients are
    allreduced and averaged over processes during ``backward()``.

    :class:`DistributedDataParallel` is optimized for use with NCCL. It achieves high performance by
    overlapping communication with computation during ``backward()`` and bucketing smaller gradient
    transfers to reduce the total number of transfers required.

    :class:`DistributedDataParallel` is designed to work with the upstream launch utility script
    ``torch.distributed.launch`` with ``--nproc_per_node <= number of gpus per node``.
    When used with this launcher, :class:`DistributedDataParallel` assumes 1:1 mapping of processes to GPUs.
    It also assumes that your script calls ``torch.cuda.set_device(args.rank)`` before creating the model.

    https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed shows detailed usage.
    https://github.com/NVIDIA/apex/tree/master/examples/imagenet shows another example
    that combines :class:`DistributedDataParallel` with mixed precision training.

    Args:
        module: Network definition to be run in multi-gpu/distributed mode.
        message_size (int, default=1e7): Minimum number of elements in a communication bucket.
        delay_allreduce (bool, default=False): Delay all communication to the end of the backward pass. This disables overlapping communication with computation.
        shared_param: No longer supported. Passing anything other than ``None`` raises ``ValueError``.
        allreduce_trigger_params (list, optional, default=None): If supplied, should contain a list of parameters drawn from the model. Allreduces will be kicked off whenever one of these parameters receives its gradient (as opposed to when a bucket of size message_size is full). At the end of backward(), a cleanup allreduce to catch any remaining gradients will also be performed automatically. If allreduce_trigger_params is supplied, the message_size argument will be ignored.
        retain_allreduce_buffers (bool, default=False): If True, each allreduced flat buffer is kept in ``self.allreduce_buffers`` and the gradients are re-pointed at views into it, instead of the reduced values being copied back into the original gradient tensors.
        allreduce_always_fp32 (bool, default=False): Convert any FP16 gradients to FP32 before allreducing. This can improve stability for widely scaled-out runs.
        num_allreduce_streams (int, default=1): Number of side streams that buckets are assigned to round-robin. Values greater than 1 also use one process group per stream and cannot be combined with ``delay_allreduce=True``.
        allreduce_communicators (optional, default=None): Pair ``(process_groups, streams)`` to use instead of creating new ones. Both must have ``num_allreduce_streams`` entries, and ``num_allreduce_streams`` must be greater than 1.
        gradient_average (bool, default=True): Option to toggle whether or not DDP averages the allreduced gradients over processes. For proper scaling, the default value of True is recommended.
        gradient_predivide_factor (float, default=1.0): Allows performing the average of gradients over processes partially before and partially after the allreduce. Before allreduce: ``grads.mul_(1.0/gradient_predivide_factor)``. After allreduce: ``grads.mul_(gradient_predivide_factor/world size)``. This can reduce the stress on the dynamic range of FP16 allreduces for widely scaled-out runs.
        gradient_average_split_factor: Accepted but not used by this implementation.
        prof (bool, default=False): Wrap the forward bookkeeping and backward hooks in NVTX ranges for profiling.

    .. warning::
        If ``gradient_average=False``, the pre-allreduce division (``grads.mul_(1.0/gradient_predivide_factor)``) will still be applied, but the post-allreduce gradient averaging (``grads.mul_(gradient_predivide_factor/world size)``) will be omitted.
    """

    def __init__(self,
                 module,
                 message_size=10000000,
                 delay_allreduce=False,
                 shared_param=None,
                 allreduce_trigger_params=None,
                 retain_allreduce_buffers=False,
                 allreduce_always_fp32=False,
                 num_allreduce_streams=1,
                 allreduce_communicators=None,
                 gradient_average=True,
                 gradient_predivide_factor=1.0,
                 gradient_average_split_factor=None,
                 prof=False):
        super(DistributedDataParallel, self).__init__()

        # Backward/forward compatibility around
        # https://github.com/pytorch/pytorch/commit/540ef9b1fc5506369a48491af8a285a686689b36 and
        # https://github.com/pytorch/pytorch/commit/044d00516ccd6572c0d6ab6d54587155b02a3b86
        if hasattr(dist, "get_backend"):
            self._backend = dist.get_backend()
            if hasattr(dist, "DistBackend"):
                self.backend_enum_holder = dist.DistBackend
            else:
                self.backend_enum_holder = dist.Backend
        else:
            self._backend = dist._backend
            self.backend_enum_holder = dist.dist_backend

        self.warn_on_half = self._backend == self.backend_enum_holder.GLOO

        self.prof = prof

        self.allreduce_different_streams = (num_allreduce_streams > 1)
        self.num_allreduce_streams = num_allreduce_streams
        self.allreduce_communicators = allreduce_communicators
        if self.allreduce_communicators:
            # Layout is (process_groups, streams), one of each per allreduce stream.
            assert len(allreduce_communicators[0]) == num_allreduce_streams
            assert len(allreduce_communicators[0]) == len(allreduce_communicators[1])
            assert self.allreduce_different_streams

        if self.allreduce_different_streams and delay_allreduce:
            raise ValueError("self.allreduce_different_streams may only be used if delay_allreduce=False.")

        if shared_param is not None:
            raise ValueError("shared_param is no longer supported as an option. It was misleadingly named from the start. It turns out overlapping communication with computation should work fine with shared parameters. If you still wish to delay communication to the end of the backward pass, use delay_allreduce=True|False instead.")

        self.world_size = float(dist.get_world_size())

        self.retain_allreduce_buffers = retain_allreduce_buffers
        self.allreduce_always_fp32 = allreduce_always_fp32
        self.gradient_average = gradient_average
        self.gradient_predivide_factor = gradient_predivide_factor

        self.custom_allreduce_triggers = False
        if allreduce_trigger_params is not None:
            if delay_allreduce:
                raise ValueError("Setting allreduce_trigger_params is only valid if delay_allreduce=False.")
            self.custom_allreduce_triggers = True
            # The backward hooks match trigger parameters by object identity.
            self.allreduce_trigger_params = {id(param) for param in allreduce_trigger_params}

        self.delay_allreduce = delay_allreduce
        self.message_size = message_size

        self.main_stream = torch.cuda.current_stream()

        self.bucket_streams = []
        self.bucket_events = []

        self.module = module

        self._disable_allreduce = False

        if self._backend == self.backend_enum_holder.NCCL:
            for param in self.module.parameters():
                assert param.is_cuda, "NCCL backend only supports model parameters to be on GPU."

        self.active_params = []

        # Index into tmp_buckets/tmp_numels for each supported gradient type.
        self.param_type_to_tmp_i = {"torch.cuda.HalfTensor": 0,
                                    "torch.cuda.FloatTensor": 1,
                                    "torch.cuda.DoubleTensor": 2}

        if multi_tensor_applier.available:
            # TODO: I really need to centralize the C++ backed imports
            import amp_C
            self.multi_tensor_scale = amp_C.multi_tensor_scale
            self._overflow_buf = torch.cuda.IntTensor([0])

        self.create_hooks()

        # Make every rank start from rank 0's parameter values.
        flat_dist_call([param.data for param in self.module.parameters()], dist.broadcast, (0,))

    def __setstate__(self, state):
        """Restore from a state produced by :meth:`__getstate__`.

        Bucket streams and events are not restored. They are reset to empty
        lists, and :meth:`forward` creates new ones when it needs them.
        """
        super(DistributedDataParallel, self).__setstate__(state)
        # Read the restored attribute; there is no local ``delay_allreduce`` in this scope.
        if self.allreduce_different_streams and self.delay_allreduce:
            raise ValueError("self.allreduce_different_streams may only be used if delay_allreduce=False.")

        if self.delay_allreduce:
            self.needs_refresh = True

        self.bucket_streams = []
        self.bucket_events = []

    def __getstate__(self):
        """Return a shallow copy of ``__dict__`` for pickling.

        For non-NCCL backends, the ``bucket_streams`` and ``bucket_events``
        entries are removed from the copy. The live instance is not modified.
        """
        attrs = copy.copy(self.__dict__)
        if self._backend != self.backend_enum_holder.NCCL:
            # __dict__ keys are the bare attribute names.
            attrs.pop('bucket_streams', None)
            attrs.pop('bucket_events', None)
        return attrs

    def enable_allreduce(self):
        """Re-enable gradient allreduce in forward/backward."""
        self._disable_allreduce = False

    def disable_allreduce(self):
        """Skip DDP bookkeeping in forward and allreduce in the backward hooks."""
        self._disable_allreduce = True

    # Broadcast rank 0's bucket structure across all processes, and have all processes
    # regenerate their bucket structures to match.
    def sync_bucket_structure(self):
        """Finalize this rank's recorded buckets, then adopt rank 0's bucket layout.

        The layout is packed into one int tensor ``[num_buckets, *bucket_sizes,
        *param_indices]``, broadcast from rank 0, and unpacked into
        ``num_buckets``, ``bucket_sizes``, ``buckets``, ``active_i_buckets`` and
        ``param_id_to_bucket``.
        """
        # Append leftover buckets
        for tmp_bucket in self.tmp_buckets:
            if len(tmp_bucket) > 0:
                self.active_i_buckets.append(tmp_bucket)

        self.num_buckets = len(self.active_i_buckets)
        self.bucket_sizes = [len(bucket) for bucket in self.active_i_buckets]

        info_tensor = torch.cuda.IntTensor([self.num_buckets] +
                                           self.bucket_sizes +
                                           list(chain(*self.active_i_buckets)))

        dist.broadcast(info_tensor, 0)

        info = [int(entry) for entry in info_tensor]

        self.num_buckets = info[0]
        self.bucket_sizes = info[1:self.num_buckets + 1]
        self.buckets = [[None for _ in range(self.bucket_sizes[i])]
                        for i in range(self.num_buckets)]
        # Technically, active_i_buckets' work is done. But the information is still useful to
        # keep around. Therefore, refresh active_i_buckets based on rank 0 as well.
        self.active_i_buckets = [[None for _ in range(self.bucket_sizes[i])]
                                 for i in range(self.num_buckets)]

        flattened_buckets = info[self.num_buckets + 1:]
        flat_i = 0
        for bucket_idx in range(self.num_buckets):
            for bucket_loc in range(self.bucket_sizes[bucket_idx]):
                param_i = flattened_buckets[flat_i]
                self.active_i_buckets[bucket_idx][bucket_loc] = param_i
                self.param_id_to_bucket[id(self.active_params[param_i])] = (bucket_idx, bucket_loc)
                flat_i += 1

    def create_hooks(self):
        """Register an allreduce hook on the grad accumulator of each trainable parameter."""
        # Fallback hook that's only called at the end of backward.
        # Used if you deliberately want to delay allreduces to the end, or to refresh the
        # bucket structure that will be used to overlap communication with computation in later
        # iterations.
        def allreduce_params():
            # Bucket record refresh
            if not self.delay_allreduce:
                if self.needs_refresh:
                    self.sync_bucket_structure()
                    self.needs_refresh = False

            self.allreduce_fallback()

        def overlapping_backward_epilogue():
            # Make the current stream wait for all outstanding bucket allreduces.
            for stream, event in zip(self.bucket_streams, self.bucket_events):
                stream.record_event(event)
                torch.cuda.current_stream().wait_event(event)

            # Sanity checks that all the buckets were kicked off
            if self.next_bucket != self.num_buckets:
                raise RuntimeError(("In epilogue, next_bucket ({}) != num_buckets ({}). "
                                    "This probably indicates some buckets were not allreduced.").format(
                                        self.next_bucket, self.num_buckets))

            for actual, expected in zip(self.buckets_ready_size, self.bucket_sizes):
                if actual != expected:
                    raise RuntimeError("Some param buckets were not allreduced.")

        self.grad_accs = []
        for param in self.module.parameters():
            if param.requires_grad:
                # wrapper() binds ``param`` per iteration so each hook closes over its own parameter.
                def wrapper(param):
                    param_tmp = param.expand_as(param)
                    # The AccumulateGrad node for ``param``, reached through a trivial expand.
                    grad_acc = param_tmp.grad_fn.next_functions[0][0]

                    def allreduce_hook(*unused):
                        if self.prof:
                            torch.cuda.nvtx.range_push("allreduce_hook")

                        if not self._disable_allreduce:
                            if self.delay_allreduce or self.needs_refresh:
                                # TODO: How do we want to handle multiple backward passes between
                                # each forward, e.g., backward passes with retain_graph=True?
                                # needs_refresh and callback_queued are both vulnerable states.
                                if not self.delay_allreduce and self.needs_refresh:
                                    # Use the backward pass to build the bucket structure on the fly.
                                    active_i = self.param_id_to_active_i[id(param)]

                                    # Float, half, and double tensors are grouped into buckets separately.
                                    current_type = self.param_type_to_tmp_i[param.type()]

                                    self.tmp_buckets[current_type].append(active_i)

                                    ship_tmp_bucket = False
                                    if self.custom_allreduce_triggers:
                                        if id(param) in self.allreduce_trigger_params:
                                            ship_tmp_bucket = True
                                    else:
                                        self.tmp_numels[current_type] += param.numel()
                                        if self.tmp_numels[current_type] >= self.message_size:
                                            ship_tmp_bucket = True

                                    # To consider: If custom_allreduce_triggers are in use, ship all
                                    # tmp_buckets, not just tmp_buckets[current_type].
                                    if ship_tmp_bucket:
                                        self.active_i_buckets.append(self.tmp_buckets[current_type])
                                        self.tmp_buckets[current_type] = []
                                        self.tmp_numels[current_type] = 0

                                if not self.callback_queued:
                                    Variable._execution_engine.queue_callback(allreduce_params)
                                    self.callback_queued = True
                            else:
                                if not self.callback_queued:
                                    Variable._execution_engine.queue_callback(overlapping_backward_epilogue)
                                    self.callback_queued = True

                                self.comm_ready_buckets(param)

                        if self.prof:
                            torch.cuda.nvtx.range_pop()

                    grad_acc.register_hook(allreduce_hook)
                    # Keep a reference so the accumulator (and its hook) stays alive.
                    self.grad_accs.append(grad_acc)

                wrapper(param)

    def _stream_this_bucket(self, bucket_idx):
        """Side stream for ``bucket_idx`` (round-robin when multiple streams are used)."""
        if self.allreduce_different_streams:
            return self.bucket_streams[bucket_idx % self.num_allreduce_streams]
        else:
            return self.bucket_streams[0]

    def _event_this_bucket(self, bucket_idx):
        """Event paired with the stream returned by :meth:`_stream_this_bucket`."""
        if self.allreduce_different_streams:
            return self.bucket_events[bucket_idx % self.num_allreduce_streams]
        else:
            return self.bucket_events[0]

    def allreduce_bucket(self, bucket, bucket_idx, force_default_stream):
        """Flatten ``bucket``, allreduce it, and (unless retaining buffers) copy results back.

        The work runs on this bucket's side stream after it waits on the current
        stream, or on ``self.main_stream`` when ``force_default_stream`` is True.
        Pre/post scaling follows ``gradient_predivide_factor`` and
        ``gradient_average``. Returns the flat tensor.
        """
        tensor = flatten(bucket)

        if force_default_stream:
            bucket_stream = self.main_stream
        else:
            bucket_stream = self._stream_this_bucket(bucket_idx)
            bucket_event = self._event_this_bucket(bucket_idx)
            torch.cuda.current_stream().record_event(bucket_event)
            bucket_stream.wait_event(bucket_event)

        with torch.cuda.stream(bucket_stream):
            # self.main_stream.wait_stream(torch.cuda.current_stream())
            # torch.cuda.synchronize()

            tensor_to_allreduce = tensor

            if self.allreduce_always_fp32:
                tensor_to_allreduce = tensor.float()

            if self.gradient_predivide_factor != 1.0:
                tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor)

            if self.allreduce_different_streams and not force_default_stream:
                dist.all_reduce(tensor_to_allreduce, group=self.bucket_pgs[bucket_idx % self.num_allreduce_streams])
            else:
                dist.all_reduce(tensor_to_allreduce)

            if self.gradient_average:
                tensor_to_allreduce.mul_(self.gradient_predivide_factor / self.world_size)

            if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:
                tensor.copy_(tensor_to_allreduce)

            if not self.retain_allreduce_buffers:
                if multi_tensor_applier.available:
                    multi_tensor_applier(
                        self.multi_tensor_scale,
                        self._overflow_buf,
                        [unflatten(tensor, bucket), bucket],
                        1.0)
                else:
                    for buf, synced in zip(bucket, unflatten(tensor, bucket)):
                        buf.copy_(synced)

            # I think we actually do need this here. After allreduce_bucket returns, tensor will
            # eventually go out of scope and die, at which point it could otherwise be freed for
            # further reuse by the main stream while the allreduce/div/unflatten are underway in bucket_stream.
            tensor.record_stream(bucket_stream)

        return tensor

    def allreduce_maybe_retain(self, bucket, bucket_idx, force_default_stream=False):
        """Allreduce ``bucket``. When retaining buffers, store the flat result and point the grads at it."""
        allreduced = self.allreduce_bucket(bucket, bucket_idx, force_default_stream)
        if self.retain_allreduce_buffers:
            if self.allreduce_buffers[bucket_idx] is not None:
                raise RuntimeError("The backward pass is attempting to replace an already-filled "
                                   "allreduce buffer. This is almost certainly an error.")
            self.allreduce_buffers[bucket_idx] = allreduced
            for view, grad in zip(unflatten(allreduced, bucket), bucket):
                grad.data = view
            # for buf, synced in zip(bucket, unflatten(allreduced, bucket)):
            #     buf.copy_(synced)

    def allreduce_fallback(self):
        """Allreduce all existing gradients, grouped by type, on the main stream."""
        for stream, event in zip(self.bucket_streams, self.bucket_events):
            stream.record_event(event)
            torch.cuda.current_stream().wait_event(event)

        if self.retain_allreduce_buffers:
            grads = [param.grad for param in self.module.parameters() if param.grad is not None]
        else:
            grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]

        split_buckets = split_half_float_double(grads)

        # If retain_allreduce_buffers is True and delay_allreduce is False,
        # this will only be done during the first backward pass, ignored by the
        # training script, and overwritten in the next forward pass. So it's harmless.
        if self.retain_allreduce_buffers:
            self.allreduce_buffers = [None for _ in range(len(split_buckets))]

        for i, bucket in enumerate(split_buckets):
            self.allreduce_maybe_retain(bucket, i, force_default_stream=True)

    def comm_ready_buckets(self, param):
        """Record ``param``'s gradient in its bucket slot and launch every bucket that is ready, in order.

        Buckets are allreduced strictly in index order. A bucket that fills before
        its predecessors is parked in ``ready_buckets_not_reduced`` until they finish.
        """
        # Need to do this in every hook for compatibility with Ruberry's streaming backward PR.
        # self.reduction_stream.wait_stream(torch.cuda.current_stream())
        if self.prof:
            torch.cuda.nvtx.range_push("comm_ready_buckets")

        bucket_idx, bucket_loc = self.param_id_to_bucket[id(param)]

        if self.buckets[bucket_idx][bucket_loc] is not None:
            raise RuntimeError("The backward pass is attempting to replace an already-filled "
                               "bucket slot. This is almost certainly an error.")

        if self.retain_allreduce_buffers:
            self.buckets[bucket_idx][bucket_loc] = param.grad
        else:
            self.buckets[bucket_idx][bucket_loc] = param.grad.data

        self.buckets_ready_size[bucket_idx] += 1

        if self.buckets_ready_size[bucket_idx] == self.bucket_sizes[bucket_idx]:
            if bucket_idx == self.next_bucket:
                self.allreduce_maybe_retain(self.buckets[bucket_idx], bucket_idx)

                self.next_bucket += 1

                # Reversing upstream's logic here, because we constructed our buckets based on
                # the order things were received during backward.
                if len(self.ready_buckets_not_reduced) > 0:
                    sorted_todo = sorted(self.ready_buckets_not_reduced)
                    for i in sorted_todo:
                        # Nothing can be reduced now
                        if i > self.next_bucket:
                            break
                        elif i == self.next_bucket:
                            self.allreduce_maybe_retain(self.buckets[i], i)
                            self.ready_buckets_not_reduced.remove(i)
                            self.next_bucket += 1
                        else:
                            raise ValueError("i should always be >= next_bucket")
            else:
                self.ready_buckets_not_reduced.add(bucket_idx)

        if self.prof:
            torch.cuda.nvtx.range_pop()

    def forward(self, *inputs, **kwargs):
        """Run the wrapped module, then reset the per-iteration allreduce bookkeeping.

        Unless allreduce is disabled, this decides whether the bucket structure
        must be re-recorded during the coming backward pass (``needs_refresh``).
        If no refresh is needed, it clears bucket slots, ready counters and
        retained buffers, and creates streams, events and process groups on
        first use.
        """
        result = self.module(*inputs, **kwargs)

        if self.prof:
            torch.cuda.nvtx.range_push("forward pass DDP logic")

        if not self._disable_allreduce:
            if not self.delay_allreduce:
                param_list = [param for param in self.module.parameters() if param.requires_grad]

                # Conditions under which to refresh the bucket structure.
                # Forward has the authority to set needs_refresh to True, but only allreduce_params
                # in backward has the authority to set needs_refresh to False.
                if ((not self.active_params) or
                        (len(param_list) != len(self.active_params)) or
                        any(param1 is not param2 for param1, param2 in zip(param_list, self.active_params))):
                    self.needs_refresh = True

                if self.needs_refresh:
                    # Start a fresh recording; the backward hooks fill tmp_buckets.
                    self.active_i_buckets = []
                    self.buckets = []
                    self.tmp_buckets = [[], [], []]  # [running half, float, double buckets]
                    self.tmp_numels = [0, 0, 0]
                    self.bucket_sizes = []
                    self.param_id_to_active_i = {id(param): i for i, param in enumerate(param_list)}
                    self.param_id_to_bucket = {}
                    self.bucket_pgs = []
                    self.bucket_streams = []
                    self.bucket_events = []
                else:
                    if not self.buckets:
                        self.buckets = [[None for _ in range(self.bucket_sizes[i])]
                                        for i in range(self.num_buckets)]
                    else:
                        assert len(self.buckets) == self.num_buckets, "len(buckets) = {}, expected {}".format(
                            len(self.buckets), self.num_buckets)
                        for b, bucket in enumerate(self.buckets):
                            assert len(bucket) == self.bucket_sizes[b], "len(buckets[{}]) = {}, expected {})".format(
                                b, len(bucket), self.bucket_sizes[b])
                            for i in range(len(bucket)):
                                bucket[i] = None

                    if self.allreduce_communicators:
                        self.bucket_pgs = self.allreduce_communicators[0]
                        self.bucket_streams = self.allreduce_communicators[1]
                        self.bucket_events = [torch.cuda.Event(enable_timing=False,
                                                               blocking=False) for _ in range(self.num_allreduce_streams)]
                    else:
                        if self.allreduce_different_streams:
                            if not self.bucket_pgs:
                                self.bucket_pgs = [dist.new_group() for _ in range(self.num_allreduce_streams)]
                                for i, bg in enumerate(self.bucket_pgs):
                                    print("rank {} created group {} with backend {}".format(
                                        dist.get_rank(), i, dist.get_backend(bg)))
                        if self.allreduce_different_streams:
                            if not self.bucket_streams:
                                self.bucket_streams = [torch.cuda.Stream() for _ in range(self.num_allreduce_streams)]
                                self.bucket_events = [torch.cuda.Event(enable_timing=False,
                                                                       blocking=False) for _ in range(self.num_allreduce_streams)]
                        else:
                            if not self.bucket_streams:
                                self.bucket_streams = [torch.cuda.Stream()]
                                self.bucket_events = [torch.cuda.Event(enable_timing=False, blocking=False)]

                    self.buckets_ready_size = [0 for _ in range(self.num_buckets)]
                    if self.retain_allreduce_buffers:
                        self.allreduce_buffers = [None for _ in range(self.num_buckets)]
                    self.next_bucket = 0
                    self.ready_buckets_not_reduced = set()

                self.active_params = param_list

            # Allow the first hook of the next backward to queue its end-of-backward callback.
            self.callback_queued = False

        if self.prof:
            torch.cuda.nvtx.range_pop()

        return result