RSTNet
RuntimeError: gather(): Expected dtype int64 for index
When I run train_transformer.py, I get RuntimeError: gather(): Expected dtype int64 for index. Can anyone tell me how to solve it?
Thank you very much.
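I haven't run into this problem myself. It may be a bug caused by a floating-point operation going wrong at some step. I suggest checking your PyTorch version; I'm using torch==1.1.0.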
With recent versions of PyTorch, you need to add .long() to the index argument of most (if not all) torch.gather() calls in models/beam_search/beam_search.py.
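A minimal sketch of why the cast helps, using made-up tensors rather than the actual beam_search.py code. It assumes the float index comes from dividing a flattened top-k index by the vocabulary size with `/`. In recent PyTorch this is true division and returns a float tensor, which torch.gather() rejects. The names `scores`, `flat_idx`, `beam_idx` and `beam_logprob` are hypothetical.

```python
import torch

# Hypothetical scores: batch of 2, flattened (beam_size * vocab_size) axis.
beam_size, vocab_size = 2, 3
scores = torch.tensor([[0.1, 0.9, 0.2, 0.4, 0.3, 0.8],
                       [0.6, 0.1, 0.5, 0.2, 0.7, 0.3]])

# Top-1 position in the flattened axis for each batch row (int64).
_, flat_idx = scores.topk(1, dim=-1)      # [[1], [4]]

# In recent PyTorch, `/` on integer tensors is true division, so this
# "beam index" becomes a float tensor (assumed cause of the error).
beam_idx_float = flat_idx / vocab_size    # [[0.333], [1.333]]

# Hypothetical running log-probabilities, one per beam.
beam_logprob = torch.tensor([[-1.0, -2.0],
                             [-0.5, -0.25]])

try:
    torch.gather(beam_logprob, 1, beam_idx_float)
except RuntimeError as err:
    print("float index rejected:", err)

# .long() casts to int64 and truncates toward zero, which matches floor
# division here because the indices are non-negative.
beam_idx = beam_idx_float.long()          # [[0], [1]]
picked = torch.gather(beam_logprob, 1, beam_idx)
print(picked)
```

On PyTorch 1.8 or later, `torch.div(flat_idx, vocab_size, rounding_mode="floor")` keeps the result as an integer tensor. This avoids the round trip through float, if you prefer to fix the division itself rather than cast at each gather() call.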