data2vec_text: EMA teacher processes only masked tokens instead of the full sequence
🐛 Bug
The target_tokens variable in the forward call of the Data2VecTextEncoder model contains the original tokens only at masked positions and padding tokens everywhere else. In the method described in the data2vec paper, target_tokens should be the full uncorrupted input sequence, so the EMA teacher currently sees mostly padding instead of the complete sentence.
To Reproduce
- Add a print("target_tokens", target_tokens) in data2vec_text.py, right before the target tokens are passed to the EMA model.
- Pre-train data2vec with the provided configuration:
python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/text/pretraining \
--config-name base task.data=/path/to/data common.user_dir=examples/data2vec
- Logging shows target_tokens consisting mostly of padding indices (the entries equal to 1):
target_tokens tensor([[ 1, 1, 1, ..., 1, 1, 1],
[ 1, 1, 1, ..., 128, 128, 2],
[ 1, 13863, 17383, ..., 1228, 479, 2],
...,
[ 1, 1, 1, ..., 1, 1, 1],
[ 1, 1, 1, ..., 1, 1, 1],
[ 0, 133, 4573, ..., 1, 1, 1]])
I believe this is because MaskTokensDataset returns a target sequence that keeps the original tokens only at masked positions and uses padding everywhere else.
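For illustration, a minimal simulation of that behavior with plain torch (toy values; 1 is fairseq's default padding index, matching the 1-entries in the log above, and 50264 is only an illustrative mask index):

import torch

padding_index, mask_index = 1, 50264
original = torch.tensor([0, 133, 4573, 17383, 479, 2, padding_index, padding_index])  # <s> ... </s> <pad> <pad>
src_tokens = original.clone()
src_tokens[2:4] = mask_index  # two positions are masked for the student

# Target as described above: the original token at masked positions,
# padding everywhere else.
target_tokens = torch.where(src_tokens == mask_index, original, torch.full_like(original, padding_index))
print(target_tokens)  # tensor([1, 1, 4573, 17383, 1, 1, 1, 1])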
Expected behavior
target_tokens should be the full uncorrupted sequence, and only have padding at positions where src_tokens has them, i.e. something like
target_tokens = torch.where(target_tokens == padding_index, src_tokens, target_tokens)
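A self-contained sketch of that fix on toy tensors (padding index 1, with 50264 standing in for the mask index, which in practice comes from the task dictionary):

import torch

padding_index = 1
src_tokens = torch.tensor([0, 133, 50264, 50264, 479, 2, 1, 1])  # student input: two masked positions, two pads
target_tokens = torch.tensor([1, 1, 4573, 17383, 1, 1, 1, 1])    # targets as currently produced

# Proposed fix: fall back to src_tokens wherever the target is padding.
target_tokens = torch.where(target_tokens == padding_index, src_tokens, target_tokens)
print(target_tokens)  # tensor([0, 133, 4573, 17383, 479, 2, 1, 1])
# Masked positions keep the original token from the old target, unmasked positions
# are recovered from src_tokens, and true padding stays padding.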
Environment
- fairseq Version: 0.12.2
- OS: linux
- How you installed fairseq: pip
- Python version: 3.10.8