Currently we support the following models:
| Model description | Config file | Checkpoint |
| LSTM with WikiText-2 | lstm-wkt2-fp32.py | Perplexity=89.9 |
| LSTM with WikiText-103 | lstm-wkt103-mixed.py | Perplexity=48.6 |
The model specification and training parameters can be found in the corresponding config file.
The WikiText-103 dataset, developed by Salesforce, contains over 100 million tokens extracted from the set of verified Good and Featured articles on Wikipedia. It has 267,340 unique tokens that appear at least 3 times in the dataset. Because it consists of full-length Wikipedia articles, the dataset is well suited for tasks that can benefit from long-term dependencies, such as language modeling.
The WikiText-2 dataset is a smaller version of the WikiText-103 dataset that contains only 2 million tokens. Its small size makes it suitable for testing your language model.
The WikiText datasets are available in both a word-level version (with minor preprocessing and rare tokens replaced with <UNK>) and a raw character-level version. OpenSeq2Seq’s WKTDataLayer is equipped to deal with both versions, but we recommend that you use the raw dataset.
You can download the datasets here <https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/> and extract them to the location of your choice. Each dataset should consist of 3 files: train, validation, and test. Don’t forget to update the data_root parameter in your config file to point to the location of your dataset.
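A quick sanity check before training can catch a wrong data_root early. The sketch below assumes the three split files can be recognized by the substrings train, valid, and test in their file names. That naming convention and the helper find_split_files are assumptions made for illustration and are not part of OpenSeq2Seq.

```python
import os


def find_split_files(data_root, splits=("train", "valid", "test")):
    """Map each split name to the first file in data_root whose name contains it.

    Matching on substrings is an assumption about how the extracted files
    are named; adjust it to match your download.
    """
    names = sorted(os.listdir(data_root))
    found = {}
    for split in splits:
        matches = [name for name in names if split in name]
        if not matches:
            raise FileNotFoundError(
                "no file for split %r in %s" % (split, data_root))
        found[split] = os.path.join(data_root, matches[0])
    return found
```

Calling find_split_files on the directory you plan to use as data_root either returns one path per split or raises an error that names the missing split.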
WKTDataLayer does the necessary pre-processing to make the WikiText datasets ready to be fed into the model. We use the
word_token method available in the
You can pre-process the data for your language model in any way you deem fit. However, if you want to use your trained language model for other tasks such as sentiment analysis, make sure that the dataset used for your language model and the dataset used for the sentiment analysis model have similar pre-processing and share the same vocabulary.
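One way to keep pre-processing consistent across tasks is to build the vocabulary once from the language-model corpus and reuse it to encode the downstream dataset. The following is a minimal sketch under stated assumptions: tokens are already split, rare tokens fall back to <UNK>, and the helpers build_vocab and encode are hypothetical names, not WKTDataLayer functions. The default threshold of 3 mirrors the "appear at least 3 times" figure quoted for WikiText-103.

```python
from collections import Counter

UNK = "<UNK>"


def build_vocab(tokens, min_count=3):
    """Assign an id to every token seen at least min_count times."""
    counts = Counter(tokens)
    vocab = {UNK: 0}
    # Most frequent tokens first; ties broken alphabetically for determinism.
    for token, count in sorted(counts.items(), key=lambda kv: (-kv[1], kv[0])):
        if count >= min_count and token not in vocab:
            vocab[token] = len(vocab)
    return vocab


def encode(tokens, vocab):
    """Convert tokens to ids, sending out-of-vocabulary tokens to <UNK>."""
    unk_id = vocab[UNK]
    return [vocab.get(token, unk_id) for token in tokens]


# Build the vocabulary from the language-model corpus ...
lm_vocab = build_vocab("the cat sat on the mat the cat".split(), min_count=2)
# ... and reuse the same mapping for the downstream dataset.
downstream_ids = encode("the dog sat".split(), lm_vocab)
```

Sharing lm_vocab this way means a token id has the same meaning in both models, which is what allows the trained language model to be reused.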
Next let’s create a simple LSTM language model by defining a config file for it or using one of the config files defined in
- set data_root to point to the directory containing the raw dataset used to train your language model, for example, the WikiText dataset downloaded above.
- set processed_data_folder to point to the location where you want to store the processed dataset. If the dataset has been pre-processed before, the data layer can just load the data from this location.
- update other hyperparameters such as number of layers, number of hidden units, cell type, loss function, learning rate, optimizer, etc. to meet your needs.
- set the precision to "mixed" if you want to use mixed-precision training, or to tf.float32 to train only in FP32.
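As a rough illustration of the settings listed above, the fragment below collects them in a plain Python dictionary and checks that the two path settings are filled in. The key names data_root, processed_data_folder, use_horovod, and num_gpus are taken from this document; the paths, the values, the flat layout, and the helper missing_settings are placeholders and assumptions. Refer to the shipped config files for the actual structure.

```python
# Hypothetical fragment: key names come from this document, while the paths,
# values, and flat dictionary layout are placeholders.
lm_settings = {
    "data_root": "/path/to/wikitext-103-raw",       # raw dataset directory
    "processed_data_folder": "/path/to/processed",  # cache for processed data
    "use_horovod": False,                           # single-process training
    "num_gpus": 1,                                  # GPUs to train on
}


def missing_settings(settings, required=("data_root", "processed_data_folder")):
    """Return the required keys that are absent or empty."""
    return [key for key in required if not settings.get(key)]
```

Running missing_settings(lm_settings) before launching a job returns an empty list when both paths are set.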
For example, suppose your config file is lstm-wkt103-mixed.py. To train without Horovod, set use_horovod to False in the config file and run:
python run.py --config_file=example_configs/lstmlm/lstm-wkt103-mixed.py --mode=train_eval --enable_logs
When training with Horovod, use the following command:
mpiexec --allow-run-as-root -np <num_gpus> python run.py --config_file=example_configs/lstmlm/lstm-wkt103-mixed.py --mode=train_eval --use_horovod=True --enable_logs
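If you launch training from a script, the two commands above can be assembled programmatically. This is a small sketch: build_command is a hypothetical helper, and it only reproduces the flags shown in this section.

```python
def build_command(config_file, mode, num_gpus=None):
    """Return the argument list for run.py.

    num_gpus=None builds the plain command; an integer wraps it in the
    mpiexec/Horovod form shown above.
    """
    cmd = ["python", "run.py",
           "--config_file=" + config_file,
           "--mode=" + mode]
    if num_gpus is not None:
        cmd = ["mpiexec", "--allow-run-as-root", "-np", str(num_gpus)] + cmd
        cmd.append("--use_horovod=True")
    cmd.append("--enable_logs")
    return cmd
```

The returned list can be passed to subprocess.run(cmd, check=True) from the standard library, which avoids shell quoting issues.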
Some things to keep in mind:
- Don’t forget to update num_gpus to the number of GPUs you want to use.
- If the vocabulary is large (the word-level vocabulary for WikiText-103 is 267,000+), you might want to use BasicSampledSequenceLoss, which uses sampled softmax, instead of BasicSequenceLoss, which uses full softmax.
- If your GPUs still run out of memory, reduce the
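To see why sampled softmax helps with a large vocabulary, compare the two losses for a single prediction. The sketch below is a simplified illustration, not the implementation behind BasicSampledSequenceLoss: it draws negatives uniformly and omits the sampling-probability correction that production implementations apply. The function names are hypothetical.

```python
import math
import random


def full_softmax_loss(logits, target):
    """Cross-entropy over every entry in logits (the whole vocabulary)."""
    top = max(logits)  # subtract the max for numerical stability
    log_z = top + math.log(sum(math.exp(x - top) for x in logits))
    return log_z - logits[target]


def sampled_softmax_loss(logits, target, num_sampled, rng=random):
    """Cross-entropy over the target plus num_sampled random negatives.

    Simplification: negatives are drawn uniformly without replacement and no
    correction for the sampling distribution is applied. Building the list of
    negatives is itself O(vocabulary) here, which is fine for a demo only.
    """
    negatives = [i for i in range(len(logits)) if i != target]
    picked = rng.sample(negatives, num_sampled)
    subset = [logits[target]] + [logits[i] for i in picked]
    return full_softmax_loss(subset, 0)
```

The sampled version normalizes over num_sampled + 1 scores instead of the full vocabulary, so the cost of the output layer no longer grows with the vocabulary size during training.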
Even if you train with sampled softmax, evaluation and text generation are always done using full softmax. Running in eval mode evaluates your model on the evaluation set:
python run.py --config_file=example_configs/lstmlm/lstm-wkt103-mixed.py --mode=eval --enable_logs
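The checkpoints listed at the top of this page are characterized by their perplexity. Under the standard definition, perplexity is the exponential of the average negative log-probability per token. The sketch below assumes natural-log probabilities; the helper name perplexity is illustrative, and how OpenSeq2Seq reports the metric in its logs is not shown here.

```python
import math


def perplexity(token_log_probs):
    """exp of the average negative natural-log probability per token."""
    if not token_log_probs:
        raise ValueError("need at least one token")
    return math.exp(-sum(token_log_probs) / len(token_log_probs))
```

For example, a model that assigns probability 0.25 to every token has a perplexity of 4, as if it were choosing uniformly among four options at each step.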
Running in infer mode generates text from the seed tokens, which are defined in the config file under the parameter name seed_tokens. Separate the seed tokens with spaces. [TODO: make seed_tokens take a list of strings instead]:
python run.py --config_file=example_configs/lstmlm/lstm-wkt103-mixed.py --mode=infer --enable_logs
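The space-separated format for seed_tokens, and the list form mentioned in the TODO, can both be normalized with a few lines of Python. This is a hypothetical helper for illustration, not code from OpenSeq2Seq.

```python
def split_seed_tokens(seed_tokens):
    """Accept a space-separated string or a list of strings; return a list."""
    if isinstance(seed_tokens, str):
        return seed_tokens.split()
    return list(seed_tokens)
```

Because str.split() with no arguments collapses repeated whitespace, stray double spaces in the config value do not produce empty seed tokens.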