Graph2Tree and GTS can't handle test_batch_size > 1

Open · liamjxu opened this issue 2 years ago · 2 comments

When test_batch_size is set through the command line, e.g.,

python run_mwptoolkit.py --model=GTS --dataset=mawps --task_type=multi_equation --gpu_id=0 --equation_fix=prefix --test_batch_size=32

both Graph2Tree and GTS fail during the forward pass:

Traceback (most recent call last):
  File "run_mwptoolkit.py", line 63, in <module>
    run_toolkit(config)
  File "/data/MWPToolkit/mwptoolkit/quick_start.py", line 220, in run_toolkit
    train_with_train_valid_test_split(config)
  File "/data/MWPToolkit/mwptoolkit/quick_start.py", line 109, in train_with_train_valid_test_split
    trainer.fit()
  File "/data/MWPToolkit/mwptoolkit/trainer/supervised_trainer.py", line 583, in fit
    valid_equ_ac, valid_val_ac, valid_total, valid_time_cost = self.evaluate(DatasetType.Valid)
  File "/data/MWPToolkit/mwptoolkit/trainer/supervised_trainer.py", line 645, in evaluate
    batch_val_ac, batch_equ_ac = self._eval_batch(batch)
  File "/data/MWPToolkit/mwptoolkit/trainer/supervised_trainer.py", line 506, in _eval_batch
    test_out, target = self.model.model_test(batch)
  File "/data/MWPToolkit/mwptoolkit/model/Seq2Tree/gts.py", line 179, in model_test
    _, outputs, _ = self.forward(seq, seq_length, nums_stack, num_size, num_pos)
  File "/data/MWPToolkit/mwptoolkit/model/Seq2Tree/gts.py", line 128, in forward
    output_all_layers)
  File "/data/MWPToolkit/mwptoolkit/model/Seq2Tree/gts.py", line 350, in decoder_forward
    out_token = int(ti)
ValueError: only one element tensors can be converted to Python scalars
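The last frame shows the decoder calling `int(ti)`. Here is a minimal sketch that reproduces the error with plain PyTorch and shows a per-sample conversion instead. It assumes that `ti` holds one predicted token index per sample (shape `[batch_size]`, as a greedy argmax over a batch would produce). The scores and the helper names `to_scalar_token` and `to_batch_tokens` are hypothetical and are not MWPToolkit's code.

```python
import torch

# Hypothetical decoder scores for a batch of 3 samples over a 5-token vocabulary.
scores = torch.tensor([
    [0.1, 0.7, 0.1, 0.05, 0.05],
    [0.3, 0.1, 0.4, 0.1, 0.1],
    [0.0, 0.0, 0.1, 0.2, 0.7],
])

# Greedy choice per sample: ti has shape [3], one token index per sample.
ti = scores.argmax(dim=1)

def to_scalar_token(ti):
    """Mimics `out_token = int(ti)`, which only works for a one-element tensor."""
    return int(ti)

def to_batch_tokens(ti):
    """Converts one token index per sample into a list of Python ints."""
    return [int(t) for t in ti.tolist()]

print(to_scalar_token(ti[:1]))  # batch of 1 works: 1
try:
    to_scalar_token(ti)          # batch of 3 raises ValueError
except ValueError as err:
    print("ValueError:", err)
print(to_batch_tokens(ti))       # [1, 2, 4]
```

Because `int()` on a tensor needs exactly one element, it works at `test_batch_size=1` and fails as soon as a batch carries more than one prediction.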

Because the model_test logic is used for both validation and testing, being limited to test_batch_size = 1 makes training and testing exceptionally slow. Would you consider adding support for test_batch_size > 1?
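Fixing the `int(ti)` conversion alone would probably not be enough. With `--equation_fix=prefix`, each sample's equation is generated in prefix order, and samples in a batch finish at different steps. The following framework-free sketch shows one way to do the per-sample bookkeeping. It assumes only binary operators, and it assumes that every decoded token is either an operator or an operand. A prefix expression is complete when its count of open operand slots reaches zero, so finished samples can be masked while the others keep decoding. The function names, the operator set, and the `N0`-style operand tokens are hypothetical. The sketch replays pre-chosen tokens in place of a model and is not how MWPToolkit decodes.

```python
from typing import List

OPERATORS = {"+", "-", "*", "/", "^"}  # assumed binary operators

def step_open_slots(open_slots: List[int], tokens: List[str]) -> List[int]:
    """Update each sample's count of unfilled operand slots after one step.

    A binary operator fills one slot and opens two (net +1); an operand
    fills one slot (net -1). Finished samples (0 open slots) ignore their token.
    """
    updated = []
    for slots, tok in zip(open_slots, tokens):
        if slots == 0:
            updated.append(0)  # finished sample: nothing left to fill
        elif tok in OPERATORS:
            updated.append(slots + 1)
        else:
            updated.append(slots - 1)
    return updated

def decode_batch(steps: List[List[str]], batch_size: int) -> List[List[str]]:
    """Replay per-step tokens for a batch, keeping each sample's prefix expression."""
    open_slots = [1] * batch_size  # every expression starts with one slot to fill
    outputs: List[List[str]] = [[] for _ in range(batch_size)]
    for tokens in steps:
        for i, tok in enumerate(tokens):
            if open_slots[i] > 0:  # only unfinished samples record tokens
                outputs[i].append(tok)
        open_slots = step_open_slots(open_slots, tokens)
        if all(s == 0 for s in open_slots):
            break  # stop once every sample's expression is complete
    return outputs

# Sample 0 finishes after 3 steps ("+ N0 N1"), sample 1 after 5 ("* N0 - N1 N2").
steps = [["+", "*"], ["N0", "N0"], ["N1", "-"], ["N2", "N1"], ["N0", "N2"]]
print(decode_batch(steps, batch_size=2))
# [['+', 'N0', 'N1'], ['*', 'N0', '-', 'N1', 'N2']]
```

Tokens produced for a sample after it has finished (here `N2` and `N0` for sample 0) are simply discarded. A batched decoder needs this kind of per-sample mask so that it can run until the longest expression in the batch is complete.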

liamjxu, May 06 '22 17:05

Yes, I'm working on this. I will post an update once I have verified that it is correct and that it significantly improves speed.

LYH-YF, May 08 '22 09:05

Thanks for the reply! Looking forward to it.

liamjxu, May 08 '22 10:05