Source code for parts.transformer.attention_layer

# Copyright 2018 MLBenchmark Group. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of multiheaded attention and self-attention layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

class Attention(tf.layers.Layer):
    """Multi-headed attention layer.

    Queries are projected from ``x`` and keys/values from ``y`` with
    bias-free dense layers, split into ``num_heads`` heads, scored according
    to ``mode``, and the recombined heads are passed through a final
    bias-free dense layer.
    """

    def __init__(
        self,
        hidden_size,
        num_heads,
        attention_dropout,
        train,
        mode="loung",
        regularizer=None,
        window_size=None,
        back_step_size=None
    ):
        """Create the projection layers and store the configuration.

        Args:
          hidden_size: output size of the q/k/v/output projections; must be
            divisible by ``num_heads``.
          num_heads: number of attention heads.
          attention_dropout: drop probability applied to the attention weights
            when ``train`` is truthy (``keep_prob = 1 - attention_dropout``).
          train: whether dropout is applied to the attention weights.
          mode: ``"loung"`` (spelling as used by this API) selects scaled
            dot-product attention; ``"bahdanau"`` selects additive attention.
            The value is only checked when the layer is called.
          regularizer: kernel regularizer passed to all four dense layers.
          window_size: width of the attention window used for monotonic
            attention forcing; only consulted when ``positions`` is passed to
            ``call``.
          back_step_size: how many positions the window start may move back
            from the given alignment; required when the window is used.

        Raises:
          ValueError: if ``hidden_size`` is not divisible by ``num_heads``.
        """
        if hidden_size % num_heads != 0:
            raise ValueError("Hidden size must be evenly divisible by the number of "
                             "heads.")

        super(Attention, self).__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.attention_dropout = attention_dropout
        self.train = train
        self.mode = mode

        # Parameters for monotonic attention forcing during inference
        self.window_size = window_size
        self.back_step_size = back_step_size

        # Layers for linearly projecting the queries, keys, and values.
        self.q_dense_layer = tf.layers.Dense(hidden_size, use_bias=False, name="q",
                                             kernel_regularizer=regularizer)
        self.k_dense_layer = tf.layers.Dense(hidden_size, use_bias=False, name="k",
                                             kernel_regularizer=regularizer)
        self.v_dense_layer = tf.layers.Dense(hidden_size, use_bias=False, name="v",
                                             kernel_regularizer=regularizer)
        self.output_dense_layer = tf.layers.Dense(hidden_size, use_bias=False,
                                                  name="output_transform",
                                                  kernel_regularizer=regularizer)

    def split_heads(self, x):
        """Split x into different heads, and transpose the resulting value.

        The tensor is transposed to ensure the inner dimensions hold the
        correct values during the matrix multiplication.

        Args:
          x: A tensor with shape [batch_size, length, hidden_size]

        Returns:
          A tensor with shape [batch_size, num_heads, length, hidden_size/num_heads]
        """
        with tf.name_scope("split_heads"):
            batch_size = tf.shape(x)[0]
            length = tf.shape(x)[1]

            # Calculate depth of last dimension after it has been split.
            depth = self.hidden_size // self.num_heads

            # Split the last dimension
            x = tf.reshape(x, [batch_size, length, self.num_heads, depth])

            # Transpose the result
            return tf.transpose(x, [0, 2, 1, 3])

    def combine_heads(self, x):
        """Combine tensor that has been split.

        Args:
          x: A tensor [batch_size, num_heads, length, hidden_size/num_heads]

        Returns:
          A tensor with shape [batch_size, length, hidden_size]
        """
        with tf.name_scope("combine_heads"):
            batch_size = tf.shape(x)[0]
            length = tf.shape(x)[2]
            x = tf.transpose(x, [0, 2, 1, 3])  # --> [batch, length, num_heads, depth]
            return tf.reshape(x, [batch_size, length, self.hidden_size])

    def _monotonic_window_bias(self, logits, bias, positions):
        """Return `bias` with everything outside the attention window masked.

        For every query the window covers key indices
        [window_pos, window_pos + window_size), where
        window_pos = max(positions - back_step_size, 0).
        """
        assert self.back_step_size is not None, (
            "back_step_size must be set when window_size and positions are used"
        )
        max_length = tf.shape(logits)[-1]

        # Allow to make back_step_size steps back
        window_pos = tf.maximum(positions - self.back_step_size,
                                tf.zeros_like(positions))

        # Ones inside the window, zeros elsewhere: the difference of two
        # prefix masks keeps only the indices between the two prefix lengths.
        mask_large = tf.sequence_mask(window_pos + self.window_size,
                                      maxlen=max_length)
        mask_large = tf.cast(mask_large, tf.float32)
        mask_small = tf.sequence_mask(window_pos, maxlen=max_length)
        mask_small = tf.cast(mask_small, tf.float32)
        mask = mask_large - mask_small

        # 0 inside the window, -1e9 outside of it.
        mask = -1e9 * (1 - mask)
        bias = mask + bias

        # Clipping: positions masked by both terms must not go below -1e9.
        return tf.maximum(bias, -1e9)

    def _luong_attention_weights(self, q, k, bias, positions):
        """Return softmax-normalized scaled dot-product attention weights."""
        # Scale q to prevent the dot product between q and k from growing too large.
        depth = self.hidden_size // self.num_heads
        q *= depth ** -0.5

        logits = tf.matmul(q, k, transpose_b=True)
        dtype = logits.dtype
        if dtype != tf.float32:
            # Upcast the softmax inputs, then downcast the softmax output.
            # NOTE(review): `positions` / `window_size` are ignored on this
            # path, so monotonic forcing only happens for float32 logits;
            # confirm whether that is intended.
            logits = tf.cast(x=logits, dtype=tf.float32)
            logits += bias
            weights = tf.nn.softmax(logits, name="attention_weights")
            return tf.cast(weights, dtype=dtype)

        # Logits shape: [batch, head, decoder, encoder]
        # Bias shape:   [batch, 1, 1, encoder]

        # Force monotonic attention during inference
        if positions is not None and self.window_size is not None:
            bias = self._monotonic_window_bias(logits, bias, positions)

        logits += bias
        return tf.nn.softmax(logits, name="attention_weights")

    def _bahdanau_attention_weights(self, q, k, bias):
        """Return softmax-normalized additive attention weights.

        The result has a single query position (axis 2 has size 1).
        """
        att_v = tf.get_variable(
            "attention_v",
            [self.hidden_size // self.num_heads],
            dtype=q.dtype
        )

        # Compute the attention score
        # NOTE(review): `k + q` only broadcasts when the query and key lengths
        # are equal or one of them is 1, and `bias` is added before the depth
        # axis is reduced. Presumably this mode targets one-step decoding;
        # confirm against the callers.
        if bias is not None:
            weights = tf.reduce_sum(
                tf.nn.tanh(att_v * tf.nn.tanh(k + q + bias)), 3
            )
        else:
            weights = tf.reduce_sum(
                tf.nn.tanh(att_v * tf.nn.tanh(k + q)), 3
            )
        weights = tf.nn.softmax(weights)
        return tf.expand_dims(weights, 2)

    def call(self, x, y, bias, cache=None, positions=None):
        """Apply attention mechanism to x and y.

        Args:
          x: a tensor with shape [batch_size, length_x, hidden_size]
          y: a tensor with shape [batch_size, length_y, hidden_size]
          bias: attention bias that will be added to the result of the dot product.
          cache: (Used during prediction) dictionary with tensors containing results
            of previous attentions. The dictionary must have the items:
                {"k": tensor with shape [batch_size, i, key_channels],
                 "v": tensor with shape [batch_size, i, value_channels]}
            where i is the current decoded length. The dictionary is updated
            in place with the extended keys and values.
          positions: decoder-encoder alignment for previous steps
            [batch_size, n_heads, length_x]

        Returns:
          Attention layer output with shape [batch_size, length_x, hidden_size]

        Raises:
          ValueError: if the layer's mode is neither "loung" nor "bahdanau".
        """
        # Linearly project the query (q), key (k) and value (v) using different
        # learned projections. This is in preparation of splitting them into
        # multiple heads. Multi-head attention uses multiple queries, keys, and
        # values rather than regular attention (which uses a single q, k, v).
        q = self.q_dense_layer(x)
        k = self.k_dense_layer(y)
        v = self.v_dense_layer(y)

        if cache is not None:
            # Combine cached keys and values with new keys and values.
            k = tf.concat([cache["k"], k], axis=1)
            v = tf.concat([cache["v"], v], axis=1)

            # Update cache
            cache["k"] = k
            cache["v"] = v

        # Split q, k, v into heads.
        q = self.split_heads(q)
        k = self.split_heads(k)
        v = self.split_heads(v)

        if self.mode == "loung":
            weights = self._luong_attention_weights(q, k, bias, positions)
        elif self.mode == "bahdanau":
            weights = self._bahdanau_attention_weights(q, k, bias)
        else:
            # Adjacent literals are concatenated into ONE message; separating
            # them with a comma would create a two-argument exception that
            # prints as a tuple.
            raise ValueError(
                "Mode for multi-head attention must be either loung for "
                "dot-product attention, or bahdanau for "
                "content-based/additive/mlp-base attention "
                "(got {!r})".format(self.mode)
            )

        if self.train:
            weights = tf.nn.dropout(weights, keep_prob=1 - self.attention_dropout)
        attention_output = tf.matmul(weights, v)

        # Recombine heads --> [batch_size, length, hidden_size]
        attention_output = self.combine_heads(attention_output)

        # Run the combined outputs through another linear projection layer.
        attention_output = self.output_dense_layer(attention_output)
        return attention_output
class SelfAttention(Attention):
    """Multiheaded self-attention layer.

    Identical to the parent attention layer, except that queries, keys and
    values are all derived from the same input tensor.
    """

    def call(self, x, bias, cache=None):
        """Attend from `x` to itself.

        Args:
          x: input tensor, used both as the query source and as the
            key/value source of the parent layer.
          bias: attention bias forwarded unchanged to the parent layer.
          cache: optional cache dictionary forwarded unchanged to the parent
            layer.

        Returns:
          Whatever the parent layer's `call` returns for `(x, x, bias, cache)`.
        """
        attend = super(SelfAttention, self).call
        return attend(x, x, bias, cache)