
something wrong in the code

xywhat opened this issue on Nov 11 '22 · 0 comments

tgt_key_padding_mask cannot match tgt_mask: the two masks end up with different target-sequence lengths.

Epoch 1 Recall@5 0.40540524032301645 Recall@10 0.5550042311638902 Recall@20 0.691875286804221 Recall@50 0.7947694177616892 Recall@100 0.82678933109488...

Traceback (most recent call last):
  File "D:/xuexi/graduate/project/SGG-Seq2Seq-main/trainer.py", line 406, in <module>
    main()
  File "D:/xuexi/graduate/project/SGG-Seq2Seq-main/trainer.py", line 322, in main
    predictions = model(
  File "D:\software\anaconda\envs\pytorch1.8\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\xuexi\graduate\project\SGG-Seq2Seq-main\transformer.py", line 92, in forward
    decoder_states = self.decoder(
  File "D:\software\anaconda\envs\pytorch1.8\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\software\anaconda\envs\pytorch1.8\lib\site-packages\torch\nn\modules\transformer.py", line 291, in forward
    output = mod(output, memory, tgt_mask=tgt_mask,
  File "D:\software\anaconda\envs\pytorch1.8\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\software\anaconda\envs\pytorch1.8\lib\site-packages\torch\nn\modules\transformer.py", line 576, in forward
    x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask))
  File "D:\software\anaconda\envs\pytorch1.8\lib\site-packages\torch\nn\modules\transformer.py", line 585, in _sa_block
    x = self.self_attn(x, x, x,
  File "D:\software\anaconda\envs\pytorch1.8\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\software\anaconda\envs\pytorch1.8\lib\site-packages\torch\nn\modules\activation.py", line 1153, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "D:\software\anaconda\envs\pytorch1.8\lib\site-packages\torch\nn\functional.py", line 5155, in multi_head_attention_forward
    assert key_padding_mask.shape == (bsz, src_len),
AssertionError: expecting key_padding_mask shape of (2, 100), but got torch.Size([2, 78])

Process finished with exit code 1
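For context on what the assertion checks: in PyTorch's nn.TransformerDecoder, tgt_key_padding_mask must have shape (batch, tgt_len), where tgt_len is the same padded length as the decoder input and the (tgt_len, tgt_len) tgt_mask. The error above means the padding mask was built over a 78-step sequence while the decoder input has 100 steps. Below is a minimal self-contained sketch of consistent mask shapes; the concrete sizes (batch 2, padded length 100, 78 real tokens, d_model 64, nhead 4) are made up for illustration and are not taken from this repo's code.

```python
import torch
import torch.nn as nn

batch_size, tgt_len, mem_len, d_model = 2, 100, 50, 64

decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=4)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=1)

tgt = torch.randn(tgt_len, batch_size, d_model)    # (tgt_len, batch, d_model)
memory = torch.randn(mem_len, batch_size, d_model) # encoder output

# Causal mask over the target: (tgt_len, tgt_len), True = position not attended.
tgt_mask = torch.triu(torch.ones(tgt_len, tgt_len, dtype=torch.bool), diagonal=1)

# Padding mask MUST be (batch, tgt_len), i.e. built over the padded length (100).
# Building it over the unpadded length (e.g. shape (2, 78)) reproduces the
# AssertionError in the traceback above.
tgt_key_padding_mask = torch.zeros(batch_size, tgt_len, dtype=torch.bool)
tgt_key_padding_mask[:, 78:] = True  # mark trailing pad positions as masked

out = decoder(tgt, memory,
              tgt_mask=tgt_mask,
              tgt_key_padding_mask=tgt_key_padding_mask)
print(out.shape)  # torch.Size([100, 2, 64])
```

If I read the assertion correctly, the fix is to pad (or truncate) tgt and tgt_key_padding_mask to the same length before calling the decoder, so the padding mask covers every position of the padded target, not only the real tokens.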
