Hello, I ran into a dimension mismatch while reproducing the code:
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 181109/181109 [00:00<00:00, 389326.69it/s]
Epoch: 0 Loss: 0.9171
Epoch: 100 Loss: 0.3691
Epoch: 200 Loss: 0.1924
Epoch: 300 Loss: 0.1232
Epoch: 400 Loss: 0.0847
Epoch: 500 Loss: 0.0723
Epoch: 600 Loss: 0.0584
Epoch: 700 Loss: 0.0560
Epoch: 800 Loss: 0.0451
Epoch: 900 Loss: 0.0368
Epoch: 1000 Loss: 0.0358
Epoch 1000 has finished, saving...
Epoch 1000 has finished, validating...
Traceback (most recent call last):
File "/data/xuyiming/Few-shot/Edge-level/GANA-FewShotKGC/main_gana.py", line 63, in <module>
trainer.train()
File "/data/xuyiming/Few-shot/Edge-level/GANA-FewShotKGC/trainer_gana.py", line 308, in train
valid_data = self.eval(istest=False, epoch=e)
File "/data/xuyiming/Few-shot/Edge-level/GANA-FewShotKGC/trainer_gana.py", line 363, in eval
_, p_score, n_score = self.do_one_step(eval_task, iseval=True, curr_rel=curr_rel, istest=istest)
File "/data/xuyiming/Few-shot/Edge-level/GANA-FewShotKGC/trainer_gana.py", line 281, in do_one_step
loss = self.metaR.loss_func(p_score, n_score, y)
File "/data/xuyiming/Conda/anaconda3/envs/enmark/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/xuyiming/Conda/anaconda3/envs/enmark/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/data/xuyiming/Conda/anaconda3/envs/enmark/lib/python3.9/site-packages/torch/nn/modules/loss.py", line 1353, in forward
return F.margin_ranking_loss(input1, input2, target, margin=self.margin, reduction=self.reduction)
File "/data/xuyiming/Conda/anaconda3/envs/enmark/lib/python3.9/site-packages/torch/nn/functional.py", line 3416, in margin_ranking_loss
raise RuntimeError(
RuntimeError: margin_ranking_loss : All input tensors should have same dimension but got sizes: input1: torch.Size([1, 1]), input2: torch.Size([1, 2366]), target: torch.Size([1])
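It looks like input1, input2, and target passed to margin_ranking_loss are required to have the same dimensions, but here they do not (input1 and input2 are 2-D while target is 1-D). How can I fix this?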
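For reference, here is a minimal sketch that reproduces the shapes from the error message outside the repository and shows one possible way to make them agree. Only the three tensor shapes come from the traceback; the margin value, the random scores, and the names `y_2d` and `p_expanded` are assumptions for illustration, and this is not necessarily how `trainer_gana.py` should be patched.

```python
import torch
import torch.nn as nn

# margin=1.0 is an assumed value; the traceback does not show the one used.
loss_func = nn.MarginRankingLoss(margin=1.0)

# Shapes copied from the error message; the values are random placeholders.
p_score = torch.randn(1, 1)      # one positive score
n_score = torch.randn(1, 2366)   # scores of 2366 negative candidates
y = torch.tensor([1.0])          # 1-D target, as reported in the traceback

# Recent PyTorch releases reject this call because y is 1-D while the scores
# are 2-D. Older releases may broadcast instead, so the call is guarded.
try:
    loss_func(p_score, n_score, y)
except RuntimeError as err:
    print("original call failed:", err)

# Give all three tensors the same number of dimensions (and the same shape).
y_2d = torch.ones_like(n_score)          # shape [1, 2366], all +1
p_expanded = p_score.expand_as(n_score)  # shape [1, 2366], a broadcast view
loss = loss_func(p_expanded, n_score, y_2d)
print(loss.shape)  # scalar tensor, because the default reduction is "mean"
```

If this matches what the code intends, the same idea applied where `y` is built before `self.metaR.loss_func(p_score, n_score, y)` might resolve the error, but I have not verified that against the repository.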