tgt_key_padding_mask doesn't match tgt_mask.
Epoch 1 Recall@5 0.40540524032301645 Recall@10 0.5550042311638902 Recall@20 0.691875286804221 Recall@50 0.7947694177616892 Recall@100 0.82678933109488...
Traceback (most recent call last):
  File "D:/xuexi/graduate/project/SGG-Seq2Seq-main/trainer.py", line 406, in <module>
    main()
  File "D:/xuexi/graduate/project/SGG-Seq2Seq-main/trainer.py", line 322, in main
    predictions = model(
  File "D:\software\anaconda\envs\pytorch1.8\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\xuexi\graduate\project\SGG-Seq2Seq-main\transformer.py", line 92, in forward
    decoder_states = self.decoder(
  File "D:\software\anaconda\envs\pytorch1.8\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\software\anaconda\envs\pytorch1.8\lib\site-packages\torch\nn\modules\transformer.py", line 291, in forward
    output = mod(output, memory, tgt_mask=tgt_mask,
  File "D:\software\anaconda\envs\pytorch1.8\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\software\anaconda\envs\pytorch1.8\lib\site-packages\torch\nn\modules\transformer.py", line 576, in forward
    x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask))
  File "D:\software\anaconda\envs\pytorch1.8\lib\site-packages\torch\nn\modules\transformer.py", line 585, in _sa_block
    x = self.self_attn(x, x, x,
  File "D:\software\anaconda\envs\pytorch1.8\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\software\anaconda\envs\pytorch1.8\lib\site-packages\torch\nn\modules\activation.py", line 1153, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "D:\software\anaconda\envs\pytorch1.8\lib\site-packages\torch\nn\functional.py", line 5155, in multi_head_attention_forward
    assert key_padding_mask.shape == (bsz, src_len),
AssertionError: expecting key_padding_mask shape of (2, 100), but got torch.Size([2, 78])
Process finished with exit code 1
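Reading the assertion: inside the decoder's self-attention, the keys are the target sequence itself, so `(bsz, src_len) = (2, 100)` means `tgt` reached the decoder with batch 2 and length 100, while `tgt_key_padding_mask` was built with only 78 columns. For `nn.TransformerDecoder`, `tgt_mask` has shape `(T, T)` and `tgt_key_padding_mask` has shape `(N, T)`, where T is the length of the `tgt` actually passed in, so the two masks only "match" through that shared T. A likely cause (an assumption, since the code around `transformer.py` line 92 isn't shown) is that the padding mask comes from a differently padded or truncated tensor than the one embedded as `tgt`, for example padded to the longest sequence in the batch (78) while `tgt` is padded to a fixed length of 100. The sketch below assumes a stock `nn.TransformerDecoder` with the default `batch_first=False`; the sizes, `PAD_IDX`, the vocabulary, and the helper `build_tgt_masks` are all hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes chosen to mirror the traceback: batch 2, target length 100.
BATCH, TGT_LEN, MEM_LEN, D_MODEL, VOCAB, PAD_IDX = 2, 100, 30, 32, 50, 0


def build_tgt_masks(tgt_tokens: torch.Tensor, pad_idx: int = PAD_IDX):
    """Hypothetical helper: derive both decoder masks from one (batch, T) token tensor.

    tgt_mask:             (T, T) bool, True above the diagonal (future positions blocked).
    tgt_key_padding_mask: (batch, T) bool, True where the token is padding.
    """
    tgt_len = tgt_tokens.size(1)
    tgt_mask = torch.triu(torch.ones(tgt_len, tgt_len, dtype=torch.bool), diagonal=1)
    tgt_key_padding_mask = tgt_tokens.eq(pad_idx)
    return tgt_mask, tgt_key_padding_mask


embed = nn.Embedding(VOCAB, D_MODEL, padding_idx=PAD_IDX)
layer = nn.TransformerDecoderLayer(d_model=D_MODEL, nhead=4)  # batch_first=False (default)
decoder = nn.TransformerDecoder(layer, num_layers=2)

# Target tokens padded to a common length T = 100; the second sequence has 78 real tokens.
tgt_tokens = torch.randint(1, VOCAB, (BATCH, TGT_LEN))
tgt_tokens[1, 78:] = PAD_IDX

tgt = embed(tgt_tokens).transpose(0, 1)        # (T, N, E), the layout batch_first=False expects
memory = torch.randn(MEM_LEN, BATCH, D_MODEL)  # (S, N, E) stand-in for the encoder output

# Both masks come from the same tensor that was embedded, so their T always agrees with tgt.
tgt_mask, tgt_key_padding_mask = build_tgt_masks(tgt_tokens)
out = decoder(tgt, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask)
print(out.shape)

# Reproduce the reported failure: a padding mask of width 78 while tgt has length 100.
mismatch_caught = False
try:
    decoder(tgt, memory, tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask[:, :78])
except (AssertionError, RuntimeError) as err:
    mismatch_caught = True
    print(type(err).__name__, err)
```

Deriving both masks from the tensor that is actually embedded and passed as `tgt` keeps their T dimension identical by construction. In the real model, printing `tgt.shape` and `tgt_key_padding_mask.shape` just before the `self.decoder(` call should show which side ends up with 78 and which with 100.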