MWPToolkit

unable to run GTS on a custom dataset

Open • indranilByjus opened this issue 2 years ago • 4 comments

The module runs fine with the provided datasets, but it throws an error when I include a custom dataset.

Error trace:

    MWPToolkit/mwptoolkit/model/Seq2Tree/gts.py, line 220, in train_tree
        current_num = current_nums_embeddings[idx, i - num_start].unsqueeze(0)
    IndexError: index 124 is out of bounds for dimension 1 with size 118

indranilByjus, Sep 27 '21 07:09

I'm not sure of the specific cause, but you could check the values of the following variables:

dataset.copy_nums
dataset.num_start
dataset.generate_list
dataset.out_idx2symbol

They all define the decoder vocabulary of GTS (and of the other models). Please check that they are correct.
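One way to check these values is a small consistency test. This is a sketch that assumes out_idx2symbol lists the operators first, then generate_list starting at num_start, then one "NUM_k" slot per copy number; that layout is an assumption about MWPToolkit, so adjust it if your vocabulary differs. check_decoder_vocab is a hypothetical helper, not a library function.

```python
# Hypothetical sanity check for the GTS decoder vocabulary.
# ASSUMPTION: out_idx2symbol is laid out as
#   [operators..., *generate_list, "NUM_0", ..., "NUM_{copy_nums-1}", special tokens...]
# with num_start pointing at the first generated constant.
from typing import List


def check_decoder_vocab(out_idx2symbol: List[str], num_start: int,
                        generate_list: List[str], copy_nums: int) -> List[str]:
    """Return human-readable problems; an empty list means the layout matches."""
    problems = []
    # Everything before num_start should be an operator, not a number slot.
    for idx, sym in enumerate(out_idx2symbol[:num_start]):
        if sym in generate_list or sym.startswith("NUM_"):
            problems.append(
                f"index {idx} ({sym!r}) is a number but lies before num_start={num_start}")
    # The generated constants should follow immediately, in order.
    gen_end = num_start + len(generate_list)
    if out_idx2symbol[num_start:gen_end] != list(generate_list):
        problems.append(f"symbols {num_start}..{gen_end - 1} do not match generate_list")
    # Then one copy slot per number position: NUM_0 .. NUM_{copy_nums-1}.
    expected_copy = [f"NUM_{k}" for k in range(copy_nums)]
    if out_idx2symbol[gen_end:gen_end + copy_nums] != expected_copy:
        problems.append(
            f"symbols {gen_end}..{gen_end + copy_nums - 1} are not NUM_0..NUM_{copy_nums - 1}")
    return problems
```

With a loaded dataset you would call it as check_decoder_vocab(dataset.out_idx2symbol, dataset.num_start, dataset.generate_list, dataset.copy_nums); an empty result only means the vocabulary matches the assumed layout.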

Second, pay attention to the inputs of the model (batch_data).
On the line that throws the error, current_nums_embeddings holds the embeddings of all numbers at the current decoding step (generated numbers plus copied numbers). Its size should be [batch_size, 118, hidden_size], where 118 is the sum of the generate size (the same for every batch) and the copy size (which varies from batch to batch and equals max(batch_data["num size"])). Check whether batch_data["num size"] is correct.

Another point is batch_data["num stack"]. In GTS, if a number appears two or more times in the question sentence (one number at several positions), there is more than one candidate symbol for generating it, and which one to use is decided during decoding. The target token is therefore replaced by UNK_token, and during decoding the candidate with the maximal score is chosen as the target symbol. batch_data["num stack"] holds the candidate symbols for each UNK_token. If UNK_token is not replaced by the candidate symbols correctly, the index can go out of bounds, so please check that batch_data["num stack"] is correct.
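A minimal sketch of these two checks, assuming each sample's target is a plain list of decoder symbol indices (numbers start at num_start, generated constants first, then copied numbers) and that you pass in the generate size and each sample's "num size" value. find_bad_targets is a hypothetical name, not part of MWPToolkit. It flags targets whose offset from num_start does not fit in the sample's number slots, and UNK tokens whose num stack entries are missing or out of range.

```python
# Hypothetical per-batch check following the explanation above.
# ASSUMPTIONS: targets are lists of decoder symbol indices, number symbols
# occupy indices >= num_start (generated constants first, then copied
# numbers), and `unk` is the index of UNK_token.
from typing import List, Tuple


def find_bad_targets(targets: List[List[int]], num_sizes: List[int],
                     num_stacks: List[List[List[int]]], num_start: int,
                     generate_size: int, unk: int) -> List[Tuple[int, str]]:
    """Return (sample_index, reason) pairs for samples that would index out of range."""
    bad = []
    for b, (target, num_size, stack) in enumerate(zip(targets, num_sizes, num_stacks)):
        limit = generate_size + num_size  # valid offsets from num_start for this sample
        for tok in target:
            if tok != unk and tok >= num_start and tok - num_start >= limit:
                bad.append((b, f"target {tok} -> offset {tok - num_start} >= {limit}"))
        # generate_tree_input pops one candidate list per UNK_token.
        n_unk = sum(1 for tok in target if tok == unk)
        if n_unk != len(stack):
            bad.append((b, f"{n_unk} UNK tokens but {len(stack)} num_stack entries"))
        # Each candidate is read as decoder_output[i, num_start + num].
        for candidates in stack:
            for num in candidates:
                if num >= limit:
                    bad.append((b, f"num_stack candidate {num} >= {limit}"))
    return bad
```

It uses each sample's own limit, which is stricter than the padded batch width (118 in the trace above), so it also catches targets that would silently point at padding slots.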

LYH-YF, Sep 28 '21 06:09

Code for building the number stack (MWPToolkit/mwptoolkit/data/dataset/abstactdataset.py, line 192):

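    def _build_num_stack(self, equation, num_list):
        # For each equation token that is missing from the decoder vocabulary,
        # record the number positions it could refer to.
        num_stack = []
        for word in equation:
            temp_num = []
            flag_not = True
            if word not in self.dataset.out_idx2symbol:
                flag_not = False
                # "NUM_k" that is not in the vocabulary: candidate position k.
                if "NUM" in word:
                    temp_num.append(int(word[4:]))
                # Any position in num_list whose text equals the token.
                for i, j in enumerate(num_list):
                    if j == word:
                        temp_num.append(i)

            if not flag_not and len(temp_num) != 0:
                num_stack.append(temp_num)
            # No match at all: every number position becomes a candidate.
            if not flag_not and len(temp_num) == 0:
                num_stack.append([_ for _ in range(len(num_list))])
        # Reversed so that pop() returns the candidates of the first UNK first.
        num_stack.reverse()
        return num_stack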
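To see what _build_num_stack produces for a given equation, here is a standalone mirror of its logic that takes the output vocabulary as an argument instead of reading self.dataset.out_idx2symbol. build_num_stack and the toy vocabulary in the example are hypothetical; this is a sketch for experimenting, not MWPToolkit code.

```python
from typing import List, Sequence


def build_num_stack(equation: Sequence[str], num_list: Sequence[str],
                    out_symbols: Sequence[str]) -> List[List[int]]:
    """Standalone mirror of _build_num_stack (hypothetical helper)."""
    vocab = set(out_symbols)
    num_stack = []
    for word in equation:
        if word in vocab:
            continue  # decodable directly, so no UNK replacement and no stack entry
        candidates = []
        if "NUM" in word:
            candidates.append(int(word[4:]))  # "NUM_k" -> position k
        # Positions in num_list whose text equals the token.
        candidates.extend(i for i, num in enumerate(num_list) if num == word)
        # No match at all: every number position becomes a candidate.
        num_stack.append(candidates or list(range(len(num_list))))
    num_stack.reverse()  # pop() then yields the first UNK's candidates first
    return num_stack
```

For example, with out_symbols = ["+", "-", "*", "/", "NUM_0", "NUM_1"], equation ["x", "=", "NUM_0", "+", "NUM_2"] and num_list = ["3", "4", "5"], the result is [[2], [0, 1, 2], [0, 1, 2]]: "x" and "=" are not in the vocabulary, so each becomes an UNK whose candidates are all number positions, and "NUM_2" (beyond the two copy slots) becomes an UNK with candidate 2.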

Code for choosing the target symbol with the maximal score (MWPToolkit/mwptoolkit/model/Seq2Tree/gts.py, line 357):

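    import copy

    import torch

    def generate_tree_input(self, target, decoder_output, nums_stack_batch, num_start, unk):
        # When the gold token is a copied number that occurs at several positions,
        # it was replaced by UNK; choose the candidate with the highest score.
        target_input = copy.deepcopy(target)
        for i in range(len(target)):
            if target[i] == unk:
                # Candidate number positions for this UNK (consumed in order).
                num_stack = nums_stack_batch[i].pop()
                max_score = -float("1e12")
                for num in num_stack:
                    # Raises IndexError if num_start + num exceeds the decoder output width.
                    if decoder_output[i, num_start + num] > max_score:
                        target[i] = num + num_start
                        max_score = decoder_output[i, num_start + num]
            # In the input copy, every number (and UNK) symbol is mapped to index 0.
            if target_input[i] >= num_start:
                target_input[i] = 0
        return torch.LongTensor(target), torch.LongTensor(target_input)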
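To find which UNK candidate overflows, a plain-Python mirror of generate_tree_input can check the column before reading it. This sketch uses nested lists instead of tensors (scores[i] stands in for decoder_output[i]); choose_unk_targets is a hypothetical name. Like the original it pops one candidate list per UNK from nums_stack_batch, but it returns new lists instead of modifying target in place.

```python
from typing import List, Optional, Sequence, Tuple


def choose_unk_targets(target: Sequence[int], scores: Sequence[Sequence[float]],
                       nums_stack_batch: List[List[List[int]]], num_start: int,
                       unk: int) -> Tuple[List[int], List[int]]:
    """Mirror of generate_tree_input on plain lists, with an explicit bounds check.

    target[i] is sample i's gold symbol at this decoding step and scores[i] is
    that sample's row of decoder scores (operators first, then number slots).
    """
    new_target = list(target)      # returned instead of mutating the input
    target_input = list(target)    # copy taken before UNK replacement
    for i, tok in enumerate(target):
        if tok == unk:
            candidates = nums_stack_batch[i].pop()  # consumes one entry, as in GTS
            best: Optional[int] = None
            best_score = -1e12
            for num in candidates:
                col = num_start + num
                if col >= len(scores[i]):
                    raise IndexError(
                        f"sample {i}: candidate {num} maps to column {col}, "
                        f"but there are only {len(scores[i])} decoder scores")
                if scores[i][col] > best_score:
                    best, best_score = col, scores[i][col]
            if best is not None:  # otherwise keep UNK, like the original
                new_target[i] = best
        if target_input[i] >= num_start:
            target_input[i] = 0  # number and UNK symbols become index 0
    return new_target, target_input
```

For example, with num_start = 3, unk = 7, target = [2, 7, 1], a score row [0, 0, 0, 0.1, 0.9, 0.2] for sample 1 and nums_stack_batch = [[], [[0, 1]], []], it returns ([2, 4, 1], [2, 0, 1]).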

LYH-YF, Sep 28 '21 06:09

So I looked into the variables; apparently the values are: [screenshot]

There are a few target tokens that go beyond the specified number of symbols. I had previously patched them with: target_t = torch.LongTensor([t if t < 118 else 117 for t in target_t])

But it's definitely breaking something during training.

indranilByjus, Sep 29 '21 08:09

I encountered a similar problem. One possible reason is that the equation or question text uses a different format from what the data loader expects, e.g. you are using "x=1+2" as the equation but the data loader expects "1+2".
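If your data uses the "x=1+2" form, one option is to strip the unknown before loading. A sketch, assuming the loader wants a bare expression and the unknown is a single identifier on one side of "="; strip_unknown is hypothetical and not part of MWPToolkit, so check which equation format your dataset configuration actually expects.

```python
import re

# Hypothetical helper: remove an "x=" prefix or "=x" suffix so that only the
# expression remains. ASSUMPTION: the loader expects a bare expression like "1+2".
_LHS_UNKNOWN = re.compile(r"^\s*[A-Za-z]\w*\s*=\s*(.+?)\s*$")  # "x = expr"
_RHS_UNKNOWN = re.compile(r"^\s*(.+?)\s*=\s*[A-Za-z]\w*\s*$")  # "expr = x"


def strip_unknown(equation: str) -> str:
    for pattern in (_LHS_UNKNOWN, _RHS_UNKNOWN):
        match = pattern.match(equation)
        if match:
            return match.group(1)
    return equation.strip()  # already a bare expression (or an unsupported form)
```

Equations such as "a*2=6", where the unknown sits inside the expression, are returned unchanged and would need a different conversion.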

lijierui, Nov 04 '21 14:11