fix overwrite bug when adding symbol to dictionary
## Before submitting
- [x] Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
- [x] Did you read the contributor guideline?
- [x] Did you make sure to update the docs?
- [x] Did you write any new necessary tests?
## What does this PR do?
Fixes #3064. Fixes #3705. Fixes #1309.
TL;DR: This PR fixes a bug that duplicated symbols meant to be overwritten in the vocabulary file. See the detailed explanation in this blog post.
**Expected behavior:**
A `Dictionary` object holds an `indices` dict and two parallel lists (`symbols` and `count`). By default, when loading a vocabulary from a file, a `Dictionary` instance is first created by adding the 4 special tokens (`<s>`, `<pad>`, `</s>` and `<unk>`, in that order). Then all entries from the file are appended to the `Dictionary`. If the vocabulary file already contains some of the special tokens, their file entries must carry the `#fairseq:overwrite` flag, otherwise a "duplicate" error is raised at runtime. Furthermore, during preprocessing, the saved dictionary should not contain any of the special symbols.
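For illustration, a vocabulary file that redefines one of the special tokens might look like this (a minimal sketch; the symbols and counts here are made up, the `symbol count [flag]` line format is the one described above):

```text
<unk> 0 #fairseq:overwrite
hello 10
world 9
```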
**Current behavior:**
The `add_symbol` function is responsible for adding symbols to the `Dictionary`. It has an `overwrite` argument that is set to `True` when the corresponding line in the file carries `#fairseq:overwrite`. Rather than handling the case `word in self.indices and overwrite`, it currently only tests `word in self.indices and not overwrite`, so the overwrite case falls through to the default branch: the symbol is appended to the `symbols` list a second time, and its entry in the `indices` dict is repointed to the new position. This results in duplicate symbols and incorrect indices. In practice, only the special symbols are affected, and since the number of special tokens (`nspecial`) is set during initialization, that count itself remains correct.
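To make the faulty branch concrete, here is a simplified sketch of the logic (not a verbatim copy of fairseq's code, though the attribute names follow the real `Dictionary`):

```python
class Dictionary:
    """Simplified stand-in for fairseq.data.Dictionary."""

    def __init__(self):
        self.indices = {}   # symbol -> index
        self.symbols = []   # index  -> symbol
        self.count = []     # index  -> count

    def add_symbol(self, word, n=1, overwrite=False):
        if word in self.indices and not overwrite:
            # Known symbol, no overwrite requested: just bump its count.
            idx = self.indices[word]
            self.count[idx] = self.count[idx] + n
            return idx
        else:
            # BUG: overwrite=True also lands here, appending a second
            # entry for `word` and repointing indices[word] to it, which
            # leaves a stale duplicate behind at the old index.
            idx = len(self.symbols)
            self.indices[word] = idx
            self.symbols.append(word)
            self.count.append(n)
            return idx
```

One way to repair it is to reuse the existing slot when `overwrite=True` instead of appending (a sketch of the idea, not necessarily the exact diff in this PR):

```python
def add_symbol(self, word, n=1, overwrite=False):
    """Replacement for the buggy method above."""
    if word in self.indices:
        idx = self.indices[word]
        if overwrite:
            # Overwrite in place: keep the existing index, reset the count.
            self.count[idx] = n
        else:
            self.count[idx] = self.count[idx] + n
        return idx
    idx = len(self.symbols)
    self.indices[word] = idx
    self.symbols.append(word)
    self.count.append(n)
    return idx
```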
For example, a dictionary with 50,000 tokens that already contains `<s>`, `<pad>`, `</s>` and `<unk>` with the `#fairseq:overwrite` tag will end up with 50,004 tokens when loaded. This also propagates to the model built on top of it, which will have an embedding dimension of 50,004 instead of 50,000. Likewise, with `fairseq-preprocess`, the resulting dictionary will skip the first 4 special symbols but will still contain the duplicated ones.
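The duplication is easy to reproduce at a small scale on a fairseq checkout that predates this fix (the file contents below are made up for the demonstration):

```python
from fairseq.data import Dictionary

# A tiny vocabulary file whose first entry redefines the <unk> special token.
with open("dict.txt", "w") as f:
    f.write("<unk> 0 #fairseq:overwrite\n")
    f.write("hello 10\n")
    f.write("world 9\n")

d = Dictionary.load("dict.txt")
# Expected: 4 specials + 2 new words = 6 symbols (<unk> reuses its slot).
# Before this fix: 7 symbols, because <unk> is appended a second time.
print(len(d), d.symbols)
```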
**Domino effects and backward compatibility:**
With this fix, dictionary files are loaded properly. However, it might break pipelines that rely on existing architectures and pretrained models, because of the resulting mismatch in sentencepiece encoding and/or embedding dimension.
For the sake of backward compatibility, a `#fairseq:duplicate` flag is introduced to keep such duplicates in the dictionary, reproducing the old buggy behavior. When used with `fairseq-preprocess`, the produced `dict.txt` file will also write `#fairseq:duplicate` next to the same symbols.
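For example, a pipeline that must keep the old layout would tag the affected entries like this (a sketch based on the flag's description above; the counts are made up):

```text
<unk> 0 #fairseq:duplicate
hello 10
world 9
```

Loading this file keeps the tagged symbol as a duplicate, as before the fix, and running `fairseq-preprocess` on it writes the tag back out, so downstream runs stay consistent.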
## PR review
Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
## Did you have fun?
Yes, I did 🙃