fairseq
[Non-autoregressive Transformer] Add GLAT, CTC, DS
This PR adds the code for the following methods to the Non-Autoregressive Transformer:
- Glancing Transformer (GLAT) from "Glancing Transformer for Non-Autoregressive Neural Machine Translation" (Qian et al., 2021)
- Connectionist Temporal Classification (CTC) from "End-to-End Non-Autoregressive Neural Machine Translation with Connectionist Temporal Classification" (Libovický & Helcl, 2018)
- Deep Supervision (DS) from "Non-Autoregressive Translation with Layer-Wise Prediction and Deep Supervision" (Huang et al., 2021)
Note that this code still has one remaining dependency on the C++ code from torch_imputer, which is not yet integrated in this PR. We leave it to the fairseq team to decide how they want to include the best_alignment method used in fairseq/models/nat/nonautoregressive_transformer.py (l. 171): either building a pip package and importing it, or copying the code over directly from the respective repository, would work. It is used to obtain Viterbi-aligned target tokens when using CTC + GLAT jointly.
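For reference, best_alignment computes the Viterbi (best-path) alignment of the target through the standard CTC state lattice (blanks interleaved between target tokens); the torch_imputer C++ code does the same over batched log-probabilities. A minimal pure-Python sketch of that dynamic program (blank id and the toy inputs below are illustrative):

```python
import math

def ctc_viterbi_align(log_probs, target, blank=0):
    """Viterbi (best-path) CTC alignment of `target` to T frames.

    log_probs: list of T rows, log_probs[t][v] = log P(v at frame t)
    target:    target token ids (no blanks)
    Returns:   one id per frame from the extended sequence
               [blank, y1, blank, y2, ..., blank] (blanks included).
    """
    T = len(log_probs)
    # Extended target with blanks interleaved between (and around) tokens.
    z = [blank]
    for y in target:
        z += [y, blank]
    S = len(z)
    NEG = float("-inf")
    dp = [[NEG] * S for _ in range(T)]
    bp = [[0] * S for _ in range(T)]
    # Paths may start in the initial blank or the first target token.
    dp[0][0] = log_probs[0][z[0]]
    if S > 1:
        dp[0][1] = log_probs[0][z[1]]
    for t in range(1, T):
        for s in range(S):
            # Allowed predecessors: stay, advance by one, or skip a blank
            # between two *different* target tokens.
            cands = [c for c in (s, s - 1) if c >= 0]
            if s >= 2 and z[s] != blank and z[s] != z[s - 2]:
                cands.append(s - 2)
            score, prev = max((dp[t - 1][c], c) for c in cands)
            dp[t][s] = score + log_probs[t][z[s]]
            bp[t][s] = prev
    # A valid path ends in the final blank or the final target token.
    end = max(range(max(S - 2, 0), S), key=lambda s: dp[T - 1][s])
    states = [end]
    for t in range(T - 1, 0, -1):
        states.append(bp[t][states[-1]])
    states.reverse()
    return [z[s] for s in states]
```

With four frames whose probabilities strongly favor 1, 1, 2, 2 and target [1, 2], the returned frame-level alignment is [1, 1, 2, 2], i.e. each target token is aligned to the frames that should predict it.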
The main flags for training with each of the above methods are:
- GLAT:
--use-glat
- CTC:
--use-ctc-decoder --ctc-src-upsample-scale 2
- DS:
--use-deep-supervision
These can also be used jointly. Once this PR has been integrated, we'll work on a follow-up PR for the remaining inference-speed improvements, i.e. Shortlists and Average Attention (see the paper below). As these are not specific to non-autoregressive models, we decided to keep them separate.
If anyone using this code finds it helpful, please consider citing our associated paper:
Abstract: Non-autoregressive approaches aim to improve the inference speed of translation models by only requiring a single forward pass to generate the output sequence instead of iteratively producing each predicted token. Consequently, their translation quality still tends to be inferior to their autoregressive counterparts due to several issues involving output token interdependence. In this work, we take a step back and revisit several techniques that have been proposed for improving non-autoregressive translation models and compare their combined translation quality and speed implications under third-party testing environments. We provide novel insights for establishing strong baselines using length prediction or CTC-based architecture variants and contribute standardized BLEU, chrF++, and TER scores using sacreBLEU on four translation tasks, which crucially have been missing as inconsistencies in the use of tokenized BLEU lead to deviations of up to 1.7 BLEU points. Our open-sourced code is integrated into fairseq for reproducibility.
@misc{schmidt2022nat,
url = {https://arxiv.org/abs/2205.10577},
author = {Schmidt, Robin M. and Pires, Telmo and Peitz, Stephan and Lööf, Jonas},
title = {Non-Autoregressive Neural Machine Translation: A Call for Clarity},
publisher = {arXiv},
year = {2022}
}
Hi, thanks for this great integration of the NAT code. Could you please provide an example showing how to train a GLAT model?
Sure, as written above, the main flag is --use-glat, which enables glancing sampling. Assuming you have run fairseq-preprocess and your data is set up correctly in a folder data-bin, you should be able to start a GLAT training run with:
```shell
fairseq-train data-bin \
    --task translation_lev --arch nonautoregressive_transformer \
    --criterion nat_loss --noise full_mask --use-glat \
    --apply-bert-init --share-all-embeddings \
    --activation-fn gelu --dropout 0.1 --label-smoothing 0 \
    --decoder-learned-pos --encoder-learned-pos \
    --pred-length-offset --length-loss-factor 0.1 \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 5.0 \
    --lr 0.001 --lr-scheduler inverse_sqrt --min-lr 1e-09 \
    --warmup-init-lr 1e-07 --warmup-updates 10000 --weight-decay 0.01 \
    --max-tokens 8192 --max-update 200000 --fp16 \
    --eval-bleu --eval-bleu-args '{"iter_decode_max_iter": 0, "iter_decode_eos_penalty": 0, "iter_decode_with_beam": 1}' \
    --eval-bleu-detok moses --eval-bleu-print-samples --eval-bleu-remove-bpe \
    --log-format simple --log-interval 100
```
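For intuition, --use-glat enables glancing sampling: a first decoding pass with fully masked inputs is compared against the target, and a number of gold tokens proportional to the number of prediction errors is revealed in the decoder input for the training pass. A minimal token-level sketch, where mask_id and f_ratio are illustrative names rather than fairseq internals:

```python
import random

def glancing_sample(pred_tokens, tgt_tokens, mask_id, f_ratio=0.5, rng=random):
    """One glancing-sampling step (GLAT, Qian et al., 2021), sketched
    at the token level.

    pred_tokens: predictions from a first pass with fully masked inputs
    tgt_tokens:  gold target tokens (same length)
    Returns the new decoder input and the set of revealed positions.
    """
    assert len(pred_tokens) == len(tgt_tokens)
    # Hamming distance between prediction and target: the worse the
    # first pass, the more gold tokens we glance at.
    wrong = [i for i, (p, t) in enumerate(zip(pred_tokens, tgt_tokens)) if p != t]
    n_reveal = int(len(wrong) * f_ratio)
    reveal = set(rng.sample(range(len(tgt_tokens)), n_reveal))
    # Revealed positions get the gold token; the rest stay masked and
    # are what the model is trained to predict.
    dec_input = [tgt_tokens[i] if i in reveal else mask_id
                 for i in range(len(tgt_tokens))]
    return dec_input, reveal
```

When the first pass is already perfect, nothing is revealed and the input stays fully masked; when it is entirely wrong with f_ratio=1.0, the full gold target is revealed.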
Similarly, for vanilla CTC:
```shell
fairseq-train data-bin \
    --task translation_lev --arch nonautoregressive_transformer \
    --criterion nat_loss --noise full_mask \
    --use-ctc-decoder --ctc-src-upsample-scale 2 \
    --share-all-embeddings --dropout 0.1 --label-smoothing 0 \
    --decoder-learned-pos --encoder-learned-pos \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 5.0 \
    --lr 0.001 --lr-scheduler inverse_sqrt --min-lr 1e-09 \
    --warmup-init-lr 1e-07 --warmup-updates 10000 --weight-decay 0.01 \
    --max-tokens 8192 --max-update 200000 --fp16 \
    --eval-bleu --eval-bleu-args '{"iter_decode_max_iter": 0, "iter_decode_eos_penalty": 0, "iter_decode_with_beam": 1}' \
    --eval-bleu-detok moses --eval-bleu-print-samples --eval-bleu-remove-bpe \
    --log-format simple --log-interval 100
```
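For context, a CTC decoder emits an output sequence longer than the target (hence --ctc-src-upsample-scale 2, which sizes the decoder relative to the source) and then collapses it by merging repeated tokens and removing blanks. A minimal sketch of the collapse rule (the blank id and token ids below are illustrative):

```python
def ctc_collapse(tokens, blank=0):
    """Collapse a CTC output: merge adjacent repeats, then drop blanks.

    A repeated token separated by a blank survives as two tokens,
    which is how CTC can still produce genuine repetitions.
    """
    out = []
    prev = None
    for tok in tokens:
        if tok != prev and tok != blank:
            out.append(tok)
        prev = tok
    return out

# e.g. [1, 1, 0, 1, 2, 2, 0] collapses to [1, 1, 2]
```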
As you can see, the main flags enable the respective methods and can also be combined, e.g. for CTC + GLAT (provided the C++ code is added as described above):
```shell
fairseq-train data-bin \
    --task translation_lev --arch nonautoregressive_transformer \
    --criterion nat_loss --noise full_mask \
    --use-ctc-decoder --ctc-src-upsample-scale 2 --use-glat \
    --share-all-embeddings --dropout 0.1 --label-smoothing 0 \
    --decoder-learned-pos --encoder-learned-pos \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 5.0 \
    --lr 0.001 --lr-scheduler inverse_sqrt --min-lr 1e-09 \
    --warmup-init-lr 1e-07 --warmup-updates 10000 --weight-decay 0.01 \
    --max-tokens 8192 --max-update 200000 --fp16 \
    --eval-bleu --eval-bleu-args '{"iter_decode_max_iter": 0, "iter_decode_eos_penalty": 0, "iter_decode_with_beam": 1}' \
    --eval-bleu-detok moses --eval-bleu-print-samples --eval-bleu-remove-bpe \
    --log-format simple --log-interval 100
```
For the hyperparameter choices, please see the paper (the above is for WMT'14 EN-DE)!
Of course, --max-tokens and --lr are somewhat specific to our setup (number of GPUs, batch size) and might need some tuning to make the most of your available GPU resources. My guess is that you will need to reduce both, since we train on multiple A100s and our effective batch size is therefore quite large.
Let me know if you run into any issues; I had to strip a few internal flags, so hopefully I didn't miss anything!
Dear Robin, thanks for your help! I have successfully finished training. Could you kindly provide the test script?
Sure. Assuming you have averaged your checkpoints and saved the result in a file, e.g. ckpts_last_5.pt, you can run inference on the test set with the following command:
```shell
fairseq-generate data-bin --path ckpts_last_5.pt \
    --task translation_lev --gen-subset test \
    --iter-decode-max-iter 0 --iter-decode-eos-penalty 0 --iter-decode-with-beam 1 \
    --batch-size 128 --beam 1 --remove-bpe
```
This will generate hypotheses for the test set, which you then need to score with sacrebleu. Something like the following should work to compute BLEU, chrF++, and case-sensitive TER with sacrebleu==2.0.0:
```shell
sacrebleu -i test.hyp -t wmt14/full -l en-de -m bleu chrf ter \
    --chrf-word-order 2 --ter-case-sensitive
```
Note that EN-DE and DE-EN use wmt14/full, while EN-RO and RO-EN use wmt16.
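As for the checkpoint averaging mentioned above, fairseq ships scripts/average_checkpoints.py for this; conceptually it is just a per-parameter mean across the saved checkpoints. A sketch of the idea, with plain Python lists standing in for the actual parameter tensors:

```python
def average_checkpoints(state_dicts):
    """Average parameters across checkpoints (the idea behind fairseq's
    scripts/average_checkpoints.py).

    Each checkpoint is modeled here as a dict mapping parameter names to
    lists of floats; the real script does the same element-wise mean over
    torch tensors.
    """
    assert state_dicts, "need at least one checkpoint"
    n = len(state_dicts)
    return {
        name: [sum(vals) / n
               for vals in zip(*(sd[name] for sd in state_dicts))]
        for name in state_dicts[0]
    }
```

Averaging the last few checkpoints this way (e.g. the last five, matching the ckpts_last_5.pt name above) typically smooths out noise from the final updates.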
Sincere thanks!
No problem at all, please let me know in case you run into any issues!
How to implement --nbest?