RecBole

RuntimeError when running the GRU4Rec model on the ml-100k dataset

mksdu opened this issue 2 years ago • 1 comment

Hello developers, today I ran the GRU4Rec model on the ml-100k dataset and got the following error:

Traceback (most recent call last):
  File "run.py", line 6, in <module>
    run_recbole(model='GRU4Rec', dataset='ml-100k', config_dict=parameter_dict)
  File "C:\environment\Anaconda\envs\recbole\lib\site-packages\recbole\quick_start\quick_start.py", line 54, in run_recbole
    train_data, valid_data, saved=saved, show_progress=config['show_progress']
  File "C:\environment\Anaconda\envs\recbole\lib\site-packages\recbole\trainer\trainer.py", line 262, in fit
    train_loss = self._train_epoch(train_data, epoch_idx, show_progress=show_progress)
  File "C:\environment\Anaconda\envs\recbole\lib\site-packages\recbole\trainer\trainer.py", line 152, in _train_epoch
    losses = loss_func(interaction)
  File "C:\environment\Anaconda\envs\recbole\lib\site-packages\recbole\model\sequential_recommender\gru4rec.py", line 88, in calculate_loss
    seq_output = self.forward(item_seq, item_seq_len)
  File "C:\environment\Anaconda\envs\recbole\lib\site-packages\recbole\model\sequential_recommender\gru4rec.py", line 82, in forward
    seq_output = self.gather_indexes(gru_output, item_seq_len - 1)
  File "C:\environment\Anaconda\envs\recbole\lib\site-packages\recbole\model\abstract_recommender.py", line 119, in gather_indexes
    output_tensor = output.gather(dim=1, index=gather_index)
RuntimeError: gather(): Expected dtype int64 for index
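The last frame shows `torch.gather` rejecting its index because it is not int64. One plausible cause (an assumption, not confirmed in this thread) is that the sequence-length tensor arrives as int32, for example because NumPy's default integer type is 32-bit on Windows (the paths above point to a Windows Anaconda environment). The sketch below reproduces the same "pick the last valid time step" gather pattern with an explicit int64 cast. The helper name `gather_last_step` is hypothetical. It is not RecBole's own code or its actual patch.

```python
import torch


def gather_last_step(output: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
    """Return output[b, seq_len[b] - 1, :] for every sequence b in the batch.

    Hypothetical helper mirroring the gather pattern in the traceback;
    not RecBole's implementation.
    """
    # torch.gather requires an int64 index; cast in case the lengths
    # arrived as int32 (assumed cause, e.g. NumPy's default int on Windows).
    index = (seq_len - 1).long()
    # Reshape (B,) -> (B, 1, H) so gather selects one time step per sequence.
    index = index.view(-1, 1, 1).expand(-1, 1, output.size(-1))
    return output.gather(dim=1, index=index).squeeze(1)


# Batch of 2 sequences, 4 time steps, hidden size 3.
gru_output = torch.arange(24, dtype=torch.float32).view(2, 4, 3)
lengths = torch.tensor([2, 4], dtype=torch.int32)  # deliberately int32
last = gather_last_step(gru_output, lengths)
print(last)
```

Casting at the point of use keeps the gather working regardless of which integer type upstream code produced; fixing the dtype where the length tensor is created would be the cleaner long-term fix.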


Here is my run.py:

from recbole.quick_start import run_recbole

# Disable negative sampling during training.
parameter_dict = {
    'train_neg_sample_args': None,
}
run_recbole(model='GRU4Rec', dataset='ml-100k', config_dict=parameter_dict)

I hope you can help. Thank you very much!

mksdu, Aug 03 '22 09:08

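@mksdu Hello! Please check whether your RecBole installation is the latest version. We fixed this bug in an earlier release, so the version you are using has probably not been updated.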
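To see which versions are actually installed before upgrading, a small standard-library sketch can query the package metadata. The helper name `installed_version` is hypothetical; it only reads installed distribution versions and does not know which RecBole release contains the fix.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# Packages relevant to this traceback.
for pkg in ("recbole", "torch", "numpy"):
    print(f"{pkg}: {installed_version(pkg) or 'not installed'}")
```

If the reported RecBole version is older than the latest release, `pip install --upgrade recbole` inside the same environment pulls the newest version.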

Wicknight, Aug 04 '22 09:08

Closing this issue because there have been no new replies for a long time. If you have further questions, feel free to comment.

Wicknight, Sep 20 '22 13:09