Source code for decoders.decoder

# Copyright (c) 2018 NVIDIA Corporation
from __future__ import absolute_import, division, print_function
from __future__ import unicode_literals

import abc
import copy

import six
import tensorflow as tf

from open_seq2seq.optimizers.mp_wrapper import mp_regularizer_wrapper
from open_seq2seq.utils.utils import check_params, cast_types


@six.add_metaclass(abc.ABCMeta)
class Decoder:
  """Abstract class from which all decoders must inherit.
  """

  @staticmethod
  def get_required_params():
    """Static method with description of required parameters.

    Returns:
      dict: Dictionary containing all the parameters that **have to** be
      included into the ``params`` parameter of the class
      :meth:`__init__` method.
    """
    return {}

  @staticmethod
  def get_optional_params():
    """Static method with description of optional parameters.

    Returns:
      dict: Dictionary containing all the parameters that **can** be
      included into the ``params`` parameter of the class
      :meth:`__init__` method.
    """
    return {
        'regularizer': None,  # any valid TensorFlow regularizer
        'regularizer_params': dict,
        'initializer': None,  # any valid TensorFlow initializer
        'initializer_params': dict,
        'dtype': [tf.float32, tf.float16, 'mixed'],
    }
  def __init__(self, params, model, name="decoder", mode='train'):
    """Decoder constructor.
    Note that decoder constructors should not modify the TensorFlow graph; all
    graph construction should happen in the
    :meth:`self._decode() <_decode>` method.

    Args:
      params (dict): parameters describing the decoder.
          All supported parameters are listed in
          :meth:`get_required_params`, :meth:`get_optional_params` functions.
      model (instance of a class derived from :class:`Model<models.model.Model>`):
          parent model that created this decoder.
          Could be None if no model access is required for the use case.
      name (str): name for decoder variable scope.
      mode (str): mode decoder is going to be run in.
          Could be "train", "eval" or "infer".

    Config parameters:

    * **initializer** --- any valid TensorFlow initializer. If no initializer
      is provided, model initializer will be used.
    * **initializer_params** (dict) --- dictionary that will be passed to
      initializer ``__init__`` method.
    * **regularizer** --- any valid TensorFlow regularizer. If no regularizer
      is provided, model regularizer will be used.
    * **regularizer_params** (dict) --- dictionary that will be passed to
      regularizer ``__init__`` method.
    * **dtype** --- model dtype. Could be either ``tf.float16``, ``tf.float32``
      or "mixed". For details see
      :ref:`mixed precision training <mixed_precision>` section in docs. If no
      dtype is provided, model dtype will be used.
    """
    check_params(params, self.get_required_params(), self.get_optional_params())
    self._params = copy.deepcopy(params)
    self._model = model

    if 'dtype' not in self._params:
      if self._model:
        self._params['dtype'] = self._model.params['dtype']
      else:
        self._params['dtype'] = tf.float32

    self._name = name
    self._mode = mode
    self._compiled = False
  def decode(self, input_dict):
    """Wrapper around :meth:`self._decode() <_decode>` method.
    Here name, initializer and dtype are set in the variable scope and then
    the :meth:`self._decode() <_decode>` method is called.

    Args:
      input_dict (dict): see :meth:`self._decode() <_decode>` docs.

    Returns:
      see :meth:`self._decode() <_decode>` docs.
    """
    if not self._compiled:
      if 'regularizer' not in self._params:
        if self._model and 'regularizer' in self._model.params:
          self._params['regularizer'] = copy.deepcopy(
              self._model.params['regularizer']
          )
          self._params['regularizer_params'] = copy.deepcopy(
              self._model.params['regularizer_params']
          )

      if 'regularizer' in self._params:
        init_dict = self._params.get('regularizer_params', {})
        if self._params['regularizer'] is not None:
          self._params['regularizer'] = self._params['regularizer'](**init_dict)
          if self._params['dtype'] == 'mixed':
            self._params['regularizer'] = mp_regularizer_wrapper(
                self._params['regularizer'],
            )

      if self._params['dtype'] == 'mixed':
        self._params['dtype'] = tf.float16

    if 'initializer' in self.params:
      init_dict = self.params.get('initializer_params', {})
      initializer = self.params['initializer'](**init_dict)
    else:
      initializer = None

    self._compiled = True

    with tf.variable_scope(self._name, initializer=initializer,
                           dtype=self.params['dtype']):
      return self._decode(self._cast_types(input_dict))
  def _cast_types(self, input_dict):
    """This function performs automatic cast of all inputs to the decoder
    dtype.

    Args:
      input_dict (dict): dictionary passed to
          :meth:`self._decode() <_decode>` method.

    Returns:
      dict: same as input_dict, but with all Tensors cast to the decoder dtype.
    """
    return cast_types(input_dict, self.params['dtype'])

  @abc.abstractmethod
  def _decode(self, input_dict):
    """This is the main function which should construct the decoder graph.
    Typically, the decoder will take a hidden representation from the encoder
    as an input and produce some output sequence as an output.

    Args:
      input_dict (dict): dictionary containing decoder inputs.
          If the decoder is used with the :class:`models.encoder_decoder`
          class, ``input_dict`` will have the following content::

            {
              "encoder_output": dictionary returned from encoder.encode() method
              "target_tensors": data_layer.input_tensors['target_tensors']
            }

    Returns:
      dict: dictionary of decoder outputs. Typically this will be just::

        {
          "logits": logits that will be passed to Loss
          "outputs": list with actual decoded outputs, e.g. characters
                     instead of logits
        }
    """
    pass
  @property
  def params(self):
    """Parameters used to construct the decoder (dictionary)."""
    return self._params

  @property
  def mode(self):
    """Mode decoder is run in."""
    return self._mode

  @property
  def name(self):
    """Decoder name."""
    return self._name
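

# Illustrative sketch (not part of the original module): a minimal example of
# how a concrete decoder might subclass ``Decoder``. It declares one extra
# required parameter and implements ``_decode`` by projecting the encoder
# output with a dense layer. ``FullyConnectedDecoder``, its ``output_dim``
# parameter and the exact keys it reads from ``encoder_output`` are
# assumptions made for illustration only; real decoders define their own
# parameters and input contracts.
class FullyConnectedDecoder(Decoder):
  """Hypothetical decoder that projects encoder outputs into logits."""

  @staticmethod
  def get_required_params():
    return dict(Decoder.get_required_params(), **{
        'output_dim': int,  # size of the output projection
    })

  def _decode(self, input_dict):
    # Assumes the encoder published its outputs under the "outputs" key with
    # shape [batch, time, hidden]; this key is a convention of the encoder used.
    inputs = input_dict['encoder_output']['outputs']
    logits = tf.layers.dense(inputs, self.params['output_dim'], name='logits')
    return {
        'logits': logits,
        'outputs': [tf.argmax(logits, axis=-1)],
    }


if __name__ == '__main__':
  # Usage sketch: build the hypothetical decoder without a parent model
  # (model=None, so it falls back to tf.float32 unless 'dtype' is given) and
  # call decode() on a fake encoder output. Everything here is illustrative.
  fake_encoder_outputs = tf.placeholder(tf.float32, shape=[None, 20, 512])
  decoder = FullyConnectedDecoder(
      params={'output_dim': 10, 'dtype': tf.float32},
      model=None,
      name='fc_decoder',
      mode='train',
  )
  output_dict = decoder.decode(
      {'encoder_output': {'outputs': fake_encoder_outputs}}
  )
  print(output_dict['logits'])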