
RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 32 but got size 1 for tensor number 3 in the list

jaugustin12 opened this issue on Nov 20 '21 · 0 comments

I am running EditNTS (https://github.com/yuedongP/EditNTS) without teacher forcing on some training data. When I run main.py I get the following error:

  File "/home/jba5337/work/ds440w/EditNTS-Google/editnts.py", line 252, in forward
    output_t = torch.cat((output_edits, attn_applied_org_t, c, hidden_words[0]),
RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 32 but got size 1 for tensor number 3 in the list.
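For context, my understanding is that torch.cat along dim=2 requires every dimension except 2 to match across all tensors in the list, so one of the four tensors must have a mismatched leading dimension. A minimal sketch that reproduces this class of error:

    import torch

    # torch.cat along dim=2 checks that all other dimensions agree;
    # here dim 0 is 32 vs 1, which triggers the same RuntimeError.
    a = torch.zeros(32, 1, 400)
    b = torch.zeros(1, 32, 400)
    torch.cat((a, b), dim=2)
    # RuntimeError: Sizes of tensors must match except in dimension 2.
    # Expected size 32 but got size 1 for tensor number 1 in the list.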

Here is the output when I print c:

c tensor([[[-0.0353, -0.0617, -0.1176,  ...,  0.0507, -0.0174,  0.1828]],

        [[-0.0769, -0.0166, -0.1737,  ..., -0.1302, -0.1488,  0.1480]],

        [[-0.0570, -0.0683, -0.2270,  ..., -0.0820, -0.2011,  0.1915]],

        ...,

        [[-0.1127,  0.0051, -0.2119,  ..., -0.0853, -0.1813,  0.2058]],

        [[-0.0570, -0.0683, -0.2270,  ..., -0.0412, -0.1851,  0.1975]],

        [[-0.1127,  0.0051, -0.2119,  ..., -0.0477, -0.1822,  0.2200]]],
       device='cuda:0', grad_fn=<GatherBackward0>)

size torch.Size([32, 1, 400])
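To compare, the shapes of all four tensors could be printed right before the failing cat (a quick debugging sketch, using the names from editnts.py):

    # Debug sketch: print the shape of every tensor going into the cat,
    # to spot which one carries the size-1 dimension.
    for name, x in [('output_edits', output_edits),
                    ('attn_applied_org_t', attn_applied_org_t),
                    ('c', c),
                    ('hidden_words[0]', hidden_words[0])]:
        print(name, tuple(x.shape))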

The first dimension of c is 32, not 1, so I am unsure where the size-1 tensor is coming from. Here is the function the error comes from, if you could please take a look:

        else: # no teacher forcing
            decoder_input_edit = input_edits[:, :1]
            decoder_input_word = simp_sent[:, :1]
            t, tt = 0, max(MAX_LEN, input_edits.size(1) - 1)

            # initialize
            embedded_edits = self.embedding(decoder_input_edit)
            output_edits, hidden_edits = self.rnn_edits(embedded_edits, hidden_org)

            embedded_words = self.embedding(decoder_input_word)
            output_words, hidden_words = self.rnn_words(embedded_words, hidden_org)
            #
            # # give previous word from tgt simp_sent
            # inds = torch.LongTensor(counter_for_keep_ins)
            # dummy = inds.view(-1, 1, 1)
            # dummy = dummy.expand(dummy.size(0), dummy.size(1), output_words.size(2)).cuda()
            # c_word = output_words.gather(1, dummy)

            while t < tt:
                if t > 0:
                    embedded_edits = self.embedding(decoder_input_edit)
                    output_edits, hidden_edits = self.rnn_edits(embedded_edits, hidden_edits)

                key_org = self.attn_Projection_org(output_edits)  # bsz x nsteps x nhid
                logits_org = torch.bmm(key_org, encoder_outputs_org.transpose(1, 2))  # bsz x nsteps x encsteps
                attn_weights_org_t = F.softmax(logits_org, dim=-1)  # bsz x nsteps x encsteps
                attn_applied_org_t = torch.bmm(attn_weights_org_t, encoder_outputs_org)  # bsz x nsteps x nhid

                ## find current word
                inds = torch.LongTensor(counter_for_keep_del)
                dummy = inds.view(-1, 1, 1)
                dummy = dummy.expand(dummy.size(0), dummy.size(1), encoder_outputs_org.size(2)).cuda()
                c = encoder_outputs_org.gather(1, dummy)
                print('c',c)
                output_t = torch.cat((output_edits, attn_applied_org_t, c, hidden_words[0]),
                                     2)  # bsz*nsteps x nhid*2
                output_t = self.attn_MLP(output_t)
                output_t = F.log_softmax(self.out(output_t), dim=-1)

                decoder_out.append(output_t)
                decoder_input_edit = torch.argmax(output_t, dim=2)

                # gold_action = input[:, t + 1].vocab_data.cpu().numpy()  # might need to realign here because start added
                pred_action = torch.argmax(output_t, dim=2)
                counter_for_keep_del = [i[0] + 1 if i[1] == 2 or i[1] == 3 or i[1] == 5 else i[0]
                                        for i in zip(counter_for_keep_del, pred_action)]

                # update rnn_words
                # find previous generated word
                # give previous word from tgt simp_sent
                dummy_2 = inds.view(-1, 1).cuda()
                org_t = org_ids.gather(1, dummy_2)
                hidden_words = self.execute_batch(pred_action, org_t, hidden_words)  # we give the edited subsequence
                # hidden_words = self.execute_batch(pred_action, org_t, hidden_org)  #here we only give the word

                t += 1
                check = sum([x >= org_ids.size(1) for x in counter_for_keep_del])
                if check:
                    break
        return torch.cat(decoder_out, dim=1), hidden_edits
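One guess: since tensor number 3 in the list should be the fourth argument to torch.cat (the numbering looks 0-indexed), the mismatched tensor would be hidden_words[0] rather than c. For an LSTM, hidden_words[0] has shape (num_layers, bsz, nhid), i.e. (1, 32, 400) here, so its first dimension is 1 where the other three tensors have the batch size 32. If that is right, permuting it before the cat might line things up, though I have not verified this:

    # Untested sketch: move the batch dimension of hidden_words[0] to the
    # front, (num_layers, bsz, nhid) -> (bsz, num_layers, nhid), so all four
    # tensors agree on dims 0 and 1 before concatenating along dim 2.
    output_t = torch.cat((output_edits, attn_applied_org_t, c,
                          hidden_words[0].permute(1, 0, 2)), 2)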
