NgramRes
Residual Learning of Neural Text Generation with n-gram Language Model
Authors: Huayang Li, Deng Cai, Jin Xu, Taro Watanabe
NOTE: This is a draft version. I will clean up this repository soon.
1. Introduction
In this README, we describe how to train and evaluate our proposed neuro-symbolic system on the standard language modeling task.
Dependencies:
- python==3.7
- pytorch==1.8.0
- fairseq==1.0.0
- KenLM
Dataset: WikiText-103 (the commands below assume the raw wiki.train/valid/test.tokens files are under $DEST/data/wikitext-103/).
2. $n$-gram Language Model
Since we propose to combine the $n$-gram model and the neural model at the logits layer (the unnormalized scores just before the softmax), we need to generate the $n$-gram model's prediction distribution at each step and convert it back to logits with an inverse of the softmax (an element-wise log, which inverts the softmax up to an additive constant).
However, at least when I started this project, KenLM did not support generating the prediction distribution over the vocabulary at each step. Therefore, I implemented a model that can efficiently generate this prediction distribution from the $n$-gram LM trained by KenLM.
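As a rough illustration of this combination (not the actual fairseq integration), the sketch below applies exactly that log to the $n$-gram distribution and adds the result to the neural logits; the alpha weight mirrors the --ngram-alpha option used in Section 3.2, though the exact interpolation in the released code may differ.
import torch
import torch.nn.functional as F

def combine_logits(neural_logits, ngram_probs, alpha=0.1, eps=1e-8):
    # neural_logits: (batch, vocab) unnormalized scores from the neural LM
    # ngram_probs:   (batch, vocab) normalized n-gram prediction distribution
    # log() inverts the softmax up to an additive constant, so it maps the
    # n-gram distribution back to the logit scale before the two are added.
    ngram_logits = torch.log(ngram_probs + eps)
    return neural_logits + alpha * ngram_logits

# toy usage with random tensors
neural_logits = torch.randn(2, 10)
ngram_probs = F.softmax(torch.randn(2, 10), dim=-1)
final_probs = F.softmax(combine_logits(neural_logits, ngram_probs), dim=-1)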
2.1 Train the $n$-gram LM
We use KenLM to train the $n$-gram LM. Please follow the KenLM instructions to compile the binaries first.
DEST=path/to/the/project/
TEXT=$DEST/data/wikitext-103/wiki.train.tokens
./bin/lmplz -o 5 < $TEXT > $TEXT.5gram.arpa
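If the kenlm Python bindings are installed, a quick sanity check of the resulting ARPA file looks like this (the sentence is arbitrary; scores are log10 probabilities):
import kenlm

# path produced by the lmplz command above, relative to $DEST
model = kenlm.Model("data/wikitext-103/wiki.train.tokens.5gram.arpa")
print(model.order)                                              # expect 5
print(model.score("the quick brown fox", bos=True, eos=True))   # log10 probability
print(model.perplexity("the quick brown fox"))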
2.2 Model Conversion
In this part, we convert the ARPA model generated by KenLM to our model format.
NOTE: before running this part, please prepare the dictionary dict.txt with fairseq first (Section 3.1), because the neural model and the $n$-gram model must use the same dictionary.
#########
# setup
#########
cd $DEST/KNQuery
bash setup.sh
########
# run
########
DATADIR=$DEST/data/wikitext-103/
ARPA_FNAME=wiki.train.tokens.5gram.arpa
ARPA=$DATADIR/$ARPA_FNAME
CACHEDIR=./cache_5gram_wt103/
# this is the dict file generated by fairseq
TKZ=$DEST/data-bin/wikitext-103/dict.txt
########
# arpa -> binary
########
mkdir -p $CACHEDIR
python arpa2binary.py --tokenizer_path $TKZ \
--arpa $ARPA \
--eos_token "</s>" \
--unk_token "<unk>" \
--binary ./$ARPA_FNAME.pt
########
# build our model
########
python query.py --tokenizer_path $TKZ \
--lm_path ./$ARPA_FNAME.pt \
--cache_path $CACHEDIR/$ARPA_FNAME.cache \
--mode fast
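For intuition only: what the trainer needs from the $n$-gram side is, for each position, a dense prediction distribution over the shared fairseq vocabulary that can be mapped back to the logit scale with a log. The toy stand-in below only illustrates that interface shape; it is not the KNQuery API.
import torch

# Toy stand-in for the n-gram generation model: given a batch of contexts,
# return one normalized next-token distribution per context, aligned with
# the shared fairseq dictionary (here just a uniform distribution). The real
# cache built above produces these distributions from the trained 5-gram LM.
def toy_ngram_step(context_ids: torch.LongTensor, vocab_size: int) -> torch.Tensor:
    return torch.full((context_ids.size(0), vocab_size), 1.0 / vocab_size)

probs = toy_ngram_step(torch.tensor([[4, 15, 7, 42]]), vocab_size=32)
assert torch.allclose(probs.sum(dim=-1), torch.ones(1))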
3. Neural Language Model
Setup:
cd $DEST/fairseq
bash setup.sh
3.1 Prepare data-bin
DATADIR=$DEST/data/wikitext-103/
fairseq-preprocess \
--only-source \
--trainpref $DATADIR/wiki.train.tokens \
--validpref $DATADIR/wiki.valid.tokens \
--testpref $DATADIR/wiki.test.tokens \
--destdir $DEST/data-bin/wikitext-103 \
--workers 20
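fairseq-preprocess also writes the dictionary to $DEST/data-bin/wikitext-103/dict.txt, which is the file the $n$-gram conversion in Section 2.2 expects. A quick check that it loads (assuming the fairseq version listed above):
from fairseq.data import Dictionary

# dict.txt written by fairseq-preprocess; shared with the n-gram model (Section 2.2)
d = Dictionary.load("data-bin/wikitext-103/dict.txt")
print(len(d), d.eos_word, d.unk_word)  # vocabulary size and the special symbols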
3.2 Train Model
NOTE: Directly training ADP on fp16 using the code provided by fairseq may cause
NAN
loss. A tricky solution is to train ADP on fp32 for 3 epochs and then load the pre-trained parameters for fp16 training.
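If fp16 training still diverges, one small debugging sketch (assuming the usual fairseq checkpoint layout, a torch pickle with the weights stored under a "model" key) is to scan the fp32 warm-start checkpoint for non-finite parameters before restoring it:
import torch

# Verify the fp32 warm-start checkpoint is finite before fp16 training.
ckpt = torch.load("results/adaptive.input.wt103/checkpoint3.pt", map_location="cpu")
bad = [name for name, p in ckpt["model"].items() if not torch.isfinite(p).all()]
print("non-finite tensors:", bad if bad else "none")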
Standard neural LM:
########
# ADP
########
SAVE_DIR=$DEST/results/adaptive.input.wt103.base.fp16
mkdir -p $SAVE_DIR
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 nohup python train.py \
--task language_modeling \
$DEST/data-bin/wikitext-103 \
--save-dir $SAVE_DIR \
--arch transformer_lm_wiki103 \
--max-update 286000 --lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
--warmup-updates 16000 --warmup-init-lr 1e-07 --stop-min-lr 1e-09 --optimizer nag --min-lr 1e-5 --clip-norm 0.1 \
--criterion adaptive_loss --max-tokens 8192 --update-freq 1 --tokens-per-sample 2048 --seed 1 \
--sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=legacy_ddp \
--restore-file $DEST/results/adaptive.input.wt103/checkpoint3.pt \
--fp16 \
--reset-optimizer \
--reset-lr-scheduler > $SAVE_DIR/log.train &
Our neuro-symbolic model:
########
# ADP+5gram
########
mkdir -p $DEST/results/adaptive.input.wt103.5gram.fp16
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 nohup python train.py \
--task language_modeling_with_ngram \
$DEST/data-bin/wikitext-103 \
--save-dir $DEST/results/adaptive.input.wt103.5gram.fp16 \
--arch transformer_lm_wiki103 \
--max-update 286000 --lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
--warmup-updates 16000 --warmup-init-lr 1e-07 --stop-min-lr 1e-09 --optimizer nag --min-lr 0.0001 --clip-norm 0.1 \
--criterion adaptive_loss_with_ngram --max-tokens 8192 --update-freq 1 --tokens-per-sample 2048 --seed 1 \
--sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=legacy_ddp \
--restore-file $DEST/results/adaptive.input.wt103/checkpoint3.pt \
--fp16 \
--ngram-generation-model-cache "${DEST}/KNQuery/cache_5gram_wt103/wiki.train.tokens.5gram.arpa.cache" \
--ngram-alpha 0.1 \
--ngram-module-path $DEST/KNQuery/ \
--ngram-warmup-updates 0 > $DEST/results/adaptive.input.wt103.5gram.fp16/log.train &
3.3 Evaluate the Model
MODEL=$DEST/results/adaptive.input.wt103.5gram.fp16/checkpoint_best.pt
CUDA_VISIBLE_DEVICES=0 python eval_lm.py \
$DEST/data-bin/wikitext-103 \
--gen-subset "test" \
--task language_modeling_with_ngram \
--path $MODEL \
--distributed-world-size 1 \
--batch-size 2 \
--context-window 1536 \
--tokens-per-sample 2048 \
--ngram-generation-model-cache "${DEST}/KNQuery/cache_5gram_wt103/wiki.train.tokens.5gram.arpa.cache" \
--ngram-module-path $DEST/KNQuery/
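eval_lm typically reports the average loss in base 2 alongside the perplexity; if you only log the loss, the conversion is simply 2 raised to the loss (the number below is a placeholder, not a real result):
# placeholder arithmetic, not a measured result
loss_base2 = 4.50
perplexity = 2 ** loss_base2
print(round(perplexity, 1))  # 22.6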