After training an RNN/LSTM model, testing it with the command textclf --config-file test_rnn.json test
fails with the following error:
Load checkpoint from ckpts/best.pt..
Traceback (most recent call last):
  File "d:\anaconda3\lib\runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "d:\anaconda3\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "D:\Anaconda3\Scripts\textclf.exe\__main__.py", line 9, in <module>
  File "d:\anaconda3\lib\site-packages\click\core.py", line 764, in __call__
    return self.main(*args, **kwargs)
  File "d:\anaconda3\lib\site-packages\click\core.py", line 717, in main
    rv = self.invoke(ctx)
  File "d:\anaconda3\lib\site-packages\click\core.py", line 1137, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "d:\anaconda3\lib\site-packages\click\core.py", line 956, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "d:\anaconda3\lib\site-packages\click\core.py", line 555, in invoke
    return callback(*args, **kwargs)
  File "d:\anaconda3\lib\site-packages\click\decorators.py", line 17, in new_func
    return f(get_current_context(), *args, **kwargs)
  File "d:\anaconda3\lib\site-packages\textclf\main.py", line 72, in test
    tester.test()
  File "d:\anaconda3\lib\site-packages\textclf\tester\base_tester.py", line 16, in test
    self.test_file()
  File "d:\anaconda3\lib\site-packages\textclf\tester\base_tester.py", line 78, in test_file
    predict = self.predict_label(text)
  File "d:\anaconda3\lib\site-packages\textclf\tester\dl_tester.py", line 58, in predict_label
    logits = self.model(text_processed, text_len)
  File "d:\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "d:\anaconda3\lib\site-packages\textclf\models\dl_model.py", line 24, in forward
    logits = self.classifier(embedding, lens)
  File "d:\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "d:\anaconda3\lib\site-packages\textclf\models\classifier\rnn.py", line 30, in forward
    context = self.attention_layer(outputs, seq_lens)
  File "d:\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "d:\anaconda3\lib\site-packages\textclf\models\classifier\components.py", line 81, in forward
    attn_logits[i][seq_len.item():] = -1e9
IndexError: dimension specified as 0 but tensor has no dimensions
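I asked the original author about this, and the suggested fix was: (1) check the PyTorch version and make sure torch==1.4.0 is installed; (2) if the version is correct, set use_attention to false in the custom JSON file used to train the RNN (train_rnn.json).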
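The two checks can be scripted roughly as follows. This is a minimal sketch, not part of textclf: the helper names (torch_version_ok, installed_torch_version, disable_attention, patch_config) are hypothetical. The exact location of use_attention inside textclf's config is not shown here, so the helper searches the JSON recursively rather than assuming a fixed path. It also assumes the model is then retrained with the modified train_rnn.json, because the existing checkpoint was trained with attention enabled.

```python
import json
from pathlib import Path
from typing import Any, Optional

REQUIRED_TORCH = "1.4.0"  # version recommended by the textclf author


def torch_version_ok(installed: str, required: str = REQUIRED_TORCH) -> bool:
    """Compare versions, ignoring local build tags such as '+cpu' on Windows wheels."""
    return installed.split("+", 1)[0] == required


def installed_torch_version() -> Optional[str]:
    """Return the installed torch version string, or None if torch is missing."""
    try:
        import torch
    except ImportError:
        return None
    return str(torch.__version__)


def disable_attention(node: Any) -> int:
    """Recursively set every 'use_attention' key to False.

    Hypothetical helper: the real nesting of the key in textclf's config
    is not known, so every dict and list in the tree is searched.
    Returns how many keys were switched off.
    """
    changed = 0
    if isinstance(node, dict):
        for key, value in node.items():
            if key == "use_attention":
                node[key] = False  # replacing a value does not resize the dict
                changed += 1
            else:
                changed += disable_attention(value)
    elif isinstance(node, list):
        for item in node:
            changed += disable_attention(item)
    return changed


def patch_config(path: str) -> int:
    """Rewrite a JSON training config in place with attention disabled."""
    cfg_path = Path(path)
    config = json.loads(cfg_path.read_text(encoding="utf-8"))
    changed = disable_attention(config)
    cfg_path.write_text(
        json.dumps(config, indent=2, ensure_ascii=False), encoding="utf-8"
    )
    return changed
```

For example, torch_version_ok(installed_torch_version() or "") reports whether step (1) is satisfied, and patch_config("train_rnn.json") returns the number of use_attention keys it switched off; a return value of 0 means the key was not found and the file should be checked by hand.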
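These two steps should resolve the problem. According to the author, the implementation of the use_attention code has some issues and will be fixed in a later update.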