# Source code for data.utils

# Copyright (c) 2017 NVIDIA Corporation
from __future__ import absolute_import, division, print_function
from __future__ import unicode_literals

import io

from six.moves import range


def pad_vocab_to_eight(vocab):
    """Pads vocabulary so that its size is divisible by 8.

    Placeholder tokens of the form ``<$0$>``, ``<$1$>``, ... are inserted
    into ``vocab`` in place until ``len(vocab)`` is a multiple of 8. A
    placeholder name that is already present is skipped, so existing
    entries are never overwritten.

    New ids are handed out consecutively starting at the original
    ``len(vocab)``; this presumes the existing ids are ``0..len(vocab)-1``
    (not verified here).

    Args:
        vocab (dict): vocabulary in the form token->id. Modified in place.

    Returns:
        dict: the same ``vocab`` object, with new tokens added if necessary,
        such that the total vocab size is divisible by 8.
    """
    next_id = len(vocab)
    suffix = 0
    while len(vocab) % 8 != 0:
        token = '<$' + str(suffix) + '$>'
        suffix += 1
        # Skip names that already exist: re-assigning one would change its
        # id without growing the vocabulary.
        if token in vocab:
            continue
        vocab[token] = next_id
        next_id += 1
    return vocab
def load_pre_existing_vocabulary(path, min_idx=0, read_chars=False):
    """Loads pre-existing vocabulary into memory.

    The vocabulary file should contain a token on each line, optionally
    followed by a tab character and a token count; everything after the
    first tab is ignored. Example (columns are tab-separated)::

        a	1234
        b	4321
        c	32342
        d
        e
        word	234

    Lines that contain nothing but a line terminator (``\\n``, ``\\r\\n`` or
    ``\\r``) are skipped and do not consume an id. If a token occurs more
    than once, the id of its last occurrence is kept.

    Args:
        path (str): path to vocabulary.
        min_idx (int, optional): minimum id to assign for a token.
        read_chars (bool, optional): whether to read only the first symbol
            of the line.

    Returns:
        dict: vocabulary dictionary mapping tokens (chars/words) to int ids.
    """
    idx = min_idx
    vocab_dict = {}
    # newline='' leaves line terminators untranslated, so '\r\n' and '\r'
    # endings reach us verbatim and have to be recognised explicitly.
    with io.open(path, newline='', encoding='utf-8') as f:
        for line in f:
            # Ignore empty lines regardless of the terminator style.
            if not line.rstrip('\r\n'):
                continue
            if read_chars:
                # Only the first symbol is the token; it may be whitespace.
                token = line[0]
            else:
                token = line.rstrip().split('\t')[0]
            vocab_dict[token] = idx
            idx += 1
    return vocab_dict