# Source code for apex.parallel

import torch

# Resolve the reduction-op enum under a single name, probing the spellings
# exposed by torch.distributed from newest to oldest:
#   1. torch.distributed.ReduceOp
#   2. torch.distributed.reduce_op
#   3. torch.distributed.deprecated.reduce_op
ReduceOp = (
    torch.distributed.ReduceOp
    if hasattr(torch.distributed, 'ReduceOp')
    else torch.distributed.reduce_op
    if hasattr(torch.distributed, 'reduce_op')
    else torch.distributed.deprecated.reduce_op
)

from .distributed import DistributedDataParallel, Reducer
# SyncBatchNorm has to be exposed under the same name for both the
# cuda-enabled and the python-fallback versions, and the reason for falling
# back must not be hidden from the user.
try:
    # ``syncbn`` is imported first so that the optimized path fails with
    # ImportError when that module is unavailable (presumably the compiled
    # extension -- confirm against the build setup).
    import syncbn
    from .optimized_sync_batchnorm import SyncBatchNorm
except ImportError as err:
    from .sync_batchnorm import SyncBatchNorm
    # Keep the original import failure on the fallback class so the error
    # information stays available for later inspection.
    SyncBatchNorm.syncbn_import_error = err

def convert_syncbn_model(module, process_group=None, channel_last=False):
    '''
    Recursively traverse module and its children to replace all instances of
    ``torch.nn.modules.batchnorm._BatchNorm`` with :class:`apex.parallel.SyncBatchNorm`.

    All ``torch.nn.BatchNorm*N*d`` wrap around
    ``torch.nn.modules.batchnorm._BatchNorm``, so this function lets you easily
    switch to use sync BN.

    Instance-norm modules (``torch.nn.modules.instancenorm._InstanceNorm``) are
    returned untouched and their children are not visited.

    Args:
        module (torch.nn.Module): input module
        process_group: passed through unchanged to every created
            ``SyncBatchNorm``. Default: ``None``.
        channel_last (bool): passed through unchanged to every created
            ``SyncBatchNorm`` as the ``channel_last`` keyword. Default: ``False``.

    Returns:
        The converted module. This is a new ``SyncBatchNorm`` when ``module``
        itself is a batch-norm layer; otherwise it is ``module`` itself, with
        its children replaced in place by their converted versions.

    Example::

        >>> # model is an instance of torch.nn.Module
        >>> import apex
        >>> sync_bn_model = apex.parallel.convert_syncbn_model(model)
    '''
    # Tested before the _BatchNorm check so that instance-norm layers are never
    # converted, even on PyTorch releases where _InstanceNorm derives from
    # _BatchNorm.
    if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):
        return module

    mod = module
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        mod = SyncBatchNorm(module.num_features, module.eps, module.momentum,
                            module.affine, module.track_running_stats,
                            process_group, channel_last=channel_last)
        # Running statistics are shared with (not copied from) the source layer.
        mod.running_mean = module.running_mean
        mod.running_var = module.running_var
        # Carry the batch counter over as well; without it the converted layer
        # would restart its count. Guarded because the attribute may be absent
        # on the source module.
        if hasattr(module, 'num_batches_tracked'):
            mod.num_batches_tracked = module.num_batches_tracked
        if module.affine:
            # Affine parameters get independent storage in the new layer.
            mod.weight.data = module.weight.data.clone().detach()
            mod.bias.data = module.bias.data.clone().detach()

    # Re-register every child under its original name with its converted
    # counterpart (for non-batch-norm parents, mod is module, so this swaps the
    # children in place).
    for name, child in module.named_children():
        mod.add_module(name, convert_syncbn_model(child,
                                                  process_group=process_group,
                                                  channel_last=channel_last))
    # No explicit ``del module``: it would only unbind the local name and has
    # no effect on the caller's reference.
    return mod
def create_syncbn_process_group(group_size):
    '''
    Creates process groups to be used for syncbn of a given ``group_size`` and
    returns the process group that the current GPU participates in.

    ``group_size`` must divide the total number of GPUs (world_size).

    ``group_size`` of 0 would be considered as =world_size. In this case
    ``None`` will be returned.

    ``group_size`` of 1 would be equivalent to using non-sync bn, but will
    still carry the overhead.

    Args:
        group_size (int): number of GPU's to collaborate for sync bn

    Returns:
        ``None`` when ``group_size`` is 0; otherwise the group returned by
        ``torch.distributed.new_group`` whose ranks include the current rank.

    Example::

        >>> # model is an instance of torch.nn.Module
        >>> import apex
        >>> group = apex.parallel.create_syncbn_process_group(group_size)
    '''
    if group_size == 0:
        return None

    world_size = torch.distributed.get_world_size()
    assert(world_size >= group_size)
    assert(world_size % group_size == 0)

    num_groups = world_size // group_size
    own_group = None
    # Every process must go through the creation of all subgroups, so the loop
    # never exits early once the matching group has been found.
    for index in range(num_groups):
        first_rank = index * group_size
        members = range(first_rank, first_rank + group_size)
        candidate = torch.distributed.new_group(ranks=members)
        if torch.distributed.get_rank() // group_size == index:
            own_group = candidate

    assert(own_group is not None)
    return own_group