[Non-autoregressive Transformer] Add GLAT, CTC, DS

Open SirRob1997 opened this issue 2 years ago • 7 comments

This PR adds the code for the following methods to the Non-Autoregressive Transformer: GLAT (glancing sampling), CTC, and DS (deep supervision).

Note that this code still has one dependency on the C++ code from torch_imputer, which is not currently integrated in this PR. We leave it up to the fairseq team to decide how to include the best_alignment method used in fairseq/models/nat/nonautoregressive_transformer.py (l. 171). Either building a pip package and importing it, or directly copying the code over from the respective repository, would work. The method is used to obtain Viterbi-aligned target tokens when CTC and GLAT are used jointly.
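To make concrete what a Viterbi alignment under CTC looks like, here is a minimal pure-Python sketch. It is not the torch_imputer implementation and does not reproduce its interface: the function name ctc_best_alignment, the list-of-lists input, and the blank index 0 are all assumptions made for illustration.

```python
import math


def ctc_best_alignment(log_probs, target, blank=0):
    """Return the most likely CTC alignment of `target` to the frames.

    log_probs: list of T frames, each a list of per-class log-probabilities.
    target:    list of token ids (without blanks).
    Returns a list of T ids, each either `blank` or a token from `target`.
    NOTE: illustrative sketch only, not the torch_imputer API.
    """
    T = len(log_probs)
    if T == 0:
        raise ValueError("need at least one frame")

    # Extended target: blank, y1, blank, y2, ..., blank
    ext = [blank]
    for tok in target:
        ext.extend([tok, blank])
    S = len(ext)

    neg_inf = -math.inf
    score = [[neg_inf] * S for _ in range(T)]
    back = [[0] * S for _ in range(T)]

    # A path may start in the leading blank or in the first token.
    score[0][0] = log_probs[0][blank]
    if S > 1:
        score[0][1] = log_probs[0][ext[1]]

    for t in range(1, T):
        for s in range(S):
            # Allowed predecessors: stay, advance by one, or skip a blank
            # between two different tokens.
            preds = [s]
            if s >= 1:
                preds.append(s - 1)
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                preds.append(s - 2)
            best = max(preds, key=lambda p: score[t - 1][p])
            if score[t - 1][best] == neg_inf:
                continue  # state s is unreachable at time t
            score[t][s] = score[t - 1][best] + log_probs[t][ext[s]]
            back[t][s] = best

    # A valid path ends in the trailing blank or in the last token.
    finals = [S - 1] if S == 1 else [S - 1, S - 2]
    s = max(finals, key=lambda f: score[T - 1][f])
    if score[T - 1][s] == neg_inf:
        raise ValueError("target is too long for the number of frames")

    # Follow the back-pointers from the last frame to the first.
    states = [s]
    for t in range(T - 1, 0, -1):
        s = back[t][s]
        states.append(s)
    states.reverse()
    return [ext[s] for s in states]
```

The returned sequence has one entry per output position, so it can serve as a position-wise target, which is the role the aligned tokens play when CTC is combined with GLAT.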

Main flags for training using any of the above methods are:

  • GLAT: --use-glat
  • CTC: --use-ctc-decoder --ctc-src-upsample-scale 2
  • DS: --use-deep-supervision
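For context on what --use-glat switches on, the following is a minimal sketch of glancing sampling as it is commonly described for GLAT. It is not the code from this PR: the function name glancing_inputs, the fixed ratio argument, and the plain-list inputs are assumptions made for illustration.

```python
import random


def glancing_inputs(target, prediction, mask_id, ratio, rng):
    """Build the decoder input for the second (training) pass.

    target:     reference token ids.
    prediction: token ids predicted in a first pass from a fully masked input.
    ratio:      fraction of the wrongly predicted positions to reveal
                (hypothetical fixed value; schedules are also possible).
    Returns (decoder_input, loss_positions).
    """
    if len(target) != len(prediction):
        raise ValueError("target and prediction must have the same length")

    # The worse the first-pass prediction, the more reference tokens
    # are revealed to the decoder.
    distance = sum(p != t for p, t in zip(prediction, target))
    num_reveal = int(distance * ratio)

    revealed = set(rng.sample(range(len(target)), num_reveal))
    decoder_input = [
        tok if i in revealed else mask_id for i, tok in enumerate(target)
    ]
    # The loss is computed only on positions that were not revealed.
    loss_positions = [i for i in range(len(target)) if i not in revealed]
    return decoder_input, loss_positions


if __name__ == "__main__":
    rng = random.Random(0)
    inp, pos = glancing_inputs([5, 6, 7, 8], [5, 9, 9, 8], 0, 0.5, rng)
    print(inp, pos)
```

When the first-pass prediction is already correct, nothing is revealed and the decoder input stays fully masked, which matches the --noise full_mask setting used at inference time.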

The flags above can also be combined. Once this PR has been integrated, we'll work on a follow-up PR for the required inference speed improvements, i.e. Shortlists and Average Attention (see the paper below). As these are not specific to non-autoregressive models, we decided to keep them separate.

If anyone using this code finds it helpful, please consider citing our associated paper:

Abstract: Non-autoregressive approaches aim to improve the inference speed of translation models by only requiring a single forward pass to generate the output sequence instead of iteratively producing each predicted token. Consequently, their translation quality still tends to be inferior to their autoregressive counterparts due to several issues involving output token interdependence. In this work, we take a step back and revisit several techniques that have been proposed for improving non-autoregressive translation models and compare their combined translation quality and speed implications under third-party testing environments. We provide novel insights for establishing strong baselines using length prediction or CTC-based architecture variants and contribute standardized BLEU, chrF++, and TER scores using sacreBLEU on four translation tasks, which crucially have been missing as inconsistencies in the use of tokenized BLEU lead to deviations of up to 1.7 BLEU points. Our open-sourced code is integrated into fairseq for reproducibility.

@misc{schmidt2022nat,
  url = {https://arxiv.org/abs/2205.10577}, 
  author = {Schmidt, Robin M. and Pires, Telmo and Peitz, Stephan and Lööf, Jonas},
  title = {Non-Autoregressive Neural Machine Translation: A Call for Clarity},
  publisher = {arXiv},
  year = {2022}
}

SirRob1997 avatar May 21 '22 06:05 SirRob1997

Hi, thanks for this great integration of NAT code. Could you please provide an example showing how to train a GLAT model?

xcfcode avatar Jun 04 '22 10:06 xcfcode

Sure, as written above, the main flag for that is --use-glat, which enables glancing sampling. Assuming you have run fairseq-preprocess and your data is set up correctly in a folder data-bin, you should be able to start a GLAT training run with:

fairseq-train data-bin --log-format simple --log-interval 100 --max-tokens 8192 --activation-fn gelu --adam-betas '(0.9, 0.98)' --apply-bert-init --arch nonautoregressive_transformer --clip-norm 5.0 --criterion nat_loss --decoder-learned-pos --dropout 0.1 --encoder-learned-pos --eval-bleu --eval-bleu-args '{"iter_decode_max_iter": 0, "iter_decode_eos_penalty": 0, "iter_decode_with_beam": 1}' --eval-bleu-detok moses --eval-bleu-print-samples --eval-bleu-remove-bpe --fp16 --label-smoothing 0 --length-loss-factor 0.1 --lr 0.001 --lr-scheduler inverse_sqrt --max-update 200000 --min-lr 1e-09 --noise full_mask --optimizer adam --pred-length-offset --share-all-embeddings --task translation_lev --use-glat --warmup-init-lr 1e-07 --warmup-updates 10000 --weight-decay 0.01

Similarly, for vanilla CTC:

fairseq-train data-bin --log-format simple --log-interval 100 --max-tokens 8192 --adam-betas '(0.9, 0.98)' --arch nonautoregressive_transformer --clip-norm 5.0 --criterion nat_loss --ctc-src-upsample-scale 2 --decoder-learned-pos --dropout 0.1 --encoder-learned-pos --eval-bleu --eval-bleu-args '{"iter_decode_max_iter": 0, "iter_decode_eos_penalty": 0, "iter_decode_with_beam": 1}' --eval-bleu-detok moses --eval-bleu-print-samples --eval-bleu-remove-bpe --fp16 --label-smoothing 0 --lr 0.001 --lr-scheduler inverse_sqrt --max-update 200000 --min-lr 1e-09 --noise full_mask --optimizer adam --share-all-embeddings --task translation_lev --use-ctc-decoder --warmup-init-lr 1e-07 --warmup-updates 10000 --weight-decay 0.01

As you can see, the main flags that enable each method are passed explicitly, and they can also be combined for CTC + GLAT (provided the C++ code has been added as stated above):

fairseq-train data-bin --log-format simple --log-interval 100 --max-tokens 8192 --adam-betas '(0.9, 0.98)' --arch nonautoregressive_transformer --clip-norm 5.0 --criterion nat_loss --ctc-src-upsample-scale 2 --decoder-learned-pos --dropout 0.1 --encoder-learned-pos --eval-bleu --eval-bleu-args '{"iter_decode_max_iter": 0, "iter_decode_eos_penalty": 0, "iter_decode_with_beam": 1}' --eval-bleu-detok moses --eval-bleu-print-samples --eval-bleu-remove-bpe --fp16 --label-smoothing 0 --lr 0.001 --lr-scheduler inverse_sqrt --max-update 200000 --min-lr 1e-09 --noise full_mask --optimizer adam --share-all-embeddings --task translation_lev --use-ctc-decoder --use-glat --warmup-init-lr 1e-07 --warmup-updates 10000 --weight-decay 0.01

For some of the hyperparameter choices, please see the paper (the above is for WMT'14 EN-DE)!

Of course, max-tokens and lr are somewhat specific to our setup (number of GPUs, batch size) and might need some tuning to make the best use of your available GPU resources. My guess is that you will need to reduce both, since we train on multiple A100s and our batch size is therefore quite large.
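As a rough aid for that tuning, the sketch below computes the approximate number of tokens per optimizer step and the gradient-accumulation factor needed to match a reference value on smaller hardware. The helper names and all example numbers are hypothetical (the GPU count of the original setup is not stated in this thread); fairseq's --update-freq flag is one common way to apply the resulting factor instead of lowering the learning rate.

```python
import math


def tokens_per_update(max_tokens, num_gpus, update_freq=1):
    """Approximate upper bound on tokens seen per optimizer step."""
    return max_tokens * num_gpus * update_freq


def update_freq_to_match(reference_tokens, max_tokens, num_gpus):
    """Smallest accumulation factor that reaches `reference_tokens`."""
    return max(1, math.ceil(reference_tokens / (max_tokens * num_gpus)))


if __name__ == "__main__":
    # Hypothetical reference setup: 8 GPUs at --max-tokens 8192.
    reference = tokens_per_update(8192, num_gpus=8)
    # Hypothetical local setup: 2 GPUs that only fit --max-tokens 4096.
    print(update_freq_to_match(reference, max_tokens=4096, num_gpus=2))
```

If accumulation is not an option and the effective batch ends up smaller, lowering lr along with max-tokens, as suggested above, is the simpler route.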

Let me know in case you run into any issues. I needed to strip a few internal flags, so hopefully I didn't miss anything!

SirRob1997 avatar Jun 04 '22 13:06 SirRob1997

Dear Robin, thanks for your help. I have successfully finished the training process. Could you kindly provide the test script?

xcfcode avatar Jun 29 '22 12:06 xcfcode

Sure, assuming you have averaged your checkpoints and saved the result in a file, e.g. ckpts_last_5.pt, running inference on the test set works with the following command:

fairseq-generate data-bin --path ckpts_last_5.pt --iter-decode-max-iter 0 --iter-decode-eos-penalty 0 --iter-decode-with-beam 1 --batch-size 128 --beam 1 --remove-bpe --task translation_lev --gen-subset test

This will generate hypotheses for the test set, which you will need to score with sacrebleu. Something like this should work to generate BLEU, chrF++, and case-sensitive TER metrics with sacrebleu==2.0.0:

sacrebleu -i test.hyp -t wmt14/full -l en-de -m bleu chrf ter --chrf-word-order 2 --ter-case-sensitive

Note that EN-DE and DE-EN use wmt14/full while EN-RO and RO-EN use wmt16.
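The sacrebleu command above reads a test.hyp file with one hypothesis per line in test-set order, whereas fairseq-generate prints a mixed log. A minimal sketch for extracting the hypotheses follows. It assumes the log was redirected to a file (gen.out is a hypothetical name) and that hypothesis lines have the form H-<id>, score, text separated by tabs. The function name extract_hypotheses is also an assumption.

```python
def extract_hypotheses(lines, prefix="H"):
    """Collect hypothesis texts from fairseq-generate output lines.

    Lines are expected to look like "H-<id>\t<score>\t<text>"; all other
    lines (source, target, scores, log messages) are ignored. The result
    is sorted by sentence id, because generation batches sentences by
    length and so does not print them in test-set order.
    """
    tag = prefix + "-"
    hyps = {}
    for line in lines:
        line = line.rstrip("\n")
        if not line.startswith(tag):
            continue
        fields = line.split("\t")
        if len(fields) < 3:
            continue  # malformed line, skip rather than guess
        sent_id = int(fields[0][len(tag):])
        hyps[sent_id] = fields[2]
    return [hyps[i] for i in sorted(hyps)]


if __name__ == "__main__":
    # "gen.out" and "test.hyp" are example file names.
    with open("gen.out", encoding="utf-8") as f:
        hypotheses = extract_hypotheses(f)
    with open("test.hyp", "w", encoding="utf-8") as f:
        f.write("\n".join(hypotheses) + "\n")
```

Depending on how the data was preprocessed, the extracted text may still be tokenized and need detokenizing before it is scored with sacrebleu.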

SirRob1997 avatar Jun 29 '22 17:06 SirRob1997

Sincere thanks!

xcfcode avatar Jun 30 '22 01:06 xcfcode

No problem at all, please let me know in case you run into any issues!

SirRob1997 avatar Jun 30 '22 19:06 SirRob1997

How to implement --nbest?

PPPNut avatar Apr 05 '23 05:04 PPPNut