pytorch-transformer
Should encoder mask be a (1, seq_len, seq_len) matrix?
Suppose the input sentence is (A, B, C, D), where D is a PAD token.
In this implementation, the encoder mask is [[[False, False, False, True]]], i.e. shape (1, 1, seq_len).
But the encoder attention is
[
[AA, AB, AC, AD],
[BA, BB, BC, BD],
[CA, CB, CC, CD],
[DA, DB, DC, DD]
]
This encoder mask only masks the 4th column. Shouldn't the 4th row be masked as well?
If so, the encoder mask should be a (1, seq_len, seq_len) matrix:
[
 [
  [False, False, False, True],
  [False, False, False, True],
  [False, False, False, True],
  [True,  True,  True,  True]
 ]
]
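To make the broadcasting behavior concrete, here is a minimal NumPy sketch (not the repo's code) of softmax attention with a boolean mask where True marks positions to suppress. It shows that a (1, 1, seq_len) mask broadcasts over all query rows, so every row, including the PAD row, zeroes out the PAD column, while the PAD query row itself still produces valid weights over A, B, C:

```python
import numpy as np

def masked_attention_weights(scores, mask):
    # mask: True marks positions to mask out (e.g. PAD keys)
    scores = np.where(mask, -1e9, scores)
    # numerically stable softmax over the last axis
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

seq_len = 4
scores = np.zeros((1, seq_len, seq_len))  # uniform scores, for illustration only

# (1, 1, seq_len) mask: broadcasts over all query rows, masking only the 4th column
col_mask = np.array([[[False, False, False, True]]])
w = masked_attention_weights(scores, col_mask)

print(w[0, :, 3])   # PAD column weight is ~0 in every row
print(w[0, 3, :3])  # PAD query row still spreads weight evenly over A, B, C
```

Note that if the 4th row were fully masked as proposed, every score in that row would be equal after masking and the softmax would degenerate to uniform weights anyway; one possible reason implementations leave PAD query rows unmasked is that their outputs are simply discarded downstream.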