Source code for open_seq2seq.data.text2text.text2text

# Copyright (c) 2017 NVIDIA Corporation
from __future__ import absolute_import, division, print_function
from __future__ import unicode_literals

import numpy as np
import tensorflow as tf
import os
from enum import Enum
from open_seq2seq.data.data_layer import DataLayer
from open_seq2seq.data.utils import load_pre_existing_vocabulary, pad_vocab_to_eight
from open_seq2seq.data.text2text.t2t import _read_and_batch_from_files
from open_seq2seq.data.text2text.tokenizer import PAD_ID

class SpecialTextTokens(Enum):
  PAD_ID = 0   # special padding token
  EOS_ID = 1   # special end of sentence token
  S_ID = 2     # special start of sentence token
  UNK_ID = 3   # out-of-vocabulary tokens will map there
  OUT_OF_BUCKET = 1234567890
  END_OF_CHOICE = -100

  @staticmethod
  def to_string(s_token):
    if s_token == SpecialTextTokens.UNK_ID.value:
      return '<UNK>'
    elif s_token == SpecialTextTokens.S_ID.value:
      return '<S>'
    elif s_token == SpecialTextTokens.EOS_ID.value:
      return '</S>'
    elif s_token == SpecialTextTokens.PAD_ID.value:
      return '<PAD>'
    else:
      raise ValueError("Unknown Value in SpecialTokens")
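For reference, a short illustrative sketch (not part of the original module) of how these IDs map to their string forms via to_string; it assumes only the enum defined above:

# Illustrative only: round-trip the four special token IDs.
for tok in (SpecialTextTokens.PAD_ID, SpecialTextTokens.EOS_ID,
            SpecialTextTokens.S_ID, SpecialTextTokens.UNK_ID):
  print(tok.value, SpecialTextTokens.to_string(tok.value))
# expected output:
# 0 <PAD>
# 1 </S>
# 2 <S>
# 3 <UNK>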
class ParallelTextDataLayer(DataLayer):
  @staticmethod
  def get_required_params():
    return dict(DataLayer.get_required_params(), **{
        'source_file': str,
        'src_vocab_file': str,
        'tgt_vocab_file': str,
        'max_length': int,
        'shuffle': bool,
        'repeat': bool,
    })
  @staticmethod
  def get_optional_params():
    return dict(DataLayer.get_optional_params(), **{
        'use_targets': bool,
        'delimiter': str,
        'target_file': str,
        'map_parallel_calls': int,
        'prefetch_buffer_size': int,
        'pad_lengths_to_eight': bool,
        'pad_vocab_to_eight': bool,
        'shuffle_buffer_size': int,
        'special_tokens_already_in_vocab': bool,
        'use_start_token': bool,
    })
  def __init__(self, params, model, num_workers=1, worker_id=0):
    super(ParallelTextDataLayer, self).__init__(params, model,
                                                num_workers, worker_id)
    self._batch_size = self.params['batch_size']
    self.source_file = self.params['source_file']
    self._use_targets = self.params.get('use_targets', True)
    if not self._use_targets:
      self.target_file = self.source_file
      if 'target_file' in self.params:
        print("WARNING: target file was specified but was "
              "ignored by data layer because 'use_targets'=False")
    else:
      self.target_file = self.params['target_file']
    self.src_vocab_file = self.params['src_vocab_file']
    self.tgt_vocab_file = self.params['tgt_vocab_file']
    self.max_len = self.params['max_length']
    self._delimiter = self.params.get('delimiter', ' ')
    self._map_parallel_calls = self.params.get('map_parallel_calls', 8)
    self._pad_lengths_to_eight = self.params.get('pad_lengths_to_eight', False)
    self._prefetch_buffer_size = self.params.get('prefetch_buffer_size',
                                                 tf.contrib.data.AUTOTUNE)
    self._shuffle_buffer_size = self.params.get('shuffle_buffer_size', -1)
    self._num_workers = num_workers
    self._worker_id = worker_id
    self._use_start_token = self.params.get('use_start_token', True)
    if self._pad_lengths_to_eight and not (self.params['max_length'] % 8 == 0):
      raise ValueError("If padding to 8 in data layer, then "
                       "max_length should be multiple of 8")

    def file_len(fname):
      with open(fname) as f:
        for i, l in enumerate(f):
          pass
      return i + 1

    self.dataset_size = file_len(self.source_file)
    special_tokens_already_in_vocab = self.params.get(
        'special_tokens_already_in_vocab', True)

    # load source and target vocabularies to RAM
    self.src_seq2idx = load_pre_existing_vocabulary(
        self.src_vocab_file,
        min_idx=0 if special_tokens_already_in_vocab
        else SpecialTextTokens.UNK_ID.value + 1)
    self.tgt_seq2idx = load_pre_existing_vocabulary(
        self.tgt_vocab_file,
        min_idx=0 if special_tokens_already_in_vocab
        else SpecialTextTokens.UNK_ID.value + 1)

    if not special_tokens_already_in_vocab:
      # manually add special tokens
      # unknown symbol
      self.src_seq2idx[
          SpecialTextTokens.to_string(SpecialTextTokens.UNK_ID.value)] = \
          SpecialTextTokens.UNK_ID.value
      self.tgt_seq2idx[
          SpecialTextTokens.to_string(SpecialTextTokens.UNK_ID.value)] = \
          SpecialTextTokens.UNK_ID.value
      # sentence start
      self.src_seq2idx[
          SpecialTextTokens.to_string(SpecialTextTokens.S_ID.value)] = \
          SpecialTextTokens.S_ID.value
      self.tgt_seq2idx[
          SpecialTextTokens.to_string(SpecialTextTokens.S_ID.value)] = \
          SpecialTextTokens.S_ID.value
      # sentence end
      self.src_seq2idx[
          SpecialTextTokens.to_string(SpecialTextTokens.EOS_ID.value)] = \
          SpecialTextTokens.EOS_ID.value
      self.tgt_seq2idx[
          SpecialTextTokens.to_string(SpecialTextTokens.EOS_ID.value)] = \
          SpecialTextTokens.EOS_ID.value
      # padding
      self.src_seq2idx[
          SpecialTextTokens.to_string(SpecialTextTokens.PAD_ID.value)] = \
          SpecialTextTokens.PAD_ID.value
      self.tgt_seq2idx[
          SpecialTextTokens.to_string(SpecialTextTokens.PAD_ID.value)] = \
          SpecialTextTokens.PAD_ID.value

    if self.params.get('pad_vocab_to_eight', False):
      self.src_seq2idx = pad_vocab_to_eight(self.src_seq2idx)
      self.tgt_seq2idx = pad_vocab_to_eight(self.tgt_seq2idx)

    self.src_idx2seq = {idx: w for w, idx in self.src_seq2idx.items()}
    self.tgt_idx2seq = {idx: w for w, idx in self.tgt_seq2idx.items()}

    self.params['src_vocab_size'] = len(self.src_seq2idx)
    self.params['tgt_vocab_size'] = len(self.tgt_seq2idx)
    self.params['target_seq2idx'] = self.tgt_seq2idx
    self.params['source_seq2idx'] = self.src_seq2idx
    self.params['target_idx2seq'] = self.tgt_idx2seq
    self.params['source_idx2seq'] = self.src_idx2seq

    self._input_tensors = {}

  def _pad2eight(self, lst, do_pad_eight):
    if len(lst) % 8 == 0 or not do_pad_eight:
      return lst
    else:
      return lst + [SpecialTextTokens.PAD_ID.value] * (8 - len(lst) % 8)

  def _src_token_to_id(self, line):
    tokens = line.decode("utf-8").split(self._delimiter)  # line.numpy().decode
    if self._use_start_token:
      return np.array(self._pad2eight(
          [SpecialTextTokens.S_ID.value] +
          [self.src_seq2idx.get(token, SpecialTextTokens.UNK_ID.value)
           for token in tokens[:self.max_len - 2]] +
          [SpecialTextTokens.EOS_ID.value],
          self._pad_lengths_to_eight), dtype="int32")
    else:
      return np.array(self._pad2eight(
          [self.src_seq2idx.get(token, SpecialTextTokens.UNK_ID.value)
           for token in tokens[:self.max_len - 2]] +
          [SpecialTextTokens.EOS_ID.value],
          self._pad_lengths_to_eight), dtype="int32")

  def _tgt_token_to_id(self, line):
    tokens = line.decode("utf-8").split(self._delimiter)  # line.numpy().decode
    if self._use_start_token:
      return np.array(self._pad2eight(
          [SpecialTextTokens.S_ID.value] +
          [self.tgt_seq2idx.get(token, SpecialTextTokens.UNK_ID.value)
           for token in tokens[:self.max_len - 2]] +
          [SpecialTextTokens.EOS_ID.value],
          self._pad_lengths_to_eight), dtype="int32")
    else:
      return np.array(self._pad2eight(
          [self.tgt_seq2idx.get(token, SpecialTextTokens.UNK_ID.value)
           for token in tokens[:self.max_len - 2]] +
          [SpecialTextTokens.EOS_ID.value],
          self._pad_lengths_to_eight), dtype="int32")
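  # Note (illustrative comment, not in the original source): with
  # 'pad_lengths_to_eight' enabled, _pad2eight right-pads the ID list with
  # PAD_ID until its length is a multiple of 8. For example, a 5-token line
  # [S_ID, w1, w2, w3, EOS_ID] gets three trailing PAD_IDs, giving 8 IDs.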
  def build_graph(self):
    with tf.device('/cpu:0'):
      _sources = tf.data.TextLineDataset(self.source_file)
      _targets = tf.data.TextLineDataset(self.target_file)
      if self._num_workers > 1:
        _sources = _sources.shard(num_shards=self._num_workers,
                                  index=self._worker_id)
        _targets = _targets.shard(num_shards=self._num_workers,
                                  index=self._worker_id)

      _sources = _sources.map(
          lambda line: tf.py_func(func=self._src_token_to_id, inp=[line],
                                  Tout=[tf.int32], stateful=False),
          num_parallel_calls=self._map_parallel_calls
      ).map(lambda tokens: (tokens, tf.size(tokens)),
            num_parallel_calls=self._map_parallel_calls)

      _targets = _targets.map(
          lambda line: tf.py_func(func=self._tgt_token_to_id, inp=[line],
                                  Tout=[tf.int32], stateful=False),
          num_parallel_calls=self._map_parallel_calls
      ).map(lambda tokens: (tokens, tf.size(tokens)),
            num_parallel_calls=self._map_parallel_calls)

      _src_tgt_dataset = tf.data.Dataset.zip((_sources, _targets)).filter(
          lambda t1, t2: tf.logical_and(tf.less_equal(t1[1], self.max_len),
                                        tf.less_equal(t2[1], self.max_len))
      ).cache()

      if self.params['shuffle']:
        bf_size = self.get_size_in_samples() \
            if self._shuffle_buffer_size == -1 else self._shuffle_buffer_size
        _src_tgt_dataset = _src_tgt_dataset.shuffle(buffer_size=bf_size)

      if self.params['repeat']:
        _src_tgt_dataset = _src_tgt_dataset.repeat()

      self.batched_dataset = _src_tgt_dataset.padded_batch(
          self._batch_size,
          padded_shapes=((tf.TensorShape([None]), tf.TensorShape([])),
                         (tf.TensorShape([None]), tf.TensorShape([]))),
          padding_values=((SpecialTextTokens.PAD_ID.value, 0),
                          (SpecialTextTokens.PAD_ID.value, 0))
      ).prefetch(buffer_size=self._prefetch_buffer_size)

      self._iterator = self.batched_dataset.make_initializable_iterator()

      if self.params['mode'] == 'train' or self.params['mode'] == 'eval':
        t1, t2 = self.iterator.get_next()
        x, x_length = t1[0], t1[1]
        y, y_length = t2[0], t2[1]
        self._input_tensors['source_tensors'] = [x, x_length]
        self._input_tensors['target_tensors'] = [y, y_length]
      else:
        t1, _ = self.iterator.get_next()
        self._input_tensors['source_tensors'] = [t1[0], t1[1]]
  def create_interactive_placeholders(self):
    self._text = tf.placeholder(dtype=tf.int32,
                                shape=[self._batch_size, None])
    self._text_length = tf.placeholder(dtype=tf.int32,
                                       shape=[self._batch_size])
    self._input_tensors = {}
    self._input_tensors['source_tensors'] = [self._text, self._text_length]
  def create_feed_dict(self, model_in):
    """Creates the feed dict for interactive infer.

    Args:
      model_in (list): the strings to be translated. Each string
        should already be in BPE format.

    Returns:
      feed_dict (dict): Dictionary with values for the placeholders.
    """
    text = []
    text_length = []
    for line in model_in:
      line = self._src_token_to_id(line)
      text.append(line)
      text_length.append(line.shape[0])
    max_len = np.max(text_length)
    for i, line in enumerate(text):
      line = np.pad(line, ((0, max_len - len(line))),
                    "constant",
                    constant_values=SpecialTextTokens.PAD_ID.value)
      text[i] = line
    text = np.reshape(text, [self._batch_size, -1])
    text_length = np.reshape(text_length, [self._batch_size])
    feed_dict = {
        self._text: text,
        self._text_length: text_length,
    }
    return feed_dict
  def get_size_in_samples(self):
    return self.dataset_size
  @property
  def iterator(self):
    return self._iterator

  @property
  def input_tensors(self):
    return self._input_tensors
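A minimal construction sketch (illustrative only; file paths are hypothetical, and it assumes the base DataLayer also expects the 'batch_size' and 'mode' keys that __init__ and build_graph read, plus whatever model object your pipeline provides):

params = {
    'source_file': 'train.src',     # hypothetical path
    'target_file': 'train.tgt',     # hypothetical path
    'src_vocab_file': 'vocab.src',  # hypothetical path
    'tgt_vocab_file': 'vocab.tgt',  # hypothetical path
    'max_length': 56,
    'shuffle': True,
    'repeat': True,
    'batch_size': 128,
    'mode': 'train',
}
data_layer = ParallelTextDataLayer(params, model=None)  # model=None only for sketch
data_layer.build_graph()
# data_layer.input_tensors['source_tensors'] -> [x, x_length]
# data_layer.input_tensors['target_tensors'] -> [y, y_length]

Note that for interactive infer, create_feed_dict passes each input line through _src_token_to_id, which calls line.decode("utf-8"), so the inputs are expected to be byte strings.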
class TransformerDataLayer(DataLayer):
  """Wraps Transformer's data pipeline into the form used by OpenSeq2Seq."""
  @staticmethod
  def get_required_params():
    return dict(DataLayer.get_required_params(), **{
        'data_dir': str,
        'file_pattern': str,
        'src_vocab_file': str,
        'batch_size': int,
        'max_length': int,
        'shuffle': bool,
        'delimiter': str,
    })
  @staticmethod
  def get_optional_params():
    return dict(DataLayer.get_optional_params(), **{
        'repeat': int,
        'num_cpu_cores': int,
        'tgt_vocab_file': str,
        'pad_data_to_eight': bool,
        'batch_in_tokens': bool,
    })
  def __init__(self, params, model, num_workers=1, worker_id=0):
    super(TransformerDataLayer, self).__init__(params, model,
                                               num_workers, worker_id)
    self.src_vocab_file = self.params['src_vocab_file']
    # if the target vocab isn't specified, assume a common vocab file
    self.tgt_vocab_file = self.params.get('tgt_vocab_file',
                                          self.src_vocab_file)

    # load source and target vocabularies to RAM
    # pre-processed vocab starts from PAD, EOS
    self.src_seq2idx = load_pre_existing_vocabulary(
        self.src_vocab_file, min_idx=PAD_ID)
    self.tgt_seq2idx = load_pre_existing_vocabulary(
        self.tgt_vocab_file, min_idx=PAD_ID)

    self.src_idx2seq = {idx: w for w, idx in self.src_seq2idx.items()}
    self.tgt_idx2seq = {idx: w for w, idx in self.tgt_seq2idx.items()}

    self.params['src_vocab_size'] = len(self.src_seq2idx)
    self.params['tgt_vocab_size'] = len(self.tgt_seq2idx)
    self.params['target_seq2idx'] = self.tgt_seq2idx
    self.params['source_seq2idx'] = self.src_seq2idx
    self.params['target_idx2seq'] = self.tgt_idx2seq
    self.params['source_idx2seq'] = self.src_idx2seq

    self._num_workers = num_workers
    self._worker_id = worker_id

    self._input_tensors = {}
    self._iterator = None
    self.batched_dataset = None
  def build_graph(self):
    file_pattern = os.path.join(self.params['data_dir'],
                                self.params['file_pattern'])
    self.batched_dataset = _read_and_batch_from_files(
        file_pattern=file_pattern,
        batch_size=self.params['batch_size'],
        max_length=self.params['max_length'],
        num_cpu_cores=self.params.get('num_cpu_cores', 2),
        shuffle=self.params['shuffle'],
        repeat=self.params['repeat'],
        num_workers=self._num_workers,
        worker_id=self._worker_id,
        batch_in_tokens=self.params.get('batch_in_tokens', True),
        pad2eight=self.params.get('pad_data_to_eight', False))

    self._iterator = self.batched_dataset.make_initializable_iterator()
    x, y = self.iterator.get_next()

    len_x = tf.count_nonzero(x, axis=1, dtype=tf.int32)
    len_y = tf.count_nonzero(y, axis=1, dtype=tf.int32)
    if self.params['mode'] == 'train' or self.params['mode'] == 'eval':
      self._input_tensors['source_tensors'] = [x, len_x]
      self._input_tensors['target_tensors'] = [y, len_y]
    else:
      self._input_tensors['source_tensors'] = [x, len_x]
  @property
  def iterator(self):
    return self._iterator

  @property
  def input_tensors(self):
    return self._input_tensors
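A minimal configuration sketch for this layer (illustrative only; the paths and values are hypothetical, 'mode' comes from the base DataLayer, and 'repeat' is included because build_graph reads self.params['repeat'] directly even though it is listed as optional):

params = {
    'data_dir': 'wmt16_en_de',           # hypothetical directory
    'file_pattern': '*train*',           # hypothetical TFRecord pattern
    'src_vocab_file': 'vocab.bpe.32000', # hypothetical path
    'batch_size': 256,
    'max_length': 256,
    'shuffle': True,
    'delimiter': ' ',
    'repeat': 1,
    'mode': 'train',
}
layer = TransformerDataLayer(params, model=None)  # model=None only for sketch
layer.build_graph()
# layer.input_tensors['source_tensors'] -> [x, len_x]
# layer.input_tensors['target_tensors'] -> [y, len_y]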