RecBole
RuntimeError when running the GRU4Rec model on the ml-100k dataset
Hello developers, today I ran the GRU4Rec model on the ml-100k dataset and got the following error:
Traceback (most recent call last):
File "run.py", line 6, in <module>
run_recbole(model='GRU4Rec', dataset='ml-100k', config_dict=parameter_dict)
File "C:\environment\Anaconda\envs\recbole\lib\site-packages\recbole\quick_start\quick_start.py", line 54, in run_recbole
train_data, valid_data, saved=saved, show_progress=config['show_progress']
File "C:\environment\Anaconda\envs\recbole\lib\site-packages\recbole\trainer\trainer.py", line 262, in fit
train_loss = self._train_epoch(train_data, epoch_idx, show_progress=show_progress)
File "C:\environment\Anaconda\envs\recbole\lib\site-packages\recbole\trainer\trainer.py", line 152, in _train_epoch
losses = loss_func(interaction)
File "C:\environment\Anaconda\envs\recbole\lib\site-packages\recbole\model\sequential_recommender\gru4rec.py", line 88, in calculate_loss
seq_output = self.forward(item_seq, item_seq_len)
File "C:\environment\Anaconda\envs\recbole\lib\site-packages\recbole\model\sequential_recommender\gru4rec.py", line 82, in forward
seq_output = self.gather_indexes(gru_output, item_seq_len - 1)
File "C:\environment\Anaconda\envs\recbole\lib\site-packages\recbole\model\abstract_recommender.py", line 119, in gather_indexes
output_tensor = output.gather(dim=1, index=gather_index)
RuntimeError: gather(): Expected dtype int64 for index
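For context, the last frame fails inside `output.gather(dim=1, index=gather_index)` because the index tensor is not int64. One plausible cause, not confirmed in this thread, is that the sequence lengths arrive as int32 tensors, which can happen on Windows, where NumPy's default integer type was 32-bit before NumPy 2.0. The following is a minimal sketch of the "pick the last valid hidden state" step with an explicit `.long()` cast; `gather_last_hidden` is a hypothetical helper for illustration, not RecBole's API, and upgrading RecBole remains the recommended fix.

```python
import torch


def gather_last_hidden(output: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
    """Return output[b, seq_len[b] - 1, :] for every batch row b.

    output:  (batch, max_len, hidden) float tensor
    seq_len: (batch,) integer tensor of valid sequence lengths
    """
    # gather() requires an int64 index, so cast in case seq_len is int32
    index = (seq_len - 1).long()
    # Shape the index as (batch, 1, hidden) so gather picks one time step per row
    index = index.view(-1, 1, 1).expand(-1, 1, output.shape[-1])
    return output.gather(dim=1, index=index).squeeze(1)


output = torch.arange(2 * 3 * 4, dtype=torch.float32).view(2, 3, 4)
seq_len = torch.tensor([2, 3], dtype=torch.int32)  # deliberately int32
print(gather_last_hidden(output, seq_len))
```

With the cast in place, the int32 lengths no longer trigger the dtype error; without it, PyTorch versions that enforce int64 indices raise the same `RuntimeError` as in the traceback.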
Here is my run.py:
from recbole.quick_start import run_recbole
parameter_dict = {
'train_neg_sample_args': None,
}
run_recbole(model='GRU4Rec', dataset='ml-100k', config_dict=parameter_dict)
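I'd appreciate any help. Thank you very much!
@mksdu Hello! Please check whether your RecBole installation is the latest version. We fixed this bug in an earlier release, so the version you are using has probably not been updated.
Since there have been no new replies for a long time, this issue has been closed. Feel free to comment if you have further questions.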
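To follow the maintainer's suggestion, the installed version can be checked from Python with the standard library. This is a small sketch that assumes RecBole was installed under the distribution name `recbole` (as with `pip install recbole`); `installed_recbole_version` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_recbole_version() -> Optional[str]:
    """Return the installed RecBole version string, or None if it is not installed."""
    try:
        return version("recbole")
    except PackageNotFoundError:
        return None


print(installed_recbole_version())
```

If the printed version is older than the latest release, upgrading inside the same environment (for example `pip install --upgrade recbole`) and rerunning run.py is the first thing to try.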