# Copyright (c) 2018 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
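"""Mixed precision training utilities: an optimizer wrapper that keeps
float32 master copies of float16 variables (with optional static or
automatic loss scaling), and a regularizer wrapper that defers
regularization of float16 weights to their float32 copies."""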
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import tensorflow as tf
from .automatic_loss_scaler import AutomaticLossScaler


# pylint: disable=abstract-method
class MixedPrecisionOptimizerWrapper(tf.train.Optimizer):
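  """Optimizer wrapper that enables mixed precision training.

  For every trainable float16 variable, gradients are computed with
  loss scaling and applied to a float32 "master copy" of the variable;
  the updated copy is then cast back into the float16 variable. The
  loss scale can be a fixed float or an ``AutomaticLossScaler``.

  A minimal usage sketch (the base optimizer and ``loss`` tensor here
  are hypothetical placeholders, not part of this module)::

      base_opt = tf.train.MomentumOptimizer(0.01, momentum=0.9)
      opt = MixedPrecisionOptimizerWrapper(base_opt, loss_scale=512.0)
      train_op = opt.minimize(loss)
  """
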
  def __init__(self, optimizer, loss_scale=None):
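    """Creates the wrapper around ``optimizer``.

    Args:
      optimizer: instance of ``tf.train.Optimizer`` to wrap.
      loss_scale (optional): ``None`` for no loss scaling, a ``float``
          for a static loss scale, or an ``AutomaticLossScaler``
          instance for dynamic loss scaling.
    """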
    super(MixedPrecisionOptimizerWrapper, self).__init__(
        optimizer._use_locking,
        optimizer._name + '-MP',
    )
    self._optimizer = optimizer
    self._fp32_to_fp16 = {}
    self._loss_scaler = None
    if loss_scale is None:
      self._loss_scale = 1.0
    elif isinstance(loss_scale, float):
      self._loss_scale = loss_scale
    elif isinstance(loss_scale, AutomaticLossScaler):
      self._loss_scaler = loss_scale
      self._loss_scale = self._loss_scaler.loss_scale

  def compute_gradients(self, loss, var_list=None,
                        gate_gradients=tf.train.Optimizer.GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):
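    """Computes scaled gradients and creates float32 master copies.

    The loss is multiplied by the loss scale before differentiation.
    For each float16 variable a float32 copy is created, any deferred
    regularization registered via ``mp_regularizer_wrapper`` is added
    with respect to that copy, and all gradients are divided by the
    loss scale before being returned.
    """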
    loss *= self._loss_scale
    grads_and_vars_fp16 = self._optimizer.compute_gradients(
        loss, var_list=var_list,
        gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops,
        grad_loss=grad_loss,
    )

    # collecting regularization functions
    reg_var_funcs = tf.get_collection('REGULARIZATION_FUNCTIONS')
    reg_funcs = dict(map(lambda x: (x[0].name, x[1]), reg_var_funcs))

    # creating FP-32 variables and filling the fp32 dict
    grads_and_vars_fp32 = []
    with tf.variable_scope('FP32-master-copy'):
      for grad, var in grads_and_vars_fp16:
        if var.dtype.base_dtype == tf.float16:
          fp32_var = tf.Variable(
              initial_value=tf.cast(var.initialized_value(), tf.float32),
              name=var.name.split(':')[0],
              expected_shape=var.shape,
              dtype=tf.float32,
              trainable=False,
              # necessary for cudnn_rnn layers which have unknown shape
              validate_shape=bool(var.get_shape()),
              collections=[tf.GraphKeys.GLOBAL_VARIABLES,
                           "FP32_MASTER_COPIES"],
          )
          self._fp32_to_fp16[fp32_var.name] = var
          fp32_grad = tf.cast(grad, tf.float32)
          # adding regularization part with respect to fp32 copy
          if var.name in reg_funcs:
            fp32_grad += self._loss_scale * tf.gradients(
                # pylint: disable=no-member
                tf.contrib.layers.apply_regularization(
                    reg_funcs[var.name],
                    [fp32_var],
                ),
                fp32_var,
            )[0]
          grads_and_vars_fp32.append((fp32_grad, fp32_var))
        else:
          grads_and_vars_fp32.append((grad, var))

    grads_and_vars_fp32 = _scale_grads(grads_and_vars_fp32,
                                       1.0 / self._loss_scale)
    return grads_and_vars_fp32

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
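    """Applies gradients and copies master values back to float16.

    After the wrapped optimizer's update op runs, each float32 master
    copy is saturate-cast back into its float16 variable. When an
    ``AutomaticLossScaler`` is used, the whole update is skipped if the
    gradients contain NaN/Inf, and the loss scale is adjusted.
    """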
    def apply_ops_wrapper():
      update_op = self._optimizer.apply_gradients(grads_and_vars,
                                                  global_step, name)
      apply_ops = []
      with tf.control_dependencies([update_op]):
        for grad, var in grads_and_vars:
          if var.name in self._fp32_to_fp16:
            dst_var = self._fp32_to_fp16[var.name]
            apply_ops.append(
                tf.assign(dst_var, tf.saturate_cast(var, tf.float16))
            )
      if apply_ops:
        return tf.group(apply_ops)
      return update_op

    if self._loss_scaler:
      grad_has_nans, grad_amax = AutomaticLossScaler.check_grads(grads_and_vars)
      should_skip_update = tf.logical_or(tf.is_inf(grad_amax), grad_has_nans)
      loss_scale_update_op = self._loss_scaler.update_op(grad_has_nans,
                                                         grad_amax)
      with tf.control_dependencies([loss_scale_update_op]):
        return tf.cond(should_skip_update, tf.no_op, apply_ops_wrapper)
    else:
      return apply_ops_wrapper()


def mp_regularizer_wrapper(regularizer):
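  """Wraps ``regularizer`` to work with float16 weights.

  For float16 weights, the (weights, regularizer) pair is stored in
  the 'REGULARIZATION_FUNCTIONS' collection and the wrapped function
  returns ``None``; ``MixedPrecisionOptimizerWrapper`` later applies
  the regularizer to the float32 master copy. Float32 weights are
  regularized as usual.

  A minimal sketch of the intended use (the dense layer and L2
  regularizer here are illustrative assumptions)::

      reg = mp_regularizer_wrapper(tf.contrib.layers.l2_regularizer(1e-4))
      logits = tf.layers.dense(inputs, 128, kernel_regularizer=reg)
  """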
  def func_wrapper(weights):
    if weights.dtype.base_dtype == tf.float16:
      tf.add_to_collection('REGULARIZATION_FUNCTIONS', (weights, regularizer))
      # disabling the inner regularizer
      return None
    return regularizer(weights)

  return func_wrapper


def _scale_grads(grads_and_vars, scale):
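  """Multiplies every gradient in ``grads_and_vars`` by ``scale``.

  ``tf.IndexedSlices`` gradients (produced e.g. by embedding lookups)
  are rebuilt with scaled values rather than converted to dense tensors.
  """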
  scaled_grads_and_vars = []
  for grad, var in grads_and_vars:
    if grad is not None:
      if isinstance(grad, tf.IndexedSlices):
        grad_values = grad.values * scale
        grad = tf.IndexedSlices(grad_values, grad.indices, grad.dense_shape)
      else:
        grad *= scale
    scaled_grads_and_vars.append((grad, var))
  return scaled_grads_and_vars