# This code is heavily based on the code from MLPerf
# https://github.com/mlperf/reference/tree/master/translation/tensorflow/transformer
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf


class LayerNormalization(tf.layers.Layer):
  """Layer normalization for BTC format: supports L2 (default) and L1 modes."""

  def __init__(self, hidden_size, params=None):
    super(LayerNormalization, self).__init__()
    # Avoid a mutable default argument; fall back to an empty params dict.
    params = params if params is not None else {}
    self.hidden_size = hidden_size
    self.norm_type = params.get("type", "layernorm_L2")
    self.epsilon = params.get("epsilon", 1e-6)

  def build(self, _):
    # Trainable scale (gamma) and bias (beta) parameters, kept in float32.
    self.scale = tf.get_variable("layer_norm_scale", [self.hidden_size],
                                 initializer=tf.keras.initializers.Ones(),
                                 dtype=tf.float32)
    self.bias = tf.get_variable("layer_norm_bias", [self.hidden_size],
                                initializer=tf.keras.initializers.Zeros(),
                                dtype=tf.float32)
    self.built = True

  def call(self, x):
    if self.norm_type == "layernorm_L2":
      epsilon = self.epsilon
      dtype = x.dtype
      # Compute the statistics in float32 for numerical stability, then cast
      # the result back to the input dtype.
      x = tf.cast(x=x, dtype=tf.float32)
      mean = tf.reduce_mean(x, axis=[-1], keepdims=True)
      variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keepdims=True)
      norm_x = (x - mean) * tf.rsqrt(variance + epsilon)
      result = norm_x * self.scale + self.bias
      return tf.cast(x=result, dtype=dtype)
    else:
      # L1 mode: normalize by the mean absolute deviation instead of the
      # standard deviation.
      dtype = x.dtype
      if dtype == tf.float16:
        x = tf.cast(x, dtype=tf.float32)
      mean = tf.reduce_mean(x, axis=[-1], keepdims=True)
      x = x - mean
      variance = tf.reduce_mean(tf.abs(x), axis=[-1], keepdims=True)
      norm_x = tf.div(x, variance + self.epsilon)
      y = norm_x * self.scale + self.bias
      if dtype == tf.float16:
        y = tf.saturate_cast(y, dtype)
      return y
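
# A minimal usage sketch (illustrative only, assuming a TF 1.x graph and the
# BTC layout [batch, time, channels]; the placeholder shape and hidden size
# below are hypothetical):
#
#   inputs = tf.placeholder(tf.float32, [None, None, 512])
#   layer_norm = LayerNormalization(hidden_size=512,
#                                   params={"type": "layernorm_L1"})
#   outputs = layer_norm(inputs)  # same shape and dtype as inputs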


class PrePostProcessingWrapper(object):
  """Wrapper around a layer that applies pre-processing and post-processing."""

  def __init__(self, layer, params, training):
    self.layer = layer
    self.postprocess_dropout = params["layer_postprocess_dropout"]
    self.training = training
    self.norm_params = params.get("norm_params", {"type": "layernorm_L2"})
    # Create the normalization layer
    if self.norm_params["type"] == "batch_norm":
      self.norm = Transformer_BatchNorm(training=training,
                                        params=self.norm_params)
    else:
      self.norm = LayerNormalization(hidden_size=params["hidden_size"],
                                     params=self.norm_params)

  def __call__(self, x, *args, **kwargs):
    # Preprocessing: normalization
    y = self.norm(x)
    y = self.layer(y, *args, **kwargs)
    # Postprocessing: dropout and residual connection
    if self.training:
      y = tf.nn.dropout(y, keep_prob=1 - self.postprocess_dropout)
    return x + y
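
# A minimal usage sketch (illustrative; the params keys shown are the ones
# this wrapper actually reads, but the Dense sub-layer, shapes, and variable
# names are hypothetical stand-ins for a real Transformer sub-layer):
#
#   params = {"hidden_size": 512, "layer_postprocess_dropout": 0.1,
#             "norm_params": {"type": "layernorm_L2"}}
#   ffn = tf.layers.Dense(512)
#   wrapped_ffn = PrePostProcessingWrapper(ffn, params, training=True)
#   outputs = wrapped_ffn(inputs)  # norm -> layer -> dropout -> residual add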