RSTNet

RuntimeError: gather(): Expected dtype int64 for index

Open zhuhaifengaaa opened this issue 2 years ago • 2 comments

When I run train_transformer.py, it fails with RuntimeError: gather(): Expected dtype int64 for index. Can anyone tell me how to solve it? Thank you very much. (Screenshot: Quicker_20220407_164612)

zhuhaifengaaa, Apr 07 '22 08:04

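I haven't run into this problem myself. It may be a bug caused by a floating-point operation going wrong at some step. I suggest you check your PyTorch version; I am using torch==1.1.0.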

zhangxuying1004, May 12 '22 09:05

With recent versions of PyTorch, most (if not all) of the torch.gather() calls in models/beam_search/beam_search.py need .long() added to the index argument.

yiren-jian, Aug 19 '22 19:08
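
For readers landing here with the same error, the .long() fix suggested above can be sketched roughly as follows. This is a minimal illustration, not code from RSTNet: the tensors log_probs and selected and the helper gather_with_long_index are hypothetical stand-ins, and the float index is only an assumption about how the dtype mismatch might arise (for example, from a division that returns floats on newer PyTorch).

```python
import torch

# Hypothetical stand-ins for beam-search tensors; not taken from RSTNet.
log_probs = torch.tensor([[0.1, 0.7, 0.2],
                          [0.5, 0.3, 0.2]])

# Assumption: the index ended up with a non-int64 dtype, e.g. after some
# arithmetic that produced floats instead of integers.
selected = torch.tensor([[1.0], [0.0]])  # float32, not int64


def gather_with_long_index(source, index, dim=-1):
    """Cast the index to int64 before calling torch.gather."""
    return torch.gather(source, dim, index.long())


# Picks column 1 from the first row and column 0 from the second row.
picked = gather_with_long_index(log_probs, selected)
print(picked)
```

In the actual repository the same idea would mean appending .long() to whatever tensor is passed as the index argument of each torch.gather() call, rather than introducing a helper like this one.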