Seq-NAT
Source code for "Sequence-Level Training for Non-Autoregressive Neural Machine Translation".
PyTorch implementation of the methods described in the Computational Linguistics 2021 paper Sequence-Level Training for Non-Autoregressive Neural Machine Translation. The code is based on fairseq v0.9.0. We only modified nat_loss.py and utils.py.
Dependencies
- Python 3.8
- PyTorch 1.7
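Since the code is based on fairseq v0.9.0 with only two files modified, the repository presumably ships the full (modified) fairseq source. In that case a standard editable install from the repository root should be enough (this is common fairseq practice rather than a step documented here):

$ pip install --editable .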
Dataset
First, download and preprocess the WMT'14 En-De dataset, for example with the preparation scripts in fairseq's examples/translation directory.
Make sure to learn a joint vocabulary by passing the --joined-dictionary option to fairseq-preprocess.
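As a rough guide, a preprocessing command along the following lines binarizes the data with a joint vocabulary (the paths and worker count are illustrative):

$ fairseq-preprocess --source-lang en --target-lang de \
    --trainpref wmt14_en_de/train --validpref wmt14_en_de/valid --testpref wmt14_en_de/test \
    --destdir data-bin/wmt14_en_de --joined-dictionary --workers 20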
Knowledge Distillation
Following Gu et al. 2019, knowledge distillation from an autoregressive model can effectively simplify the training data distribution, which is sometimes essential for NAT-based models to learn good translations. The easiest way to perform distillation is to follow the fairseq instructions for training a standard transformer model on the same data, and then decode the training set to produce a distillation dataset for NAT.
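As a sketch of the distillation step (the checkpoint path, beam size, and batch settings are illustrative, not taken from the paper), the training split can be decoded with the trained autoregressive teacher and the hypotheses re-binarized as the NAT training data:

$ fairseq-generate data-bin/wmt14_en_de --gen-subset train \
    --path checkpoints/at_teacher/checkpoint_best.pt \
    --beam 5 --max-tokens 8192 > train.distill.log
# Extract the H- (hypothesis) lines from train.distill.log, pair them with the
# original source sentences, and rerun fairseq-preprocess to binarize the
# distilled dataset.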
Training
The training scripts are provided in the training_scripts folder. First, run the pretraining script to pretrain the baseline NAT model:
$ sh training_scripts/pretrain.sh
Then, run the other scripts for finetuning. For example, to finetune the NAT model with the BoN-L1 objective, run:
$ sh training_scripts/bag2grams.sh
Decoding
To decode the test set, run:
$ sh decode.sh model_path
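For reference, decode.sh presumably wraps a fairseq-generate call; a typical NAT decoding command in fairseq v0.9.0 looks like the following (the task name, iteration setting, and data path are assumptions, not read from the script):

$ fairseq-generate data-bin/wmt14_en_de --gen-subset test \
    --task translation_lev --path model_path/checkpoint_best.pt \
    --iter-decode-max-iter 0 --beam 1 --remove-bpe --print-step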
Citation
If you find the resources in this repository useful, please consider citing:
@article{10.1162/coli_a_00421,
author = {Shao, Chenze and Feng, Yang and Zhang, Jinchao and Meng, Fandong and Zhou, Jie},
title = "{Sequence-Level Training for Non-Autoregressive Neural Machine Translation}",
journal = {Computational Linguistics},
volume = {47},
number = {4},
pages = {891-925},
year = {2021},
month = {12},
issn = {0891-2017},
doi = {10.1162/coli_a_00421},
url = {https://doi.org/10.1162/coli_a_00421},
eprint = {https://direct.mit.edu/coli/article-pdf/47/4/891/1979393/coli_a_00421.pdf},
}