
[Tutorial 2.1 error] TypeError: where(): argument 'other' (position 3) must be Tensor, not int

Status: Open. canghongjian opened this issue 3 years ago (0 comments).

This error occurs when running Tutorial 2.1. Full traceback:

```text
Traceback (most recent call last):
  File "condional_prompt.py", line 112, in <module>
    loss = prompt_model(inputs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/openprompt/pipeline_base.py", line 449, in forward
    return self._forward(*args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/openprompt/pipeline_base.py", line 467, in _forward
    logits, labels = self.shift_logits_and_labels(logits, batch['loss_ids'], reference_ids)
  File "/opt/conda/lib/python3.7/site-packages/openprompt/pipeline_base.py", line 434, in shift_logits_and_labels
    shift_input_ids = torch.where(shift_loss_ids>0, shift_input_ids, -100)
TypeError: where(): argument 'other' (position 3) must be Tensor, not int
```
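The failing call passes the Python int `-100` as the third argument of `torch.where`. Recent PyTorch releases document that argument as "Tensor or Scalar", but the PyTorch installed here evidently only accepts a Tensor. A possible workaround, offered as an assumption rather than a fix confirmed by the OpenPrompt maintainers, is to build the fallback value as a tensor so the call works on both old and new versions. The helper name `mask_non_loss_positions` below is hypothetical; it mirrors what line 434 of `pipeline_base.py` appears to do.

```python
import torch


def mask_non_loss_positions(input_ids: torch.Tensor,
                            loss_ids: torch.Tensor,
                            ignore_index: int = -100) -> torch.Tensor:
    """Hypothetical helper: keep input_ids where loss_ids > 0, else ignore_index.

    Equivalent to torch.where(loss_ids > 0, input_ids, ignore_index), but the
    fallback is passed as a tensor, so PyTorch versions whose torch.where
    only accepts Tensor arguments do not raise a TypeError.
    """
    # Tensor with the same shape, dtype and device as input_ids, filled with ignore_index
    fallback = torch.full_like(input_ids, ignore_index)
    return torch.where(loss_ids > 0, input_ids, fallback)


input_ids = torch.tensor([[5, 6, 7, 8]])
loss_ids = torch.tensor([[0, 1, 1, 0]])
masked = mask_non_loss_positions(input_ids, loss_ids)
print(masked.tolist())  # [[-100, 6, 7, -100]]
```

An equivalent alternative is `input_ids.masked_fill(loss_ids <= 0, -100)`, which accepts a scalar fill value directly. Upgrading PyTorch may also make the original line work unchanged.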

Posted by canghongjian, Jul 07 '22 02:07